이 글은 keras 공식 문서의 'Few-Shot learning with Reptile' 글을 번역한 것입니다.


Introduction

Reptile 알고리즘은 Meta-Learning을 위해 open-AI가 개발한 알고리즘입니다. 특히, 이 알고리즘은 최소한의 training으로 새로운 task를 수행하기 위한 학습을 빠르게 진행하기 위해 고안되었습니다(Few-Shot learning). Few-shot learning의 목적은 소수의 데이터만 학습한 모델이 보지 못한 새로운 데이터를 접했을 때, 이를 분류하게 하는 것입니다. 예를 들어, 희귀병 진단이나 레이블 cost가 높은 경우에 유용한 방법이라고 할 수 있습니다. 즉, 해결해야 될 문제에서 데이터가 적을때 충분히 고려해볼만한 방법입니다.

알고리즘은 이전에 보지 못한 데이터의 mini-batch로 학습된 가중치와 고정된 meta-iteration을 통해 사전 학습된 모델 가중치 사이의 차이를 최적화하는 SGD와 함께 작동합니다. 즉, n번 만큼의 횟수에서 mini-batch로 학습하여 얻은 가중치와 모델 weight initialization 차이를 최소하하면서 optimal weight를 update하게 됩니다.
(동일 클래스 이미지에 계속해서 다른 레이블로 학습하는데 어떻게 정답을 맞추는지 정말 신기합니다. 이 부분을 알려면 더 깊게 공부해봐야 알듯...)

필요 라이브러리를 임포트합니다.

import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

하이퍼파라미터 정의

learning_rate = 0.003 # 학습률
meta_step_size = 0.25 # 메타 스템 크기

inner_batch_size = 25 # 학습 배치 크기
eval_batch_size = 25 # 테스트 배치 크기

meta_iters = 2000 # 학습 총 횟수
eval_iters = 5 # 평가 데이터셋 반복 횟수
inner_iters = 4 # 학습 데이터셋 반복 횟수

eval_interval = 1 # 평가 기간 설정
train_shots = 20 # 학습에서 각 클래스마다 몇 개 샘플을 가져올지
shots = 5 # 평가에서 각 클래스마다 몇 개 샘플을 가져올지
classes = 5 # 몇 개 클래스를 사용할지

여기서는 5-way 5-shot이라고 표현할 수 있겠네요(N-way, k-shot)
--> 5개 클래스에서 5장 뽑아서 이를 학습에 사용하겠다는 의미

 

데이터 준비

Omniglot 데이터셋은 각 character별로 20개 샘플을 가지고 있는 50가지 알파벳(즉, 50개의 다른 언어)으로부터 얻어진 1,623개 데이터로 이루어져 있습니다(서로 다른 나라의 문자에서 특정 클래스들을 모아둔 데이터셋입니다). 여기서 20개 샘플은 Amazon's Mechanical Turk으로 그려졌습니다.

few-shot learning task를 수행하기 위해 k개 샘플은 무작위로 선택된 n개 클래스로부터 그려진 것을 사용합니다. n개 숫자값은 몇 가지 예에서 새 작업을 학습하는 모델 능력을 테스트하는 데 사용할 임시 레이블 집합을 만드는 데 사용됩니다. 예를 들어, 5개 클래스를 사용한다고 하면, 새로운 클래스 레이블은 0, 1, 2, 3, 4가 될 것입니다.

핵심은 모델이 계속 반복해서 class instance를 보게 되지만, 해당 instance에 대해서는 계속해서 서로 다른 레이블을 보게 됩니다. 이러한 과정에서 모델은 단순하게 클래스를 분류하는 방법이 아니라 해당 데이터를 특정 클래스로 구별하는 방법을 배워야 합니다.

Omniglot 데이터셋은 다양한 클래스에서 만족할만한 개수의 샘플을 가져다 사용할 수 있기 때문에 이번 task에 적합합니다.

