분류 평가 지표
학습 목표
- Accuracy, Precision, Recall, F1-Score의 의미와 계산법을 설명할 수 있다
- Confusion Matrix를 해석하고 시각화할 수 있다
- 클래스 불균형 상황에서 적절한 지표를 선택할 수 있다
- torchmetrics를 사용하여 평가 지표를 계산할 수 있다
왜 중요한가
모델 학습이 끝났다고 해서 작업이 완료된 것이 아닙니다. 적절한 평가 지표로 모델의 성능을 정량적으로 측정하고, 어떤 클래스에서 실수가 많은지 분석해야 합니다. 특히 클래스 불균형이 있는 실무 데이터에서는 Accuracy만으로는 모델의 실제 성능을 파악할 수 없습니다.
핵심 지표
Confusion Matrix (혼동 행렬)
이진 분류에서 4가지 예측 결과를 정리한 행렬입니다.
| 예측: Positive | 예측: Negative |
|---|
| 실제: Positive | TP (True Positive) | FN (False Negative) |
| 실제: Negative | FP (False Positive) | TN (True Negative) |
지표 정의
| 지표 | 수식 | 의미 |
|---|
| Accuracy | (TP+TN) / 전체 | 전체 중 맞춘 비율 |
| Precision | TP / (TP+FP) | Positive 예측 중 실제 Positive 비율 |
| Recall (Sensitivity) | TP / (TP+FN) | 실제 Positive 중 맞춘 비율 |
| F1-Score | 2 * P * R / (P+R) | Precision과 Recall의 조화 평균 |
지표 선택 가이드
| 상황 | 중요한 지표 | 이유 |
|---|
| 클래스 균형 | Accuracy | 직관적, 충분 |
| 클래스 불균형 | F1-Score, Recall | Accuracy는 편향 |
| 스팸 필터링 | Precision | 정상 메일 오분류 방지 |
| 질병 진단 | Recall | 실제 환자 놓치면 안 됨 |
| 불량 검출 | Recall + F1 | 불량 놓치지 않는 것이 중요 |
99% 양품, 1% 불량인 데이터에서 모든 이미지를 “양품”으로 예측해도 Accuracy는 99%입니다. 이런 경우 F1-Score나 Recall을 반드시 확인하세요.
torchmetrics 활용
import torch
import torchmetrics
# 지표 정의
accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=5)
f1 = torchmetrics.F1Score(task='multiclass', num_classes=5, average='macro')
precision = torchmetrics.Precision(task='multiclass', num_classes=5, average='macro')
recall = torchmetrics.Recall(task='multiclass', num_classes=5, average='macro')
conf_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=5)
# 예측 결과 누적
preds = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
targets = torch.tensor([0, 1, 2, 3, 0, 0, 2, 2])
print(f"Accuracy: {accuracy(preds, targets):.4f}")
print(f"F1-Score (macro): {f1(preds, targets):.4f}")
print(f"Precision (macro): {precision(preds, targets):.4f}")
print(f"Recall (macro): {recall(preds, targets):.4f}")
Confusion Matrix 시각화
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def plot_confusion_matrix(cm, class_names):
"""Confusion Matrix를 히트맵으로 시각화합니다."""
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names,
yticklabels=class_names,
ax=ax
)
ax.set_xlabel('예측')
ax.set_ylabel('실제')
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()
# 사용 예시
cm = conf_matrix(preds, targets).numpy()
class_names = ['클래스0', '클래스1', '클래스2', '클래스3', '클래스4']
plot_confusion_matrix(cm, class_names)
학습 루프에서 평가
def evaluate_model(model, val_loader, device, num_classes):
"""모델을 종합적으로 평가합니다."""
model.eval()
metrics = torchmetrics.MetricCollection({
'accuracy': torchmetrics.Accuracy(task='multiclass', num_classes=num_classes),
'f1_macro': torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro'),
'f1_per_class': torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average=None),
}).to(device)
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
preds = outputs.argmax(dim=1)
metrics.update(preds, labels)
results = metrics.compute()
print(f"Accuracy: {results['accuracy']:.4f}")
print(f"F1 (macro): {results['f1_macro']:.4f}")
print(f"F1 (per class): {results['f1_per_class']}")
metrics.reset()
return results
Multi-class 평균 방식
| 방식 | 설명 | 추천 상황 |
|---|
| macro | 클래스별 지표의 단순 평균 | 클래스 불균형 시 (추천) |
| micro | 전체 TP/FP/FN으로 계산 | 클래스 균형 시 |
| weighted | 클래스별 샘플 수로 가중 평균 | 클래스 중요도가 다를 때 |
모델이 예측한 상위 5개 클래스 안에 정답이 포함되면 맞춘 것으로 간주합니다. ImageNet처럼 1,000개 클래스가 있는 대규모 분류에서 주로 사용되며, 실무의 소규모 분류에서는 Top-1 Accuracy를 사용합니다.
ROC-AUC는 이진 분류에서 분류 임계값(Threshold)에 관계없이 모델의 판별 능력을 평가합니다. 의료 진단처럼 임계값 설정이 중요한 도메인에서 유용합니다.
F1-Score가 높더라도 특정 클래스에서 성능이 극단적으로 낮을 수 있습니다. 반드시 클래스별 F1(per-class F1)과 Confusion Matrix를 함께 확인하여 약점을 파악하세요.
체크리스트
다음 문서