Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os
import random
from functools import cache
from operator import itemgetter

import langsmith
from langchain.memory import ConversationBufferWindowMemory
from langchain.retrievers import EnsembleRetriever
from langchain_community.document_transformers import LongContextReorder
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_openai.chat_models import ChatOpenAI

from .prompt_template import generate_prompt_template
from .retrievers_setup import (
    DenseRetrieverClient,
    SparseRetrieverClient,
    compression_retriever_setup,
    multi_query_retriever_setup,
)
# Helpers
def reorder_documents(docs: list[Document]) -> list[Document]:
    """Pass ``docs`` through LangChain's ``LongContextReorder`` transformer.

    The reordering is meant to mitigate performance degradation with long
    contexts.

    Args:
        docs (list[Document]): Documents to reorder.

    Returns:
        list[Document]: Whatever ``LongContextReorder.transform_documents``
        returns for ``docs``.
    """
    reorderer = LongContextReorder()
    reordered = reorderer.transform_documents(docs)
    return reordered
def randomize_documents(documents: list[Document]) -> list[Document]:
    """Return the documents in random order to vary model recommendations.

    The shuffle is applied to a copy, so the list passed in by the caller
    keeps its original order.

    Args:
        documents (list[Document]): Documents to randomize.

    Returns:
        list[Document]: A new list holding the same documents in random order.
    """
    # random.shuffle works in place; copy first so a caller that still holds
    # the input list does not see it reordered as a side effect.
    shuffled = list(documents)
    random.shuffle(shuffled)
    return shuffled
class DocumentFormatter:
    """Callable that renders a list of documents as a single markdown string.

    Each document becomes a bullet labelled with ``prefix`` and its 1-based
    position in the list; consecutive entries are separated by a ``---`` line.
    """

    def __init__(self, prefix: str) -> None:
        # Label placed before each document number, e.g. "Practitioner #".
        self.prefix = prefix

    def __call__(self, docs: list[Document]) -> str:
        """Format the Documents to markdown.

        Args:
            docs (list[Document]): List of Langchain documents. Only the
                ``page_content`` attribute of each one is read.

        Returns:
            str: One entry per document, joined by a ``---`` separator line.
            An empty list yields an empty string.
        """
        # The header is concatenated (not interpolated) with page_content so a
        # non-string page_content still fails loudly instead of being coerced.
        return "\n---\n".join(
            f"- {self.prefix} {position}:\n\n\t" + doc.page_content
            for position, doc in enumerate(docs, start=1)
        )
def create_langsmith_client():
    """Create a Langsmith client and set the tracing environment variables.

    On success this sets ``LANGCHAIN_TRACING_V2``, ``LANGCHAIN_PROJECT`` and
    ``LANGCHAIN_ENDPOINT`` in ``os.environ``, overwriting any existing values.

    Returns:
        langsmith.Client: A client constructed with default arguments.

    Raises:
        EnvironmentError: If ``LANGCHAIN_API_KEY`` is unset or empty.
    """
    # Validate the key before touching os.environ so a failed call does not
    # leave tracing switched on without credentials.
    if not os.getenv("LANGCHAIN_API_KEY"):
        raise EnvironmentError("Missing environment variable: LANGCHAIN_API_KEY")

    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "talltree-ai-assistant"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    return langsmith.Client()
