Skip to main content

모델 저장과 로드

학습된 모델을 저장하고 복원하는 방법을 다룹니다. 체크포인트 관리, 학습 재개, 다른 프레임워크로의 내보내기(ONNX)까지 실무에서 필요한 패턴을 학습합니다.

state_dict 저장 (권장)

모델의 가중치(파라미터)만 저장하는 방식입니다. 가장 유연하고 권장되는 방법입니다.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# 저장
torch.save(model.state_dict(), 'model_weights.pt')

# 로드 (같은 구조의 모델이 필요)
model_loaded = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
model_loaded.load_state_dict(torch.load('model_weights.pt', weights_only=True))
model_loaded.eval()  # 추론 모드로 전환

체크포인트 저장 (학습 재개용)

학습 중단 후 재개하려면 모델 가중치뿐 아니라 옵티마이저 상태, 에포크 번호 등도 저장해야 합니다.
# 체크포인트 저장
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
    'best_val_acc': best_val_acc,
}
torch.save(checkpoint, f'checkpoint_epoch{epoch}.pt')

# 체크포인트에서 학습 재개
checkpoint = torch.load('checkpoint_epoch5.pt', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_val_acc = checkpoint['best_val_acc']

# 이어서 학습
for epoch in range(start_epoch, num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    # ...

전체 모델 저장

모델 구조와 가중치를 함께 저장합니다. 간편하지만 Python pickle 의존성이 있어 이식성이 낮습니다.
# 전체 모델 저장 (비권장)
torch.save(model, 'full_model.pt')

# 로드 (모델 클래스 정의가 접근 가능해야 함)
model = torch.load('full_model.pt', weights_only=False)
torch.save(model, ...) 방식은 Python pickle을 사용하므로, 모델 클래스 코드가 변경되면 로드에 실패할 수 있습니다. 프로덕션에서는 state_dict 방식을 사용하세요.

ONNX 내보내기

ONNX(Open Neural Network Exchange)는 다양한 프레임워크와 추론 엔진 간 모델을 교환하는 표준 포맷입니다.
# ONNX 내보내기
dummy_input = torch.randn(1, 784)  # 예시 입력
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},   # 가변 배치 크기 지원
        'output': {0: 'batch_size'},
    },
)
print("ONNX 모델 저장 완료: model.onnx")

# ONNX Runtime으로 추론
# import onnxruntime as ort
# session = ort.InferenceSession('model.onnx')
# result = session.run(None, {'input': dummy_input.numpy()})

저장 방식 비교

방식저장 내용이식성용도
state_dict가중치만높음일반 저장/로드
체크포인트가중치 + 옵티마이저 + 메타높음학습 재개
전체 모델구조 + 가중치낮음빠른 프로토타입
ONNX그래프 + 가중치매우 높음서빙, 다른 프레임워크

체크리스트

  • state_dict로 모델을 저장하고 로드할 수 있다
  • 체크포인트에 옵티마이저 상태를 포함하여 학습을 재개할 수 있다
  • ONNX 포맷으로 모델을 내보낼 수 있다
  • 각 저장 방식의 장단점을 이해한다

다음 문서