import torch
import torch.nn as nn
import torchvision.transforms as T
class SimCLR(nn.Module):
"""SimCLR 프레임워크"""
def __init__(self, base_encoder, projection_dim=128):
super().__init__()
self.encoder = base_encoder # 예: ResNet-50 (fc 제거)
encoder_dim = 2048 # ResNet-50 출력 차원
# 프로젝션 헤드 (학습 시에만 사용)
self.projector = nn.Sequential(
nn.Linear(encoder_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, projection_dim),
)
def forward(self, x):
h = self.encoder(x) # 표현 (다운스트림 사용)
z = self.projector(h) # 프로젝션 (학습 시에만 사용)
return h, z
# SimCLR 데이터 증강 파이프라인
class SimCLRAugmentation:
"""하나의 이미지에서 두 개의 증강 뷰 생성"""
def __init__(self, size=224):
self.transform = T.Compose([
T.RandomResizedCrop(size, scale=(0.2, 1.0)),
T.RandomHorizontalFlip(),
T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
T.RandomGrayscale(p=0.2),
T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
def __call__(self, x):
return self.transform(x), self.transform(x)