Project
LoGO 해외로고 프로젝트 - RAG 4.
- -
LoGO 해외로고, 해외진출을 희망하는 대한민국 기업을 위한 정보 검색 서비스에서 RAG 부분
네번째 게시물이다.
https://github.com/khw11044/KT_BIGPRO_RAG
위 깃헙링크에서 코드를 따라하면 되겠다.
해당 게시물에서는 빈 프로젝트 폴더에서 시작해서 하나하나 코딩을 해본다.
네번째 게시물은 RAG pipeline에서 검색창의 검색어 기반 게시물을 만들어 낸다.
1. API Key를 로드한다.
# API KEY를 환경변수로 관리하기 위한 설정 파일
from dotenv import load_dotenv
# API KEY 정보로드
load_dotenv()
2. 지금까지 필요했던 라이브러리 다 불러오기
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
# retriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
import pickle
from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.schema import HumanMessage
from utils.redis_utils import save_message_to_redis, get_messages_from_redis
from utils.prompt import *
from utils.config import config, metadata_field_info
from utils.update import convert_file_to_documents
3. 주어진 검색어를 통해 여러 게시물 제목을 만들고 그 제목을 바탕으로 게시물을 만들자.
아래 코드를 작성해본다.
# title_generator_prompt와 llm을 연결하고 리트리버랑 연결
def init_title_chain(self):
question_answer_chain = create_stuff_documents_chain(self.llm, title_generator_prompt)
rag_title_chain = create_retrieval_chain(self.retriever, question_answer_chain)
return rag_title_chain
# post_generator_prompt와 llm을 연결하고 리트리버랑 연결
def init_post_chain(self):
question_answer_chain = create_stuff_documents_chain(self.llm, post_generator_prompt)
rag_text_chain = create_retrieval_chain(self.mq_ensemble_retriever, question_answer_chain)
return rag_text_chain
# 사용자 쿼리에 대한 게시물 제목 생성
def title_generation(self, question: str):
response = self.title_chain.invoke({'input': question})
return response
# 생성된 게시물 제목에 대한 게시물 생성
def post_generation(self, question: str):
response = self.post_chain.invoke({'input': question})
return response
llm은 title 프롬프트와 연결되고, 리트리버랑 연결이 된다. 그럼, 우리가 가진 데이터 기반 게시물 제목이 잘 만들어 질 것이다.
(근데 반대로, 굳이 리트리버랑 연결 할 필요가 있나 싶다. 그냥 llm과 프롬프트만 연결해도 되지 않나?)
post chain 역시 title chain과 같은 구조의 코드이다. 그리고 이런 체인을 미리 만들어 두었다가 question일 들어올때 마다 invoke를 해준다. 질문이 들어올때마다 체인을 걸어주지 않고 미리 걸었다가 질문에 대한 답변을 만들어 낸다.
최종 코드 및 확인
class Ragpipeline:
def __init__(self):
# chatGPT API를 통해 llm 모델 로드
self.llm = ChatOpenAI(
model=config["llm_predictor"]["model_name"], # chatgpt 모델 이름
temperature=config["llm_predictor"]["temperature"], # 창의성 0~1
)
# 초기화 리스트들
self.vector_store = self.init_vectorDB()
self.retriever = self.init_retriever()
self.bm25_retriever = self.init_bm25_retriever()
self.ensemble_retriever = self.init_ensemble_retriever()
self.mq_ensemble_retriever = self.init_mq_ensemble_retriever()
self.chain = self.init_chat_chain()
self.title_chain = self.init_title_chain()
self.post_chain = self.init_post_chain()
self.session_histories = {}
self.current_user_email = None
self.current_session_id = None
def init_vectorDB(self, persist_directory=config["chroma"]["persist_dir"]):
"""vectorDB 설정"""
embeddings = OpenAIEmbeddings(model=config["embed_model"]["model_name"])
vector_store = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings,
collection_name = 'india',
collection_metadata = {'hnsw:space': 'cosine'},
)
return vector_store
# --1. 리트리버 ---------------------------------------------------------------------------------------------------------------------------------
def init_retriever(self):
# base retriever 3
retriever = self.vector_store.as_retriever(
search_type="mmr",
search_kwargs={'fetch_k': 5, "k": 2, 'lambda_mult': 0.4},
)
return retriever
def init_bm25_retriever(self):
all_docs = pickle.load(open(config["pkl_path"], 'rb'))
bm25_retriever = BM25Retriever.from_documents(all_docs)
bm25_retriever.k = 1
return bm25_retriever
def init_ensemble_retriever(self):
ensemble_retriever = EnsembleRetriever(
retrievers=[self.bm25_retriever, self.retriever],
weights=[0.4, 0.6],
search_type=config["ensemble_search_type"], # mmr
)
return ensemble_retriever
def init_mq_ensemble_retriever(self):
mq_ensemble_retriever = MultiQueryRetriever.from_llm(
llm=self.llm,
retriever=self.ensemble_retriever
)
return mq_ensemble_retriever
# --2. 생성 chain 초기화 ---------------------------------------------------------------------------------------------------------------------------------
def init_chat_chain(self):
history_aware_retriever = create_history_aware_retriever(self.llm, self.mq_ensemble_retriever, contextualize_q_prompt) # self.mq_ensemble_retriever
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
rag_chat_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
return rag_chat_chain
def init_title_chain(self):
question_answer_chain = create_stuff_documents_chain(self.llm, title_generator_prompt)
rag_title_chain = create_retrieval_chain(self.retriever, question_answer_chain)
return rag_title_chain
def init_post_chain(self):
question_answer_chain = create_stuff_documents_chain(self.llm, post_generator_prompt)
rag_text_chain = create_retrieval_chain(self.mq_ensemble_retriever, question_answer_chain)
return rag_text_chain
# --3. 생성 결과 출력 ---------------------------------------------------------------------------------------------------------------------------------
def chat_generation(self, question: str) -> dict:
def get_session_history(session_id=None, user_email=None):
session_id = session_id if session_id else self.current_session_id
user_email = user_email if user_email else self.current_user_email
if session_id not in self.session_histories:
self.session_histories[session_id] = ChatMessageHistory()
# Redis에서 세션 히스토리 불러오기
history_messages = get_messages_from_redis(user_email, session_id)
for message in history_messages:
self.session_histories[session_id].add_message(HumanMessage(content=message))
return self.session_histories[session_id]
conversational_rag_chain = RunnableWithMessageHistory(
self.chain, # 실행할 Runnable 객체
get_session_history, # 세션 기록을 가져오는 함수
input_messages_key="input", # 입력 메시지의 키
history_messages_key="chat_history", # 기록 메시지의 키
output_messages_key="answer" # 출력 메시지의 키
)
response = conversational_rag_chain.invoke(
{"input": question},
config={"configurable": {"session_id": self.current_session_id}}
)
# Redis에 세션 히스토리 저장
save_message_to_redis(self.current_user_email, self.current_session_id, question)
save_message_to_redis(self.current_user_email, self.current_session_id, response["answer"])
return response
def title_generation(self, question: str):
response = self.title_chain.invoke({'input': question})
return response
def post_generation(self, question: str):
response = self.post_chain.invoke({'input': question})
return response
1. 게시물 제목 생성
pipeline = Ragpipeline()
question = '인도 통관'
titles = pipeline.title_generation(question) # 제목 개수 정할 수 있음
print(titles)
print(titles['answer'])
2. 생성한 게시물 제목을 바탕으로 게시물 생성하기
title = titles['answer'].split('\n')[0].split('.')[-1]
post = pipeline.post_generation(title) # 제목 개수 정할 수 있음
print(post)
print(post['answer'])
다음 5번째 게시물을 마지막으로 한다. 5번째 게시물은 RAGAS를 이용하여 실험을 진행해본다.
'Project' 카테고리의 다른 글
LoGO 해외로고 프로젝트 - RAG 5: RAGAS로 성능 평가하기 (0) | 2024.08.06 |
---|---|
Chroma DB 폴더 및 파일 구조 (0) | 2024.08.01 |
LoGO 해외로고 프로젝트 - RAG 3. (0) | 2024.07.31 |
LoGO 해외로고 프로젝트 - RAG 2. (0) | 2024.07.31 |
LoGO 해외로고 프로젝트 - RAG 1. (0) | 2024.07.31 |
Contents
소중한 공감 감사합니다