Spaces:
Running
Running
""" dspy_utils.py | |
Utilities for building a DSPy based retrieval (augmented) generation model. | |
:author: Didier Guillevic | |
:email: [email protected] | |
:creation: 2024-12-21 | |
""" | |
import html
import logging
import os
from collections.abc import Iterable

import dspy
from ragatouille import RAGPretrainedModel
# Configure the root logger at INFO level when this module is imported.
# Per the logging docs, basicConfig is a no-op if the root logger already
# has handlers, so an application that configured logging first keeps its
# own setup.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class DSPyRagModel:
    """Retrieval-augmented question answering built on DSPy.

    Passages are fetched with a RAGatouille ``RAGPretrainedModel`` and passed,
    together with the question, to a DSPy predictor backed by a Mistral
    language model.
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Set up the retrieval model, the language model and the predictors.

        Args:
            retrieval_model: RAGatouille model whose ``search`` method is used
                to fetch passages for each question.

        Raises:
            KeyError: if the ``MISTRAL_API_KEY`` environment variable is unset.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(model="mistral/mistral-large-latest", api_key=os.environ["MISTRAL_API_KEY"])

        # Register both models in DSPy's shared `dspy.settings`. This is DSPy
        # state, not per-instance state; presumably a second DSPyRagModel
        # overwrites it (confirm).
        # NOTE(review): generate_response queries self.retrieval_model
        # directly, so the `rm` registered here is not used by this class.
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model
        )

        # Signature shared by both predictors. Its docstring and field
        # descriptions are handed to DSPy as prompt text, so edits to them
        # change model behaviour.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        # Direct prediction with a low temperature (0.01).
        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        # NOTE(review): no temperature override here, so ChainOfThought runs
        # with the LM's default temperature. Confirm the asymmetry is intended.
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int = 3,
        method: str = 'chain_of_thought'
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response: ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references, one list item per retrieved
              passage's ``document_metadata``
            - (html string): the retrieved passages used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported names. This
                is checked before any retrieval is performed.
        """
        # Resolve the predictor first, so an unknown method fails fast
        # instead of after a (potentially expensive) retrieval call.
        predictors = {
            'simple': self.predict,
            'chain_of_thought': self.predict_chain_of_thought,
        }
        if method not in predictors:
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")
        predictor = predictors[method]

        # Retrieval: each result is read as a dict with optional 'content' and
        # 'document_metadata' keys (missing keys yield None).
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictor(context=passages, question=question).answer

        # Create HTML strings with the references and the snippets
        references = "<h4>References</h4>\n" + create_bulleted_list(metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(passages)
        return response, references, snippets
def create_bulleted_list(texts: Iterable[object]) -> str:
    """Render items as an HTML unordered list.

    Each item is converted with ``str()`` and HTML-escaped. Retrieved text
    containing ``<``, ``>`` or ``&`` is therefore displayed literally instead
    of being interpreted as markup. Non-string items (e.g. document metadata
    dicts or ``None``) are rendered via their ``str()`` form.

    Args:
        texts: items to list, in order.

    Returns:
        A ``<ul>...</ul>`` string with one ``<li>`` per item, or
        ``"<ul></ul>"`` for an empty input.
    """
    items = "".join(f"<li>{html.escape(str(item))}</li>" for item in texts)
    return f"<ul>{items}</ul>"