세그멘테이션 프로젝트 — 의료 영상 세그멘테이션
의료 영상(예: 피부 병변, 폴립, 세포)에서 관심 영역을 픽셀 단위로 분할하는 Semantic Segmentation 프로젝트입니다. UNet 기반 모델과 segmentation-models-pytorch(smp)를 사용하여 데이터 준비부터 평가까지 전체 파이프라인을 수행합니다.medical-segmentation/
├── data/
│ ├── images/
│ │ ├── train/ # 원본 이미지
│ │ └── val/
│ ├── masks/
│ │ ├── train/ # 바이너리 마스크
│ │ └── val/
│ └── splits.csv # 학습/검증 분할 기록
├── outputs/
│ ├── checkpoints/
│ └── predictions/
├── train.py
├── dataset.py
└── evaluate.py
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
class MedicalSegDataset(Dataset):
"""의료 영상 세그멘테이션 데이터셋입니다."""
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = sorted([
f for f in os.listdir(image_dir)
if f.endswith(('.png', '.jpg', '.tif'))
])
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_name = self.images[idx]
mask_name = img_name.replace('.jpg', '.png') # 마스크는 PNG
# 이미지 로드
image = cv2.imread(os.path.join(self.image_dir, img_name))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 마스크 로드 (0 또는 255 → 0 또는 1)
mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
mask = (mask > 127).astype(np.float32)
if self.transform:
augmented = self.transform(image=image, mask=mask)
image = augmented['image']
mask = augmented['mask']
# 마스크에 채널 차원 추가: (H, W) → (1, H, W)
mask = mask.unsqueeze(0) if isinstance(mask, torch.Tensor) else torch.from_numpy(mask).unsqueeze(0)
return image, mask
def get_medical_transforms(phase='train', image_size=256):
"""의료 영상 전용 증강 파이프라인을 반환합니다."""
if phase == 'train':
return A.Compose([
A.Resize(image_size, image_size),
# 기하학적 변환 (적극 적용)
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.ShiftScaleRotate(
shift_limit=0.1,
scale_limit=0.2,
rotate_limit=30,
border_mode=cv2.BORDER_CONSTANT,
p=0.5,
),
A.ElasticTransform(alpha=120, sigma=6, p=0.3),
# 색상 변환 (보수적 적용)
A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
A.GaussNoise(var_limit=(5, 25), p=0.2),
# 정규화
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])
else:
return A.Compose([
A.Resize(image_size, image_size),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])
# DataLoader 생성
IMAGE_SIZE = 256
BATCH_SIZE = 16
train_dataset = MedicalSegDataset(
'data/images/train', 'data/masks/train',
transform=get_medical_transforms('train', IMAGE_SIZE),
)
val_dataset = MedicalSegDataset(
'data/images/val', 'data/masks/val',
transform=get_medical_transforms('val', IMAGE_SIZE),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
print(f"학습: {len(train_dataset)}장, 검증: {len(val_dataset)}장")
ElasticTransform은 의료 영상에서 매우 효과적인 증강입니다. 조직의 변형을 시뮬레이션하여 모델의 일반화 성능을 크게 향상시킵니다. 다만 과도한 변형은 비현실적인 패턴을 만들 수 있으므로 alpha와 sigma를 적절히 조절하세요.
import segmentation_models_pytorch as smp
# UNet 모델 (EfficientNet-B3 백본)
model = smp.Unet(
encoder_name='efficientnet-b3',
encoder_weights='imagenet',
in_channels=3,
classes=1, # 바이너리 세그멘테이션
activation=None, # 손실 함수에서 Sigmoid 처리
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# 모델 파라미터 수 확인
total_params = sum(p.numel() for p in model.parameters()) / 1e6
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"전체 파라미터: {total_params:.1f}M")
print(f"학습 가능 파라미터: {trainable_params:.1f}M")
import torch.nn as nn
import torch.nn.functional as F
class DiceBCELoss(nn.Module):
"""Dice Loss와 BCE Loss를 결합한 손실 함수입니다."""
def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1.0):
super().__init__()
self.dice_weight = dice_weight
self.bce_weight = bce_weight
self.smooth = smooth
def forward(self, pred, target):
pred_sigmoid = torch.sigmoid(pred)
# Dice Loss
pred_flat = pred_sigmoid.view(-1)
target_flat = target.view(-1)
intersection = (pred_flat * target_flat).sum()
dice = (2. * intersection + self.smooth) / (
pred_flat.sum() + target_flat.sum() + self.smooth
)
dice_loss = 1 - dice
# BCE Loss
bce_loss = F.binary_cross_entropy_with_logits(pred, target)
return self.dice_weight * dice_loss + self.bce_weight * bce_loss
# 학습 설정
criterion = DiceBCELoss(dice_weight=0.5, bce_weight=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
# 학습 함수
def train_one_epoch(model, loader, criterion, optimizer, device):
"""한 에포크를 학습합니다."""
model.train()
total_loss = 0
for images, masks in loader:
images = images.to(device)
masks = masks.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
return total_loss / len(loader.dataset)
# 검증 함수
@torch.no_grad()
def validate(model, loader, criterion, device):
"""모델을 검증하고 IoU와 Dice Score를 계산합니다."""
model.eval()
total_loss = 0
total_iou = 0
total_dice = 0
count = 0
for images, masks in loader:
images = images.to(device)
masks = masks.to(device)
outputs = model(images)
loss = criterion(outputs, masks)
total_loss += loss.item() * images.size(0)
# 예측 이진화
preds = (torch.sigmoid(outputs) > 0.5).float()
# 배치 내 각 이미지에 대해 메트릭 계산
for pred, mask in zip(preds, masks):
iou = compute_iou(pred, mask)
dice = compute_dice(pred, mask)
total_iou += iou
total_dice += dice
count += 1
return total_loss / len(loader.dataset), total_iou / count, total_dice / count
def compute_iou(pred, target, smooth=1e-6):
"""IoU(Intersection over Union)를 계산합니다."""
pred_flat = pred.view(-1)
target_flat = target.view(-1)
intersection = (pred_flat * target_flat).sum()
union = pred_flat.sum() + target_flat.sum() - intersection
return ((intersection + smooth) / (union + smooth)).item()
def compute_dice(pred, target, smooth=1e-6):
"""Dice Score를 계산합니다."""
pred_flat = pred.view(-1)
target_flat = target.view(-1)
intersection = (pred_flat * target_flat).sum()
return ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)).item()
# 학습 루프
EPOCHS = 50
best_dice = 0
for epoch in range(EPOCHS):
train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
val_loss, val_iou, val_dice = validate(model, val_loader, criterion, device)
scheduler.step()
print(f"Epoch [{epoch+1}/{EPOCHS}] "
f"Train Loss: {train_loss:.4f} | "
f"Val Loss: {val_loss:.4f} IoU: {val_iou:.4f} Dice: {val_dice:.4f}")
# 최고 성능 모델 저장
if val_dice > best_dice:
best_dice = val_dice
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'val_dice': val_dice,
'val_iou': val_iou,
}, 'outputs/checkpoints/best_model.pth')
print(f" ✓ 최고 모델 저장 (Dice: {val_dice:.4f})")
import matplotlib.pyplot as plt
@torch.no_grad()
def visualize_predictions(model, dataset, device, n_samples=4):
"""예측 결과를 시각화합니다."""
model.eval()
fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples))
for i in range(n_samples):
image, mask = dataset[i]
input_tensor = image.unsqueeze(0).to(device)
pred = torch.sigmoid(model(input_tensor)).squeeze().cpu().numpy()
pred_binary = (pred > 0.5).astype(np.float32)
# 원본 이미지 역정규화
img_np = image.permute(1, 2, 0).numpy()
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img_np = (img_np * std + mean).clip(0, 1)
mask_np = mask.squeeze().numpy()
# 오버레이 생성
overlay = img_np.copy()
overlay[pred_binary > 0.5] = [1, 0, 0] # 예측 영역을 빨간색으로
blended = 0.6 * img_np + 0.4 * overlay
axes[i, 0].imshow(img_np)
axes[i, 0].set_title('원본')
axes[i, 1].imshow(mask_np, cmap='gray')
axes[i, 1].set_title('정답 마스크')
axes[i, 2].imshow(pred, cmap='hot')
axes[i, 2].set_title(f'예측 확률맵')
axes[i, 3].imshow(blended)
axes[i, 3].set_title('오버레이')
for ax in axes[i]:
ax.axis('off')
plt.tight_layout()
plt.savefig('outputs/predictions/visualization.png', dpi=150)
# 검증 세트 시각화
checkpoint = torch.load('outputs/checkpoints/best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
visualize_predictions(model, val_dataset, device)
@torch.no_grad()
def error_analysis(model, loader, device, threshold=0.5):
"""오류 유형별 분석을 수행합니다."""
model.eval()
results = {'dice_scores': [], 'iou_scores': [], 'sizes': []}
for images, masks in loader:
images = images.to(device)
preds = torch.sigmoid(model(images)).cpu()
preds_binary = (preds > threshold).float()
for pred, mask in zip(preds_binary, masks):
dice = compute_dice(pred, mask)
iou = compute_iou(pred, mask)
mask_ratio = mask.sum().item() / mask.numel()
results['dice_scores'].append(dice)
results['iou_scores'].append(iou)
results['sizes'].append(mask_ratio)
# 마스크 크기별 성능 분석
dice_arr = np.array(results['dice_scores'])
size_arr = np.array(results['sizes'])
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Dice Score 분포
axes[0].hist(dice_arr, bins=30, edgecolor='black')
axes[0].set_title(f'Dice Score 분포 (평균: {dice_arr.mean():.4f})')
axes[0].set_xlabel('Dice Score')
axes[0].axvline(dice_arr.mean(), color='red', linestyle='--')
# 마스크 크기 vs 성능
axes[1].scatter(size_arr, dice_arr, alpha=0.5, s=10)
axes[1].set_title('마스크 크기 vs Dice Score')
axes[1].set_xlabel('마스크 비율')
axes[1].set_ylabel('Dice Score')
plt.tight_layout()
plt.savefig('outputs/predictions/error_analysis.png', dpi=150)
# 성능 구간별 통계
print(f"전체 평균 Dice: {dice_arr.mean():.4f}")
print(f"작은 마스크 (< 5%): Dice = {dice_arr[size_arr < 0.05].mean():.4f}")
print(f"중간 마스크 (5-20%): Dice = {dice_arr[(size_arr >= 0.05) & (size_arr < 0.2)].mean():.4f}")
print(f"큰 마스크 (>= 20%): Dice = {dice_arr[size_arr >= 0.2].mean():.4f}")
error_analysis(model, val_loader, device)
# ONNX 변환
model.eval()
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
torch.onnx.export(
model, dummy_input, 'outputs/segmentation_model.onnx',
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=17,
)
# ONNX Runtime 추론 검증
import onnxruntime as ort
session = ort.InferenceSession('outputs/segmentation_model.onnx')
test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
output = session.run(None, {'input': test_input})
print(f"ONNX 출력 형태: {output[0].shape}") # (1, 1, 256, 256)
프로젝트 결과 요약 예시
| 항목 | 값 |
|---|---|
| 모델 | UNet (EfficientNet-B3 백본) |
| 입력 크기 | 256 x 256 |
| 학습 데이터 | 500장 |
| 손실 함수 | DiceBCELoss (Dice 0.5 + BCE 0.5) |
| 증강 | Flip, Rotate90, ElasticTransform, ShiftScaleRotate |
| 검증 Dice Score | 0.891 |
| 검증 IoU | 0.823 |
| 추론 시간 | 6.5ms (ONNX Runtime, GPU) |
데이터가 300장 이하로 적을 때 성능을 높이려면?
데이터가 300장 이하로 적을 때 성능을 높이려면?
(1) 증강을 적극 활용하세요. 특히 ElasticTransform, GridDistortion이 의료 영상에서 효과적입니다. (2) 사전학습된 백본(ImageNet)의 가중치를 최대한 활용하고, 디코더 부분만 먼저 학습한 후 전체를 미세 조정하세요. (3) 유사 도메인의 공개 데이터셋으로 사전학습 후 타겟 데이터에 Fine-tuning하는 전략도 고려하세요.
다중 클래스 세그멘테이션으로 확장하려면?
다중 클래스 세그멘테이션으로 확장하려면?
모델의 classes 파라미터를 클래스 수로 변경하고, 마스크를 원핫 인코딩 또는 클래스 인덱스 형식으로 변환합니다. 손실 함수는 CrossEntropyLoss(클래스 인덱스) 또는 클래스별 Dice Loss의 평균을 사용합니다. 각 클래스의 데이터 비율이 다르면 클래스 가중치를 설정하세요.

