Skip to main content

Fine-tuning

Fine-tuning은 사전학습된 모델의 전체 가중치를 커스텀 데이터에 맞게 미세 조정하는 기법입니다. Feature Extraction보다 높은 성능을 달성할 수 있으며, 실무에서 가장 많이 사용되는 접근법입니다.

Fine-tuning 전략

1
데이터셋과 DataLoader 준비
2
import torch
import timm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Augmentation 포함 학습 변환
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = ImageFolder('data/train', transform=train_transform)
val_dataset = ImageFolder('data/val', transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                        num_workers=4, pin_memory=True)

num_classes = len(train_dataset.classes)
print(f"클래스: {train_dataset.classes} ({num_classes}개)")
3
모델 구성 (Discriminative Learning Rate)
4
계층별로 다른 학습률을 적용하는 방식입니다. 입력에 가까운 초기 계층은 범용적인 특징을 학습하므로 낮은 학습률, 출력에 가까운 후반 계층은 태스크 특화 특징을 학습하므로 높은 학습률을 적용합니다.
5
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 계층별 학습률 설정
base_lr = 1e-4
param_groups = [
    # 초기 계층: 낮은 학습률
    {'params': model.conv_stem.parameters(), 'lr': base_lr * 0.1},
    {'params': model.bn1.parameters(), 'lr': base_lr * 0.1},
    # 중간 계층
    {'params': model.blocks[:4].parameters(), 'lr': base_lr * 0.5},
    {'params': model.blocks[4:].parameters(), 'lr': base_lr},
    # 분류 헤드: 높은 학습률
    {'params': model.classifier.parameters(), 'lr': base_lr * 10},
]

optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
6
학습률 스케줄러 설정
7
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR

# 코사인 어닐링 (가장 일반적)
scheduler = CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-6)

# 또는 OneCycleLR (빠른 수렴)
# scheduler = OneCycleLR(
#     optimizer, max_lr=base_lr * 10,
#     epochs=30, steps_per_epoch=len(train_loader),
# )
8
학습 루프 (Early Stopping 포함)
9
import torch.nn as nn
from copy import deepcopy

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Label Smoothing

best_val_acc = 0.0
best_model = None
patience = 5
patience_counter = 0
num_epochs = 30

for epoch in range(num_epochs):
    # 학습
    model.train()
    train_loss, correct, total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # 그래디언트 클리핑
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        train_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    scheduler.step()
    train_acc = correct / total

    # 검증
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            val_correct += preds.eq(labels).sum().item()
            val_total += labels.size(0)

    val_acc = val_correct / val_total

    print(f"Epoch {epoch+1}/{num_epochs} | "
          f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | "
          f"LR: {optimizer.param_groups[-1]['lr']:.6f}")

    # Early Stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model = deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# 최적 모델 복원
model.load_state_dict(best_model)
torch.save(best_model, 'best_model.pth')
print(f"최고 검증 정확도: {best_val_acc:.4f}")

Fine-tuning 모범 사례

기법설명효과
Label Smoothing정답 확률을 1.0 대신 0.9로과적합 방지
Gradient Clipping그래디언트 크기 제한학습 안정화
Cosine Annealing학습률을 코사인 곡선으로 감소부드러운 수렴
Weight DecayL2 정규화과적합 방지
Mixed PrecisionFP16 학습메모리 절약, 속도 향상

Mixed Precision 학습

from torch.amp import autocast, GradScaler

scaler = GradScaler('cuda')

for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)

    optimizer.zero_grad()
    with autocast('cuda'):
        outputs = model(images)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Fine-tuning 시 학습률이 너무 높으면 사전학습된 지식이 파괴됩니다. 일반적으로 사전학습에 사용된 학습률(0.1)보다 10100배 작은 값(1e-4 ~ 1e-3)으로 시작하세요.

트러블슈팅

  1. 배치 사이즈를 절반으로 줄이세요. 2) Mixed Precision(FP16)을 활성화하세요. 3) Gradient Accumulation을 사용하세요. 4) 더 작은 모델(EfficientNet-B0)로 교체하세요.
과적합 징후입니다. Weight Decay를 높이거나(0.01→0.05), 더 강한 Data Augmentation을 적용하거나, 학습 에폭을 줄이세요(Early Stopping). Label Smoothing(0.1)도 효과적입니다.