이미지 분류 실습 — CIFAR-10 CNN
CIFAR-10 데이터셋을 사용하여 CNN 이미지 분류기를 처음부터 구현하고 학습합니다. 데이터 로딩부터 학습, 평가, 시각화까지 전체 파이프라인을 다룹니다.CIFAR-10 데이터셋
10개 클래스, 32x32 크기의 컬러 이미지 60,000장(학습 50,000 + 테스트 10,000)으로 구성됩니다.데이터 준비
Copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 데이터 전처리
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616)),
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2470, 0.2435, 0.2616)),
])
# 데이터셋 로드
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
모델 정의
Copy
class CIFAR10CNN(nn.Module):
"""CIFAR-10 분류를 위한 CNN"""
def __init__(self, num_classes=10):
super().__init__()
# 특성 추출기
self.features = nn.Sequential(
# 블록 1: 3 → 32채널
nn.Conv2d(3, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # 32×32 → 16×16
nn.Dropout2d(0.25),
# 블록 2: 32 → 64채널
nn.Conv2d(32, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # 16×16 → 8×8
nn.Dropout2d(0.25),
# 블록 3: 64 → 128채널
nn.Conv2d(64, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1), # 8×8 → 1×1
)
# 분류기
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128, 256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(256, num_classes),
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CIFAR10CNN().to(device)
# 파라미터 수 확인
total_params = sum(p.numel() for p in model.parameters())
print(f"전체 파라미터: {total_params:,}")
학습
Copy
from tqdm import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
num_epochs = 50
best_acc = 0
for epoch in range(num_epochs):
# 학습
model.train()
train_loss, correct, total = 0, 0, 0
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
scheduler.step()
# 평가
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = outputs.max(1)
test_total += labels.size(0)
test_correct += predicted.eq(labels).sum().item()
train_acc = correct / total
test_acc = test_correct / test_total
print(f"Epoch {epoch+1}: Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}")
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), 'cifar10_best.pt')
print(f"최고 테스트 정확도: {best_acc:.4f}")
평가 및 분석
Copy
# 클래스별 정확도
model.load_state_dict(torch.load('cifar10_best.pt', weights_only=True))
model.eval()
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = outputs.max(1)
for i in range(labels.size(0)):
label = labels[i].item()
class_total[label] += 1
if predicted[i] == label:
class_correct[label] += 1
for i in range(10):
print(f"{classes[i]:>12}: {class_correct[i]/class_total[i]:.2%}")
체크리스트
- 데이터 증강(RandomHorizontalFlip, RandomCrop)의 역할을 이해한다
- Conv → BN → ReLU → Pool 패턴으로 CNN 블록을 구성할 수 있다
- 학습 스케줄러를 적용하여 학습률을 조정할 수 있다
- 클래스별 정확도를 분석하여 모델의 강점과 약점을 파악할 수 있다

