import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Image preprocessing pipeline applied to every sample of both splits.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    # (x - 0.5) / 0.5 maps the [0, 1] range onto [-1, 1].
    transforms.Normalize((0.5,), (0.5,)),
])

# Download (if not already present under ./data) and load MNIST.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# download=True here as well, so building the test split does not rely on
# the training split having been constructed first; it is a no-op when the
# files already exist.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Shuffle only the training data; keep evaluation order fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Sanity-check the first training batch.
images, labels = next(iter(train_loader))
print(f"이미지 배치: {images.shape}")  # expected: (64, 1, 32, 32)
print(f"레이블 배치: {labels.shape}")  # expected: (64,)