# Set up Runnable and Memory
def get_rag_chain(
    model_name: str = "gpt-4", temperature: float = 0.2
) -> tuple[Runnable, ConversationBufferWindowMemory]:
    """Set up the RAG runnable and its chat memory.

    The returned chain expects a mapping with a ``"message"`` key. It retrieves
    context for that message from the practitioners and tall_tree collections,
    loads the conversation history from the returned memory, fills the prompt
    template, calls the LLM and parses the output to a string. The chain only
    reads from the memory; saving new turns is not done in this function.

    Args:
        model_name (str, optional): LLM model. Defaults to "gpt-4".
        temperature (float, optional): Model temperature. Defaults to 0.2.

    Returns:
        tuple[Runnable, ConversationBufferWindowMemory]: The LCEL chain and the
        window memory it reads its ``history`` from.

    Raises:
        EnvironmentError: Propagated from ``create_langsmith_client`` when
            ``LANGCHAIN_API_KEY`` is missing.
    """
    RETRIEVER_PARAMETERS = {
        "embeddings_model": "text-embedding-3-small",
        "k_dense_practitioners_db": 8,
        "k_sparse_practitioners_db": 15,
        "weights_ensemble_practitioners_db": [0.2, 0.8],
        "k_compression_practitioners_db": 12,
        "k_dense_talltree": 6,
        "k_compression_talltree": 6,
    }

    # Set up Langsmith to trace the chain. Called for its side effects only
    # (tracing environment variables, API-key check); the client object itself
    # is not used here.
    create_langsmith_client()

    # LLM and prompt template
    llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
    )
    prompt = generate_prompt_template()

    # Set retrievers pointing to the practitioners' dataset
    practitioners_dense_client = DenseRetrieverClient(
        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
        collection_name="practitioners_db",
        search_type="similarity",
        k=RETRIEVER_PARAMETERS["k_dense_practitioners_db"],
    )  # k x 4 using multiquery retriever

    # Qdrant db as a retriever
    practitioners_db_dense_retriever = practitioners_dense_client.get_dense_retriever()

    # Multiquery retriever using the dense retriever.
    # This retriever can be passed or not to the EnsembleRetriever; according
    # to the original author's note it uses GPT-3.5-turbo.
    # NOTE(review): it is currently built but not wired into the ensemble
    # below - kept as an optional toggle; confirm whether it is still wanted.
    practitioners_db_dense_multiquery_retriever = multi_query_retriever_setup(
        practitioners_db_dense_retriever
    )

    # Sparse vector retriever
    sparse_retriever_client = SparseRetrieverClient(
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        splade_model_id="naver/splade-cocondenser-ensembledistil",
        k=RETRIEVER_PARAMETERS["k_sparse_practitioners_db"],
    )
    practitioners_db_sparse_retriever = sparse_retriever_client.get_sparse_retriever()

    # Ensemble retriever for hybrid search; the weights are given in the same
    # order as the retrievers (dense, sparse).
    # NOTE(review): the original comment named the dense retriever twice;
    # presumably the sparse retriever is the one that helps with acronyms like
    # RMT - confirm.
    practitioners_ensemble_retriever = EnsembleRetriever(
        retrievers=[
            practitioners_db_dense_retriever,
            practitioners_db_sparse_retriever,
        ],
        weights=RETRIEVER_PARAMETERS["weights_ensemble_practitioners_db"],
    )

    # Compression retriever for practitioners db
    practitioners_db_compression_retriever = compression_retriever_setup(
        practitioners_ensemble_retriever,
        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
        k=RETRIEVER_PARAMETERS["k_compression_practitioners_db"],
    )

    # Set retrievers pointing to the tall_tree_db
    tall_tree_dense_client = DenseRetrieverClient(
        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
        collection_name="tall_tree_db",
        search_type="similarity",
        k=RETRIEVER_PARAMETERS["k_dense_talltree"],
    )
    tall_tree_db_dense_retriever = tall_tree_dense_client.get_dense_retriever()

    # Compression retriever for tall_tree_db
    # NOTE(review): this retriever is built but the chain below uses
    # tall_tree_db_dense_retriever instead - confirm which one is intended.
    tall_tree_db_compression_retriever = compression_retriever_setup(
        tall_tree_db_dense_retriever,
        embeddings_model=RETRIEVER_PARAMETERS["embeddings_model"],
        k=RETRIEVER_PARAMETERS["k_compression_talltree"],
    )

    # Set conversation history window memory. It only uses the last k interactions.
    memory = ConversationBufferWindowMemory(
        memory_key="history",
        return_messages=True,
        k=6,
    )

    # Set up runnable using LCEL. Each entry receives the same input mapping
    # and produces one variable handed to the prompt.
    setup_and_retrieval = {
        "practitioners_db": itemgetter("message")
        | practitioners_db_compression_retriever
        | DocumentFormatter("Practitioner #"),
        "tall_tree_db": itemgetter("message")
        | tall_tree_db_dense_retriever
        | DocumentFormatter("No."),
        "history": RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        "message": itemgetter("message"),
    }

    chain = setup_and_retrieval | prompt | llm | StrOutputParser()

    return chain, memory