Custom data generator를 만들 때는 keras.utils.Sequence 클래스를 상속하는 것으로 시작합니다.
Sequence는 __getitem__, __len__, on_epoch_end, __iter__를 sub method로서 가지고 있습니다.
따라서, 이들을 우리의 데이터에 맞게 변형하여 사용하게 됩니다.
MNIST는 예를 들기 위해 사용했습니다.
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.utils import to_categorical
import numpy as np
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
class DataGenerator(Sequence):
def __init__(self, X, y, batch_size, dim, n_channels, n_classes, shuffle = True):
self.X = X
self.y = y if y is not None else y
self.batch_size = batch_size
self.dim = dim
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.on_epoch_end()
def on_epoch_end(self):
self.indexes = np.arange(len(self.X))
if self.shuffle:
np.random.shuffle(self.indexes)
def __len__(self):
return int(np.floor(len(self.X) / self.batch_size))
def __data_generation(self, X_list, y_list):
X = np.empty((self.batch_size, *self.dim))
y = np.empty((self.batch_size), dtype = int)
if y is not None:
# 지금 같은 경우는 MNIST를 로드해서 사용하기 때문에
# 배열에 그냥 넣어주면 되는 식이지만,
# custom image data를 사용하는 경우
# 이 부분에서 이미지를 batch_size만큼 불러오게 하면 됩니다.
for i, (img, label) in enumerate(zip(X_list, y_list)):
X[i] = img
y[i] = label
return X, to_categorical(y, num_classes = self.n_classes)
else:
for i, img in enumerate(X_list):
X[i] = img
return X
def __getitem__(self, index):
indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
X_list = [self.X[k] for k in indexes]
if self.y is not None:
y_list = [self.y[k] for k in indexes]
X, y = self.__data_generation(X_list, y_list)
return X, y
else:
y_list = None
X = self.__data_generation(X_list, y_list)
return X
__data_generation부분을 수정해서 사용하면 됩니다. 원래 같은 경우는 X를 정의할 때, (batch_size, *dim, n_channel)로 정의하나 편의를 위해 지웠습니다. MNIST 데이터셋은 (28, 28, 1)이 아닌 (28, 28)로 인식되기 때문. 필요에 따라 수정하시면 됩니다.
dg = DataGenerator(x_train, y_train, 4, (28, 28), 1, 10)
import matplotlib.pyplot as plt
for i, (x, y) in enumerate(dg):
if(i <= 1):
x_first = x[0]
plt.title(y[0])
plt.imshow(x_first)
'# Machine Learning > Keras Implementation' 카테고리의 다른 글
Keras, 1x1 Convolution만 사용해서 MNIST 학습시키기 (0) | 2019.11.05 |
---|---|
Keras Custom Activation 사용해보기 (0) | 2019.10.27 |
TTA(test time augmentation) with 케라스 (2) | 2019.07.01 |
Keras callback함수 쓰기 (0) | 2018.12.23 |
Keras ImageDataGenerator flow 사용 (0) | 2018.12.20 |