Spaces:
Running
Running
Didier Guillevic
committed on
Commit
·
1c18375
1
Parent(s):
3d4db34
Initial commit
Browse files- app.py +118 -0
- colbert_utils.py +44 -0
- dspy_utils.py +87 -0
- pdf_utils.py +180 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" app.py
|
2 |
+
|
3 |
+
Question / answer over a collection of PDF documents using late interaction
|
4 |
+
ColBERT model for retrieval and DSPy+Mistral for answer generation.
|
5 |
+
|
6 |
+
:author: Didier Guillevic
|
7 |
+
:date: 2024-12-22
|
8 |
+
"""
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
import logging
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
logging.basicConfig(level=logging.INFO)
|
15 |
+
|
16 |
+
import os
|
17 |
+
import pdf_utils # utilities for pdf processing
|
18 |
+
import colbert_utils # utilities for to build a ColBERT retrieval model
|
19 |
+
import dspy_utils # utilities for building a DSPy based retrieval generation model
|
20 |
+
|
21 |
+
from tqdm.notebook import tqdm
|
22 |
+
import warnings
|
23 |
+
warnings.filterwarnings('ignore')
|
24 |
+
|
25 |
+
|
26 |
+
def generate_response(question: str) -> tuple[str, str, str]:
    """Generate a response to a given question using the RAG model.

    Args:
        question: the user question to answer.

    Returns:
        (answer, references_html, snippets_html). If the RAG model has not
        been built yet, a placeholder message plus two empty strings, so the
        three Gradio output components always receive a value.
    """
    # The model is created lazily by the "build" button; use globals().get()
    # so an unbuilt (still undefined) model yields a friendly message
    # instead of a NameError.
    model = globals().get("dspy_rag_model")
    if model is None:
        return "RAG model not built. Please build the model first.", "", ""

    # k=5: number of passages retrieved to ground the answer.
    responses, references, snippets = model.generate_response(
        question=question, k=5, method='chain_of_thought')

    return responses, references, snippets
|
40 |
+
|
41 |
+
|
42 |
+
dspy_rag_model = None  # Global RAG model instance; set by build_rag_model().


def build_rag_model(file_paths: list[str]) -> str:
    """Build the retrieval (ColBERT) + generation (DSPy/Mistral) model.

    NOTE(review): the build button referenced `build_rag_model`, but no such
    function existed anywhere in the original file (NameError at startup).
    Reconstructed here from the pdf_utils / colbert_utils / dspy_utils
    interfaces — confirm against the intended behavior.

    Args:
        file_paths: paths of the uploaded PDF files to index.

    Returns:
        a human-readable build status message.
    """
    global dspy_rag_model

    if not file_paths:
        return "No PDF files provided."

    documents = []
    metadatas = []
    for file_path in file_paths:
        text = pdf_utils.get_text_from_pdf(file_path)
        if text is None:
            logger.warning("Skipping invalid PDF: %s", file_path)
            continue
        documents.append(text)
        metadatas.append(pdf_utils.get_metadata_info(file_path))

    if not documents:
        return "No valid PDF documents found."

    # Building the ColBERT index may take a few minutes.
    retrieval_model = colbert_utils.build_colbert_model(
        documents=documents, metadatas=metadatas)
    dspy_rag_model = dspy_utils.DSPyRagModel(retrieval_model)
    return f"Model built over {len(documents)} document(s)."


with gr.Blocks() as demo:
    gr.Markdown("""
    # Retrieval (ColBERT) + Generation (DSPy & Mistral)
    Note: building the retrieval model might take a few minutes.
    """)

    # Input files and build status
    with gr.Row():
        upload_files = gr.File(
            label="Upload PDF files to index", file_count="multiple",
            value=["OECD_Engaging_with_HNW_individuals_tax_compliance_(2009).pdf",],
            scale=5)
        build_status = gr.Textbox(label="Build status", placeholder="", scale=2)

    # Button triggering the (slow) index/model build.
    build_button = gr.Button("Build retrieval generation model", variant='primary')

    # Question to answer
    question = gr.Textbox(
        label="Question to answer",
        placeholder="How do tax administrations address aggressive tax planning by HNWIs?"
    )
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")

    # Button triggering answer generation.
    response_button = gr.Button("Submit", variant='primary')

    # Example questions given default provided PDF file
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                ["What are the tax risks associated with high net worth individuals (HNWIs)?",],
                ["How do tax administrations address aggressive tax planning by HNWIs?",],
                ["How can tax administrations engage with HNWIs to improve tax compliance?",],
                ["What are the benefits of establishing dedicated HNWI units within tax administrations?",],
                ["How can international cooperation help address offshore tax risks associated with HNWIs?",],
            ],
            inputs=[question,],
            outputs=[response, references, snippets],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions"
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - What
            - Retrieval augmented generation (RAG) model based on ColBERT and DSPy.
            - Retrieval base model: 'antoinelouis/colbert-xm' (multilingual model)
            - Generation framework: DSPy and Mistral.
        - How
            - Upload PDF files to index.
            - Build the retrieval augmented model (might take a few minutes)
            - Ask a question to generate a response.
        """)

    # Click actions
    build_button.click(
        fn=build_rag_model,
        inputs=[upload_files],
        outputs=[build_status]
    )
    response_button.click(
        fn=generate_response,
        inputs=[question],
        outputs=[response, references, snippets]
    )


