Skip to main content

GAN 기초 (Generative Adversarial Networks)

학습 목표

  • GAN의 생성자-판별자 적대적 학습 구조를 이해한다
  • Min-Max 목적 함수의 수학적 의미를 설명할 수 있다
  • 모드 붕괴(Mode Collapse)와 학습 불안정 문제를 안다
  • Wasserstein GAN의 개선 아이디어를 이해한다

왜 중요한가

생성적 적대 신경망(GAN, Generative Adversarial Network)은 2014년 Goodfellow et al.이 제안한 생성 모델입니다. 두 신경망이 적대적으로 경쟁하며 학습하여, VAE보다 선명하고 사실적인 데이터를 생성합니다. 이미지 생성, 스타일 변환, 데이터 증강 등에 혁신적인 성과를 이뤘습니다.

구조

구성 요소역할목표
생성자(Generator) GG노이즈 zz에서 가짜 데이터 생성판별자를 속이는 것
판별자(Discriminator) DD진짜/가짜 데이터 구분진짜와 가짜를 정확히 분류

목적 함수

Min-Max 게임

minGmaxD  V(D,G)=Expdata[logD(x)]+Ezpz[log(1D(G(z)))]\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
  • 판별자 DD: VV최대화 — 진짜는 1, 가짜는 0으로 분류
  • 생성자 GG: VV최소화 — 판별자가 가짜를 진짜로 착각하도록

비포화 생성자 손실 (Non-Saturating Loss)

실전에서는 생성자의 기울기 소실을 방지하기 위해 변형된 손실을 사용합니다. LG=Ezpz[logD(G(z))]\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}[\log D(G(z))]

구현

import torch
import torch.nn as nn

class Generator(nn.Module):
    """생성자: 노이즈 → 이미지"""
    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, img_dim),
            nn.Tanh(),  # 출력 [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """판별자: 이미지 → 진짜/가짜 확률"""
    def __init__(self, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

학습

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 하이퍼파라미터
latent_dim = 100
lr = 2e-4
epochs = 100

# 데이터 ([-1, 1] 정규화)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# 모델
G = Generator(latent_dim=latent_dim).to('cuda')
D = Discriminator().to('cuda')

# 옵티마이저 (Adam, β₁=0.5 권장)
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

for epoch in range(epochs):
    for real_images, _ in train_loader:
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1).to('cuda')

        # 레이블
        real_labels = torch.ones(batch_size, 1).to('cuda')
        fake_labels = torch.zeros(batch_size, 1).to('cuda')

        # ── 판별자 학습 ──
        z = torch.randn(batch_size, latent_dim).to('cuda')
        fake_images = G(z).detach()  # 생성자 기울기 차단

        d_real = D(real_images)
        d_fake = D(fake_images)

        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # ── 생성자 학습 ──
        z = torch.randn(batch_size, latent_dim).to('cuda')
        fake_images = G(z)
        d_fake = D(fake_images)

        g_loss = criterion(d_fake, real_labels)  # 비포화 손실

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    print(f"Epoch {epoch+1}: D 손실={d_loss.item():.4f}, G 손실={g_loss.item():.4f}")

학습 문제와 해결

모드 붕괴 (Mode Collapse)

생성자가 다양성 없이 소수의 출력만 반복 생성하는 현상입니다.

학습 안정화 기법

기법설명효과
레이블 스무딩진짜 레이블을 0.9로, 가짜를 0.1로판별자 과신 방지
스펙트럴 노말라이제이션판별자 가중치의 스펙트럴 노름 제한리프시츠 연속성 보장
학습률 균형판별자와 생성자의 학습률 조절일방적 지배 방지
그래디언트 페널티기울기 크기에 패널티WGAN-GP

Wasserstein GAN (WGAN)

JS Divergence 대신 **Wasserstein 거리(Earth Mover’s Distance)**를 사용하여 학습을 안정화합니다. LWGAN=Expdata[D(x)]Ezpz[D(G(z))]\mathcal{L}_{\text{WGAN}} = \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]
# WGAN-GP 판별자 손실 (개념)
def wgan_gp_d_loss(D, real, fake, lambda_gp=10):
    """WGAN-GP 판별자 손실"""
    d_real = D(real).mean()
    d_fake = D(fake.detach()).mean()

    # Gradient Penalty
    alpha = torch.rand(real.size(0), 1).to(real.device)
    interpolated = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_interp = D(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interp, inputs=interpolated,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,
    )[0]
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return d_fake - d_real + lambda_gp * gp

GAN 변형 비교

모델연도손실 함수핵심 개선
GAN (원본)2014BCE적대적 학습 제안
WGAN2017Wasserstein학습 안정성
WGAN-GP2017Wasserstein + GP그래디언트 페널티
SNGAN2018BCE + SN스펙트럴 노말라이제이션
GAN 학습은 매우 불안정합니다. 판별자가 너무 강하면 생성자의 기울기가 소실되고, 생성자가 너무 강하면 모드 붕괴가 발생합니다. 학습 곡선을 주의 깊게 모니터링하고, 생성 결과를 주기적으로 시각화해야 합니다.

참고 논문

논문학회/연도핵심 기여
Generative Adversarial Nets (Goodfellow et al.)NeurIPS 2014GAN 제안
Wasserstein GAN (Arjovsky et al.)ICML 2017Wasserstein 거리 기반 학습
Improved Training of Wasserstein GANs (Gulrajani et al.)NeurIPS 2017Gradient Penalty

체크리스트

  • 생성자와 판별자의 역할과 목표를 설명할 수 있다
  • Min-Max 목적 함수의 의미를 이해한다
  • 모드 붕괴의 원인과 해결 방법을 안다
  • WGAN이 원본 GAN의 어떤 문제를 해결하는지 안다

다음 문서