Didier Guillevic committed on
Commit
1c18375
·
1 Parent(s): 3d4db34

Initial commit

Browse files
Files changed (5) hide show
  1. app.py +118 -0
  2. colbert_utils.py +44 -0
  3. dspy_utils.py +87 -0
  4. pdf_utils.py +180 -0
  5. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ app.py
2
+
3
+ Question / answer over a collection of PDF documents using late interaction
4
+ ColBERT model for retrieval and DSPy+Mistral for answer generation.
5
+
6
+ :author: Didier Guillevic
7
+ :date: 2024-12-22
8
+ """
9
+
10
+ import gradio as gr
11
+
12
+ import logging
13
+ logger = logging.getLogger(__name__)
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+ import os
17
+ import pdf_utils # utilities for pdf processing
18
+ import colbert_utils # utilities to build a ColBERT retrieval model
19
+ import dspy_utils # utilities for building a DSPy based retrieval generation model
20
+
21
+ from tqdm.notebook import tqdm
22
+ import warnings
23
+ warnings.filterwarnings('ignore')
24
+
25
+
26
# Module-level handle on the retrieval-augmented generation model.
# It stays None until build_rag_model() has run successfully.
dspy_rag_model = None


def build_rag_model(files) -> str:
    """Index the given PDF files and build the RAG model used to answer questions.

    Args:
        files: the uploaded PDF files (may be None or empty). Each entry is
            expected to be a file path; objects exposing the path through a
            ``name`` attribute are accepted as well
            (NOTE(review): which of the two Gradio provides depends on its
            version -- confirm against the installed release).

    Returns:
        a short status message meant for the "Build status" text box.
    """
    global dspy_rag_model

    if not files:
        return "No PDF files provided."

    documents = []
    metadatas = []
    for uploaded in files:
        file_path = getattr(uploaded, "name", uploaded)
        text = pdf_utils.get_text_from_pdf(file_path)
        if not text:
            # get_text_from_pdf returns None on invalid / unreadable files
            logger.warning("Skipping file without extractable text: %s", file_path)
            continue
        documents.append(text)
        metadatas.append(pdf_utils.get_metadata_info(file_path))

    if not documents:
        return "No text could be extracted from the given files."

    retrieval_model = colbert_utils.build_colbert_model(
        documents=documents, metadatas=metadatas)
    dspy_rag_model = dspy_utils.DSPyRagModel(retrieval_model)
    return f"Model built from {len(documents)} document(s)."


def generate_response(question: str) -> tuple[str, str, str]:
    """Generate a response to a given question using the RAG model.

    Args:
        question: the question to answer

    Returns:
        a 3-tuple (response, references, snippets). A tuple of three values is
        always returned, including when no answer can be produced, because the
        caller wires this function to three output components.
    """
    if dspy_rag_model is None:
        return "RAG model not built. Please build the model first.", "", ""

    if not question or not question.strip():
        return "Please enter a question.", "", ""

    # Generate response
    responses, references, snippets = dspy_rag_model.generate_response(
        question=question, k=5, method='chain_of_thought')

    return responses, references, snippets
40
+
41
+
42
# Questions offered as clickable examples; they relate to the default PDF file.
sample_questions = [
    "What are the tax risks associated with high net worth individuals (HNWIs)?",
    "How do tax administrations address aggressive tax planning by HNWIs?",
    "How can tax administrations engage with HNWIs to improve tax compliance?",
    "What are the benefits of establishing dedicated HNWI units within tax administrations?",
    "How can international cooperation help address offshore tax risks associated with HNWIs?",
]

with gr.Blocks() as demo:
    gr.Markdown("""
    # Retrieval (ColBERT) + Generation (DSPy & Mistral)
    Note: building the retrieval model might take a few minutes.
    """)

    # Files to index and the status of the model build, in one row
    with gr.Row():
        upload_files = gr.File(
            label="Upload PDF files to index",
            file_count="multiple",
            value=["OECD_Engaging_with_HNW_individuals_tax_compliance_(2009).pdf"],
            scale=5,
        )
        build_status = gr.Textbox(label="Build status", placeholder="", scale=2)

    # Triggers the (possibly slow) build of the retrieval generation model
    build_button = gr.Button("Build retrieval generation model", variant='primary')

    # Question entered by the user and the generated answer
    question = gr.Textbox(
        label="Question to answer",
        placeholder="How do tax administrations address aggressive tax planning by HNWIs?",
    )
    response = gr.Textbox(label="Response", placeholder="")

    # Supporting material for the answer, collapsed by default
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")

    # Triggers the answer generation
    response_button = gr.Button("Submit", variant='primary')

    # Example questions given default provided PDF file
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            examples=[[sample] for sample in sample_questions],
            inputs=[question],
            outputs=[response, references, snippets],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions",
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - What
            - Retrieval augmented generation (RAG) model based on ColBERT and DSPy.
            - Retrieval base model: 'antoinelouis/colbert-xm' (multilingual model)
            - Generation framework: DSPy and Mistral.
        - How
            - Upload PDF files to index.
            - Build the retrieval augmented model (might take a few minutes)
            - Ask a question to generate a response.
        """)

    # Click actions. build_rag_model and generate_response must be defined at
    # module level before this point.
    build_button.click(
        fn=build_rag_model,
        inputs=[upload_files],
        outputs=[build_status],
    )
    response_button.click(
        fn=generate_response,
        inputs=[question],
        outputs=[response, references, snippets],
    )


