Skip to main content

세그멘테이션 프로젝트 — 의료 영상 세그멘테이션

의료 영상(예: 피부 병변, 폴립, 세포)에서 관심 영역을 픽셀 단위로 분할하는 Semantic Segmentation 프로젝트입니다. UNet 기반 모델과 segmentation-models-pytorch(smp)를 사용하여 데이터 준비부터 평가까지 전체 파이프라인을 수행합니다.
1
프로젝트 구조 설정
2
medical-segmentation/
├── data/
│   ├── images/
│   │   ├── train/         # 원본 이미지
│   │   └── val/
│   ├── masks/
│   │   ├── train/         # 바이너리 마스크
│   │   └── val/
│   └── splits.csv         # 학습/검증 분할 기록
├── outputs/
│   ├── checkpoints/
│   └── predictions/
├── train.py
├── dataset.py
└── evaluate.py
3
데이터셋 클래스 구현
4
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

class MedicalSegDataset(Dataset):
    """의료 영상 세그멘테이션 데이터셋입니다."""

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

        self.images = sorted([
            f for f in os.listdir(image_dir)
            if f.endswith(('.png', '.jpg', '.tif'))
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        mask_name = img_name.replace('.jpg', '.png')  # 마스크는 PNG

        # 이미지 로드
        image = cv2.imread(os.path.join(self.image_dir, img_name))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 마스크 로드 (0 또는 255 → 0 또는 1)
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        mask = (mask > 127).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # 마스크에 채널 차원 추가: (H, W) → (1, H, W)
        mask = mask.unsqueeze(0) if isinstance(mask, torch.Tensor) else torch.from_numpy(mask).unsqueeze(0)

        return image, mask
5
의료 영상 전용 증강
6
의료 영상은 일반 이미지와 증강 전략이 다릅니다. 색상 변환은 조심스럽게, 기하학적 변환은 적극적으로 적용합니다.
7
def get_medical_transforms(phase='train', image_size=256):
    """의료 영상 전용 증강 파이프라인을 반환합니다."""
    if phase == 'train':
        return A.Compose([
            A.Resize(image_size, image_size),
            # 기하학적 변환 (적극 적용)
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.2,
                rotate_limit=30,
                border_mode=cv2.BORDER_CONSTANT,
                p=0.5,
            ),
            A.ElasticTransform(alpha=120, sigma=6, p=0.3),
            # 색상 변환 (보수적 적용)
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
            A.GaussNoise(var_limit=(5, 25), p=0.2),
            # 정규화
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(image_size, image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

# DataLoader 생성
IMAGE_SIZE = 256
BATCH_SIZE = 16

train_dataset = MedicalSegDataset(
    'data/images/train', 'data/masks/train',
    transform=get_medical_transforms('train', IMAGE_SIZE),
)
val_dataset = MedicalSegDataset(
    'data/images/val', 'data/masks/val',
    transform=get_medical_transforms('val', IMAGE_SIZE),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"학습: {len(train_dataset)}장, 검증: {len(val_dataset)}장")
8
ElasticTransform은 의료 영상에서 매우 효과적인 증강입니다. 조직의 변형을 시뮬레이션하여 모델의 일반화 성능을 크게 향상시킵니다. 다만 과도한 변형은 비현실적인 패턴을 만들 수 있으므로 alpha와 sigma를 적절히 조절하세요.
9
모델 정의
10
import segmentation_models_pytorch as smp

# UNet 모델 (EfficientNet-B3 백본)
model = smp.Unet(
    encoder_name='efficientnet-b3',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,             # 바이너리 세그멘테이션
    activation=None,       # 손실 함수에서 Sigmoid 처리
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 모델 파라미터 수 확인
total_params = sum(p.numel() for p in model.parameters()) / 1e6
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"전체 파라미터: {total_params:.1f}M")
print(f"학습 가능 파라미터: {trainable_params:.1f}M")
11
모델백본파라미터추천 상황UNetEfficientNet-B3~13M범용, 소규모 데이터UNet++ResNet50~26M경계 정밀도 중시DeepLabV3+ResNet50~26M큰 객체, 넓은 컨텍스트FPNEfficientNet-B4~20M다중 스케일
12
손실 함수와 학습
13
의료 영상 세그멘테이션에서는 Dice Loss와 BCE Loss를 조합하면 성능이 향상됩니다.
14
import torch.nn as nn
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    """Dice Loss와 BCE Loss를 결합한 손실 함수입니다."""

    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1.0):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.smooth = smooth

    def forward(self, pred, target):
        pred_sigmoid = torch.sigmoid(pred)

        # Dice Loss
        pred_flat = pred_sigmoid.view(-1)
        target_flat = target.view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2. * intersection + self.smooth) / (
            pred_flat.sum() + target_flat.sum() + self.smooth
        )
        dice_loss = 1 - dice

        # BCE Loss
        bce_loss = F.binary_cross_entropy_with_logits(pred, target)

        return self.dice_weight * dice_loss + self.bce_weight * bce_loss

# 학습 설정
criterion = DiceBCELoss(dice_weight=0.5, bce_weight=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
15
# 학습 함수
def train_one_epoch(model, loader, criterion, optimizer, device):
    """한 에포크를 학습합니다."""
    model.train()
    total_loss = 0

    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)

    return total_loss / len(loader.dataset)

# 검증 함수
@torch.no_grad()
def validate(model, loader, criterion, device):
    """모델을 검증하고 IoU와 Dice Score를 계산합니다."""
    model.eval()
    total_loss = 0
    total_iou = 0
    total_dice = 0
    count = 0

    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)
        total_loss += loss.item() * images.size(0)

        # 예측 이진화
        preds = (torch.sigmoid(outputs) > 0.5).float()

        # 배치 내 각 이미지에 대해 메트릭 계산
        for pred, mask in zip(preds, masks):
            iou = compute_iou(pred, mask)
            dice = compute_dice(pred, mask)
            total_iou += iou
            total_dice += dice
            count += 1

    return total_loss / len(loader.dataset), total_iou / count, total_dice / count

def compute_iou(pred, target, smooth=1e-6):
    """IoU(Intersection over Union)를 계산합니다."""
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    union = pred_flat.sum() + target_flat.sum() - intersection
    return ((intersection + smooth) / (union + smooth)).item()

def compute_dice(pred, target, smooth=1e-6):
    """Dice Score를 계산합니다."""
    pred_flat = pred.view(-1)
    target_flat = target.view(-1)
    intersection = (pred_flat * target_flat).sum()
    return ((2. * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)).item()
16
# 학습 루프
EPOCHS = 50
best_dice = 0

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_iou, val_dice = validate(model, val_loader, criterion, device)
    scheduler.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}] "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} IoU: {val_iou:.4f} Dice: {val_dice:.4f}")

    # 최고 성능 모델 저장
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_dice': val_dice,
            'val_iou': val_iou,
        }, 'outputs/checkpoints/best_model.pth')
        print(f"  ✓ 최고 모델 저장 (Dice: {val_dice:.4f})")