class Dataset:
    # 이 클래스는 few-shot dataset을 만듭니다.
    # 새로운 레이블을 만듬과 동시에 Omniglot 데이터셋에서 샘플을 뽑습니다.
    def __init__(self, training):
        # omniglot data를 포함한 tfrecord file을 다운로드하고 dataset으로 변환합니다.
        split = "train" if training else "test"
        ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)
        # dataset을 순환하면서 이미지와 클래스를 data dictionary에 담을겁니다.
        self.data = {}

        def extraction(image, label):
            # RGB에서 grayscale로 변환하고, resize를 수행합니다.
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.rgb_to_grayscale(image)
            image = tf.image.resize(image, [28, 28])
            return image, label

        for image, label in ds.map(extraction):
            image = image.numpy()
            label = str(label.numpy())
            # 레이블이 존재하지 않으면 data dictionary에 넣고,
            if label not in self.data:
                self.data[label] = []
            # 존재하면 해당 레이블에 붙여줍니다.
            self.data[label].append(image)
            # 레이블을 저장합니다.
            self.labels = list(self.data.keys())

    def get_mini_dataset(self, batch_size, repetitions, shots, num_classes, split=False):
        # num_classes * shots 수 만큼 이미지와 레이블을 가져올 것입니다.
        temp_labels = np.zeros(shape=(num_classes * shots))
        temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))
        if split:
            test_labels = np.zeros(shape=(num_classes))
            test_images = np.zeros(shape=(num_classes, 28, 28, 1))
          
        # 전체 레이블 셋에서 랜덤하게 몇 개의 레이블만 가져옵니다.
        # self.labels는 omniglot에서 가져온 데이터의 전체 레이블을 담고 있습니다.
        # 전체 레이블에서 num_classes만큼 label을 랜덤하게 뽑습니다.
        label_subset = random.choices(self.labels, k=num_classes)
        for class_idx, _ in enumerate(label_subset):
            # few-shot learning mini-batch를 위한 현재 레이블값으로 enumerate index를 사용합니다.
            # temp_label에서는 shots 수만큼 class_idx를 가지게 됩니다.[0, 0, 0, 0, 0]
            # 해당되는 이미지의 레이블은 계속해서 변하는 점을 참고
            temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx

            # 테스트를 위한 데이터셋을 만든다면, 선택된 sample과 label을 넣어줍니다.
            # 테스트 과정에서 temp_label과 test_label은 동일해야 합니다.
            if split:
                test_labels[class_idx] = class_idx
                images_to_split = random.choices(self.data[label_subset[class_idx]], 
                                                 k=shots + 1)
                test_images[class_idx] = images_to_split[-1]
                temp_images[class_idx * shots : (class_idx + 1) * shots] = images_to_split[:-1]
            else:
                # label_subset[class_idx]에 해당하는 이미지를 shots만큼 temp_images에 담습니다.
                # 이로써 temp_images는 label_subset[class_idx] 레이블에 관한 이미지를 가지고 있지만,
                # temp_labels는 이에 상관없이 class_idx 클래스를 가집니다.
                temp_images[
                    class_idx * shots : (class_idx + 1) * shots
                ] = random.choices(self.data[label_subset[class_idx]], k=shots)

        # TensorFlow Dataset을 만듭니다.
        dataset = tf.data.Dataset.from_tensor_slices(
            (temp_images.astype(np.float32), temp_labels.astype(np.int32))
        )
        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)

        if split:
            return dataset, test_images, test_labels
            
        return dataset


import urllib3

urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

 

모델 구성

def conv_bn(x):
    x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)


inputs = layers.Input(shape=(28, 28, 1))
x = conv_bn(inputs)
x = conv_bn(x)
x = conv_bn(x)
x = conv_bn(x)
x = layers.Flatten()(x)
outputs = layers.Dense(classes, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile()
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)

 

모델 학습

