Skip to main content

Grad-CAM — 모델의 판단 근거 시각화

Grad-CAM(Gradient-weighted Class Activation Mapping)은 모델이 이미지의 어느 부분에 주목하여 판단했는지를 히트맵(Heatmap)으로 시각화하는 XAI(Explainable AI) 기법입니다.
1
설치
2
pip install pytorch-grad-cam
3
CNN 모델에 적용
4
import torch
import timm
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# 모델 로드
model = timm.create_model('resnet50', pretrained=True)
model.eval()

# 타겟 레이어 지정 (마지막 Conv 레이어)
target_layers = [model.layer4[-1]]

# Grad-CAM 생성기
cam = GradCAM(model=model, target_layers=target_layers)

# 이미지 준비
image = cv2.imread('image.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_resized = cv2.resize(image, (224, 224))
rgb_img = image_resized / 255.0

# 텐서 변환
input_tensor = torch.from_numpy(image_resized).permute(2, 0, 1).float() / 255.0
input_tensor = input_tensor.unsqueeze(0)

# ImageNet 정규화
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
input_tensor = (input_tensor - mean) / std

# 히트맵 생성 (특정 클래스 기준)
targets = [ClassifierOutputTarget(281)]  # 예: 고양이 클래스
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]

# 원본 이미지 위에 히트맵 오버레이
visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(image_resized)
plt.title('원본')
plt.subplot(1, 3, 2)
plt.imshow(grayscale_cam, cmap='jet')
plt.title('Grad-CAM')
plt.subplot(1, 3, 3)
plt.imshow(visualization)
plt.title('오버레이')
plt.tight_layout()
plt.savefig('gradcam_result.png', dpi=150)
5
ViT 모델에 적용
6
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform

# ViT 모델 로드
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# ViT의 마지막 블록의 LayerNorm을 타겟으로
target_layers = [model.blocks[-1].norm1]

# ViT용 reshape transform 필요
cam = GradCAM(
    model=model,
    target_layers=target_layers,
    reshape_transform=vit_reshape_transform,
)

grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
visualization = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
7
다양한 CAM 변형
8
from pytorch_grad_cam import (
    GradCAM,
    GradCAMPlusPlus,
    EigenCAM,
    LayerCAM,
)

# 변형별 비교
cam_methods = {
    'Grad-CAM': GradCAM,
    'Grad-CAM++': GradCAMPlusPlus,
    'EigenCAM': EigenCAM,
    'LayerCAM': LayerCAM,
}

fig, axes = plt.subplots(1, len(cam_methods), figsize=(16, 4))
for ax, (name, CamClass) in zip(axes, cam_methods.items()):
    cam = CamClass(model=model, target_layers=target_layers)
    heatmap = cam(input_tensor=input_tensor, targets=targets)[0, :]
    vis = show_cam_on_image(rgb_img, heatmap, use_rgb=True)
    ax.imshow(vis)
    ax.set_title(name)
    ax.axis('off')
plt.tight_layout()
plt.savefig('cam_comparison.png', dpi=150)

CAM 변형 비교

방법특징추천 용도
Grad-CAM가장 기본, 안정적범용
Grad-CAM++다중 객체에 강함여러 객체가 있는 이미지
EigenCAM그래디언트 불필요빠른 시각화
LayerCAM초기 레이어에서도 효과적세밀한 특징 분석

실무 활용 사례

활용설명
모델 디버깅모델이 잘못된 영역에 주목하는지 확인
신뢰성 검증의료/안전 분야에서 모델 판단 근거 제시
데이터 품질 점검배경 편향(Spurious Correlation) 발견
모델 비교서로 다른 모델의 주목 영역 비교
Grad-CAM은 모델의 “주목 영역”을 보여주지만, 이것이 반드시 “올바른 근거”를 의미하지는 않습니다. 모델이 올바른 영역에 주목하더라도 잘못된 예측을 할 수 있고, 그 반대도 가능합니다.
가능하지만 복잡합니다. 탐지 모델은 특정 객체의 활성화만 분리해야 합니다. pytorch-grad-cam의 ObjectDetectionTarget을 사용하거나, 탐지된 영역을 크롭하여 분류 모델에 적용하는 방법이 있습니다.
Grad-CAM은 모델 내부 그래디언트를 직접 사용하여 빠르고 직관적입니다. SHAP과 LIME은 모델을 블랙박스로 취급하여 범용적이지만 계산이 느립니다. 이미지 모델에서는 Grad-CAM이 가장 널리 사용됩니다.