Skip to main content

Semantic Segmentation

Semantic Segmentation은 이미지의 모든 픽셀에 클래스 레이블을 할당합니다. 같은 클래스의 객체는 구분하지 않으며, 도로/건물/하늘처럼 영역 단위 분류에 적합합니다.
1
환경 준비
2
pip install segmentation-models-pytorch albumentations opencv-python-headless
3
UNet 모델 구성
4
import segmentation_models_pytorch as smp

# UNet + ResNet34 백본 (ImageNet 사전학습)
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights='imagenet',
    in_channels=3,
    classes=3,  # 클래스 수 (배경 포함)
)

print(f"파라미터 수: {sum(p.numel() for p in model.parameters()):,}")
5
DeepLabV3+ 모델 구성
6
# DeepLabV3+ (ASPP + Decoder)
model = smp.DeepLabV3Plus(
    encoder_name='resnet50',
    encoder_weights='imagenet',
    in_channels=3,
    classes=3,
)
7
데이터셋 준비
8
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import albumentations as A

class SegmentationDataset(Dataset):
    """세그멘테이션 데이터셋입니다."""

    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # 정규화 및 텐서 변환
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()

        return image, mask

# 증강 파이프라인
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.RandomBrightnessContrast(p=0.3),
])
9
학습 루프
10
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Dice Loss + CrossEntropy 조합
criterion = smp.losses.DiceLoss(mode='multiclass')
ce_loss = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # [B, C, H, W]

        loss = criterion(outputs, masks) + ce_loss(outputs, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    scheduler.step()

    # 검증
    model.eval()
    val_iou = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            val_iou += compute_iou(preds, masks)

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss/len(train_loader):.4f} | Val mIoU: {val_iou/len(val_loader):.4f}")
11
추론과 시각화
12
import matplotlib.pyplot as plt

def predict_and_visualize(model, image_path, device):
    """이미지에 대해 세그멘테이션을 수행하고 결과를 시각화합니다."""
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(image, (256, 256))

    input_tensor = torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0
    input_tensor = input_tensor.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
        pred_mask = output.argmax(dim=1).squeeze().cpu().numpy()

    # 시각화
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(resized)
    axes[0].set_title('원본')
    axes[1].imshow(pred_mask, cmap='jet')
    axes[1].set_title('예측 마스크')
    plt.tight_layout()
    plt.show()

UNet vs DeepLabV3+ 비교

비교 항목UNetDeepLabV3+
구조대칭 Encoder-Decoder + Skip ConnectionASPP + Decoder
멀티스케일Skip Connection으로 처리ASPP(Atrous Spatial Pyramid Pooling)
원래 용도의료 영상일반 장면
경계 정밀도높음매우 높음
추천 용도의료, 소규모 데이터일반 장면, 자율주행
segmentation-models-pytorch는 timm의 모든 백본을 지원합니다. resnet34, efficientnet-b4, mit_b2(SegFormer), convnext_tiny 등을 encoder_name에 지정할 수 있습니다. 소규모 데이터에는 resnet34, 고성능이 필요하면 efficientnet-b4 이상을 추천합니다.
그레이스케일 PNG 이미지로, 각 픽셀 값이 클래스 인덱스(0, 1, 2, …)를 나타냅니다. 배경은 0, 클래스1은 1, 클래스2는 2 등으로 설정합니다. RGB 컬러 마스크는 클래스 매핑 테이블로 인덱스 마스크로 변환해야 합니다.