이글은 다음 문서를 참조합니다.

www.tensorflow.org/guide/keras/train_and_evaluate

(번역은 자력 + 파파고 + 구글 번역기를 사용하였으니, 부자연스럽더라도 양해바랍니다.)

 


Using callbacks

케라스에서 callback은 학습(epoch 시작, 끝 시점, batch의 끝시점 등)하는 동안 서로 다른 지점에서 호출되어지고, 다음과 같은 기능을 위해 사용되어집니다.

  • 학습하는 동안 다른 지점에서 검증하기
  • 일정 구간이나 특정 accuracy threshold를 초과할 때 모델에서 checkpointing
  • 학습이 안정적일때 상단 레이어에서 fine-tuning
  • email을 보내거나 학습 종료시에 알림보내기 또는 특정 성능이 일정 수준을 초과했을 때
  • Etc.

Callback은 fit함수를 통해 사용할 수 있습니다.

model = get_compiled_model()

callbacks = [
    keras.callbacks.EarlyStopping(
        # Stop training when `val_loss` is no longer improving
        monitor='val_loss',
        # "no longer improving" being defined as "no better than 1e-2 less"
        min_delta=1e-2,
        # "no longer improving" being further defined as "for at least 2 epochs"
        patience=2,
        verbose=1)
]
model.fit(x_train, y_train,
          epochs=20,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)

 

Many built-in callbacks are available

  • ModelCheckpoint : 주기적으로 모델을 저장하기
  • EarlyStopping : 학습하는 동안 validation이 더이상 향상되지 않을 경우 학습을 중단하기
  • TensorBoard : 각종 파라미터에 대해 TensorBoard로 사용할 수 있음
  • CSVLogger : loss와 metrics를 csv형태로 저장하기
  • etc.

 

Writing your own callback

keras.callbacks.Callback을 통해 customizing할 수 있습니다. 

콜백은 클래스 속성 self.model을 통해 관련 모델에 액세스할 수 있습니다.

class LossHistory(keras.callbacks.Callback):

    def on_train_begin(self, logs):
        self.losses = []

    def on_batch_end(self, batch, logs):
        self.losses.append(logs.get('loss'))

 

Checkpointing models

상대적으로 큰 데이터셋을 학습할 때, 빈번하게 모델의 체크포인트를 저장하는 것은 매우 중요합니다.

ModelCheckpoint callback을 통해 쉽게 구현할 수 있습니다.

model = get_compiled_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath='mymodel_{epoch}.h5',
        # Path where to save the model
        # The two parameters below mean that we will overwrite
        # the current checkpoint if and only if
        # the `val_loss` score has improved.
        save_best_only=True,
        monitor='val_loss',
        verbose=1)
]
model.fit(x_train, y_train,
          epochs=3,
          batch_size=64,
          callbacks=callbacks,
          validation_split=0.2)

모델을 저장하고 복원하기 위한 콜백도 직접 작성할 수 있습니다.

 

Using learning rate schedules

딥러닝 모델을 학습할 때 공통적인 패턴은 점진적으로 learning rate를 감소 시키는 것입니다. 이를 "learning rate decay"라고 부릅니다.

learning decay schedule은 정적(현재 epoch 또는 현재 batch index의 함수로서 미리 고정) 또는 동적(특히 검증 손실)일 수 있다.

 

Passing a schedule to an optimizer

우리는 optimizer에서 learning_rate인자를 통해 schedule object를 전달하면 static learning rate schedule을 쉽게 사용할 수 있습니다.

initial_learning_rate = 0.1
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True)

optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)

내장된 함수는 다음과 같습니다 : ExponentialDecay, PiecewiseConstantDecay, PolynomialDecay and InverseTimeDecay

 

Using callbacks to implement a dynamic leraning rate schedule

동적 learning rate schedule(예를 들어, validation loss가 더 이상 감소하지 않을 때 lr을 감소시켜주는 경우)는 optimizer가 validation metrics에 허용되지 않는다면 이러한 스케줄을 행할 수 없습니다.

그러나 callback은 validation metrics를 포함하여 모든 평가지표에 접근할 수 있습니다! 그래서 optimizer의 현재 lr을 수정하는 callback을 사용하는 패턴을 사용할 수 있습니다. 사실, ReduceLROnPlateau callback함수로 내장되어 있습니다.

 

Visualizing loss and metrics during training

학습하는 동안 모델을 시각적으로 유지하는 가장 좋은 방법은 TensorBoard를 사용하는 것입니다. 

  • 학습, 평가에 대한 loss 및 metrics를 실시간으로 plot할 수 있습니다.
  • (optionally) layer activations에 대한 히스토그램을 시각화 할 수 있습니다.
  • (optionally) Embedding layers에 의해 학습된 임베딩 공간을 3D 시각화 할 수 있습니다.

pip를 통해 텐서플로우를 설치했다면 다음과 같이 사용할 수 있습니다.

tensorboard --logdir=/full_path_to_your_logs

 

Using the TensorBoard callback

Keras 모델로 TensorBoard를 사용할 수 있는 가장 쉬운 방법은 TensorBoard Callback함수를 사용하는 것입니다.

tensorboard_cbk = keras.callbacks.TensorBoard(log_dir='/full_path_to_your_logs')
model.fit(dataset, epochs=10, callbacks=[tensorboard_cbk])

TensorBoard는 유용한 옵션들이 많습니다.

keras.callbacks.TensorBoard(
  log_dir='/full_path_to_your_logs',
  histogram_freq=0,  # How often to log histogram visualizations
  embeddings_freq=0,  # How often to log embedding visualizations
  update_freq='epoch')  # How often to write logs (default: once per epoch)