Skip to main content

학습 루프

학습 루프(Training Loop)는 딥러닝 모델을 학습시키는 핵심 코드입니다. PyTorch에서는 학습 루프를 직접 작성하므로, 모든 과정을 세밀하게 제어할 수 있습니다.

기본 학습 루프 구조

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run a single training epoch over ``dataloader``.

    Performs the standard per-batch cycle — zero grads, forward, loss,
    backward, optimizer step — and accumulates per-sample metrics.

    Args:
        model: network to train (switched to train mode here).
        dataloader: yields ``(inputs, labels)`` batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss``.
        optimizer: updates model parameters from computed gradients.
        device: target device for the batch tensors ('cpu' or 'cuda').

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()  # enable training behavior of Dropout / BatchNorm
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in dataloader:
        # Move the batch to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()            # clear gradients from the previous step
        logits = model(inputs)           # forward pass
        batch_loss = criterion(logits, labels)
        batch_loss.backward()            # backward pass
        optimizer.step()                 # parameter update

        # Weight the batch loss by batch size so the epoch average is per-sample.
        running_loss += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return running_loss / n_seen, n_correct / n_seen

검증 루프

@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Evaluate the model on a validation/test ``dataloader``.

    Gradient tracking is disabled by the ``@torch.no_grad()`` decorator,
    which reduces memory use and speeds up the forward passes. The model
    is put into eval mode so Dropout is disabled and BatchNorm uses its
    running statistics.

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over all samples.
    """
    model.eval()  # deterministic inference behavior
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = model(inputs)
        # Per-sample weighting keeps the average correct for ragged last batches.
        running_loss += criterion(logits, labels).item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

    return running_loss / n_seen, n_correct / n_seen

전체 학습 파이프라인

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 1. Device setup — use the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2. Data preparation.
# Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transform)

# Shuffle only the training set; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 3. Model, loss function, and optimizer.
# Simple MLP classifier: 28x28 images flattened to 784 features, 10 classes.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 10),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 4. Run training.
# NOTE(review): the test set is used here for model selection — for a real
# project, hold out a separate validation split.
num_epochs = 10
best_val_acc = 0

for epoch in range(num_epochs):
    # Train for one epoch.
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

    # Evaluate on the held-out set.
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    # Report progress.
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

    # Save a checkpoint whenever validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"  → 최적 모델 저장 (Val Acc: {val_acc:.4f})")

tqdm 진행률 표시

def train_one_epoch_tqdm(model, dataloader, criterion, optimizer, device):
    """Train one epoch, showing live loss/accuracy in a tqdm progress bar.

    Identical training logic to ``train_one_epoch``; the only addition is
    the per-batch running metrics rendered via ``tqdm.set_postfix``.

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over all samples.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, desc="Training")
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)

        # Update the bar with running (epoch-so-far) averages after each batch.
        progress.set_postfix({
            'loss': f'{running_loss/n_seen:.4f}',
            'acc': f'{n_correct/n_seen:.4f}'
        })

    return running_loss / n_seen, n_correct / n_seen

학습 이력 기록

# Accumulate per-epoch metrics so training/validation curves can be plotted.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

# Visualization (matplotlib)
import matplotlib.pyplot as plt

# Side-by-side panels: loss on the left, accuracy on the right.
# A widening gap between the Train and Val curves indicates overfitting.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(history['train_loss'], label='Train')
ax1.plot(history['val_loss'], label='Val')
ax1.set_title('Loss')
ax1.legend()

ax2.plot(history['train_acc'], label='Train')
ax2.plot(history['val_acc'], label='Val')
ax2.set_title('Accuracy')
ax2.legend()
plt.savefig('training_history.png')

체크리스트

  • train/eval 루프를 분리하여 작성할 수 있다
  • model.train()과 model.eval()을 적절히 호출한다
  • torch.no_grad()로 평가 시 메모리를 절약한다
  • 학습 이력을 기록하고 과적합 여부를 판단할 수 있다

다음 문서