Speculative RAG
Speculative RAG는 소형 전문 모델(RAG Drafter)이 검색된 문서의 서로 다른 부분집합으로 복수의 초안을 병렬 생성하고, 대형 범용 모델(RAG Verifier)이 각 초안을 평가하여 최적의 답변을 선택하는 아키텍처입니다. Speculative Decoding의 아이디어를 RAG에 적용한 것입니다.
핵심 아이디어
기존 RAG는 하나의 모델이 검색된 모든 문서를 한꺼번에 처리하여 단일 답변을 생성합니다. 이 방식은 문서 간 노이즈에 취약하고, 답변의 다양성을 확보하기 어렵습니다.
Speculative RAG는 두 모델의 역할을 분리합니다.
| 역할 | 모델 | 특징 |
|---|
| RAG Drafter | 소형 전문 모델 | 각 문서 부분집합에서 빠르게 초안 생성, 병렬 실행 |
| RAG Verifier | 대형 범용 모델 | 초안의 품질을 평가하고 최적의 답변 선택 |
이 접근 방식은 LLM의 Speculative Decoding에서 영감을 받았습니다. Speculative Decoding에서 소형 모델이 후보 토큰을 생성하고 대형 모델이 검증하듯이, Speculative RAG에서는 소형 모델이 후보 답변을 생성하고 대형 모델이 검증합니다.
동작 방식
각 Drafter는 서로 다른 문서 부분집합만을 참고하므로, 동일한 질문에 대해 다양한 관점의 초안이 생성됩니다. Verifier는 이 초안들을 종합적으로 평가하여 가장 정확하고 근거가 충실한 답변을 선택합니다.
기존 RAG와의 차이점
| 항목 | Standard RAG | Speculative RAG |
|---|
| 모델 구조 | 단일 모델 | Drafter (소형) + Verifier (대형) |
| 답변 생성 | 단일 답변 | K개 초안 병렬 생성 |
| 문서 활용 | 전체 문서를 한 번에 처리 | 부분집합으로 분할하여 독립 처리 |
| 검증 단계 | 없음 (또는 후처리) | Verifier가 각 초안을 점수화 |
| 노이즈 내성 | 문서 노이즈에 취약 | 부분집합 분할로 노이즈 영향 분산 |
| 지연 시간 | 대형 모델의 긴 문서 처리 시간 | 소형 모델 병렬 처리로 단축 |
LangGraph 구현
상태 정의
from typing import TypedDict, List, Annotated
from langchain_core.documents import Document
import operator
class DraftState(TypedDict):
question: str
subset: List[Document]
draft: str
score: float
class GraphState(TypedDict):
question: str
documents: List[Document]
drafts: Annotated[list[DraftState], operator.add]
final_answer: str
노드 함수
검색 + 분할
초안 생성 (Drafter)
검증 (Verifier)
최종 선택
import random
from langchain.chat_models import init_chat_model
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
embeddings = OpenAIEmbeddings()
vectorstore = Chroma(collection_name="documents", embedding_function=embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
NUM_DRAFTS = 3
def retrieve(state: GraphState) -> GraphState:
"""문서를 검색합니다."""
question = state["question"]
documents = retriever.invoke(question)
return {"documents": documents}
def assign_drafts(state: GraphState) -> list:
"""검색된 문서를 부분집합으로 분할하여 병렬 초안 생성을 준비합니다."""
from langgraph.types import Send
documents = state["documents"]
question = state["question"]
# 문서를 K개의 부분집합으로 랜덤 분할
shuffled = documents.copy()
random.shuffle(shuffled)
subset_size = max(1, len(shuffled) // NUM_DRAFTS)
sends = []
for i in range(NUM_DRAFTS):
start = i * subset_size
end = start + subset_size if i < NUM_DRAFTS - 1 else len(shuffled)
subset = shuffled[start:end]
sends.append(Send("generate_draft", {
"question": question,
"subset": subset,
"draft": "",
"score": 0.0,
}))
return sends
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Drafter: 소형 전문 모델
drafter_llm = init_chat_model("gpt-4o-mini", temperature=0.7)
def generate_draft(state: DraftState) -> GraphState:
"""소형 모델이 문서 부분집합을 기반으로 초안을 생성합니다."""
question = state["question"]
subset = state["subset"]
prompt = ChatPromptTemplate.from_messages([
("system", (
"주어진 문서만을 참고하여 질문에 답변하세요.\n"
"문서에 근거한 사실만 작성하세요.\n\n"
"문서:\n{context}"
)),
("human", "{question}"),
])
chain = prompt | drafter_llm | StrOutputParser()
context = "\n\n".join(doc.page_content for doc in subset)
draft = chain.invoke({"context": context, "question": question})
return {"drafts": [{"question": question, "subset": subset, "draft": draft, "score": 0.0}]}
# Verifier: 대형 범용 모델
verifier_llm = init_chat_model("gpt-4o", temperature=0)
def verify_drafts(state: GraphState) -> GraphState:
"""대형 모델이 각 초안의 품질을 평가합니다."""
question = state["question"]
drafts = state["drafts"]
prompt = ChatPromptTemplate.from_messages([
("system", (
"질문에 대한 답변 초안을 1~10점으로 평가하세요.\n"
"평가 기준: 정확성, 완전성, 문서 근거 충실도\n"
"숫자만 출력하세요."
)),
("human", "질문: {question}\n\n초안: {draft}"),
])
chain = prompt | verifier_llm | StrOutputParser()
scored_drafts = []
for d in drafts:
result = chain.invoke({"question": question, "draft": d["draft"]})
try:
score = float(result.strip())
except ValueError:
score = 0.0
scored_drafts.append({**d, "score": score})
return {"drafts": scored_drafts}
def select_best(state: GraphState) -> GraphState:
"""가장 높은 점수의 초안을 최종 답변으로 선택합니다."""
drafts = state["drafts"]
best = max(drafts, key=lambda d: d["score"])
return {"final_answer": best["draft"]}
그래프 구성
from langgraph.graph import StateGraph, START, END
workflow = StateGraph(GraphState)
# 노드 추가
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate_draft", generate_draft)
workflow.add_node("verify_drafts", verify_drafts)
workflow.add_node("select_best", select_best)
# 엣지 연결
workflow.add_edge(START, "retrieve")
workflow.add_conditional_edges("retrieve", assign_drafts) # Send API로 병렬 초안 생성
workflow.add_edge("generate_draft", "verify_drafts")
workflow.add_edge("verify_drafts", "select_best")
workflow.add_edge("select_best", END)
# 컴파일 및 실행
app = workflow.compile()
result = app.invoke({"question": "Speculative RAG의 핵심 원리는?", "drafts": []})
print(result["final_answer"])
Drafter 병렬 실행 시 API 호출이 동시에 발생하므로, Rate Limit에 주의하세요. NUM_DRAFTS를 조절하여 병렬도를 관리할 수 있습니다.
장단점
| 항목 | 설명 |
|---|
| 장점 | 소형 모델 병렬 처리로 지연 시간 단축 |
| 장점 | 다양한 문서 부분집합에서 다각도 답변 생성 |
| 장점 | Verifier가 환각 및 낮은 품질 초안을 필터링 |
| 장점 | 문서 노이즈의 영향을 부분집합 분할로 분산 |
| 단점 | 두 모델 운영에 따른 인프라 복잡도 증가 |
| 단점 | Drafter 수(K) 증가 시 API 비용 증가 |
| 단점 | Verifier의 평가 품질에 최종 결과가 의존 |
참고 논문
| 논문 | 학회 | 링크 |
|---|
| Speculative RAG: Enhancing Retrieval Augmented Generation through Drafting (Wang et al., 2024) | - | arXiv 2407.08223 |