이미지 제네레이터와 활용하고 싶은 데이터를 포함한 데이터 제네레이터의 구현 코드입니다.

이미지는 이미지데이터 제네레이터를 통해 불러오며, 활용하고 싶은 데이터인 color는 직접 인덱스를 통해 배치 크기만큼 부르는 것을 볼 수 있습니다.

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, df, batch_size = 32, target_size = (112, 112), shuffle = True):
        self.len_df = len(df)
        self.batch_size = batch_size
        self.target_size = target_size
        self.shuffle = shuffle
        self.class_col = ['black', 'blue', 'brown', 'green', 'red', 'white', 
             'dress', 'shirt', 'pants', 'shorts', 'shoes']
        self.generator = ImageDataGenerator(rescale = 1./255)
        self.df_generator = self.generator.flow_from_dataframe(dataframe=df, 
                                                          directory='',
                                                            x_col = 'image',
                                                            y_col = self.class_col,
                                                            target_size = self.target_size,
                                                            color_mode='rgb',
                                                            class_mode='other',
                                                            batch_size=self.batch_size,
                                                            seed=42)
        self.colors_df = df['color']
        self.on_epoch_end()
        
    def __len__(self):
        return int(np.floor(self.len_df) / self.batch_size)
    
    def on_epoch_end(self):
        self.indexes = np.arange(self.len_df)
        if self.shuffle:
            np.random.shuffle(self.indexes)
        
    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        colors = self.__data_generation(indexes)
        
        images, labels = self.df_generator.__getitem__(index)
        
        # return multi-input and output
        return [images, colors], labels
    
    def __data_generation(self, indexes):
        colors = self.colors_df[indexes].to_numpy()
        # 또는
        # colors = np.array([self.colors_df[k] for k in indexes])
        
        return colors

 

1 - https://hwiyong.tistory.com/241