Docs_QA_ColBERT_DSPy / dspy_utils.py
Didier Guillevic
Initial commit
1c18375
raw
history blame
3.18 kB
""" dspy_utils.py
Utilities for building a DSPy based retrieval (augmented) generation model.
:author: Didier Guillevic
:email: [email protected]
:creation: 2024-12-21
"""
import html
import logging
import os

import dspy
from ragatouille import RAGPretrainedModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Import-time side effect: configure the root logger at INFO level.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers, so an application that configures logging first keeps its setup.
logging.basicConfig(level=logging.INFO)
class DSPyRagModel:
    """Retrieval-augmented question answering built on DSPy.

    Couples a retrieval model (used to fetch passages relevant to a question)
    with a Mistral language model driven through DSPy predictors.

    Attributes:
        retrieval_model: the retriever passed to the constructor.
        language_model: the ``dspy.LM`` instance used for generation.
        predict: plain ``dspy.Predict`` predictor (method ``'simple'``).
        predict_chain_of_thought: ``dspy.ChainOfThought`` predictor
            (method ``'chain_of_thought'``).
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Set up the language model, the global DSPy settings and predictors.

        Args:
            retrieval_model: retriever whose ``search(query=..., k=...)``
                method is called by :meth:`generate_response`.

        Raises:
            KeyError: if the ``MISTRAL_API_KEY`` environment variable is unset.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(
            model="mistral/mistral-large-latest",
            api_key=os.environ["MISTRAL_API_KEY"],
        )

        # Set dspy retrieval and language model.
        # NOTE: this mutates the process-wide DSPy settings.
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model,
        )

        # Set dspy prediction functions.
        # The signature's docstring and field descriptions are runtime text
        # handed to DSPy, so they must not be reworded casually.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int = 3,
        method: str = 'chain_of_thought'
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response:
                ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text
              used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported values.
        """
        # Resolve the predictor first so an invalid `method` fails fast,
        # before any retrieval work is done.
        if method == 'simple':
            predictor = self.predict
        elif method == 'chain_of_thought':
            predictor = self.predict_chain_of_thought
        else:
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")

        # Retrieval: each result is read with `.get`, so a result lacking
        # 'content' or 'document_metadata' contributes None for that field.
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictor(context=passages, question=question).answer

        # Create HTML strings with the references and the snippets
        references = "<h4>References</h4>\n" + create_bulleted_list(metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(passages)

        return response, references, snippets
def create_bulleted_list(texts: list[str]) -> str:
    """Return an HTML unordered list with one ``<li>`` per item of *texts*.

    Each item is converted with ``str()`` (so non-string items such as
    metadata dicts or ``None`` are accepted) and then HTML-escaped, so that
    ``&``, ``<`` and ``>`` occurring in the text are displayed literally
    instead of being interpreted as markup.

    Args:
        texts: the items to render, in order.

    Returns:
        ``"<ul>...</ul>"`` containing one ``<li>`` element per item; an empty
        input yields ``"<ul></ul>"``.
    """
    # quote=False: items only ever land in element content, never inside an
    # attribute value, so quotes can stay readable as-is.
    html_items = (
        f"<li>{html.escape(str(item), quote=False)}</li>" for item in texts
    )
    return "<ul>" + "".join(html_items) + "</ul>"