pytorch로 Dataset, DataLoader를 작성하는 경우 주의해야 할 사항

일반적인 data augmentation의 과정의 예시이다.

transform = transforms.Compose([
    transforms.Resize(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize()
])

 

그리고는 Dataset을 정의하면서 다음과 같이 적용할 수 있다.

class CustomDataset(Dataset):
    def __init__(self, x, y=None, aug=None):
        self.x = x
        self.y = y
        self.transform = aug

    def __len__(self):
        return len(self.x)
        
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        
        if self.transform is not None:
            img = self.transform(img)
        
        if self.y is not None:
            label = self.y[idx]
            return img, label
        else:
            return img

 


 

train data를 train, valid 데이터로 나누고, 위와 같은 코드로 두 데이터셋 모두에 동일한 transform을 적용하면,

valid data에도 augmentation이 적용된다. 따라서, valid는 test와 같은 데이터로 취급한다. (resize, normalize 만 적용)

class CustomDataset(Dataset):
    def __init__(self, x, y=None, aug=None):
        self.x = x
        self.y = y
        self.transform = aug

    def __len__(self):
        return len(self.x)
        
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        
        if self.transform is not None:
            img = self.transform(img)
        
        if self.y is not None:
            label = self.y[idx]
            return img, label
        else:
            return img
            
train_dataset = CustomDataset(train_x, train_y, transform)
val_dataset = CustomDataset(val_x, val_y, test_transform)

+ Recent posts