Skip to main content

dataclass와 slots

학습 목표

  • @dataclass로 보일러플레이트 코드를 줄일 수 있다
  • field(), frozen, slots 옵션을 활용할 수 있다
  • __slots__로 메모리를 최적화할 수 있다

왜 중요한가

데이터를 담는 클래스에서 __init__, __repr__, __eq__ 등을 반복 작성하는 것은 비효율적입니다. @dataclass는 이 보일러플레이트를 자동 생성합니다. ML에서 실험 설정, 하이퍼파라미터, 결과 기록 등 데이터 중심 객체에 필수적인 도구입니다.

@dataclass 기본

from dataclasses import dataclass

@dataclass
class Point:
    x: float
    y: float

# 위 코드가 자동 생성하는 것:
# __init__, __repr__, __eq__

p1 = Point(3.0, 4.0)
p2 = Point(3.0, 4.0)

print(p1)          # Point(x=3.0, y=4.0)
print(p1 == p2)    # True
print(p1.x)        # 3.0

기본값과 field()

from dataclasses import dataclass, field

@dataclass
class TrainConfig:
    model_name: str
    epochs: int = 10
    learning_rate: float = 0.001
    batch_size: int = 32
    tags: list = field(default_factory=list)      # 가변 객체 기본값
    metadata: dict = field(default_factory=dict)

config = TrainConfig(model_name="bert-base")
print(config)
# TrainConfig(model_name='bert-base', epochs=10, learning_rate=0.001, ...)

config.tags.append("experiment-1")
@dataclass에서 가변 객체 기본값은 반드시 field(default_factory=...)를 사용하세요. tags: list = []는 에러가 발생합니다.

field() 고급 옵션

from dataclasses import dataclass, field

@dataclass
class Experiment:
    name: str
    score: float
    _id: str = field(init=False, repr=False)           # __init__과 __repr__에서 제외
    description: str = field(default="", compare=False)  # 비교에서 제외

    def __post_init__(self):
        """__init__ 이후 실행되는 후처리"""
        import uuid
        self._id = str(uuid.uuid4())[:8]

frozen (불변 dataclass)

@dataclass(frozen=True)
class Coordinate:
    latitude: float
    longitude: float

coord = Coordinate(37.5, 127.0)
# coord.latitude = 38.0  # FrozenInstanceError!

# frozen이면 해시 가능 -> 딕셔너리 키, 집합 요소 가능
locations = {
    Coordinate(37.5, 127.0): "서울",
    Coordinate(35.1, 129.0): "부산",
}

slots

__slots__는 인스턴스가 가질 수 있는 속성을 제한하여 메모리를 절약합니다.
# 일반 클래스: __dict__로 속성 저장 (유연하지만 메모리 사용 많음)
class PointNormal:
    def __init__(self, x, y):
        self.x = x
        self.y = y

# __slots__ 클래스: 고정 속성만 허용 (메모리 절약)
class PointSlots:
    __slots__ = ("x", "y")

    def __init__(self, x, y):
        self.x = x
        self.y = y

p1 = PointNormal(1, 2)
p2 = PointSlots(1, 2)

import sys
print(sys.getsizeof(p1.__dict__))  # ~100 bytes
# p2는 __dict__가 없음 -> 메모리 절약

# p2.z = 3  # AttributeError! __slots__에 없는 속성 추가 불가

dataclass + slots (Python 3.10+)

@dataclass(slots=True)
class DataPoint:
    x: float
    y: float
    label: str

# __slots__가 자동 설정됨
dp = DataPoint(1.0, 2.0, "A")
# dp.extra = "test"  # AttributeError!

AI/ML에서의 활용

from dataclasses import dataclass, field, asdict

@dataclass
class ModelConfig:
    """모델 설정"""
    name: str
    input_dim: int
    hidden_dim: int = 256
    output_dim: int = 10
    dropout: float = 0.1
    activation: str = "relu"

@dataclass
class TrainingConfig:
    """학습 설정"""
    model: ModelConfig
    epochs: int = 100
    lr: float = 1e-3
    batch_size: int = 32
    device: str = "cuda"
    checkpoint_dir: str = "./checkpoints"

    def to_dict(self):
        return asdict(self)

# 설정 조합
config = TrainingConfig(
    model=ModelConfig(name="mlp", input_dim=784),
    epochs=50,
    lr=0.01,
)

print(config.model.name)      # "mlp"
print(config.to_dict())       # 중첩 딕셔너리로 변환

# JSON 저장
import json
with open("config.json", "w") as f:
    json.dump(config.to_dict(), f, indent=2)
가변이 필요하면 @dataclass, 불변이면 @dataclass(frozen=True) 또는 NamedTuple을 사용합니다. @dataclass가 더 많은 기능(기본값, post_init, field)을 제공하므로 새 코드에서는 @dataclass를 권장합니다.
@dataclass는 타입 검증을 하지 않고, Pydantic BaseModel은 런타임에 타입을 검증합니다. 내부 데이터 구조에는 @dataclass, API 입출력 검증에는 Pydantic을 사용합니다.

체크리스트

  • @dataclass로 데이터 클래스를 정의할 수 있다
  • field()default_factory, init, repr, compare 옵션을 활용할 수 있다
  • frozen=True로 불변 데이터 클래스를 만들 수 있다
  • __slots__의 메모리 절약 원리를 설명할 수 있다

다음 문서