17
예측 시각화
18
import matplotlib.pyplot as plt

@torch.no_grad()
def visualize_predictions(model, dataset, device, n_samples=4):
    """예측 결과를 시각화합니다."""
    model.eval()
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples))

    for i in range(n_samples):
        image, mask = dataset[i]
        input_tensor = image.unsqueeze(0).to(device)
        pred = torch.sigmoid(model(input_tensor)).squeeze().cpu().numpy()
        pred_binary = (pred > 0.5).astype(np.float32)

        # 원본 이미지 역정규화
        img_np = image.permute(1, 2, 0).numpy()
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img_np = (img_np * std + mean).clip(0, 1)

        mask_np = mask.squeeze().numpy()

        # 오버레이 생성
        overlay = img_np.copy()
        overlay[pred_binary > 0.5] = [1, 0, 0]  # 예측 영역을 빨간색으로
        blended = 0.6 * img_np + 0.4 * overlay

        axes[i, 0].imshow(img_np)
        axes[i, 0].set_title('원본')
        axes[i, 1].imshow(mask_np, cmap='gray')
        axes[i, 1].set_title('정답 마스크')
        axes[i, 2].imshow(pred, cmap='hot')
        axes[i, 2].set_title(f'예측 확률맵')
        axes[i, 3].imshow(blended)
        axes[i, 3].set_title('오버레이')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.savefig('outputs/predictions/visualization.png', dpi=150)

