728x90
반응형
pytorch로 Dataset, DataLoader를 작성하는 경우 주의해야 할 사항
일반적인 data augmentation의 과정의 예시이다.
transform = transforms.Compose([
transforms.Resize(),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize()
])
그리고는 Dataset을 정의하면서 다음과 같이 적용할 수 있다.
class CustomDataset(Dataset):
def __init__(self, x, y=None, aug=None):
self.x = x
self.y = y
self.transform = aug
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
img = Image.open(self.x[idx])
if self.transform is not None:
img = self.transform(img)
if self.y is not None:
label = self.y[idx]
return img, label
else:
return img
train data를 train, valid 데이터로 나누고, 위와 같은 코드로 두 데이터셋 모두에 동일한 transform을 적용하면,
valid data에도 augmentation이 적용된다. 따라서, valid는 test와 같은 데이터로 취급한다. (resize, normalize 만 적용)
class CustomDataset(Dataset):
def __init__(self, x, y=None, aug=None):
self.x = x
self.y = y
self.transform = aug
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
img = Image.open(self.x[idx])
if self.transform is not None:
img = self.transform(img)
if self.y is not None:
label = self.y[idx]
return img, label
else:
return img
train_dataset = CustomDataset(train_x, train_y, transform)
val_dataset = CustomDataset(val_x, val_y, test_transform)
'ML_DL > 딥러닝 공부하기' 카테고리의 다른 글
Ollama 설치 및 Llama3.1 모델 사용 (0) | 2024.08.14 |
---|---|
텍스트 임베딩 해보기 (0) | 2024.08.10 |
머신러닝 VS 딥러닝 (0) | 2024.04.17 |
Word Embedding (0) | 2024.01.08 |
Bag of Words (0) | 2024.01.07 |