1x1 convolution is a concept developed mainly in the Network in Network paper.

Lin, M., Chen, Q., & Yan, S. (2013). Network in network. arXiv preprint arXiv:1312.4400.

 

In this post, we will train a model on the MNIST dataset without using any Dense (fully-connected) layers.

Depending on how you code it, the model can be built in two ways. When I ran both, the accuracy came out to about 98%.
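Before the models, a bit of intuition: a 1x1 convolution is just the same channel-mixing matrix multiplied at every pixel independently. A minimal NumPy sketch (the shapes here are made-up for illustration, not taken from the models below):

```python
import numpy as np

rng = np.random.default_rng(0)

H, W, C_in, C_out = 4, 4, 3, 5          # hypothetical sizes for illustration
x = rng.standard_normal((H, W, C_in))   # one feature map
w = rng.standard_normal((C_in, C_out))  # a 1x1 conv kernel is just a (C_in, C_out) matrix

# 1x1 convolution: apply the same channel-mixing matrix at every pixel
y_conv = np.einsum('hwc,cd->hwd', x, w)

# Equivalent view: flatten the pixels and do a plain matmul (a "Dense" layer per pixel)
y_dense = (x.reshape(-1, C_in) @ w).reshape(H, W, C_out)

assert np.allclose(y_conv, y_dense)
print(y_conv.shape)  # (4, 4, 5)
```

This per-pixel-Dense view is why a 1x1 conv can replace fully-connected layers in the models below.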


1. Using only Global Average Pooling

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1)   # add a channel axis: (N, 28, 28, 1)
x_train = x_train / 255                    # scale pixel values to [0, 1]
x_test = x_test.reshape(-1, 28, 28, 1)
x_test = x_test / 255

from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.models import Model


inputs = Input(shape = (28, 28, 1))
x = Conv2D(32, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(inputs)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(10, (1, 1), activation = 'softmax')(x)   # per-location class probabilities: (3, 3, 10)
x = GlobalAveragePooling2D()(x)                     # average the 9 distributions into one (batch, 10)

model = Model(inputs = inputs, outputs = x)

model.compile(optimizer = 'adam', 
              loss = 'sparse_categorical_crossentropy', 
              metrics = ['acc'])
model.fit(x_train, y_train, 
          epochs = 10, batch_size = 32)
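In this first variant, softmax runs at every position of the final 3x3 map, and GlobalAveragePooling2D then averages those nine 10-way distributions. The average of probability distributions is itself a valid distribution, which a quick NumPy check confirms (random values stand in for real activations; only the shapes match the model above):

```python
import numpy as np

rng = np.random.default_rng(42)
logits = rng.standard_normal((3, 3, 10))  # final Conv2D pre-softmax output: 3x3 map, 10 classes

# softmax over the channel axis at each spatial position
e = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = e / e.sum(axis=-1, keepdims=True)      # each of the 9 positions sums to 1

# GlobalAveragePooling2D: average over the 3x3 spatial grid
pooled = probs.mean(axis=(0, 1))               # shape (10,)

print(pooled.sum())  # ≈ 1.0: still a valid probability distribution
```

The 3x3 spatial size comes from three rounds of valid 2x2 max pooling: 28 → 14 → 7 → 3.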

2. Mixing in an appropriate Reshape

import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1)   # add a channel axis: (N, 28, 28, 1)
x_train = x_train / 255                    # scale pixel values to [0, 1]
x_test = x_test.reshape(-1, 28, 28, 1)
x_test = x_test / 255

from tensorflow.keras.layers import Input, Conv2D, Reshape
from tensorflow.keras.layers import GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.models import Model


inputs = Input(shape = (28, 28, 1))
x = Conv2D(32, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(inputs)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(64, (3, 3), strides = (1, 1), padding = 'same', activation = 'relu')(x)
x = MaxPooling2D(strides = (2, 2))(x)
x = Conv2D(128, (1, 1), padding = 'same', activation = 'relu')(x)
x = GlobalAveragePooling2D()(x)       # (batch, 128)
x = Reshape((1, 1, 128))(x)           # back to a 1x1 "image" so Conv2D can act as the classifier head
x = Conv2D(10, (1, 1), padding = 'same', activation = 'softmax')(x)   # equivalent to Dense(10)
x = Reshape((10,))(x)                 # (batch, 10) class probabilities

model = Model(inputs = inputs, outputs = x)

model.compile(optimizer = 'adam', 
              loss = 'sparse_categorical_crossentropy', 
              metrics = ['acc'])
model.fit(x_train, y_train, 
          epochs = 10, batch_size = 32)
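The Reshape trick in this second variant makes the final 1x1 convolution behave exactly like a Dense(10) classifier head: on a (1, 1, 128) tensor there is only one "pixel", so channel mixing is a plain affine map. A NumPy sketch of that equivalence (random weights stand in for trained ones):

```python
import numpy as np

rng = np.random.default_rng(7)
features = rng.standard_normal(128)    # GlobalAveragePooling2D output for one sample
w = rng.standard_normal((128, 10))     # the 1x1 conv's kernel, viewed as a matrix
b = rng.standard_normal(10)            # the conv's bias

# Path 1: "Dense" view: a plain affine map on the 128-dim vector
dense_out = features @ w + b

# Path 2: 1x1 conv view: reshape to (1, 1, 128), mix channels at the single pixel
x = features.reshape(1, 1, 128)
conv_out = np.einsum('hwc,cd->hwd', x, w) + b   # shape (1, 1, 10)

assert np.allclose(dense_out, conv_out.reshape(10))
```

So the final Reshape((10,)) in the model only removes the dummy 1x1 spatial axes; the computation is identical to a fully-connected layer.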

 

There is very little practical difference between the two approaches above. I wrote this mainly to show that it can be done this way.

I hope it helps anyone who isn't yet familiar with 1x1 conv.