[Paper review] TabNet: Attentive Interpretable Tabular Learning ๋ฐ TabNet ์‹ค์Šต

2024. 9. 13. 15:30 · 🧪 Data Science/Paper review

I'd like to write up what I heard at our lab's paper seminar.

The paper is 'TabNet: Attentive Interpretable Tabular Learning'.


Paper link:

https://ojs.aaai.org/index.php/AAAI/article/view/16826

TabNet: Attentive Interpretable Tabular Learning | Proceedings of the AAAI Conference on Artificial Intelligence

Paper in brief
Tabular data has mostly been handled with decision-tree-based models. Tree-based models, however, have limited representational power, and they fall short of deep learning's ability to handle diverse data types. TabNet aims to overcome these limits by using an attention mechanism to select, for each input, only the features that matter during learning.


1. Introduction

 

Tabular data is analyzed everywhere, yet decision-tree (DT) methods are still used far more often than DNNs. DT methods work well because they are easy to interpret and fast to train: you can follow the tree structure and see exactly which criterion split the data at each node. Combining many trees to reduce the variance of the predictions gives us ensembles (Random Forests, XGBoost).

* Tree-based learning: at each step, efficiently pick the feature with the largest statistical information gain.
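As a quick illustration (not from the paper), greedy split selection by information gain can be sketched like this; `information_gain` and the toy arrays below are hypothetical:

```python
import numpy as np

def entropy(y):
    # Shannon entropy of a label vector, in bits
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(x, y, threshold):
    # entropy reduction from splitting on feature x at the given threshold
    left, right = y[x <= threshold], y[x > threshold]
    if len(left) == 0 or len(right) == 0:
        return 0.0  # degenerate split: no information gained
    w_left = len(left) / len(y)
    return entropy(y) - (w_left * entropy(left) + (1 - w_left) * entropy(right))

# toy example: the feature perfectly separates the labels at 0.5
x = np.array([0.1, 0.2, 0.3, 0.7, 0.8, 0.9])
y = np.array([0, 0, 0, 1, 1, 1])
print(information_gain(x, y, 0.5))  # 1.0 — entropy drops from 1 bit to 0
```

A tree learner would evaluate this gain for every candidate feature and threshold and greedily pick the best one at each node.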

 

Still, DTs require various processing steps for training and are hard to combine with other data modalities, which is where DNNs have the advantage in more advanced pipelines.


2. Main Idea

 

2-0. TabNet overview

TabNet uses standard DNN building blocks, with careful feature-level design, to implement an output structure similar to a decision tree.

1) For each data instance, the important features are learned and selected from the data itself
2) At each step, a partial decision is made from the selected features; the results of the steps are combined into the final decision
3) The selected features are processed and learned non-linearly
4) Higher dimensionality and the multi-step structure give an ensemble-like effect

Model pipeline: Feature transformer > Attentive transformer > Feature masking


2-1. Feature selection
Feature selection means choosing a subset of features that are useful for prediction. The commonly used approaches, forward selection and Lasso regularization, score feature importance against the entire training set (global methods).

In contrast, instance-wise feature selection chooses features individually for each instance.

 

TabNet์€ soft feature selection์„ ์‚ฌ์šฉํ•˜๊ณ , cotrollable sparsity(์ œ์–ด ๊ฐ€๋Šฅํ•œ ํฌ์†Œ์„ฑ)์œผ๋กœ ํŠน์ง• ์„ ํƒ๊ณผ ์ถœ๋ ฅ ๋งคํ•‘์„ end-to-end  ๋ฐฉ์‹์œผ๋กœ ํ•™์Šตํ•œ๋‹ค. ์ฆ‰, ๋ชจ๋ธ์ด ํŠน์ง•์„ ์„ ํƒํ•˜๊ณ  ์ถœ๋ ฅํ•˜๋Š” ๊ฒƒ์„ ๋™์‹œ์— ์ฒ˜๋ฆฌํ•œ๋‹ค. 

Attentive transformer > sparsemax normalization์„ ์‚ฌ์šฉํ•ด ๋งˆ์Šคํฌ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค. ๊ฐ ํŠน์ง•์ด ์ด์ „ ๋‹จ๊ณ„์—์„œ ์–ผ๋งˆ๋‚˜ ์‚ฌ์šฉ๋˜์—ˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ์šฐ์„ ์ˆœ์œ„ ์Šค์ผ€์ผ(prior scale) ๊ฐ’์„ ํ™œ์šฉํ•œ๋‹ค.
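For intuition, here is a small NumPy implementation of sparsemax together with the prior-scale update. This is a sketch of the idea, not the library code; `gamma` is TabNet's relaxation parameter.

```python
import numpy as np

def sparsemax(z):
    # sparsemax (Martins & Astudillo, 2016): softmax-like, but can output exact zeros
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, z.size + 1)
    support = 1 + k * z_sorted > cumsum   # sorted entries that stay in the support
    k_z = k[support].max()
    tau = (cumsum[k_z - 1] - 1) / k_z     # threshold chosen so the outputs sum to 1
    return np.maximum(z - tau, 0.0)

scores = np.array([3.0, 1.0, 0.2, 0.1])
mask = sparsemax(scores)
print(mask)  # [1. 0. 0. 0.] — only the dominant feature is selected

# prior-scale update: a feature used at this step is down-weighted at later steps
gamma = 1.3
prior = np.ones(4)
prior = prior * (gamma - mask)
print(prior)  # [0.3 1.3 1.3 1.3]
```

Unlike softmax, which always assigns every feature a little weight, sparsemax zeroes out weak features entirely, which is what makes the learned masks interpretable.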


2-2. Feature transformer and Attentive transformer structure

Feature transformer : a module that transforms the features and passes them on to the next step

Attentive transformer : selects the important features at each decision step

Feature masking : computes the final output based on the selected features
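The paper's feature transformer is built from fully-connected + BN + GLU blocks chained with residual connections scaled by √0.5. A stripped-down NumPy sketch of the GLU part (random hypothetical weights, BN omitted, purely illustrative):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def glu_block(x, W, b):
    # FC layer producing 2*d_out units; one half gates the other (Gated Linear Unit)
    h = x @ W + b
    a, g = np.split(h, 2, axis=1)
    return a * sigmoid(g)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))    # 4 instances, 8 features
W1 = rng.normal(size=(8, 16))  # hypothetical weights: d_out = 8, so 16 columns
W2 = rng.normal(size=(8, 16))
b = np.zeros(16)

h1 = glu_block(x, W1, b)
# residual connection scaled by sqrt(0.5) to keep variance stable across blocks
h2 = (h1 + glu_block(h1, W2, b)) * np.sqrt(0.5)
print(h2.shape)  # (4, 8)
```

The gate lets each block learn which transformed features to pass on, which complements the hard feature selection done by the attentive transformer's mask.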


2-3. Ghost Batch Normalization

A variant of batch normalization that focuses on maximizing the benefits of batch normalization while managing memory usage efficiently.

It was proposed by Hoffer, Hubara, and Soudry (2017).

 

DNNs์—์„œ Batch Normalization์€ ๋ฐฐ์น˜ ๋‹จ์œ„๋กœ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ •๊ทœํ™”ํ•˜์—ฌ ํ•™์Šต์„ ์•ˆ์ •์‹œํ‚ค๊ณ , ๋” ๋น ๋ฅธ ํ•™์Šต ์†๋„๋ฅผ ์ œ๊ณตํ•œ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋Œ€๊ทœ๋ชจ ๋ฐฐ์น˜์ผ ๊ฒฝ์šฐ, ํ†ต๊ณ„๊ฐ’(ํ‰๊ท ๊ณผ ๋ถ„์‚ฐ)์˜ ๊ณ„์‚ฐ์ด ๋„ˆ๋ฌด ๊ธ€๋กœ๋ฒŒํ•ด์ ธ, ๋ชจ๋ธ์ด ์†Œ๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์—์„œ ์ž˜ ์ผ๋ฐ˜ํ™”๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ๋‹ค. ๋˜ํ•œ BN ํ†ต๊ณ„๊ฐ’์ด ์ง€๋‚˜์น˜๊ฒŒ ํ‰ํ™œํ™”๋˜๋ฉด ์†๋„๊ฐ€ ๋Š๋ ค์ง€๊ณ  ๋„ˆ๋ฌด ์ž‘์€ ๋ฐฐ์น˜๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๊ณผ์ ํ•ฉ์ด ๋ฐœ์ƒํ•œ๋‹ค.

 

Ghost BN์€ ํฐ ๋ฐฐ์น˜๋ฅผ ์ž‘์€ ๊ฐ€์ƒ ๋ฐฐ์น˜๋กœ ๋‚˜๋ˆ„์–ด ๊ฐ ๋ฐฐ์น˜์—์„œ ํ†ต๊ณ„๊ฐ’์„ ๊ณ„์‚ฐํ•˜์—ฌ ๋กœ์ปฌ ํ†ต๊ณ„๋ฅผ ๋ฐ˜์˜ํ•œ๋‹ค. ์ผ๊ด€๋œ ํ•™์Šต ์†๋„๋ฅผ ์œ ์ง€ํ•˜๋„๋ก ๋„์™€์ค€๋‹ค.

TabNet์—์„œ Ghost BN์€ ์ž…๋ ฅ ํŠน์ง•๋“ค์— ๋Œ€ํ•ด์„œ๋Š” ์‚ฌ์šฉํ•˜์ง€ ์•Š์ง€๋งŒ, ๋‚˜๋จธ์ง€ ๋„คํŠธ์›Œํฌ ์ธต์—์„œ๋Š” ์‚ฌ์šฉํ•œ๋‹ค. ์ž…๋ ฅ ํŠน์ง•์— ๋Œ€ํ•ด์„œ๋Š” ์ €๋ถ„์‚ฐ ํ†ต๊ณ„์น˜๊ฐ€ ๋„์›€์ด ๋˜์ง€๋งŒ, ์ค‘๊ฐ„ ๊ณผ์ •์—์„  Ghost BN์„ ์ ์šฉํ•˜์—ฌ ํ•™์Šต์„ ์•ˆ์ •ํ™”ํ•˜๊ณ  ํ•™์Šต ์†๋„๋ฅผ ๋†’์ด๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
> ์ง€๊ธˆ ์—ฐ๊ตฌ ์ค‘์ธ ํ”„๋กœ์ ํŠธ๋Š” Batch ์ „์— ์ „๋ถ€ standard normalization์„ ๊ฑฐ์น˜๊ณ  ์‹œ์ž‘ํ•œ๋‹ค. BN์„ ํ™œ์šฉํ•ด ๋ณผ ์ˆ˜ ์žˆ๊ฒ ๋‹ค.
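A minimal sketch of the idea behind Ghost BN (inference-style normalization only; no learnable scale/shift or running statistics):

```python
import numpy as np

def ghost_batch_norm(X, virtual_batch_size, eps=1e-5):
    # split the batch into virtual batches and normalize each with its own statistics
    out = np.empty_like(X, dtype=float)
    for start in range(0, len(X), virtual_batch_size):
        chunk = X[start:start + virtual_batch_size]
        mu = chunk.mean(axis=0)
        var = chunk.var(axis=0)
        out[start:start + virtual_batch_size] = (chunk - mu) / np.sqrt(var + eps)
    return out

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(1024, 3))
Xn = ghost_batch_norm(X, virtual_batch_size=128)
# each virtual batch is normalized independently with its own local statistics
print(np.allclose(Xn[:128].mean(axis=0), 0.0, atol=1e-6))  # True
```

The `virtual_batch_size=128` here mirrors the same-named argument passed to `fit()` in the hands-on section below.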


3. Experiment

  • ๋ฐ์ดํ„ฐ์…‹:
    • 6๊ฐœ์˜ ํ‘œ ํ˜•์‹(tabular) ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•˜์˜€์œผ๋ฉฐ, ๊ฐ ๋ฐ์ดํ„ฐ์…‹์€ 10,000๊ฐœ์˜ ํ›ˆ๋ จ ์ƒ˜ํ”Œ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.
    • ๋ฐ์ดํ„ฐ์…‹์€ ํŠน์ • ํŠน์ง•๋“ค์ด ์ถœ๋ ฅ์— ๊ฒฐ์ •์ ์ธ ์—ญํ• ์„ ํ•˜๋„๋ก ์„ค๊ณ„๋˜์—ˆ๋‹ค.
  • ์ผ๋ฐ˜์ ์ธ ํŠน์ง• ์„ ํƒ:
    • Syn1-Syn3 ๋ฐ์ดํ„ฐ์…‹์—์„œ๋Š” ์ค‘์š”ํ•œ ํŠน์ง•์ด ๋ชจ๋“  ์ธ์Šคํ„ด์Šค์— ๋Œ€ํ•ด ๋™์ผํ•˜๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, Syn2๋Š” ํŠน์ง• X3-X6์— ์˜์กดํ•œ๋‹ค.
    • ์ด๋Ÿฌํ•œ ๊ฒฝ์šฐ, ์ „์—ญ(feature-wise) ํŠน์ง• ์„ ํƒ(global feature selection) ๋ฐฉ์‹์ด ๋†’์€ ์„ฑ๋Šฅ์„ ๋ณด์ธ๋‹ค.
  • ์ธ์Šคํ„ด์Šค ์˜์กด์  ํŠน์ง• ์„ ํƒ:
    • Syn4-Syn6 ๋ฐ์ดํ„ฐ์…‹์—์„œ๋Š” ์ค‘์š”ํ•œ ํŠน์ง•์ด ์ธ์Šคํ„ด์Šค์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง„๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, Syn4์—์„œ๋Š” ์ถœ๋ ฅ์ด X1-X2 ๋˜๋Š” X3-X6์— ์˜์กดํ•˜๋ฉฐ, ์ด๋Š” X11์˜ ๊ฐ’์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง„๋‹ค.
    • ์ด ๊ฒฝ์šฐ, ์ „์—ญ ํŠน์ง• ์„ ํƒ์€ ๋น„ํšจ์œจ์ ์ด๋ฉฐ, ์ธ์Šคํ„ด์Šค๋ณ„ ํŠน์ง• ์„ ํƒ์ด ํ•„์š”ํ•˜๋‹ค.

์„ฑ๋Šฅ ๋น„๊ต

  • TabNet์˜ ์„ฑ๋Šฅ:
    • TabNet์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๋ฐฉ๋ฒ•๋“ค๊ณผ ๋น„๊ต๋œ๋‹ค:
      • Tree Ensembles (Geurts et al. 2006)
      • LASSO Regularization
      • L2X (Chen et al. 2018)
      • INVASE (Yoon et al. 2019)
    • TabNet์€ Syn1-Syn3 ๋ฐ์ดํ„ฐ์…‹์—์„œ ์ „์—ญ ํŠน์ง• ์„ ํƒ์— ๊ทผ์ ‘ํ•œ ์„ฑ๋Šฅ์„ ๋ณด์ธ๋‹ค. ์ด๋Š” TabNet์ด ๊ธ€๋กœ๋ฒŒ ์ค‘์š” ํŠน์ง•์„ ํšจ๊ณผ์ ์œผ๋กœ ์‹๋ณ„ํ•  ์ˆ˜ ์žˆ์Œ์„ ์˜๋ฏธํ•œ๋‹ค.
    • Syn4-Syn6 ๋ฐ์ดํ„ฐ์…‹์—์„œ๋Š” ์ธ์Šคํ„ด์Šค๋ณ„๋กœ ์ค‘๋ณต๋œ ํŠน์ง•์„ ์ œ๊ฑฐํ•จ์œผ๋กœ์จ ์ „์—ญ ํŠน์ง• ์„ ํƒ์˜ ์„ฑ๋Šฅ์„ ๊ฐœ์„ ํ•œ๋‹ค.
    • TabNet์€ ๋‹จ์ผ ์•„ํ‚คํ…์ฒ˜๋กœ, Syn1-Syn3์˜ ๊ฒฝ์šฐ 26,000๊ฐœ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฐ€์ง€๋ฉฐ, Syn4-Syn6์˜ ๊ฒฝ์šฐ 31,000๊ฐœ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฐ€์ง„๋‹ค. ๋ฐ˜๋ฉด, ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•๋“ค์€ 43,000๊ฐœ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ์˜ˆ์ธก ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋ฉฐ, INVASE๋Š” 101,000๊ฐœ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ฐ€์ง„๋‹ค(์•กํ„ฐ-๋น„ํ‰๊ฐ€ ํ”„๋ ˆ์ž„์›Œํฌ์˜ ๋‘ ๋ชจ๋ธ ํฌํ•จ).


4. TabNet PyTorch hands-on and takeaways

TabNet was once a heavily used model on Kaggle.

Let's walk through using TabNet on a classification problem with PyTorch code.


4-1. Import the main libraries, load the Titanic data, and do some light preprocessing.

import torch  # needed for the optimizer/scheduler passed to TabNetClassifier below
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score
from pytorch_tabnet.tab_model import TabNetClassifier

# Download the data
url_train = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
df = pd.read_csv(url_train)

# Preprocessing: fill missing values, label-encode the categorical columns
df['Age'] = df['Age'].fillna(df['Age'].median())
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])
df['Fare'] = df['Fare'].fillna(df['Fare'].median())
df['Sex'] = LabelEncoder().fit_transform(df['Sex'])
df['Embarked'] = LabelEncoder().fit_transform(df['Embarked'])


4-2. ๋ฐ์ดํ„ฐ๋ฅผ X(input), Y(target)๋กœ ๋‚˜๋ˆˆ๋‹ค. ๋ชจ๋ธ์ด ์˜ˆ์ธกํ•˜๋Š” ๊ฒƒ์€ ์ด ์‚ฌ๋žŒ์˜ ๋Šฅ๋ ฅ์„ ๋ฐ”ํƒ•์œผ๋กœ ์ƒ์กดํ–ˆ๋ƒ ์ƒ์กดํ•˜์ง€ ์•Š์•˜๋ƒ, ์ฆ‰ ์ƒ์กด์—ฌ๋ถ€์ด๋‹ค. ๋ฐ์ดํ„ฐ ๋ถ„ํ•  ํ›„์—, ๋ชจ๋“  input ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด์„œ scale์„ ๋งž์ถฐ์ฃผ๊ธฐ ์œ„ํ•ด ํ‘œ์ค€ํ™”๋ฅผ ์ง„ํ–‰ํ•œ๋‹ค.

# Separate features and label
X = df[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']].values
y = df['Survived'].values

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Standardize the inputs
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)


4-3. Set the TabNet hyperparameters, then train and evaluate.

# TabNet ๋ชจ๋ธ ์ƒ์„ฑ
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params=dict(step_size=10, gamma=0.9),
    mask_type='sparsemax',  # 'sparsemax' or 'entmax'
)

# ๋ชจ๋ธ ํ•™์Šต
clf.fit(
    X_train, y_train,
    eval_set=[(X_test, y_test)],
    batch_size=1024,
    virtual_batch_size=128,
    num_workers=0,
    max_epochs=100,
    patience=10,
)

# ์˜ˆ์ธก ๋ฐ ํ‰๊ฐ€
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")


Accuracy: 0.9200


Attention ๊ธฐ๋ฒ•์„ ์‚ฌ์šฉํ•œ๋‹ค๊ณ  ํ–ˆ์ง€๋งŒ, ์‚ฌ์‹ค ์ตœ๊ทผ ์šฐ๋ฆฌ๊ฐ€ ์•Œ๊ณ  ์žˆ๋Š” Attention ๊ธฐ๋ฒ•์ด ์‚ฌ์šฉ๋œ ๊ฒƒ์€ ์•„๋‹ˆ๋‹ค. ์ € ๋‹น์‹œ์— Attention์˜ ์˜๋ฏธ๊ฐ€ ์กฐ๊ธˆ ํ˜ผ๋™๋˜์–ด ์‚ฌ์šฉ๋œ ๊ฒƒ์ด ์•„๋‹Œ๊ฐ€ ์ƒ๊ฐ๋œ๋‹ค.

๋จธ์‹ ๋Ÿฌ๋‹(xgboost)์œผ๋กœ VAEP๋ฅผ ๋ฝ‘์•˜์„ ๋•Œ ์ž˜ ๋˜์—ˆ๋Š”๋ฐ, ์ตœ๊ทผ DNNs๋กœ ํ–ˆ์„ ๋•Œ๋Š” ์ œ๋Œ€๋กœ ๋‚˜์˜ค์ง€ ์•Š์•˜๋‹ค. ์ด์— ๋Œ€ํ•œ ๋Œ€์•ˆ์„ ๋™๋ฃŒ ์—ฐ๊ตฌ์ƒ์ด ์ฐพ๋‹ค๊ฐ€ ๋ฐœํ‘œํ–ˆ๋‹ค๊ณ  ํ•˜์˜€๋‹ค.

 

๊ต์ˆ˜๋‹˜๊ป˜์„  Ghost Normalization์„ ํ•ด๊ฒฐ์ฑ…์œผ๋กœ ์ผ๋˜ ์—ฐ์œ ๊ฐ€ ๊ถ๊ธˆํ•˜๋ฉฐ ์–ด์ฉŒ๋ฉด ์ €๋Ÿฐ ์•„์ด๋””์–ด๊ฐ€ ๋‚˜์ค‘์— ์—ฐ๊ตฌํ•  ๋•Œ ์ข‹์€ ๋ณดํƒฌ์ด ๋  ๊ฒƒ์ด๋ผ๊ณ  ํ•˜์…จ๋‹ค.