# 검증 세트 시각화
checkpoint = torch.load('outputs/checkpoints/best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
visualize_predictions(model, val_dataset, device)
19
오류 분석
20
@torch.no_grad()
def error_analysis(model, loader, device, threshold=0.5):
    """오류 유형별 분석을 수행합니다."""
    model.eval()
    results = {'dice_scores': [], 'iou_scores': [], 'sizes': []}

    for images, masks in loader:
        images = images.to(device)
        preds = torch.sigmoid(model(images)).cpu()
        preds_binary = (preds > threshold).float()

        for pred, mask in zip(preds_binary, masks):
            dice = compute_dice(pred, mask)
            iou = compute_iou(pred, mask)
            mask_ratio = mask.sum().item() / mask.numel()

            results['dice_scores'].append(dice)
            results['iou_scores'].append(iou)
            results['sizes'].append(mask_ratio)

    # 마스크 크기별 성능 분석
    dice_arr = np.array(results['dice_scores'])
    size_arr = np.array(results['sizes'])

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Dice Score 분포
    axes[0].hist(dice_arr, bins=30, edgecolor='black')
    axes[0].set_title(f'Dice Score 분포 (평균: {dice_arr.mean():.4f})')
    axes[0].set_xlabel('Dice Score')
    axes[0].axvline(dice_arr.mean(), color='red', linestyle='--')

    # 마스크 크기 vs 성능
    axes[1].scatter(size_arr, dice_arr, alpha=0.5, s=10)
    axes[1].set_title('마스크 크기 vs Dice Score')
    axes[1].set_xlabel('마스크 비율')
    axes[1].set_ylabel('Dice Score')

    plt.tight_layout()
    plt.savefig('outputs/predictions/error_analysis.png', dpi=150)

    # 성능 구간별 통계
    print(f"전체 평균 Dice: {dice_arr.mean():.4f}")
    print(f"작은 마스크 (< 5%): Dice = {dice_arr[size_arr < 0.05].mean():.4f}")
    print(f"중간 마스크 (5-20%): Dice = {dice_arr[(size_arr >= 0.05) & (size_arr < 0.2)].mean():.4f}")
    print(f"큰 마스크 (>= 20%): Dice = {dice_arr[size_arr >= 0.2].mean():.4f}")

error_analysis(model, val_loader, device)
21
모델 내보내기
22
# ONNX 변환
model.eval()
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)

torch.onnx.export(
    model, dummy_input, 'outputs/segmentation_model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=17,
)

# ONNX Runtime 추론 검증
import onnxruntime as ort

session = ort.InferenceSession('outputs/segmentation_model.onnx')
test_input = np.random.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)
output = session.run(None, {'input': test_input})
print(f"ONNX 출력 형태: {output[0].shape}")  # (1, 1, 256, 256)

프로젝트 결과 요약 예시

항목
모델UNet (EfficientNet-B3 백본)
입력 크기256 x 256
학습 데이터500장
손실 함수DiceBCELoss (Dice 0.5 + BCE 0.5)
증강Flip, Rotate90, ElasticTransform, ShiftScaleRotate
검증 Dice Score0.891
검증 IoU0.823
추론 시간6.5ms (ONNX Runtime, GPU)
의료 영상 세그멘테이션 모델을 임상에 적용하려면 반드시 전문 의료진의 검증이 필요합니다. 모델 예측만으로 진단을 내리면 안 되며, 보조 도구로서의 역할을 명확히 해야 합니다. 또한 데이터 프라이버시(HIPAA, 개인정보보호법)를 반드시 준수하세요.
(1) 증강을 적극 활용하세요. 특히 ElasticTransform, GridDistortion이 의료 영상에서 효과적입니다. (2) 사전학습된 백본(ImageNet)의 가중치를 최대한 활용하고, 디코더 부분만 먼저 학습한 후 전체를 미세 조정하세요. (3) 유사 도메인의 공개 데이터셋으로 사전학습 후 타겟 데이터에 Fine-tuning하는 전략도 고려하세요.
모델의 classes 파라미터를 클래스 수로 변경하고, 마스크를 원핫 인코딩 또는 클래스 인덱스 형식으로 변환합니다. 손실 함수는 CrossEntropyLoss(클래스 인덱스) 또는 클래스별 Dice Loss의 평균을 사용합니다. 각 클래스의 데이터 비율이 다르면 클래스 가중치를 설정하세요.