demo.launch(show_api=False)
|
colbert_utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" colbert_utils.py
|
2 |
+
|
3 |
+
Utilities for building (and using) a ColBERT (retrieval) model.
|
4 |
+
|
5 |
+
:author: Didier Guillevic
|
6 |
+
:email: [email protected]
|
7 |
+
:creation: 2024-12-21
|
8 |
+
"""
|
9 |
+
|
10 |
+
import logging
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
|
14 |
+
from ragatouille import RAGPretrainedModel
|
15 |
+
|
16 |
+
|
17 |
+
def build_colbert_model(
    documents: list[str],
    metadatas: list[dict[str, str]],
    pretrained_model: str='antoinelouis/colbert-xm',
    index_name: str='colbert_index'
) -> RAGPretrainedModel:
    """Create a ColBERT retrieval model and index the given documents.

    Args:
        documents: the texts to index.
        metadatas: one metadata dict per document.
        pretrained_model: name of the pretrained ColBERT checkpoint to load.
        index_name: name under which the index is stored.

    Returns:
        the RAGPretrainedModel with the given documents indexed.
    """
    rag_model = RAGPretrainedModel.from_pretrained(pretrained_model)

    indexing_options = dict(
        collection=documents,
        #document_ids=document_ids, # no unique IDs at the moment
        document_metadatas=metadatas,
        index_name=index_name,
        max_document_length=180,  # chunk size (tokens) when splitting
        split_documents=True,
        use_faiss=False  # cannot get it to work...
    )
    rag_model.index(**indexing_options)

    return rag_model
|
dspy_utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" dspy_utils.py
|
2 |
+
|
3 |
+
Utilities for building a DSPy based retrieval (augmented) generation model.
|
4 |
+
|
5 |
+
:author: Didier Guillevic
|
6 |
+
:email: [email protected]
|
7 |
+
:creation: 2024-12-21
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import dspy
|
12 |
+
from ragatouille import RAGPretrainedModel
|
13 |
+
|
14 |
+
import logging
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
|
18 |
+
|
19 |
+
class DSPyRagModel:
    """Retrieval-augmented generation over a ColBERT index.

    Wraps a RAGatouille retrieval model and a Mistral language model
    (driven through DSPy) to answer questions from retrieved passages.
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Configure DSPy with the given retrieval model and a Mistral LM.

        Args:
            retrieval_model: an already-built RAGatouille retrieval model.

        Raises:
            KeyError: if the MISTRAL_API_KEY environment variable is unset.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(model="mistral/mistral-large-latest", api_key=os.environ["MISTRAL_API_KEY"])

        # Set dspy retrieval and language model.
        # NOTE(review): dspy.settings.configure mutates process-global DSPy
        # state, so the most recently constructed instance wins — confirm
        # this is intended if several instances may coexist.
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model
        )

        # Set dspy prediction functions.
        # The Signature docstring and field descriptions below are consumed
        # by DSPy as part of the LM prompt — do not edit them casually.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        # Very low temperature for near-deterministic answers.
        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int=3,
        method: str = 'chain_of_thought'
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response: ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if method is not one of the supported values.
        """
        # Retrieval: each result is a dict exposing 'content' (passage text)
        # and 'document_metadata' entries.
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        if method == 'simple':
            response = self.predict(context=passages, question=question).answer
        elif method == 'chain_of_thought':
            response = self.predict_chain_of_thought(context=passages, question=question).answer
        else:
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")

        # Create HTML strings with the references and the text snippets.
        references = "<h4>References</h4>\n" + create_bulleted_list(metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(passages)

        return response, references, snippets
|
78 |
+
|
79 |
+
|
80 |
+
def create_bulleted_list(texts: list[str]) -> str:
    """Render a list of items as an HTML unordered list.

    Args:
        texts: items to render; non-string items (e.g. metadata dicts) are
            converted with str().

    Returns:
        an HTML string: one <li> per item inside a <ul>.
    """
    import html  # local import: stdlib HTML escaping

    # Escape each item: text extracted from PDFs may contain '<', '&', etc.
    # that would otherwise break — or inject markup into — the gr.HTML
    # components this string is rendered in.
    html_items = [f"<li>{html.escape(str(item))}</li>" for item in texts]
    return "<ul>" + "".join(html_items) + "</ul>"
|
pdf_utils.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" pdf_utils.py
|
2 |
+
|
3 |
+
Utilities for working with PDFs
|
4 |
+
|
5 |
+
:author: Didier Guillevic
|
6 |
+
:email: [email protected]
|
7 |
+
:creation: 2024-12-21
|
8 |
+
"""
|
9 |
+
|
10 |
+
import pypdf
|
11 |
+
import os
|
12 |
+
import datetime
|
13 |
+
import pytz
|
14 |
+
|
15 |
+
import logging
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
|
19 |
+
|
20 |
+
def validate_pdf(file_path: str) -> bool:
    """Check that file_path points to an existing file with a .pdf extension.

    Args:
        file_path: path of the candidate PDF file.

    Returns:
        True when the file exists and its name ends in '.pdf'
        (case-insensitive); False otherwise (the reason is logged).
    """
    file_exists = os.path.exists(file_path)
    if not file_exists:
        logging.getLogger(__name__).error(f"File not found at path: {file_path}")
        return False

    has_pdf_extension = file_path.lower().endswith('.pdf')
    if not has_pdf_extension:
        logging.getLogger(__name__).error("File is not a PDF")
        return False

    return True
