Skip to main content

실험 추적 — MLflow + PyTorch

딥러닝 실험에서는 하이퍼파라미터, 메트릭, 모델 체크포인트를 체계적으로 관리해야 합니다. MLflow는 실험 추적, 모델 레지스트리, 배포를 하나의 플랫폼에서 제공합니다.
1

MLflow 기본 설정

import mlflow
import mlflow.pytorch

# 추적 서버 설정 (원격 서버 사용 시)
# mlflow.set_tracking_uri("http://192.168.50.248:5000")

# 실험 생성 또는 선택
mlflow.set_experiment("cifar10-cnn")
2

학습 + 추적 통합

import torch
import torch.nn as nn

# 하이퍼파라미터
config = {
    "lr": 0.001,
    "batch_size": 128,
    "epochs": 50,
    "hidden_dim": 256,
    "dropout": 0.3,
    "optimizer": "AdamW",
    "weight_decay": 0.01,
}

with mlflow.start_run(run_name="resnet18-baseline"):
    # 하이퍼파라미터 기록
    mlflow.log_params(config)

    model = build_model(config)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config["lr"],
                                  weight_decay=config["weight_decay"])

    for epoch in range(config["epochs"]):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        # 메트릭 기록
        mlflow.log_metrics({
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        }, step=epoch)

    # 최종 모델 저장
    mlflow.pytorch.log_model(model, "model")

    # 추가 아티팩트 (학습 곡선 이미지 등)
    # mlflow.log_artifact("training_history.png")

    print(f"Run ID: {mlflow.active_run().info.run_id}")
3

실험 비교 및 모델 로드

# 실험 결과 조회
experiment = mlflow.get_experiment_by_name("cifar10-cnn")
runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id])
print(runs[["params.lr", "metrics.val_acc"]].sort_values("metrics.val_acc", ascending=False))

# 최적 모델 로드
best_run = runs.sort_values("metrics.val_acc", ascending=False).iloc[0]
model = mlflow.pytorch.load_model(f"runs:/{best_run.run_id}/model")

자동 로깅

# PyTorch Lightning 사용 시 자동 로깅
mlflow.pytorch.autolog()

# 이후 학습 코드는 자동으로 기록됨

MLflow 핵심 개념

개념설명예시
Experiment실험 그룹”cifar10-cnn”
Run하나의 실험 실행”resnet18-lr0.001”
Parameter입력 하이퍼파라미터lr=0.001, epochs=50
Metric성능 수치val_acc=0.9234
Artifact산출물model, 그래프, 로그 파일
Model Registry모델 버전 관리v1 → v2 → v3

체크리스트

  • MLflow의 Experiment/Run/Parameter/Metric 개념을 이해한다
  • 학습 루프에 MLflow 추적을 통합할 수 있다
  • 실험 결과를 비교하고 최적 모델을 로드할 수 있다
  • mlflow.pytorch.log_model로 모델을 저장할 수 있다

다음 문서