"""dspy_utils.py

Utilities for building a DSPy based retrieval (augmented) generation model.

:author: Didier Guillevic
:email: didier@guillevic.net
:creation: 2024-12-21
"""

import html
import logging
import os
from collections.abc import Iterable

import dspy
from ragatouille import RAGPretrainedModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class DSPyRagModel:
    """Retrieval-augmented question answering built on DSPy.

    Passages are retrieved with a RAGatouille ``RAGPretrainedModel`` and an
    answer is generated by the ``mistral/mistral-large-latest`` model through
    DSPy, either with a plain prediction or with chain-of-thought prompting.
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Set up the retrieval model, the language model and the predictors.

        Args:
            retrieval_model: RAGatouille model used to search for passages.

        Raises:
            KeyError: if the ``MISTRAL_API_KEY`` environment variable is unset.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(
            model="mistral/mistral-large-latest",
            api_key=os.environ["MISTRAL_API_KEY"],
        )

        # Register the retrieval and language model with dspy's settings.
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model,
        )

        # Set dspy prediction functions.
        # NOTE: the docstring and field descriptions of BasicQA are used by
        # DSPy as prompt text, so changing them changes what the LM is asked.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int = 3,
        method: str = 'chain_of_thought',
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response:
                ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text
              used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported methods.
        """
        # Resolve the predictor before retrieving anything so that an invalid
        # method fails fast instead of after a (potentially costly) search.
        if method == 'simple':
            predictor = self.predict
        elif method == 'chain_of_thought':
            predictor = self.predict_chain_of_thought
        else:
            raise ValueError(
                f"Unknown method: {method}. "
                "Expected ['simple', 'chain_of_thought']"
            )

        # Retrieval
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictor(context=passages, question=question).answer

        # HTML fragments showing where the answer came from.
        references = "<h3>References</h3>\n" + create_bulleted_list(metadatas)
        snippets = "<h3>Snippets</h3>\n" + create_bulleted_list(passages)

        return response, references, snippets


def create_bulleted_list(texts: Iterable[object]) -> str:
    """Return an HTML unordered list with one ``<li>`` per item of ``texts``.

    Each item is converted with ``str()`` (callers pass non-string values
    such as retrieval metadata, and ``dict.get`` may yield ``None``) and
    HTML-escaped so retrieved text cannot inject markup into the page.

    Args:
        texts: the items to list.

    Returns:
        An HTML string such as ``"<ul><li>a</li><li>b</li></ul>"``;
        ``"<ul></ul>"`` for an empty input.
    """
    html_items = [
        f"<li>{html.escape(str(item), quote=False)}</li>" for item in texts
    ]
    return "<ul>" + "".join(html_items) + "</ul>"