Custom data generator를 만들 때는 keras.utils.Sequence 클래스를 상속하는 것으로 시작합니다.

Sequence는 __getitem__, __len__, on_epoch_end, __iter__를 sub method로서 가지고 있습니다.

따라서, 이들을 우리의 데이터에 맞게 변형하여 사용하게 됩니다.


MNIST는 예를 들기 위해 사용했습니다.

import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
import numpy as np

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class DataGenerator(Sequence):
    def __init__(self, X, y, batch_size, dim, n_channels, n_classes, shuffle = True):
        self.X = X
        self.y = y if y is not None else y
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()
        
    def on_epoch_end(self):
        self.indexes = np.arange(len(self.X))
        if self.shuffle:
            np.random.shuffle(self.indexes)
            
    def __len__(self):
        return int(np.floor(len(self.X) / self.batch_size))
    
    def __data_generation(self, X_list, y_list):
        X = np.empty((self.batch_size, *self.dim))
        y = np.empty((self.batch_size), dtype = int)
        
        if y is not None:
            # 지금 같은 경우는 MNIST를 로드해서 사용하기 때문에
            # 배열에 그냥 넣어주면 되는 식이지만,
            # custom image data를 사용하는 경우 
            # 이 부분에서 이미지를 batch_size만큼 불러오게 하면 됩니다. 
            for i, (img, label) in enumerate(zip(X_list, y_list)):
                X[i] = img
                y[i] = label
                
            return X, to_categorical(y, num_classes = self.n_classes)
        
        else:
            for i, img in enumerate(X_list):
                X[i] = img
                
            return X
        
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        X_list = [self.X[k] for k in indexes]
        
        if self.y is not None:
            y_list = [self.y[k] for k in indexes]
            X, y = self.__data_generation(X_list, y_list)
            return X, y
        else:
            y_list = None
            X = self.__data_generation(X_list, y_list)
            return X

__data_generation부분을 수정해서 사용하면 됩니다. 원래 같은 경우는 X를 정의할 때, (batch_size, *dim, n_channel)로 정의하나 편의를 위해 지웠습니다. MNIST 데이터셋은 (28, 28, 1)이 아닌 (28, 28)로 인식되기 때문. 필요에 따라 수정하시면 됩니다.


dg = DataGenerator(x_train, y_train, 4, (28, 28), 1, 10)

import matplotlib.pyplot as plt

for i, (x, y) in enumerate(dg):
    if(i <= 1):
        x_first = x[0]
        plt.title(y[0])
        plt.imshow(x_first)