import torch
import torch.nn as nn
class DCGANGenerator(nn.Module):
    """DCGAN generator for 64x64 images.

    Projects a latent vector to a (feature_maps*8, 4, 4) map, then
    doubles the spatial resolution at each stage while halving the
    channel width, ending with a Tanh so outputs lie in [-1, 1].
    """

    def __init__(self, latent_dim=100, channels=3, feature_maps=64):
        super().__init__()
        layers = []
        # Stage 0: project (latent_dim, 1, 1) -> (feature_maps*8, 4, 4).
        in_ch, out_ch = latent_dim, feature_maps * 8
        layers += [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 1, 0, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]
        # Three upsampling stages: 4->8->16->32, channels fm*8 -> fm.
        for _ in range(3):
            in_ch, out_ch = out_ch, out_ch // 2
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Final stage: 32 -> 64 spatial, map to image channels in [-1, 1].
        layers += [
            nn.ConvTranspose2d(out_ch, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent codes of shape (batch, latent_dim) to images.

        Any trailing dimensions of `z` are flattened before the first
        transposed convolution, mirroring a (batch, latent_dim, 1, 1) input.
        """
        return self.net(z.view(z.size(0), -1, 1, 1))
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator for 64x64 images.

    Halves the spatial resolution at each stage while doubling the
    channel width, then collapses the final (feature_maps*8, 4, 4)
    map to a single sigmoid probability per sample.
    """

    def __init__(self, channels=3, feature_maps=64):
        super().__init__()
        layers = [
            # Stage 0: 64 -> 32. No BatchNorm on the input stage,
            # following the standard DCGAN layout.
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages: 32->16->8->4, channels fm -> fm*8.
        width = feature_maps
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width *= 2
        # Head: reduce (width, 4, 4) to a single logit, squash to (0, 1).
        layers += [
            nn.Conv2d(width, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score images of shape (batch, channels, 64, 64).

        Returns a (batch, 1) tensor of probabilities in (0, 1).
        """
        return self.net(x).view(-1, 1)