from langchain.chat_models import init_chat_model
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
# --- Module-level RAG components, shared by every graph node below ---
# Chat model used for query rewriting, reranking, and answer generation.
# temperature=0 keeps rewriting/scoring output deterministic.
llm = init_chat_model("gpt-4o-mini", temperature=0)
# Embedding model backing the vector store (OpenAI default embedding model).
embeddings = OpenAIEmbeddings()
# Chroma collection "documents"; no persist_directory is given, so this is
# presumably an ephemeral in-memory store — TODO confirm against deployment.
vectorstore = Chroma(collection_name="documents", embedding_function=embeddings)
# Fetch top-10 candidates per query; rerank() later narrows these to 4.
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
def rewrite_query(state: GraphState) -> GraphState:
    """Rewrite the user's question into a form better suited to vector search.

    Reads ``state["question"]``, asks the LLM for a search-optimized
    rewrite, and returns it under the ``rewritten_query`` key.
    """
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "주어진 질문을 벡터 검색에 최적화된 형태로 재작성하세요. 재작성된 질문만 출력하세요."),
            ("human", "{question}"),
        ]
    )
    rewriter = rewrite_prompt | llm | StrOutputParser()
    return {"rewritten_query": rewriter.invoke({"question": state["question"]})}
def retrieve(state: GraphState) -> GraphState:
    """Fetch candidate documents for the rewritten query.

    Runs the module-level retriever (top-k candidates) on
    ``state["rewritten_query"]`` and stores the hits under ``documents``.
    """
    search_query = state["rewritten_query"]
    return {"documents": retriever.invoke(search_query)}
def _parse_rerank_scores(scores_text, documents):
    """Parse '문서번호:점수' lines emitted by the LLM into (score, document) pairs.

    Lines that fail to parse, and indices outside ``documents``, are skipped
    silently — reranking is best-effort over imperfect LLM output.

    Args:
        scores_text: raw LLM output, one "index:score" entry per line.
        documents: the candidate documents, indexed as presented to the LLM.

    Returns:
        list of (float score, document) tuples for every valid line.
    """
    scored_docs = []
    for line in scores_text.strip().split("\n"):
        try:
            # maxsplit=1 guards against extra ':' characters in the score part;
            # the original split raised ValueError on such lines and dropped them.
            idx_text, score_text = line.split(":", 1)
            # Strip optional brackets/label so "[문서 0]" (the exact label format
            # this function shows the LLM) parses as well as plain "0".
            idx = int(idx_text.strip().strip("[]").replace("문서", "").strip())
            score = float(score_text.strip())
        except ValueError:
            continue
        # BUG FIX: the original only checked idx < len(documents); a negative
        # parsed index would silently select a document from the END of the list.
        if 0 <= idx < len(documents):
            scored_docs.append((score, documents[idx]))
    return scored_docs


def rerank(state: GraphState) -> GraphState:
    """Rerank the retrieved documents by LLM-judged relevance to the question.

    Shows the LLM the question plus the numbered documents, asks for a 0-10
    relevance score per document, parses the scores, and keeps the top 4
    under the ``reranked_documents`` key.
    """
    question = state["question"]
    documents = state["documents"]
    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "질문과 문서 목록이 주어집니다. "
            "각 문서의 관련도를 0~10 점수로 평가하세요.\n"
            "형식: 문서번호:점수 (한 줄에 하나씩)\n\n"
            "질문: {question}"
        )),
        ("human", "{documents}"),
    ])
    docs_text = "\n---\n".join(
        f"[문서 {i}] {doc.page_content}" for i, doc in enumerate(documents)
    )
    chain = prompt | llm | StrOutputParser()
    scores_text = chain.invoke({"question": question, "documents": docs_text})
    scored_docs = _parse_rerank_scores(scores_text, documents)
    # Highest score first; keep at most 4 documents for the generation step.
    scored_docs.sort(key=lambda pair: pair[0], reverse=True)
    top_docs = [doc for _, doc in scored_docs[:4]]
    return {"reranked_documents": top_docs}
def generate(state: GraphState) -> GraphState:
    """Produce the final answer grounded in the reranked documents.

    Joins the reranked documents into a single context string, then asks
    the LLM to answer ``state["question"]`` using that context. The answer
    is returned under the ``generation`` key.
    """
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "다음 컨텍스트를 참고하여 질문에 답변하세요.\n\n{context}"),
            ("human", "{question}"),
        ]
    )
    answer_chain = answer_prompt | llm | StrOutputParser()
    combined_context = "\n\n".join(
        doc.page_content for doc in state["reranked_documents"]
    )
    answer = answer_chain.invoke(
        {"context": combined_context, "question": state["question"]}
    )
    return {"generation": answer}