모델 저장과 로드
학습된 모델을 저장하고 복원하는 방법을 다룹니다. 체크포인트 관리, 학습 재개, 다른 프레임워크로의 내보내기(ONNX)까지 실무에서 필요한 패턴을 학습합니다.
state_dict 저장 (권장)
모델의 가중치(파라미터)만 저장하는 방식입니다. 가장 유연하고 권장되는 방법입니다.
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
# 저장
torch.save(model.state_dict(), 'model_weights.pt')
# 로드 (같은 구조의 모델이 필요)
model_loaded = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
model_loaded.load_state_dict(torch.load('model_weights.pt', weights_only=True))
model_loaded.eval() # 추론 모드로 전환
체크포인트 저장 (학습 재개용)
학습 중단 후 재개하려면 모델 가중치뿐 아니라 옵티마이저 상태, 에포크 번호 등도 저장해야 합니다.
# 체크포인트 저장
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
'best_val_acc': best_val_acc,
}
torch.save(checkpoint, f'checkpoint_epoch{epoch}.pt')
# 체크포인트에서 학습 재개
checkpoint = torch.load('checkpoint_epoch5.pt', weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_val_acc = checkpoint['best_val_acc']
# 이어서 학습
for epoch in range(start_epoch, num_epochs):
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
# ...
전체 모델 저장
모델 구조와 가중치를 함께 저장합니다. 간편하지만 Python pickle 의존성이 있어 이식성이 낮습니다.
# 전체 모델 저장 (비권장)
torch.save(model, 'full_model.pt')
# 로드 (모델 클래스 정의가 접근 가능해야 함)
model = torch.load('full_model.pt', weights_only=False)
torch.save(model, ...) 방식은 Python pickle을 사용하므로, 모델 클래스 코드가 변경되면 로드에 실패할 수 있습니다. 프로덕션에서는 state_dict 방식을 사용하세요.
ONNX 내보내기
ONNX(Open Neural Network Exchange)는 다양한 프레임워크와 추론 엔진 간 모델을 교환하는 표준 포맷입니다.
# ONNX 내보내기
dummy_input = torch.randn(1, 784) # 예시 입력
torch.onnx.export(
model,
dummy_input,
'model.onnx',
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'}, # 가변 배치 크기 지원
'output': {0: 'batch_size'},
},
)
print("ONNX 모델 저장 완료: model.onnx")
# ONNX Runtime으로 추론
# import onnxruntime as ort
# session = ort.InferenceSession('model.onnx')
# result = session.run(None, {'input': dummy_input.numpy()})
저장 방식 비교
| 방식 | 저장 내용 | 이식성 | 용도 |
|---|
state_dict | 가중치만 | 높음 | 일반 저장/로드 |
| 체크포인트 | 가중치 + 옵티마이저 + 메타 | 높음 | 학습 재개 |
| 전체 모델 | 구조 + 가중치 | 낮음 | 빠른 프로토타입 |
| ONNX | 그래프 + 가중치 | 매우 높음 | 서빙, 다른 프레임워크 |
체크리스트
다음 문서