demo.launch(show_api=False)
colbert_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ colbert_utils.py
2
+
3
+ Utilities for building (and using) a ColBERT (retrieval) model.
4
+
5
+ :author: Didier Guillevic
6
+ :email: [email protected]
7
+ :creation: 2024-12-21
8
+ """
9
+
10
+ import logging
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+ from ragatouille import RAGPretrainedModel
15
+
16
+
17
def build_colbert_model(
    documents: list[str],
    metadatas: list[dict[str, str]],
    pretrained_model: str='antoinelouis/colbert-xm',
    index_name: str='colbert_index'
) -> RAGPretrainedModel:
    """Load a pretrained ColBERT model and index the given documents with it.

    Args:
        documents: list of documents to index
        metadatas: list of metadata, one entry for each document
        pretrained_model: name of the pretrained model to use
        index_name: name of the index built with the given documents

    Returns:
        the ColBERT retrieval model built with the given documents.
    """
    # No unique document IDs are passed at the moment, and FAISS is disabled
    # (the original author could not get it to work).
    indexing_options = {
        'collection': documents,
        'document_metadatas': metadatas,
        'index_name': index_name,
        'max_document_length': 180,
        'split_documents': True,
        'use_faiss': False,
    }

    retriever = RAGPretrainedModel.from_pretrained(pretrained_model)
    retriever.index(**indexing_options)
    return retriever
dspy_utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ dspy_utils.py
2
+
3
+ Utilities for building a DSPy based retrieval (augmented) generation model.
4
+
5
+ :author: Didier Guillevic
6
+ :email: [email protected]
7
+ :creation: 2024-12-21
8
+ """
9
+
10
+ import os
11
+ import dspy
12
+ from ragatouille import RAGPretrainedModel
13
+
14
+ import logging
15
+ logger = logging.getLogger(__name__)
16
+ logging.basicConfig(level=logging.INFO)
17
+
18
+
19
class DSPyRagModel:
    """Question answering that combines a retrieval model with a DSPy language model.

    Passages are retrieved with the given retrieval model, then handed as context
    to a DSPy predictor (plain prediction or chain of thought) backed by Mistral.
    """

    def __init__(self, retrieval_model: "RAGPretrainedModel"):
        """Set up the language model and the DSPy predictors.

        Args:
            retrieval_model: an already indexed retrieval model; its ``search``
                method is used to fetch passages.

        Raises:
            KeyError: if the MISTRAL_API_KEY environment variable is not set.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(
            model="mistral/mistral-large-latest",
            api_key=os.environ["MISTRAL_API_KEY"],
        )

        # Set dspy retrieval and language model
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model
        )

        # Set dspy prediction functions
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
        self,
        question: str,
        k: int=3,
        method: str = 'chain_of_thought'
    ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response: ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported methods.
        """
        # Validate the method first so that an invalid request fails before the
        # retrieval step is run.
        predictors = {
            'simple': self.predict,
            'chain_of_thought': self.predict_chain_of_thought,
        }
        if method not in predictors:
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")

        # Retrieval
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictors[method](context=passages, question=question).answer

        # Create an HTML string with the references
        references = "<h4>References</h4>\n" + create_bulleted_list(metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(passages)

        return response, references, snippets
78
+
79
+
80
def create_bulleted_list(texts: list) -> str:
    """Return an HTML bulleted list (<ul>) with one <li> per given item.

    Each item is converted to a string and HTML-escaped, so that text
    containing characters such as '<' or '&' (e.g. text extracted from a
    document) is displayed literally instead of being interpreted as markup.

    Args:
        texts: the items to list (strings or any object with a string form)

    Returns:
        the HTML string; "<ul></ul>" when there are no items.
    """
    import html  # local import: keeps this fix self-contained

    html_items = [
        f"<li>{html.escape(str(item), quote=False)}</li>" for item in texts
    ]
    return "<ul>" + "".join(html_items) + "</ul>"
pdf_utils.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ pdf_utils.py
2
+
3
+ Utilities for working with PDFs
4
+
5
+ :author: Didier Guillevic
6
+ :email: [email protected]
7
+ :creation: 2024-12-21
8
+ """
9
+
10
+ import pypdf
11
+ import os
12
+ import datetime
13
+ import pytz
14
+
15
+ import logging
16
+ logger = logging.getLogger(__name__)
17
+ logging.basicConfig(level=logging.INFO)
18
+
19
+
20
def validate_pdf(file_path: str) -> bool:
    """Validate that the given path points to an existing file named *.pdf.

    Only the file name is checked (case-insensitive '.pdf' suffix); the file
    content is not inspected.

    Args:
        file_path: path to the candidate PDF file (string or path-like)

    Returns:
        True if the path is an existing regular file with a '.pdf' suffix,
        False otherwise (the reason is logged as an error).
    """
    if not file_path:
        logger.error(f"File not found at path: {file_path}")
        return False

    path = os.fspath(file_path)
    # isfile (rather than exists) so that a directory named "x.pdf" is rejected
    if not os.path.isfile(path):
        logger.error(f"File not found at path: {file_path}")
        return False
    if not path.lower().endswith('.pdf'):
        logger.error("File is not a PDF")
        return False
    return True
30
+
31
+
32
def get_text_from_pdf(
    file_path: str,
    max_chars: int = 100_000_000
) -> "str | None":
    """Extract the text from a given PDF file.

    The text of the pages is joined with newlines. Extraction stops once
    max_chars characters of page text have been collected; the page on which
    the limit is reached is truncated.

    Args:
        file_path: path to the PDF file
        max_chars: max length (in chars) to be read from the file

    Returns:
        the extracted text, or None if the file is not a valid PDF path or
        could not be read (the error is logged).
    """
    # PdfReadError lives in pypdf.errors; import it here so that the except
    # clause below cannot itself fail on a missing attribute.
    from pypdf.errors import PdfReadError

    if not validate_pdf(file_path):
        return None

    try:
        with open(file_path, 'rb') as file:
            # Create PDF reader object
            pdf_reader = pypdf.PdfReader(file)

            # Get total number of pages
            num_pages = len(pdf_reader.pages)
            logger.info("Processing PDF with %d pages...", num_pages)

            extracted_text = []
            total_chars = 0

            # Iterate through all pages
            for page_num, page in enumerate(pdf_reader.pages, start=1):
                text = page.extract_text() or ""

                # Check if adding this page's text would exceed the limit
                remaining_chars = max(max_chars - total_chars, 0)
                if len(text) > remaining_chars:
                    # Only add text up to the limit
                    extracted_text.append(text[:remaining_chars])
                    logger.info(
                        "Reached %d character limit at page %d", max_chars, page_num)
                    break

                extracted_text.append(text)
                total_chars += len(text)
                logger.debug("Processed page %d/%d", page_num, num_pages)

        final_text = '\n'.join(extracted_text)
        logger.info("Extraction complete! Total characters: %d", len(final_text))
        return final_text

    except PdfReadError:
        logger.error("Invalid or corrupted PDF file: %s", file_path)
        return None
    except Exception:
        # Best effort: report the problem and let the caller skip this file
        logger.exception("An unexpected error occurred while reading %s", file_path)
        return None
88
+
89
+
90
def get_pdf_metadata(file_path: str) -> "dict | None":
    """Get the metadata of a given PDF file.

    Args:
        file_path: path to a PDF file

    Returns:
        dictionary with the keys 'num_pages' (number of pages) and 'metadata'
        (the document information reported by the PDF reader, or an empty
        dictionary when the reader reports none); None if the file is not a
        valid PDF path or could not be read (the error is logged).
    """
    if not validate_pdf(file_path):
        return None

    try:
        with open(file_path, 'rb') as file:
            pdf_reader = pypdf.PdfReader(file)
            return {
                'num_pages': len(pdf_reader.pages),
                # The reader may report no metadata at all; normalise that to
                # an empty mapping so callers can always do key lookups.
                'metadata': pdf_reader.metadata or {},
            }
    except Exception:
        # Best effort: report the problem and signal failure with None
        logger.exception("Error extracting metadata from %s", file_path)
        return None
113
+
114
+
115
def get_datetime_from_pdf_metadata(metadata: dict, key: str) -> "str | None":
    """Extract a datetime string from the metadata of a PDF file.

    PDF date strings look like "D:YYYYMMDDHHmmSS" followed by an optional
    timezone designator (e.g. "+02'00'" or "Z"); every field after the year
    may be missing. The date and time are returned as written in the file:
    the timezone designator is ignored, no conversion is applied.

    Example: "D:20210714143000+02'00'" -> "2021-07-14 14:30:00"

    Args:
        metadata: dictionary with the metadata information (may be None)
        key: key to extract the datetime from (e.g. '/CreationDate')

    Returns:
        the datetime formatted as "YYYY-MM-DD HH:MM:SS", or None if the key is
        absent or its value cannot be parsed as a PDF date.
    """
    if not metadata or key not in metadata:
        return None

    raw_value = str(metadata[key]).strip()
    if raw_value.startswith("D:"):
        raw_value = raw_value[2:]

    # Leading digits of the date/time portion (at most YYYYMMDDHHmmSS)
    digits = ""
    for char in raw_value[:14]:
        if char not in "0123456789":
            break
        digits += char

    if len(digits) < 4:
        return None  # not even a year
    # Ignore a trailing half-written two-digit field
    digits = digits[:len(digits) - len(digits) % 2]

    # Missing fields default to January 1st, 00:00:00
    padded = digits + "00000101000000"[len(digits):]
    try:
        dt = datetime.datetime.strptime(padded, "%Y%m%d%H%M%S")
    except ValueError:
        return None

    return dt.strftime("%Y-%m-%d %H:%M:%S")
158
+
159
+
160
def get_metadata_info(pdf_path: str) -> dict:
    """Collect basic and PDF-specific information about a PDF file.

    Args:
        pdf_path: path to the PDF file

    Returns:
        dictionary that always holds 'file_name' and, when the PDF metadata
        could be read, also 'num_pages', 'creation_date' and
        'modification_date'.
    """
    # The file name is known even if the PDF itself cannot be read
    info = {'file_name': os.path.basename(pdf_path)}

    details = get_pdf_metadata(pdf_path)
    if not details:
        return info

    document_info = details['metadata']
    info['num_pages'] = details['num_pages']
    info['creation_date'] = get_datetime_from_pdf_metadata(document_info, '/CreationDate')
    info['modification_date'] = get_datetime_from_pdf_metadata(document_info, '/ModDate')
    return info
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ RAGatouille
3
+ dspy-ai
4
+ mistralai
5
+ litellm
6
+ pypdf
7
+ pytz