|
30 |
+
|
31 |
+
|
32 |
+
def get_text_from_pdf(
    file_path: str,
    max_chars: int = 100_000_000
) -> str:
    """Extract the text from a given PDF file.

    Args:
        file_path: path to the PDF file
        max_chars: maximum number of characters to extract from the file

    Returns:
        the extracted text, or None if the file is invalid or unreadable.
    """
    if not validate_pdf(file_path):
        return None

    try:
        with open(file_path, 'rb') as file:
            # Create PDF reader object
            pdf_reader = pypdf.PdfReader(file)

            num_pages = len(pdf_reader.pages)
            logger.info("Processing PDF with %d pages...", num_pages)

            extracted_text = []
            total_chars = 0

            # Iterate through all pages, accumulating text up to max_chars.
            for page_num in range(num_pages):
                text = pdf_reader.pages[page_num].extract_text()

                # Truncate and stop once the character budget is exhausted.
                if total_chars + len(text) > max_chars:
                    remaining_chars = max_chars - total_chars
                    extracted_text.append(text[:remaining_chars])
                    logger.info("Reached %d character limit at page %d",
                                max_chars, page_num + 1)
                    break

                extracted_text.append(text)
                total_chars += len(text)
                logger.debug("Processed page %d/%d", page_num + 1, num_pages)

            final_text = '\n'.join(extracted_text)
            logger.info("Extraction complete! Total characters: %d", len(final_text))
            return final_text

    except pypdf.errors.PdfReadError:
        # BUG FIX: PdfReadError lives in pypdf.errors, not at the package top
        # level; the original `except pypdf.PdfReadError` would itself raise
        # AttributeError while a read error was being handled.
        logger.error("Error: Invalid or corrupted PDF file")
        return None
    except Exception as e:
        # logger.exception records the traceback (was a bare print()).
        logger.exception("An unexpected error occurred: %s", e)
        return None
