이글은 다음 문서를 참조합니다.

www.tensorflow.org/guide/keras/custom_callback

(번역은 자력 + 파파고 + 구글 번역기를 사용하였으니, 부자연스럽더라도 양해바랍니다.)


Keras model을 변경하고 읽거나 학습, 평가, 추론하는 동안에 custom callbacks는 Keras model을 customize하는데에 있어서 강력한 도구입니다. 그 예로 tf.keras.callbacks가 있습니다. 또한, Tensorboard는 학습 진전과 결과를 도출하고 시각화할 수 있으며, tf.keras.callbacks.ModelCheckpoint는 자동적으로 학습 또는 그 외의 행동에 대한 결과를 자동으로 저장해줍니다. 이

이번 가이드에서는 이러한 것들이 무엇을 하며 언제 불리는지에 대해 다루고 어떻게 build하는지 설명합니다. 

 

Introduction to Keras callbacks


케라스에서 Callback은 학습(batch/epoch start and ends), 평가, 추론의 다양한 단계에서 호출할 수 있는 메소드의 집합으로서 기능적으로 구체적인 정보들에 대해 접근할 수 있습니다. 학습하는 동안 모델의 statistics(accuracy, recall etc.)와 내부 상태를 관찰하는데 매우 유용합니다. 

tf.keras.Model.fit(), tf.keras.Model.evaluate(), tf.keras.Model.predict()callbacks list인자를 주어 사용할 수 있습니다. 

먼저 tensorflow를 import하고 간단한 모델을 만들어 봅니다.

# Define the Keras model to add callbacks to
def get_model():
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, activation = 'linear', input_dim = 784))
  model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
  return model
  
  # Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

이제 batch start, end를 추적할 수 있는 간단한 custom callback을 작성합니다. 이 코드는 각 batch에 대한 정보를 나타냅니다.

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

tf.keras.Model.fit()의 인자로 전달해주면 각 단계에 대한 정보를 제공합니다.

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[MyCustomCallback()])

 

Model methods that take callbacks


위와 같은 기능들은 다음 tf.keras.Model methods에서도 사용할 수 있습니다.

  • fit(), fit_generator() : 고정된 epochs만큼 학습합니다.
  • evaluate(), evaluate_generator() : loss, metrics values를 도출합니다.
  • predict(), predict_generator() : input data or data generator에 대한 추론 결과를 도출합니다.
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5,
          callbacks=[MyCustomCallback()])

 

An overview of callback methods

Common methods for training/testing/predicting


학습, 평가, 추론시에 callback 함수는 다음과 같은 methods를 override합니다.

  • on_(train|test|predict)_begin(self, logs = None)
  • on_(train|test|predict)_end(self, logs = None)
  • on_(train|test|predict)_batch_begin(self, batch, logs = None)

: 이 method에서 log는 현재 batch number와 size를 dict형태로 가지고 있습니다.

: ex) logs["size"], logs["batch"] ( 아래 예시 참고 )

  • on_(train|test|predict)_batch_end(self, batch, logs = None)

: logs는 merics result를 담고 있습니다. ( 아래 예시 참고 )

 

Training specific methods


학습시에 추가로 제공됩니다.

  • on_epoch_begin(self, epoch, logs = None)
  • on_epoch_end(self, epoch, logs = None)

 

Usage of logs dict


logs dict는 loss value, epoch or batch의 metrics를 포함합니다.

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

  def on_train_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_test_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

  def on_epoch_end(self, epoch, logs=None):
    print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=3,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback()])

이와 같이, evaluate()에서도 동일하게 사용가능합니다.

 

Examples of Keras callback applications

Early stopping at minimum loss


다음 예제는 model.stop_training(boolean)을 이용하여 최소 loss에 도달했을때 학습을 중단시킵니다. 

학습을 중단하기 전에 사용자가 얼마나 기다려야 하는지에 대한 정보를 patience로 제공받게 됩니다.

(patience가 10이라면 최소 손실로부터 10epoch동안 변화가 없을 시 earlystop)

import numpy as np

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
  """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of no improvement, training stops.
  """

  def __init__(self, patience=0):
    super(EarlyStoppingAtMinLoss, self).__init__()

    self.patience = patience

    # best_weights to store the weights at which the minimum loss occurs.
    self.best_weights = None

  def on_train_begin(self, logs=None):
    # The number of epoch it has waited when loss is no longer minimum.
    self.wait = 0
    # The epoch the training stops at.
    self.stopped_epoch = 0
    # Initialize the best as infinity.
    self.best = np.Inf

  def on_epoch_end(self, epoch, logs=None):
    current = logs.get('loss')
    if np.less(current, self.best):
      self.best = current
      self.wait = 0
      # Record the best weights if current results is better (less).
      self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        print('Restoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=30,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])

 

Learning Late scheduling


모델 학습동안에 일반적으로 행하는 것은 epochs에 따라 learning rate를 decay시켜주는 것입니다. Keras backend는 get_value api를 통해 이를 접근합니다. 이 예제에서 learning rate가 custom Callback에 의해 어떻게 동적으로 변화하는지 보겠습니다.

tf.keras.callbacks.LearningRateScheduler는 보다 일반적인 구현을 제공합니다.

class LearningRateScheduler(tf.keras.callbacks.Callback):
  """Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """

  def __init__(self, schedule):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    # Get the current learning rate from model's optimizer.
    lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
    # Call schedule function to get the scheduled learning rate.
    scheduled_lr = self.schedule(epoch, lr)
    # Set the value back to the optimizer before this epoch starts
    tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
    print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))

keras.backend.get_value와 set_value에 주목 + self.model.optimizer.lr이 현재 모델의 learning rate에 대한 정보를 제공합니다.

tf.keras.callbacks.Callback은 기본적으로 model에 대한 정보를 가지고 있습니다. self.으로 접근.

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
  """Helper function to retrieve the scheduled learning rate based on epoch."""
  if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
    return lr
  for i in range(len(LR_SCHEDULE)):
    if epoch == LR_SCHEDULE[i][0]:
      return LR_SCHEDULE[i][1]
  return lr

model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          steps_per_epoch=5,
          epochs=15,
          verbose=0,
          callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])