Skip to main content

분류 평가 지표

학습 목표

  • Accuracy, Precision, Recall, F1-Score의 의미와 계산법을 설명할 수 있다
  • Confusion Matrix를 해석하고 시각화할 수 있다
  • 클래스 불균형 상황에서 적절한 지표를 선택할 수 있다
  • torchmetrics를 사용하여 평가 지표를 계산할 수 있다

왜 중요한가

모델 학습이 끝났다고 해서 작업이 완료된 것이 아닙니다. 적절한 평가 지표로 모델의 성능을 정량적으로 측정하고, 어떤 클래스에서 실수가 많은지 분석해야 합니다. 특히 클래스 불균형이 있는 실무 데이터에서는 Accuracy만으로는 모델의 실제 성능을 파악할 수 없습니다.

핵심 지표

Confusion Matrix (혼동 행렬)

이진 분류에서 4가지 예측 결과를 정리한 행렬입니다.
예측: Positive예측: Negative
실제: PositiveTP (True Positive)FN (False Negative)
실제: NegativeFP (False Positive)TN (True Negative)

지표 정의

지표수식의미
Accuracy(TP+TN) / 전체전체 중 맞춘 비율
PrecisionTP / (TP+FP)Positive 예측 중 실제 Positive 비율
Recall (Sensitivity)TP / (TP+FN)실제 Positive 중 맞춘 비율
F1-Score2 * P * R / (P+R)Precision과 Recall의 조화 평균

지표 선택 가이드

상황중요한 지표이유
클래스 균형Accuracy직관적, 충분
클래스 불균형F1-Score, RecallAccuracy는 편향
스팸 필터링Precision정상 메일 오분류 방지
질병 진단Recall실제 환자 놓치면 안 됨
불량 검출Recall + F1불량 놓치지 않는 것이 중요
99% 양품, 1% 불량인 데이터에서 모든 이미지를 “양품”으로 예측해도 Accuracy는 99%입니다. 이런 경우 F1-Score나 Recall을 반드시 확인하세요.

구현

torchmetrics 활용

import torch
import torchmetrics

# 지표 정의
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=5)
f1 = torchmetrics.F1Score(task='multiclass', num_classes=5, average='macro')
precision = torchmetrics.Precision(task='multiclass', num_classes=5, average='macro')
recall = torchmetrics.Recall(task='multiclass', num_classes=5, average='macro')
conf_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=5)

# 예측 결과 누적
preds = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
targets = torch.tensor([0, 1, 2, 3, 0, 0, 2, 2])

print(f"Accuracy: {accuracy(preds, targets):.4f}")
print(f"F1-Score (macro): {f1(preds, targets):.4f}")
print(f"Precision (macro): {precision(preds, targets):.4f}")
print(f"Recall (macro): {recall(preds, targets):.4f}")

Confusion Matrix 시각화

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_confusion_matrix(cm, class_names):
    """Confusion Matrix를 히트맵으로 시각화합니다."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax
    )
    ax.set_xlabel('예측')
    ax.set_ylabel('실제')
    ax.set_title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

# 사용 예시
cm = conf_matrix(preds, targets).numpy()
class_names = ['클래스0', '클래스1', '클래스2', '클래스3', '클래스4']
plot_confusion_matrix(cm, class_names)

학습 루프에서 평가

def evaluate_model(model, val_loader, device, num_classes):
    """모델을 종합적으로 평가합니다."""
    model.eval()
    metrics = torchmetrics.MetricCollection({
        'accuracy': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes),
        'f1_macro': torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro'),
        'f1_per_class': torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None),
    }).to(device)

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            metrics.update(preds, labels)

    results = metrics.compute()
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"F1 (macro): {results['f1_macro']:.4f}")
    print(f"F1 (per class): {results['f1_per_class']}")
    metrics.reset()
    return results

Multi-class 평균 방식

방식설명추천 상황
macro클래스별 지표의 단순 평균클래스 불균형 시 (추천)
micro전체 TP/FP/FN으로 계산클래스 균형 시
weighted클래스별 샘플 수로 가중 평균클래스 중요도가 다를 때
모델이 예측한 상위 5개 클래스 안에 정답이 포함되면 맞춘 것으로 간주합니다. ImageNet처럼 1,000개 클래스가 있는 대규모 분류에서 주로 사용되며, 실무의 소규모 분류에서는 Top-1 Accuracy를 사용합니다.
ROC-AUC는 이진 분류에서 분류 임계값(Threshold)에 관계없이 모델의 판별 능력을 평가합니다. 의료 진단처럼 임계값 설정이 중요한 도메인에서 유용합니다.
F1-Score가 높더라도 특정 클래스에서 성능이 극단적으로 낮을 수 있습니다. 반드시 클래스별 F1(per-class F1)과 Confusion Matrix를 함께 확인하여 약점을 파악하세요.

체크리스트

  • TP, FP, FN, TN의 의미를 이해했다
  • Accuracy, Precision, Recall, F1의 계산법을 안다
  • 클래스 불균형 시 Accuracy의 한계를 이해했다
  • Confusion Matrix를 해석할 수 있다
  • torchmetrics로 평가 지표를 계산할 수 있다
  • macro/micro/weighted 평균의 차이를 안다

다음 문서