Skip to main content

Dataset과 DataLoader

효율적인 데이터 로딩은 학습 성능에 직접적인 영향을 미칩니다. PyTorch는 Dataset으로 데이터 접근 방식을 정의하고, DataLoader로 배치 처리, 셔플링, 병렬 로딩을 수행합니다.

Dataset 인터페이스

torch.utils.data.Dataset을 상속하여 커스텀 데이터셋을 만듭니다. 두 메서드를 반드시 구현해야 합니다.
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """데이터셋 크기 반환"""
        return len(self.data)

    def __getitem__(self, idx):
        """인덱스로 샘플 하나를 반환"""
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

# 사용 예시
X = torch.randn(1000, 10)     # 1000개 샘플, 10차원
y = torch.randint(0, 3, (1000,))  # 3 클래스

dataset = CustomDataset(X, y)
print(f"데이터셋 크기: {len(dataset)}")
sample, label = dataset[0]
print(f"샘플 shape: {sample.shape}, 레이블: {label}")

DataLoader

Dataset을 감싸서 미니배치, 셔플링, 병렬 로딩 등을 자동 처리합니다.
dataloader = DataLoader(
    dataset,
    batch_size=32,       # 미니배치 크기
    shuffle=True,        # 에포크마다 데이터 섞기
    num_workers=4,       # 병렬 데이터 로딩 (CPU 코어 수에 맞게)
    drop_last=True,      # 마지막 불완전 배치 제거
    pin_memory=True,     # GPU 전송 속도 향상
)

# 배치 순회
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):
    print(f"배치 {batch_idx}: data={batch_data.shape}, labels={batch_labels.shape}")
    if batch_idx >= 2:
        break
# 배치 0: data=torch.Size([32, 10]), labels=torch.Size([32])

내장 데이터셋 (torchvision)

import torchvision
import torchvision.transforms as transforms

# 이미지 변환 파이프라인
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),                          # PIL → Tensor, [0, 1] 정규화
    transforms.Normalize((0.5,), (0.5,)),           # [-1, 1] 정규화
])

# MNIST 다운로드 및 로드
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 첫 배치 확인
images, labels = next(iter(train_loader))
print(f"이미지 배치: {images.shape}")   # (64, 1, 32, 32)
print(f"레이블 배치: {labels.shape}")    # (64,)

이미지 데이터 전처리 (transforms)

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),   # 좌우 반전
    transforms.RandomRotation(10),             # ±10도 회전
    transforms.ColorJitter(brightness=0.2),    # 밝기 변화
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet 통계
])

# 검증/테스트에는 증강 없이 정규화만
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

학습/검증 데이터 분할

from torch.utils.data import random_split

full_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# 80% 학습, 20% 검증
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

print(f"학습: {len(train_dataset)}, 검증: {len(val_dataset)}")

CSV 파일에서 커스텀 데이터셋

import pandas as pd

class CSVDataset(Dataset):
    """CSV 파일을 읽어 텐서로 변환하는 데이터셋"""
    def __init__(self, csv_path, target_column):
        df = pd.read_csv(csv_path)
        self.features = torch.FloatTensor(df.drop(columns=[target_column]).values)
        self.labels = torch.LongTensor(df[target_column].values)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# 사용
# dataset = CSVDataset('data/train.csv', target_column='label')
# loader = DataLoader(dataset, batch_size=32, shuffle=True)

체크리스트

  • Dataset__len____getitem__을 구현할 수 있다
  • DataLoader의 주요 파라미터(batch_size, shuffle, num_workers)를 이해한다
  • torchvision.transforms로 이미지 전처리 파이프라인을 구성할 수 있다
  • random_split으로 학습/검증 세트를 분할할 수 있다

다음 문서