training = []
testing = []
for meta_iter in range(meta_iters):
    frac_done = meta_iter / meta_iters
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    # 현재 모델 weights를 저장합니다.
    old_vars = model.get_weights()
    # 전체 데이터셋에서 샘플을 가져옵니다.
    mini_dataset = train_dataset.get_mini_dataset(
        inner_batch_size, inner_iters, train_shots, classes
    )

    for images, labels in mini_dataset:
        with tf.GradientTape() as tape:
            preds = model(images)
            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)

        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # 위의 loss로 업데이트된 weight를 가져옵니다.
    new_vars = model.get_weights()
    # meta step with SGD
    for var in range(len(new_vars)):
        new_vars[var] = old_vars[var] + (
            (new_vars[var] - old_vars[var]) * cur_meta_step_size
        )

    # meta-learning 단계를 수행했다면, 모델 가중치를 다시 업데이트해줍니다.
    model.set_weights(new_vars)

    # Evaluation loop
    if meta_iter % eval_interval == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # 전체 데이터셋에서 mini dataset을 가져옵니다.
            train_set, test_images, test_labels = dataset.get_mini_dataset(
                eval_batch_size, eval_iters, shots, classes, split=True
            )
            old_vars = model.get_weights()
            # Train on the samples and get the resulting accuracies.
            for images, labels in train_set:
                with tf.GradientTape() as tape:
                    preds = model(images)
                    loss = keras.losses.sparse_categorical_crossentropy(labels, preds)

                grads = tape.gradient(loss, model.trainable_weights)
                optimizer.apply_gradients(zip(grads, model.trainable_weights))
            
            test_preds = model.predict(test_images)
            test_preds = tf.argmax(test_preds).numpy()
            num_correct = (test_preds == test_labels).sum()

            # 평가 acc를 얻었다면 다시 weight를 초기화해야합니다.
            model.set_weights(old_vars)
            accuracies.append(num_correct / classes)

        training.append(accuracies[0])
        testing.append(accuracies[1])

        if meta_iter % 100 == 0:
            print(
                "batch %d: train=%f test=%f" % (meta_iter, accuracies[0], accuracies[1])
            )

 

결과 시각화

# First, some preprocessing to smooth the training and testing arrays for display.
window_length = 100
train_s = np.r_[
    training[window_length - 1 : 0 : -1], training, training[-1:-window_length:-1]
]
test_s = np.r_[
    testing[window_length - 1 : 0 : -1], testing, testing[-1:-window_length:-1]
]
w = np.hamming(window_length)
train_y = np.convolve(w / w.sum(), train_s, mode="valid")
test_y = np.convolve(w / w.sum(), test_s, mode="valid")

# 학습 acc를 그립니다.
x = np.arange(0, len(test_y), 1)
plt.plot(x, test_y, x, train_y)
plt.legend(["test", "train"])
plt.grid()

train_set, test_images, test_labels = dataset.get_mini_dataset(
    eval_batch_size, eval_iters, shots, classes, split=True
)

for images, labels in train_set:
    with tf.GradientTape() as tape:
        preds = model(images)
        loss = keras.losses.sparse_categorical_crossentropy(labels, preds)

    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

test_preds = model.predict(test_images)
test_preds = tf.argmax(test_preds).numpy()

_, axarr = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))

for i, ax in zip(range(5), axarr):
    temp_image = np.stack((test_images[i, :, :, 0],) * 3, axis=2)
    temp_image *= 255
    temp_image = np.clip(temp_image, 0, 255).astype("uint8")
    ax.set_title(
        "Label : {}, Prediction : {}".format(int(test_labels[i]), test_preds[i])
    )
    ax.imshow(temp_image, cmap="gray")
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
plt.show()

 

Reference

Few-shot learning은 카카오 기술 블로그에도 설명되어 있습니다. >> 클릭

talkingaboutme.tistory.com/entry/DL-Meta-Learning-Learning-to-Learn-Fast

www.borealisai.com/en/blog/tutorial-2-few-shot-learning-and-meta-learning-i/

'# Machine Learning > TensorFlow doc 정리' 카테고리의 다른 글

tf.data tutorial 번역 (5)  (0) 2020.02.28
tf.data tutorial 번역 (4)  (1) 2020.02.27
tf.data tutorial 번역 (3)  (0) 2020.02.26
tf.data tutorial 번역 (2)  (0) 2020.02.18
tf.data tutorial 번역 (1)  (0) 2020.02.14