2022. 4. 19. 15:49ใ๐งช Data Science/ML, DL
์ด๋ฒ ํฌ์คํ ์์ ๊ฐ๋จํ๊ฒ Keras๋ฅผ ์ด์ฉํ์ฌ CNN๋ชจ๋ธ์ ๋ง๋ค๊ณ ํ์ต, ์์ธกํ๋ค.
CNN์ ํ์ต๊ณผ์ ์ ์ ์ฒด์ ์ผ๋ก ๋ฐ๋ผ๊ฐ ๋ณด์.
CNN์ ๊ฐ๋ ์ด ์ต์์ง ์๋ค๋ฉด ์ด์ ํฌ์คํ ์ ๋ณด๊ณ ์ค์.
[์ด์ ํฌ์คํ : https://mengu.tistory.com/23]
MNIST ๋ฐ์ดํฐ์
MNIST ๋ฐ์ดํฐ์ ์ ์๊ธ์จ ๋ฐ์ดํฐ ์ ์ด๋ค.
ํด๋น ํฌ์คํ ์์ , ์ด๋ฏธ์ง๋ฅผ ๋ฐํ์ผ๋ก ์๊ธ์จ๋ฅผ 0~10๊น์ง ๋ถ๋ฅํ๋ ๋ชจ๋ธ์ ๋ง๋ค ๊ฒ์ด๋ค.
์ฐจ๊ทผ์ฐจ๊ทผ ๊ฐ๋ณด์.
1. ๋ฐ์ดํฐ์ ๋ก๋
(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
keras๋ฅผ ์ด์ฉํด ๋ฐ์ดํฐ์ ์ ๋ถ๋ฌ์จ๋ค. ๋ฐ์ดํฐ์ ์๊น์๋ฅผ ํ์ธํด๋ณด์.
x_train_all[0]
28x28 ํฌ๊ธฐ์ด๋ฉฐ, 0~255๊น์ง์ ๊ฐ์ด ๋ถํฌํด์๋ค.
y_train_all[0]
ํ๊น ๋ฐ์ดํฐ๋ 0~9๊น์ง ์กด์ฌํ๋ฉฐ, 1์ฐจ์ ์ ์๋ค.
2. ํ๋ จ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ํ๋ จ ์ธํธ์ ๊ฒ์ฆ ์ธํธ๋ก ๋๋๊ธฐ
from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(x_train_all, y_train_all,
stratify=y_train_all, test_size=0.2, random_state=42)
ํ์ต ๋ฐ์ดํฐ์ ๊ฒ์ฆ ๋ฐ์ดํฐ๋ฅผ 80%/20% ๋น์จ๋ก ๋๋ ์ค๋ค.
3. ํ๊น์ ์-ํซ ์ธ์ฝ๋ฉ์ผ๋ก ๋ณํ
y_train_encoded = tf.keras.utils.to_categorical(y_train)
y_val_encoded = tf.keras.utils.to_categorical(y_val)
์ถ๋ ฅ ๊ฐ์ softmax์ ๊ฐ์ ์ถ๋ ฅ ํจ์๋ฅผ ์ด์ฉํ 10์ฐจ์ ๋ฐฐ์ด์ด ๋ ๊ฒ์ด๋ค. ์ด์ ๋ง๊ฒ ํ๊น๋ 10์ฐจ์์ ๋ฐฐ์ด๋ก ์ธ์ฝ๋ฉํด์ค๋ค.
y_train_encoded[0]
4. ์ ๋ ฅ ๋ฐ์ดํฐ ์ค๋น
x_train = x_train.reshape(-1, 28, 28, 1)
x_val = x_val.reshape(-1, 28, 28, 1)
x_train.shape
train์ shape ํด์: 48000๊ฐ์ ์ํ์ด ์กด์ฌํ๊ณ , 28x28 ํฌ๊ธฐ์, ์ฑ๋์ด gray ํ๋๋ฟ์ธ ์ด๋ฏธ์ง ๋ฐ์ดํฐ
5. ์ ๋ ฅ ๋ฐ์ดํฐ ํ์คํ ์ ์ฒ๋ฆฌ
x_train = x_train / 255
x_val = x_val / 255
์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ 0~255 ์ฌ์ด์ ์ ์๋ก ํฝ์ ๊ฐ๋๋ฅผ ํํํ๋ค. ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ 255๋ก ๋๋์ด 0~1 ์ฌ์ด์ ๊ฐ์ผ๋ก ์กฐ์ ํ๋ค.
x_train[0]
6. ๋ชจ๋ธ๋ง / ๋ชจ๋ธ ๊ตฌ์กฐ ํ์ธ
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
conv1 = tf.keras.Sequential()
conv1.add(Conv2D(10, (3,3), activation='relu', padding='same', input_shape=(28,28,1)))
conv1.add(MaxPooling2D((2,2)))
conv1.add(Flatten())
conv1.add(Dense(100, activation='relu'))
conv1.add(Dense(10, activation='softmax'))
conv1.summary()
7. ๋ชจ๋ธ ํ๋ จ
conv1.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
history = conv1.fit(x_train, y_train_encoded, epochs=20, validation_data=(x_val, y_val_encoded))
8. ์์ค ๊ทธ๋ํ์ ์ ํ๋ ๊ทธ๋ํ ํ์ธํ๊ธฐ
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_loss', 'val_loss'])
plt.show()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_accuracy', 'val_accuracy'])
plt.show()
์ ์ loss๊ฐ ์ค์ด๋ค๊ณ accuracy๊ฐ ์ฆ๊ฐํ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
CNN ๋ชจ๋ธ๋ง์ด ์ ๋์์ผ๋ฉฐ, ํ์ต/์์ธก ๋ํ ์ ๋์์ ์ ์ ์๋ค.
์ด๋ฒ ํฌ์คํ ์ ์ฌ๊ธฐ๊น์ง.
'๐งช Data Science > ML, DL' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[๊ฐํํ์ต] REINFORCE ์๊ณ ๋ฆฌ์ฆ : ์ฝ๋ ๊ตฌํ (1) | 2024.06.02 |
---|---|
[๊ฐํํ์ต] REINFORCE ์๊ณ ๋ฆฌ์ฆ : ๊ฐ๋ ๋ฐ ์์ (0) | 2024.05.27 |
[ML] ์ฐจ์ ์ถ์ (1) - ์ ์, PCA, ์์ ์ฝ๋ (1) | 2024.02.26 |
[์ถ์ฒ ์๊ณ ๋ฆฌ์ฆ] ALS ๊ฐ๋ , Basic ํ๊ฒ feat. ์ฝ๋ X (0) | 2022.05.23 |
[CNN basic] ํฉ์ฑ๊ณฑ ์ธต, ํ๋ง ์ธต (0) | 2022.04.19 |