validation set의 data augmentation

2024. 5. 4. 23:55·ML_DL/딥러닝 공부하기
728x90
반응형

pytorch로 Dataset, DataLoader를 작성하는 경우 주의해야 할 사항

일반적인 data augmentation의 과정의 예시이다.

transform = transforms.Compose([
    transforms.Resize(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize()
])

 

그리고는 Dataset을 정의하면서 다음과 같이 적용할 수 있다.

class CustomDataset(Dataset):
    def __init__(self, x, y=None, aug=None):
        self.x = x
        self.y = y
        self.transform = aug

    def __len__(self):
        return len(self.x)
        
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        
        if self.transform is not None:
            img = self.transform(img)
        
        if self.y is not None:
            label = self.y[idx]
            return img, label
        else:
            return img

 


 

train data를 train, valid 데이터로 나누고, 위와 같은 코드로 두 데이터셋 모두에 동일한 transform을 적용하면,

valid data에도 augmentation이 적용된다. 따라서, valid는 test와 같은 데이터로 취급한다. (resize, normalize 만 적용)

class CustomDataset(Dataset):
    def __init__(self, x, y=None, aug=None):
        self.x = x
        self.y = y
        self.transform = aug

    def __len__(self):
        return len(self.x)
        
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        
        if self.transform is not None:
            img = self.transform(img)
        
        if self.y is not None:
            label = self.y[idx]
            return img, label
        else:
            return img
            
train_dataset = CustomDataset(train_x, train_y, transform)
val_dataset = CustomDataset(val_x, val_y, test_transform)
저작자표시

'ML_DL > 딥러닝 공부하기' 카테고리의 다른 글

Ollama 설치 및 Llama3.1 모델 사용  (0) 2024.08.14
텍스트 임베딩 해보기  (0) 2024.08.10
머신러닝 VS 딥러닝  (0) 2024.04.17
Word Embedding  (0) 2024.01.08
Bag of Words  (0) 2024.01.07
'ML_DL/딥러닝 공부하기' 카테고리의 다른 글
  • Ollama 설치 및 Llama3.1 모델 사용
  • 텍스트 임베딩 해보기
  • 머신러닝 VS 딥러닝
  • Word Embedding
swwho
swwho
일상을 데이터화하다
  • swwho
    하루한장
    swwho
  • 전체
    오늘
    어제
    • 분류 전체보기 (188)
      • ML_DL (39)
        • MUJAKJUNG (무작정 시리즈) (18)
        • 딥러닝 공부하기 (21)
      • 데이터사이언스 (1)
        • EDA (1)
        • 데이터과학을 위한 통계 (0)
      • 데이터엔지니어링 (2)
      • 논문리뷰 (2)
        • Computer Vision (2)
      • Python 활용하기 (12)
      • 코딩테스트 (127)
        • Python (109)
        • MySQL (14)
      • Git (3)
      • MySQL 활용하기 (0)
      • 일상 이야기 (1)
  • 블로그 메뉴

    • 홈
    • 태그
  • 최근 글

  • 250x250
  • hELLO· Designed By정상우.v4.10.3
swwho
validation set의 data augmentation
상단으로

티스토리툴바