1x1 convolution is a concept discussed mainly in the Network in Network paper.
Lin, M., Chen, Q., & Yan, S. (2013). Network in network. arXiv preprint arXiv:1312.4400.
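In this post, I will show how to train a model on the MNIST dataset without using any Dense (fully-connected) layers.
Depending on how you code it, the model can be built in two ways. When I ran them myself, accuracy came out at about 98%.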
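Before the models, here is a minimal NumPy sketch of why a 1x1 convolution can stand in for a Dense layer: it applies the same weight matrix to the channel vector at every spatial position. The helper `conv1x1`, the 3x3x64 input, and the random weights are assumptions made for illustration; they are not taken from the Keras code in this post.

```python
import numpy as np

def conv1x1(feature_map, weights, bias):
    """Apply a 1x1 convolution to a (H, W, C_in) feature map.

    weights has shape (C_in, C_out) and bias has shape (C_out,).
    """
    # Contract the channel axis independently at every spatial position.
    return np.einsum("hwc,co->hwo", feature_map, weights) + bias

rng = np.random.default_rng(0)
feature_map = rng.normal(size=(3, 3, 64))  # hypothetical 3x3 map with 64 channels
weights = rng.normal(size=(64, 10))
bias = rng.normal(size=(10,))

out = conv1x1(feature_map, weights, bias)

# The same computation written as a dense layer applied pixel by pixel.
dense_per_pixel = np.stack([
    np.stack([feature_map[i, j] @ weights + bias for j in range(3)])
    for i in range(3)
])

print(out.shape)                          # (3, 3, 10)
print(np.allclose(out, dense_per_pixel))  # True
```

When the feature map is 1x1 in space, this reduces to an ordinary fully-connected layer on a single vector.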
1. Simply using Global Average Pooling only
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1)
x_train = x_train / 255
x_test = x_test.reshape(-1, 28, 28, 1)
x_test = x_test / 255
from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.models import Model
inputs = Input(shape = (28, 28, 1))
# Three conv + max-pooling stages; the spatial size goes 28 -> 14 -> 7 -> 3
x = Conv2D(32, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(inputs)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
# 1x1 convolution instead of a Dense layer: class probabilities at each of the 3x3 positions
x = Conv2D(10, (1, 1), activation = 'softmax')(x)
# Average the per-position probabilities into a single (10,) prediction
x = GlobalAveragePooling2D()(x)
model = Model(inputs = inputs, outputs = x)
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = ['acc'])
model.fit(x_train, y_train,
epochs = 10, batch_size = 32)
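The shapes in the model above can be checked by hand. The sketch below assumes the Keras default `pool_size` of (2, 2) with 'valid' padding for `MaxPooling2D`, and uses random numbers in place of real activations. It shows that the 28x28 input shrinks to 3x3, and that averaging per-position softmax outputs still gives a vector that sums to 1. The helper names `pooled_size` and `softmax` are made up for this example.

```python
import numpy as np

def pooled_size(size, pool=2, stride=2):
    # Output length of a 'valid' pooling window along one axis.
    return (size - pool) // stride + 1

sizes = [28]
for _ in range(3):
    sizes.append(pooled_size(sizes[-1]))
print(sizes)  # [28, 14, 7, 3]

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 3, 10))     # hypothetical pre-softmax 1x1 conv output
per_position = softmax(logits)           # each position sums to 1 over the 10 classes
probs = per_position.mean(axis=(0, 1))   # global average pooling -> shape (10,)

print(probs.shape)                   # (10,)
print(np.isclose(probs.sum(), 1.0))  # True
```

Because an average of probability vectors is itself a probability vector, the output can be fed directly to `sparse_categorical_crossentropy`.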
2. Mixing in Reshape where needed
import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1)
x_train = x_train / 255
x_test = x_test.reshape(-1, 28, 28, 1)
x_test = x_test / 255
from tensorflow.keras.layers import Input, Conv2D, Reshape
from tensorflow.keras.layers import GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.models import Model
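inputs = Input(shape = (28, 28, 1))
# Three conv + max-pooling stages; the spatial size goes 28 -> 14 -> 7 -> 3
x = Conv2D(32, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(inputs)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
# 1x1 convolution that expands the channels from 64 to 128
x = Conv2D(128, (1, 1), padding = 'same', activation = 'relu')(x)
# Global average pooling flattens (3, 3, 128) to (128,)
x = GlobalAveragePooling2D()(x)
# Reshape back to a 1x1 feature map so that Conv2D can be applied again
x = Reshape((1, 1, 128))(x)
# On a 1x1 map, this 1x1 convolution plays the role of a Dense(10) layer
x = Conv2D(10, (1, 1), padding = 'same', activation = 'softmax')(x)
# Drop the spatial dimensions to get a (10,) prediction
x = Reshape((10,))(x)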
model = Model(inputs = inputs, outputs = x)
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = ['acc'])
model.fit(x_train, y_train,
epochs = 10, batch_size = 32)
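Apart from the extra 128-channel layer, the two models differ mainly in whether the final 1x1 convolution comes before or after the global average pooling. The NumPy sketch below uses random values and hypothetical names, and ignores the extra ReLU layer. It suggests that the order does not matter for the linear part of the 1x1 convolution, and only the placement of the softmax changes the numbers.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
features = rng.normal(size=(3, 3, 64))  # hypothetical output of the last pooling layer
w = rng.normal(size=(64, 10))           # 1x1 conv kernel, viewed as a (C_in, C_out) matrix
b = rng.normal(size=(10,))

# Order used in method 1: 1x1 convolution first, then global average pooling.
conv_then_pool = (features @ w + b).mean(axis=(0, 1))
# Order used in method 2: pool first, then a 1x1 convolution on the 1x1 map.
pool_then_conv = features.mean(axis=(0, 1)) @ w + b

# Without the softmax, both orders give the same 10 scores.
print(np.allclose(conv_then_pool, pool_then_conv))  # True

# With the softmax in place, the two orders no longer match exactly,
# although both are still valid probability vectors.
softmax_then_pool = softmax(features @ w + b).mean(axis=(0, 1))
pool_then_softmax = softmax(pool_then_conv)
print(np.isclose(softmax_then_pool.sum(), 1.0))  # True
print(np.isclose(pool_then_softmax.sum(), 1.0))  # True
```

Averaging probabilities (method 1) versus taking the softmax of averaged scores (method 2) is a design choice; the Network in Network paper applies the softmax after the pooling.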
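There is almost no difference between the two approaches shown above. I included both simply to show that it can also be done this way.
I hope this helps anyone who is not yet familiar with 1x1 conv.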