결론부터 말하자면, Gradient Accumulation 방법은 GPU memory issue를 보완하기 위한 방법입니다.

배치 크기는 성능에 영향을 주는 중요한 하이퍼파라미터 중 하나인데요.
이 때문에 배치 크기(Batch Size)와 성능(Performance) 관계는 이전부터 많은 방법(다양한 조건에서의 실험, 수학적 풀이, 최적화 관점 등)으로 연구되어 왔습니다.

대표적인 논문만 간단히 보자면,

위 논문뿐만 아니라 관련 논문을 보면 아래와 같이 결론을 얻을 수 있습니다.

    ① LB(Large Batch)는 학습 속도도 빠르고, 성능도 좋을 수 있다
    ② LB 일수록 더 큰 학습률(Learning Rate)를 사용하면 좋다 (하지만 큰 학습률을 사용하면 수렴이 안되는 문제가?)
    ③ 하지만 LB보다 SB(Small Batch)가 안정적으로 학습할 수 있어 일반화(Generalization)에 강하다

Gradient Accumulation 방법이 배치 크기와 관련이 있다보니 간단하게 배치 크기 관련 이야기를 해보았는데, 이번 글은 어떤 이유에서든 "비록 가지고 있는 GPU memory가 제한적이지만 일단 LB를 사용해서 학습시키고 싶다."라는 생각입니다.

Gradient Accumulation?

큰 배치 크기를 활용할때 가장 큰 문제는 '성능이 좋아질 수 있다', '안좋아질 수 있다'가 아닐겁니다. 아마도 가장 중요하게 고려하는 것은 지금 보유하고 있는 GPU를 가지고 '1,024, 2,048, ... 처럼 큰 배치 크기를 돌려볼 수 있느냐'입니다.

누구나 다 겪는 문제인데요. 만약 이와 같은 문제에 직면했다면, 이를 해결하기 위한 방법 중 하나로 Gradient Accumulation 방법을 생각해볼 수 있습니다. 가장 좋은 방법은 돈을 들이는 것...

일반적인 학습 방법과 Gradient Accumulation 방법의 차이는 아래 그림에서 바로 확인할 수 있습니다.

General Training vs Training with GA

일반적인 학습 방법이 미니 배치를 통해 gradient를 구한 후, Update를 즉시 진행하는 방법이라면,

Gradient Accumulation 방법을 포함한 학습은 미니 배치를 통해 구해진 gradient를 n-step동안 Global Gradients에 누적시킨 후, 한번에 업데이트하는 방법입니다.

매우 간단합니다. 즉, 핵심 아이디어는 256 mini-batch를 통해 gradient를 업데이트하는 것과 64 mini-batch를 4번(64 * 4 = 256) 누적하여 업데이트하는 것이 비슷한 결과를 가져다 줄 것이라는 생각입니다.

이 예와 같다면, 우리가 보유하고 있는 GPU memory 한계상 64 mini-batch 밖에 사용할 수 없다면, 이 학습 방법을 통해 4번 누적하여 업데이트할 경우 256 mini-batch를 사용하여 학습하는 것과 다름없게 되는 셈이죠.

하지만 이 방법은 memory issue를 중점적으로 해결하고자 하는 방법이지 속도, 성능과는 아직 explicit하게 증명된 바가 없습니다.

아래의 간단한 구현을 통해 위에서 글로서 살펴본 것보다 좀 더 쉽게 이해할 수 있을 거에요.


전체 코드는 깃허브_저장소에 있습니다.

Setup

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.models import Model
import tensorflow as tf

print(tf.__version__)

MNIST Dataset Laod

(x_train, y_train), (x_test, y_test) = mnist.load_data()

Make TF Dataset

def make_datasets(x, y):
    # (28, 28) -> (28, 28, 1)
    def _new_axis(x, y):
        y = tf.one_hot(y, depth = 10)
        
        return x[..., tf.newaxis], y
            
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(_new_axis, num_parallel_calls = tf.data.experimental.AUTOTUNE)
    ds = ds.shuffle(100).batch(32) # 배치 크기 조절하세요
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    
    return ds
    
ds = make_datasets(x_train, y_train)

Make Models

# rescaling, 1 / 255
preprocessing_layer = tf.keras.models.Sequential([
        tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
    ])

# simple CNN model
def get_model():
    inputs = Input(shape = (28, 28, 1))
    preprocessing_inputs = preprocessing_layer(inputs)
    
    x = Conv2D(filters = 32, kernel_size = (3, 3), activation='relu')(preprocessing_inputs)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(filters = 64, kernel_size = (3, 3), activation='relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(filters = 64, kernel_size =(3, 3), activation='relu')(x)
    
    x = Flatten()(x)
    x = Dense(64, activation = 'relu')(x)
    outputs = Dense(10, activation = 'softmax')(x)
    
    model = Model(inputs = inputs, outputs = outputs)
    
    return model

model = get_model()
model.summary()

Training with Gradient Accumulation

epochs = 10
num_accum = 4 # 누적 횟수

loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits = True)
optimizer = tf.keras.optimizers.Adam()
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x)
        loss_value = loss_fn(y, logits)
    gradients = tape.gradient(loss_value, model.trainable_weights)
    
    # update metrics    
    train_acc_metric.update_state(y, logits)
    
    return gradients, loss_value

def train():
    for epoch in range(epochs):
        print(f'################ Start of epoch: {epoch} ################')
        # 누적 gradient를 담기 위한 zeros_like 선언
        accumulation_gradients = [tf.zeros_like(ele) for ele in model.trainable_weights]
        
        for step, (batch_x_train, batch_y_train) in enumerate(ds):
            gradients, loss_value = train_step(batch_x_train, batch_y_train)
            
            if step % num_accum == 0:
                accumulation_gradients = [grad / num_accum for grad in accumulation_gradients]
                optimizer.apply_gradients(zip(gradients, model.trainable_weights))

                # zero-like init
                accumulation_gradients = [tf.zeros_like(ele) for ele in model.trainable_weights]
            else:
                accumulation_gradients = [(accum_grad + grad) for accum_grad, grad in zip(accumulation_gradients, gradients)]

            if step % 100 == 0:
                print(f"Loss at Step: {step} : {loss_value:.4f}")
            
        train_acc = train_acc_metric.result()
        print(f'Accuracy : {(train_acc * 100):.4f}%')
        train_acc_metric.reset_states()
        
# start training
train()