2024. 9. 13. 15:30ใ๐งช Data Science/Paper review
์ฐ๊ตฌ์ค ๋ ผ๋ฌธ ์ธ๋ฏธ๋์์ ๋ค์ ๋ด์ฉ์ ์ ๋ฆฌํ๊ณ ์ ํ๋ค.
๋ ผ๋ฌธ ์ ๋ชฉ์ ' TabNet: Attentive Interpretable Tabular Learning'์ด๋ค.
Paper link:
https://ojs.aaai.org/index.php/AAAI/article/view/16826
๋
ผ๋ฌธ ๊ฐ๋จ ์ค๋ช
Tabular ๋ฐ์ดํฐ์์ ์ฃผ๋ก ๊ฒฐ์ ํธ๋ฆฌ(Decision Tree) ๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ด ๋ง์ด ์ฌ์ฉ๋๋ค. ํ์ง๋ง ํธ๋ฆฌ ๊ธฐ๋ฐ ๋ชจ๋ธ์ ํํ๋ ฅ์ ํ๊ณ๊ฐ ์์ผ๋ฉฐ, ๋ฅ๋ฌ๋์ด ๋ค์ํ ๋ฐ์ดํฐ ์ ํ์ ์ฒ๋ฆฌํ๋ ๋ฅ๋ ฅ์ ๋นํด ๋ถ์กฑํ ์ ์ด ์กด์ฌํ๋ค. TabNet์ attention ๋ฉ์ปค๋์ฆ์ ํ์ฉํ์ฌ ๊ฐ ์
๋ ฅ์์ ์ค์ํ ํน์ง๋ง ์ ํํ๋ ๋ฐฉ์์ผ๋ก ํ์ตํ์ฌ ์ด๋ฌํ ํ๊ณ๋ฅผ ๊ทน๋ณตํ๊ณ ์ ํ๋ค.
1. Introduction
ํ ํ์ ๋ฐ์ดํฐ๊ฐ ๋ง์ด ๋ถ์๋๋๋ฐ, ์์ง๊น์ง DNNs ๊ธฐ๋ฒ๋ณด๋ค DT(Decision Tree) ๊ธฐ๋ฒ์ด ๋ง์ด ์ฌ์ฉ๋๋ค. DT ๋ฐฉ์์ด ์ ํตํ๋ ์ด์ ๋ ์ฌ์ด ํด์๊ณผ ๋น ๋ฅธ ํ์ต ์๋์ ์๋ค. ํธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋ฐ๋ผ๊ฐ๋ฉฐ ์ด๋ค ๊ธฐ์ค์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋๋ด๋์ง ํ์
ํ ์ ์๋ค. ์ฌ๋ฌ ๊ฐ์ ํธ๋ฆฌ๋ฅผ ๊ฒฐํฉํด ์์ธก ๊ฒฐ๊ณผ์ ๋ถ์ฐ์ ์ค์ด๋ ๋ฐฉ์์ด ์์๋ธ(Random Forests, XGBoost)์ด๋ค.
* Tree-based learning: ๊ฐ ๋จ๊ณ์์ ํต๊ณ์ ์ ๋ณด ์ด๋์ด ๊ฐ์ฅ ํฐ ํน์ง์ ํจ์จ์ ์ผ๋ก ์ ํ.
๋ค๋ง, DT๋ ํ์ต์ ์์ด ์ฌ๋ฌ ์ฒ๋ฆฌ๊ฐ ํ์ํ ์ , ๋ค๋ฅธ ๋ฐ์ดํฐ์ ํจ๊ป ์ฒ๋ฆฌํ๋ ๊ฒ์ด ์ด๋ ค์ด ์ ์์ DNNs๊ฐ ๊ณ ๊ธ ๊ธฐ๋ฒ์์ ์ฅ์ ์ ๊ฐ์ง๋ค.
2. Main Idea
2-0. TabNet overview
TabNet์ feature ์ค๊ณ๋ฅผ ํตํด DNNs ๊ธฐ๋ณธ ๊ตฌ์ฑ ์์๋ฅผ ์ฌ์ฉํ์ฌ ๊ฒฐ์ ํธ๋ฆฌ์ ์ ์ฌํ ์ถ๋ ฅ ๊ตฌ์กฐ๋ฅผ ๊ตฌํํ๋ค.
1) ๊ฐ ๋ฐ์ดํฐ ์ธ์คํด์ค์ ๋ํด ์ค์ํ ํน์ง๋ค์ ๋ฐ์ดํฐ๋ก๋ถํฐ ํ์ต ๋ฐ ์ ํ
2) ๊ฐ ๋จ๊ณ์์ ์ ํ๋ ํน์ง์ ๋ฐ๋ผ ๊ฒฐ์ ์ด ๋ถ๋ถ์ ์ผ๋ก ์ด๋ค์ง. ์ฌ๋ฌ ๋จ๊ณ์ ๊ฒฐ๊ณผ๊ฐ ๊ฒฐํฉ๋์ด ์ต์ข
๊ฒฐ์
3) ์ ํ๋ ํน์ง๋ค์ ๋น์ ํ์ ์ผ๋ก ์ฒ๋ฆฌ ๋ฐ ํ์ต
4) ๋ ๋์ ์ฐจ์๊ณผ ๋ค๋จ๊ณ ๊ตฌ์กฐ๋ฅผ ํตํด ์์๋ธ ๋ฐฉ์๊ณผ ์ ์ฌํ ํจ๊ณผ
๋ชจ๋ธ ๊ณผ์ : Feature transformer > Attentive transformer > Mask ์ฒ๋ฆฌ
2-1. Feature selection
์์ธก์ ์ ์ฉํ ํน์ง๋ค์ ๋ถ๋ถ ์งํฉ์ผ๋ก ์ ํํ๋ ๊ณผ์ ์ ๋งํ๋ค. ์ฃผ๋ก ์ฌ์ฉ๋๋ ๋ฐฉ๋ฒ์ด Forward selection, Lasso regularization์ธ๋ฐ, ์ด๋ค์ ์ ์ฒด ํ์ต ๋ฐ์ดํฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ค์๋๋ฅผ ํ๊ฐ(global methods)ํ๋ค.
๋ฐ๋ฉด์, instance-wise feature selection์ ๊ฐ instance๋ง๋ค ๊ฐ๋ณ์ ์ผ๋ก ํน์ง์ ์ ํํ๋ ๋ฐฉ์์ด๋ค.
TabNet์ soft feature selection์ ์ฌ์ฉํ๊ณ , cotrollable sparsity(์ ์ด ๊ฐ๋ฅํ ํฌ์์ฑ)์ผ๋ก ํน์ง ์ ํ๊ณผ ์ถ๋ ฅ ๋งคํ์ end-to-end ๋ฐฉ์์ผ๋ก ํ์ตํ๋ค. ์ฆ, ๋ชจ๋ธ์ด ํน์ง์ ์ ํํ๊ณ ์ถ๋ ฅํ๋ ๊ฒ์ ๋์์ ์ฒ๋ฆฌํ๋ค.
Attentive transformer > sparsemax normalization์ ์ฌ์ฉํด ๋ง์คํฌ๋ฅผ ๊ณ์ฐํ๋ค. ๊ฐ ํน์ง์ด ์ด์ ๋จ๊ณ์์ ์ผ๋ง๋ ์ฌ์ฉ๋์๋์ง๋ฅผ ๋ํ๋ด๋ ์ฐ์ ์์ ์ค์ผ์ผ(prior scale) ๊ฐ์ ํ์ฉํ๋ค.
2-2. Feature transformer, Attentive transformer ๊ตฌ์กฐ
Feature transformer : ํน์ง๋ค์ ๋ณํํ์ฌ ๋ค์ ๋จ๊ณ์ ์ ๋ฌํ๋ ๋ชจ๋
Attentive transformer : ๊ฐ ๊ฒฐ์ ๋จ๊ณ์์ ์ค์ํ ํน์ง์ ์ ํ
Feature Masking : ์ ํ๋ ํน์ง๋ค์ ๊ธฐ๋ฐ์ผ๋ก ์ต์ข ์ถ๋ ฅ ๊ณ์ฐ
2-3. Ghost Batch Normalization
๋ฐฐ์น ์ ๊ทํ ๊ธฐ๋ฒ์ ์ผ์ข ์ผ๋ก, ๋ฐฐ์น ์ ๊ทํ์ ํจ๊ณผ๋ฅผ ๊ทน๋ํํ๋ฉด์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ ํจ์จ์ ์ผ๋ก ๊ด๋ฆฌํ๋ ๋ฐ ์ด์ ์ ๋๊ณ ์๋ค.
์ด๋ Hoffer, Hubara, and Soudry (2017)์์ ์ ์๋ ๊ธฐ๋ฒ์ด๋ค.
DNNs์์ Batch Normalization์ ๋ฐฐ์น ๋จ์๋ก ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํ์ฌ ํ์ต์ ์์ ์ํค๊ณ , ๋ ๋น ๋ฅธ ํ์ต ์๋๋ฅผ ์ ๊ณตํ๋ค. ๊ทธ๋ฌ๋ ๋๊ท๋ชจ ๋ฐฐ์น์ผ ๊ฒฝ์ฐ, ํต๊ณ๊ฐ(ํ๊ท ๊ณผ ๋ถ์ฐ)์ ๊ณ์ฐ์ด ๋๋ฌด ๊ธ๋ก๋ฒํด์ ธ, ๋ชจ๋ธ์ด ์๊ท๋ชจ ๋ฐ์ดํฐ์์ ์ ์ผ๋ฐํ๋์ง ์์ ์ ์๋ค. ๋ํ BN ํต๊ณ๊ฐ์ด ์ง๋์น๊ฒ ํํํ๋๋ฉด ์๋๊ฐ ๋๋ ค์ง๊ณ ๋๋ฌด ์์ ๋ฐฐ์น๋ฅผ ์ฌ์ฉํ๋ฉด ๊ณผ์ ํฉ์ด ๋ฐ์ํ๋ค.
Ghost BN์ ํฐ ๋ฐฐ์น๋ฅผ ์์ ๊ฐ์ ๋ฐฐ์น๋ก ๋๋์ด ๊ฐ ๋ฐฐ์น์์ ํต๊ณ๊ฐ์ ๊ณ์ฐํ์ฌ ๋ก์ปฌ ํต๊ณ๋ฅผ ๋ฐ์ํ๋ค. ์ผ๊ด๋ ํ์ต ์๋๋ฅผ ์ ์งํ๋๋ก ๋์์ค๋ค.
TabNet์์ Ghost BN์ ์
๋ ฅ ํน์ง๋ค์ ๋ํด์๋ ์ฌ์ฉํ์ง ์์ง๋ง, ๋๋จธ์ง ๋คํธ์ํฌ ์ธต์์๋ ์ฌ์ฉํ๋ค. ์
๋ ฅ ํน์ง์ ๋ํด์๋ ์ ๋ถ์ฐ ํต๊ณ์น๊ฐ ๋์์ด ๋์ง๋ง, ์ค๊ฐ ๊ณผ์ ์์ Ghost BN์ ์ ์ฉํ์ฌ ํ์ต์ ์์ ํํ๊ณ ํ์ต ์๋๋ฅผ ๋์ด๊ธฐ ๋๋ฌธ์ด๋ค.
> ์ง๊ธ ์ฐ๊ตฌ ์ค์ธ ํ๋ก์ ํธ๋ Batch ์ ์ ์ ๋ถ standard normalization์ ๊ฑฐ์น๊ณ ์์ํ๋ค. BN์ ํ์ฉํด ๋ณผ ์ ์๊ฒ ๋ค.
3. Experiment
- ๋ฐ์ดํฐ์
:
- 6๊ฐ์ ํ ํ์(tabular) ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์์ผ๋ฉฐ, ๊ฐ ๋ฐ์ดํฐ์ ์ 10,000๊ฐ์ ํ๋ จ ์ํ๋ก ๊ตฌ์ฑ๋๋ค.
- ๋ฐ์ดํฐ์ ์ ํน์ ํน์ง๋ค์ด ์ถ๋ ฅ์ ๊ฒฐ์ ์ ์ธ ์ญํ ์ ํ๋๋ก ์ค๊ณ๋์๋ค.
- ์ผ๋ฐ์ ์ธ ํน์ง ์ ํ:
- Syn1-Syn3 ๋ฐ์ดํฐ์ ์์๋ ์ค์ํ ํน์ง์ด ๋ชจ๋ ์ธ์คํด์ค์ ๋ํด ๋์ผํ๋ค. ์๋ฅผ ๋ค์ด, Syn2๋ ํน์ง X3-X6์ ์์กดํ๋ค.
- ์ด๋ฌํ ๊ฒฝ์ฐ, ์ ์ญ(feature-wise) ํน์ง ์ ํ(global feature selection) ๋ฐฉ์์ด ๋์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค.
- ์ธ์คํด์ค ์์กด์ ํน์ง ์ ํ:
- Syn4-Syn6 ๋ฐ์ดํฐ์ ์์๋ ์ค์ํ ํน์ง์ด ์ธ์คํด์ค์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ค. ์๋ฅผ ๋ค์ด, Syn4์์๋ ์ถ๋ ฅ์ด X1-X2 ๋๋ X3-X6์ ์์กดํ๋ฉฐ, ์ด๋ X11์ ๊ฐ์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ค.
- ์ด ๊ฒฝ์ฐ, ์ ์ญ ํน์ง ์ ํ์ ๋นํจ์จ์ ์ด๋ฉฐ, ์ธ์คํด์ค๋ณ ํน์ง ์ ํ์ด ํ์ํ๋ค.
์ฑ๋ฅ ๋น๊ต
- TabNet์ ์ฑ๋ฅ:
- TabNet์ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ๋ค๊ณผ ๋น๊ต๋๋ค:
- Tree Ensembles (Geurts et al. 2006)
- LASSO Regularization
- L2X (Chen et al. 2018)
- INVASE (Yoon et al. 2019)
- TabNet์ Syn1-Syn3 ๋ฐ์ดํฐ์ ์์ ์ ์ญ ํน์ง ์ ํ์ ๊ทผ์ ํ ์ฑ๋ฅ์ ๋ณด์ธ๋ค. ์ด๋ TabNet์ด ๊ธ๋ก๋ฒ ์ค์ ํน์ง์ ํจ๊ณผ์ ์ผ๋ก ์๋ณํ ์ ์์์ ์๋ฏธํ๋ค.
- Syn4-Syn6 ๋ฐ์ดํฐ์ ์์๋ ์ธ์คํด์ค๋ณ๋ก ์ค๋ณต๋ ํน์ง์ ์ ๊ฑฐํจ์ผ๋ก์จ ์ ์ญ ํน์ง ์ ํ์ ์ฑ๋ฅ์ ๊ฐ์ ํ๋ค.
- TabNet์ ๋จ์ผ ์ํคํ ์ฒ๋ก, Syn1-Syn3์ ๊ฒฝ์ฐ 26,000๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ฉฐ, Syn4-Syn6์ ๊ฒฝ์ฐ 31,000๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ค. ๋ฐ๋ฉด, ๋ค๋ฅธ ๋ฐฉ๋ฒ๋ค์ 43,000๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๋ ์์ธก ๋ชจ๋ธ์ ์ฌ์ฉํ๋ฉฐ, INVASE๋ 101,000๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ค(์กํฐ-๋นํ๊ฐ ํ๋ ์์ํฌ์ ๋ ๋ชจ๋ธ ํฌํจ).
- TabNet์ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ๋ค๊ณผ ๋น๊ต๋๋ค:
4. TabNet Pytorch ์ค์ต ๋ฐ ์์
TabNet์ ํ๋ Kaggle์์ ๋ง์ด ์ฌ์ฉ๋๋ ๋ชจ๋ธ์ด๋ค.
์ค์ TabNet์ Classification ๋ฌธ์ ์์ ์ฌ์ฉํ๋ ์์๋ฅผ pytorch ์ฝ๋๋ฅผ ํตํด ๋ณด์ฌ์ฃผ๊ฒ ๋ค.
4-1. ์ฃผ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ Import ํ๊ณ , titanic ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์์ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ๋ฅผ ๊ฐ๋ณ๊ฒ ํด ์ค๋ค.
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from pytorch_tabnet.tab_model import TabNetClassifier
# ๋ฐ์ดํฐ ๋ค์ด๋ก๋
url_train = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
df = pd.read_csv(url_train)
# ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
df['Age'].fillna(df['Age'].median(), inplace=True)
df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
df['Fare'].fillna(df['Fare'].median(), inplace=True)
df['Sex'] = LabelEncoder().fit_transform(df['Sex'])
df['Embarked'] = LabelEncoder().fit_transform(df['Embarked'])
4-2. ๋ฐ์ดํฐ๋ฅผ X(input), Y(target)๋ก ๋๋๋ค. ๋ชจ๋ธ์ด ์์ธกํ๋ ๊ฒ์ ์ด ์ฌ๋์ ๋ฅ๋ ฅ์ ๋ฐํ์ผ๋ก ์์กดํ๋ ์์กดํ์ง ์์๋, ์ฆ ์์กด์ฌ๋ถ์ด๋ค. ๋ฐ์ดํฐ ๋ถํ ํ์, ๋ชจ๋ input ๋ฐ์ดํฐ์ ๋ํด์ scale์ ๋ง์ถฐ์ฃผ๊ธฐ ์ํด ํ์คํ๋ฅผ ์งํํ๋ค.
# ํน์ง๊ณผ ๋ ์ด๋ธ ๋ถ๋ฆฌ
X = df[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']].values
y = df['Survived'].values
# ๋ฐ์ดํฐ ๋ถํ
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# ๋ฐ์ดํฐ ํ์คํ
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
4-3. TabNet ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ์ค์ ํ๊ณ ํ์ต ๋ฐ ํ๊ฐํ๋ค.
# TabNet ๋ชจ๋ธ ์์ฑ
clf = TabNetClassifier(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_fn=torch.optim.lr_scheduler.StepLR,
scheduler_params=dict(step_size=10, gamma=0.9),
mask_type='sparsemax', # 'sparsemax' or 'entmax'
)
# ๋ชจ๋ธ ํ์ต
clf.fit(
X_train, y_train,
eval_set=[(X_test, y_test)],
batch_size=1024,
virtual_batch_size=128,
num_workers=0,
max_epochs=100,
patience=10,
)
# ์์ธก ๋ฐ ํ๊ฐ
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")
Accuracy: 0.9200
Attention ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ค๊ณ ํ์ง๋ง, ์ฌ์ค ์ต๊ทผ ์ฐ๋ฆฌ๊ฐ ์๊ณ ์๋ Attention ๊ธฐ๋ฒ์ด ์ฌ์ฉ๋ ๊ฒ์ ์๋๋ค. ์ ๋น์์ Attention์ ์๋ฏธ๊ฐ ์กฐ๊ธ ํผ๋๋์ด ์ฌ์ฉ๋ ๊ฒ์ด ์๋๊ฐ ์๊ฐ๋๋ค.
๋จธ์ ๋ฌ๋(xgboost)์ผ๋ก VAEP๋ฅผ ๋ฝ์์ ๋ ์ ๋์๋๋ฐ, ์ต๊ทผ DNNs๋ก ํ์ ๋๋ ์ ๋๋ก ๋์ค์ง ์์๋ค. ์ด์ ๋ํ ๋์์ ๋๋ฃ ์ฐ๊ตฌ์์ด ์ฐพ๋ค๊ฐ ๋ฐํํ๋ค๊ณ ํ์๋ค.
๊ต์๋๊ป์ Ghost Normalization์ ํด๊ฒฐ์ฑ ์ผ๋ก ์ผ๋ ์ฐ์ ๊ฐ ๊ถ๊ธํ๋ฉฐ ์ด์ฉ๋ฉด ์ ๋ฐ ์์ด๋์ด๊ฐ ๋์ค์ ์ฐ๊ตฌํ ๋ ์ข์ ๋ณดํฌ์ด ๋ ๊ฒ์ด๋ผ๊ณ ํ์ จ๋ค.