두 가지 방법으로 정의하여 사용할 수 있습니다.
예시로는 이 글이 작성된 2019/10/27일을 기준으로 가장 최신의 활성화 함수라고 말할 수 있는 Mish Activation을 사용하였습니다.
[해당 논문]: https://arxiv.org/abs/1908.08681
케라스의 Activation 함수에 그대로 넣어서 사용하기
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Flatten, Activation
from tensorflow.keras.models import Model
(x_train, y_train), (x_test, y_test) = mnist.load_data()
def mish(x):
return x * K.tanh(K.softplus(x))
inputs = Input(shape = (28, 28))
x = Flatten()(inputs)
x = Dense(50)(x)
x = Activation(mish)(x)
x = Dense(30)(x)
x = Activation(mish)(x)
x = Dense(10, activation = 'softmax')(x)
model = Model(inputs = inputs, outputs = x)
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy')
model.fit(x_train, y_train)
위의 코드와 같이 직접 activation 함수에 쓰일 연산을 정의하여 인자로 넘겨줄 수 있습니다.
함수를 등록하여 케라스의 특징인 문자열 형태로 제공하여 쓰기
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Flatten, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.utils import get_custom_objects
(x_train, y_train), (x_test, y_test) = mnist.load_data()
class Mish(Activation):
def __init__(self, activation, **kwargs):
super(Mish, self).__init__(activation, **kwargs)
self.__name__ = 'Mish'
def mish(x):
return x * K.tanh(K.softplus(x))
get_custom_objects().update({'mish': Mish(mish)})
inputs = Input(shape = (28, 28))
x = Flatten()(inputs)
x = Dense(50)(x)
x = Activation('mish')(x)
x = Dense(30)(x)
x = Activation('mish')(x)
x = Dense(10, activation = 'softmax')(x)
model = Model(inputs = inputs, outputs = x)
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy')
model.fit(x_train, y_train)
위와 같이 클래스로 정의한 뒤, get_custom_objects를 사용해 등록하여 사용할 수 있습니다.
경우에 따라 1번과 2번 중에 편리한 것이 있을 수 있으니 선택하여 사용하시면 될 것 같네요.
'# Machine Learning > Keras Implementation' 카테고리의 다른 글
keras custom generator - 2 (0) | 2020.01.31 |
---|---|
Keras, 1x1 Convolution만 사용해서 MNIST 학습시키기 (0) | 2019.11.05 |
keras Custom generator - 1 (1) | 2019.07.29 |
TTA(test time augmentation) with 케라스 (2) | 2019.07.01 |
Keras callback함수 쓰기 (0) | 2018.12.23 |