Skip to main content

Vision Transformer — 이미지를 위한 트랜스포머

Vision Transformer(ViT)는 NLP에서 성공한 Transformer 아키텍처를 이미지 분류에 적용한 모델입니다. 이미지를 패치(Patch) 단위로 분할하여 시퀀스로 변환한 뒤, Self-Attention으로 전역적 관계를 학습합니다.

핵심 아이디어

CNN은 로컬 영역(커널)을 순차적으로 확장하여 전역 정보를 파악합니다. 반면 ViT는 첫 번째 레이어부터 이미지 전체의 패치 간 관계를 직접 학습합니다. “An Image is Worth 16x16 Words”라는 논문 제목처럼, 16x16 픽셀 패치를 하나의 토큰으로 취급합니다.

동작 방식

패치 임베딩 (Patch Embedding)

224x224 이미지를 16x16 패치로 분할하면 14x14 = 196개의 패치가 생성됩니다. 각 패치는 선형 변환(Linear Projection)으로 D차원 벡터로 변환됩니다.
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """이미지를 패치로 분할하고 임베딩합니다."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 196
        # 패치 분할 + 선형 변환을 Conv2d로 구현
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, 3, 224, 224]
        x = self.proj(x)       # [B, 768, 14, 14]
        x = x.flatten(2)       # [B, 768, 196]
        x = x.transpose(1, 2)  # [B, 196, 768]
        return x

CLS 토큰과 위치 임베딩

class ViTEmbedding(nn.Module):
    """CLS 토큰과 위치 임베딩을 추가합니다."""
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS 토큰: 이미지 전체를 대표하는 학습 가능 벡터
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # 위치 임베딩: 패치의 공간적 위치 정보
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # [B, 196, 768]

        # CLS 토큰 추가
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, 768]
        x = torch.cat([cls_tokens, x], dim=1)  # [B, 197, 768]

        # 위치 임베딩 추가
        x = x + self.pos_embed  # [B, 197, 768]
        return x

구현

timm으로 ViT 사용

import timm

# ViT-Base/16 (16x16 패치, 12개 Transformer 블록)
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# 모델 구조 확인
print(f"패치 임베딩: {model.patch_embed}")
print(f"Transformer 블록 수: {len(model.blocks)}")
print(f"파라미터 수: {sum(p.numel() for p in model.parameters()):,}")

# 추론
dummy = torch.randn(1, 3, 224, 224)
output = model(dummy)
print(f"출력: {output.shape}")  # [1, 10]

Attention Map 시각화

def get_attention_map(model, image_tensor):
    """ViT의 마지막 레이어 Attention Map을 추출합니다."""
    model.eval()
    attention_maps = []

    # 어텐션 가중치를 저장하는 훅 등록
    def hook_fn(module, input, output):
        # output: (attn_output, attn_weights)
        attention_maps.append(output[1].detach())

    # 마지막 블록의 Attention에 훅 등록
    hook = model.blocks[-1].attn.register_forward_hook(hook_fn)

    with torch.no_grad():
        _ = model(image_tensor.unsqueeze(0))

    hook.remove()

    # CLS 토큰이 다른 패치에 주목하는 정도
    attn = attention_maps[0]  # [1, heads, 197, 197]
    attn = attn.mean(dim=1)   # 헤드 평균 [1, 197, 197]
    cls_attn = attn[0, 0, 1:]  # CLS → 패치 [196]
    cls_attn = cls_attn.reshape(14, 14)  # 2D 맵으로 변환
    return cls_attn.numpy()

관련 모델 비교

모델핵심 개선장점
ViT-B/16순수 Transformer대규모 데이터에서 최고 성능
DeiTKnowledge DistillationImageNet만으로 효과적 학습
Swin Transformer윈도우 기반 어텐션, 계층 구조효율적, 탐지/세그멘테이션에도 활용
BEiTMasked Image Modeling자기지도 사전학습
MAEMasked Autoencoder높은 마스킹 비율로 효율적 사전학습

Swin Transformer

Swin Transformer는 ViT의 계산 비용 문제를 해결합니다. 이미지 전체 대신 로컬 윈도우 내에서 Self-Attention을 수행하고, Shifted Window로 윈도우 간 정보를 교환합니다.
# Swin Transformer 사용
swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=10)

# 계층적 특징맵 추출 (탐지/세그멘테이션 백본으로 활용)
swin_backbone = timm.create_model(
    'swin_tiny_patch4_window7_224',
    pretrained=True,
    features_only=True,
)
features = swin_backbone(torch.randn(1, 3, 224, 224))
for i, f in enumerate(features):
    print(f"Stage {i}: {f.shape}")
# Stage 0: [1, 96, 56, 56]
# Stage 1: [1, 192, 28, 28]
# Stage 2: [1, 384, 14, 14]
# Stage 3: [1, 768, 7, 7]

CNN vs ViT 선택 기준

기준CNN 추천ViT 추천
데이터 규모수천 장 이하수만 장 이상
GPU 자원제한적충분
추론 속도중요덜 중요
태스크분류, 탐지분류, 멀티모달
해석 가능성Grad-CAMAttention Map
아닙니다. 소규모 데이터에서는 여전히 CNN이 유리하며, 엣지 배포에서는 CNN이 더 효율적입니다. 실무에서는 ConvNeXt처럼 Transformer의 장점을 흡수한 CNN이 좋은 절충안이 됩니다.
패치 크기가 작을수록 더 세밀한 특징을 포착하지만 시퀀스 길이가 길어져 계산 비용이 증가합니다. 16x16이 표준이며, 속도가 중요하면 32x32를 사용할 수 있습니다.

참고 논문

논문학회/연도링크
An Image is Worth 16x16 Words (ViT)ICLR 2021arXiv:2010.11929
Training data-efficient image transformers (DeiT)ICML 2021arXiv:2012.12877
Swin TransformerICCV 2021arXiv:2103.14030
Masked Autoencoders (MAE)CVPR 2022arXiv:2111.06377