|
88 |
+
|
89 |
+
|
90 |
+
def get_pdf_metadata(file_path: str) -> dict:
    """Get the metadata of a given PDF file.

    Args:
        file_path: path to a PDF file

    Returns:
        dictionary with 'num_pages' and 'metadata' (the raw document info —
        pypdf may report it as None), or None on error.
    """
    if not validate_pdf(file_path):
        return None

    try:
        with open(file_path, 'rb') as file:
            pdf_reader = pypdf.PdfReader(file)
            metadata = {
                'num_pages': len(pdf_reader.pages),
                'metadata': pdf_reader.metadata
            }
            return metadata
    except Exception as e:
        # logger.exception records the traceback (was a bare print()),
        # consistent with the module-level logging setup.
        logger.exception("Error extracting metadata: %s", e)
        return None
|
113 |
+
|
114 |
+
|
115 |
+
def get_datetime_from_pdf_metadata(metadata: dict, key: str) -> str:
    """Extract a formatted datetime string from PDF metadata.

    Converts a PDF date string, e.g. "D:20210714143000+02'00'", into
    "2021-07-14 14:30:00" (the document's local wall-clock time).

    Args:
        metadata: dictionary with the metadata information (may be None:
            pypdf reports no document-info dictionary for some files)
        key: key to extract the datetime from (e.g. '/CreationDate')

    Returns:
        the formatted datetime string, or None if the key is missing or the
        value cannot be parsed.
    """
    # Guard: metadata can legitimately be None (see get_pdf_metadata).
    if not metadata or key not in metadata:
        return None

    # PDF date strings look like "D:YYYYMMDDHHmmSS" optionally followed by
    # a UTC offset such as "+02'00'", "-05'00'" or "Z".
    pdf_date_string = str(metadata[key])

    # Remove the 'D:' prefix (when present) and the quotes in the offset.
    date_string = pdf_date_string[2:] if pdf_date_string.startswith("D:") else pdf_date_string
    date_string = date_string.replace("'", "")

    # Parse the date and time components
    date_part = date_string[:8]
    time_part = date_string[8:14]
    offset_part = date_string[14:]

    try:
        dt = datetime.datetime.strptime(date_part + time_part, "%Y%m%d%H%M%S")
    except ValueError:
        logger.error("Unparseable PDF date string: %r", pdf_date_string)
        return None

    # Attach the timezone when an offset is present. BUG FIXES vs original:
    # - the sign was inverted ("+02'00'" became UTC-02:00);
    # - a missing offset or a bare 'Z' crashed with IndexError/ValueError;
    # - stdlib datetime.timezone replaces the deprecated pytz.localize idiom.
    # The wall-clock time — and hence the returned string — is unchanged.
    if len(offset_part) >= 5 and offset_part[0] in '+-':
        delta = datetime.timedelta(
            hours=int(offset_part[1:3]), minutes=int(offset_part[3:5]))
        if offset_part[0] == '-':
            delta = -delta
        dt = dt.replace(tzinfo=datetime.timezone(delta))
    elif offset_part[:1] in ('Z', 'z'):
        dt = dt.replace(tzinfo=datetime.timezone.utc)

    return dt.strftime("%Y-%m-%d %H:%M:%S")
|
158 |
+
|
159 |
+
|
160 |
+
def get_metadata_info(pdf_path: str) -> dict:
    """Build a dictionary with basic and additional information about a PDF file.

    Args:
        pdf_path: path to the PDF file

    Returns:
        dictionary with the metadata information; always contains
        'file_name', plus 'num_pages', 'creation_date' and
        'modification_date' when the PDF metadata could be read.
    """
    # basic information about the file
    metadata_info = {}
    metadata_info['file_name'] = os.path.basename(pdf_path)

    # additional information about the file
    pdf_metadata = get_pdf_metadata(pdf_path)
    if pdf_metadata:
        metadata_info['num_pages'] = pdf_metadata['num_pages']
        # Guard: pypdf can report the document-info dictionary as None, which
        # would make a `key in metadata` lookup raise TypeError downstream.
        document_metadata = pdf_metadata['metadata']
        if document_metadata:
            metadata_info['creation_date'] = get_datetime_from_pdf_metadata(document_metadata, '/CreationDate')
            metadata_info['modification_date'] = get_datetime_from_pdf_metadata(document_metadata, '/ModDate')

    return metadata_info
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
RAGatouille
|
3 |
+
dspy-ai
|
4 |
+
mistralai
|
5 |
+
litellm
|
6 |
+
pypdf
|
7 |
+
pytz
|