Update app.py #1
opened by zetta-brandon-etocha

app.py CHANGED
@@ -1,6 +1,6 @@
-import os
+import os
 import PyPDF2
-import pandas as pd
+import pandas as pd
 import warnings
 import re
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
@@ -8,115 +8,123 @@ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
 import torch
 import gradio as gr
 from typing import Union
-
+import numpy as np
+from cassandra.cluster import Cluster
+from cassandra.auth import PlainTextAuthProvider
+from dotenv import load_dotenv, find_dotenv
+
 warnings.filterwarnings("ignore")
 
+# Load environment variables
+load_dotenv(find_dotenv())
+ASTRADB_TOKEN = os.getenv("ASTRADB_TOKEN")
+ASTRADB_API_ENDPOINT = os.getenv("ASTRADB_API_ENDPOINT")
 
+# AstraDB connection setup using token and endpoint
+auth_provider = PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN)
+cluster = Cluster([ASTRADB_API_ENDPOINT], auth_provider=auth_provider)
+session = cluster.connect("your_keyspace_name")
 
+# Load DPR models and tokenizers
 ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
-
 q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
 q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
 
-
-
-
-
-different line""" # XD
-    # creating a pdf file object
-    df = pd.DataFrame(columns = ["title","text"])
-    if type(parent_dir) == str :
+def process_pdfs(parent_dir: Union[str, list]):
+    """Processes the PDF files and returns a dataframe with the text of each page in a different line."""
+    df = pd.DataFrame(columns=["title", "text"])
+    if type(parent_dir) == str:
         parent_dir = [parent_dir]
     for file_path in parent_dir:
-        if ".pdf" not in file_path:
+        if ".pdf" not in file_path:  # Skip non-pdf files
             raise Exception("only pdf files are supported")
-        # creating a pdf file object
         pdfFileObj = open(file_path, 'rb')
-
-        # creating a pdf reader object
        pdfReader = PyPDF2.PdfReader(pdfFileObj)
-        # printing number of pages in pdf file
         num_pages = len(pdfReader.pages)
-        for i in range(num_pages):
+        for i in range(num_pages):
             pageObj = pdfReader.pages[i]
-
-            txt = pageObj.extract_text()
-            txt = txt.replace("\n","") # strip return to line
-            txt = txt.replace("\t","") # strip tabs
-            txt = re.sub(r" +"," ",txt) # strip extra space
-            # 512 is related to the positional encoding "facebook/dpr-ctx_encoder-single-nq-base" model
+            txt = pageObj.extract_text().replace("\n", "").replace("\t", "")
+            txt = re.sub(r" +", " ", txt)  # Strip extra space
             file_name = file_path.split("/")[-1]
-            if len(txt) < 512:
-                new_data = pd.DataFrame([[f"{file_name}-page-{i}",txt]],columns=["title","text"])
-                df = pd.concat([df,new_data],ignore_index=True)
-            else:
-                while len(txt) > 512:
-                    new_data = pd.DataFrame([[f"{file_name}-page-{i}",txt[:512]]],columns=["title","text"])
-                    df = pd.concat([df,new_data],ignore_index=True)
+            if len(txt) < 512:
+                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
+                df = pd.concat([df, new_data], ignore_index=True)
+            else:
+                while len(txt) > 512:
+                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
+                    df = pd.concat([df, new_data], ignore_index=True)
                     txt = txt[512:]
-
-        # closing the pdf file object
         pdfFileObj.close()
     return df
 
-def process(example):
-    """process the bathces of the dataset and returns the embeddings"""
-    try :
-        tokens = ctx_tokenizer(example["text"], return_tensors="pt")
-        embed = ctx_encoder(**tokens)[0][0].detach().numpy()
-        return {'embeddings': embed}
-    except Exception as e:
-        raise Exception(f"error in process: {e}")
-
 def process_dataset(df):
-    """
-    if len(df) == 0:
+    """Processes the dataframe and stores embeddings in AstraDB."""
+    if len(df) == 0:
         raise Exception("empty pdf files, or can't read text from them")
-
-
-
-
-
-
-
-
+
+    for _, row in df.iterrows():
+        title = row['title']
+        text = row['text']
+        tokens = ctx_tokenizer(text, return_tensors="pt")
+        embed = ctx_encoder(**tokens)[0][0].detach().numpy().tolist()
+
+        query = "INSERT INTO your_table_name (title, text, embeddings) VALUES (%s, %s, %s)"
+        session.execute(query, (title, text, embed))
+
+    return df
+
+def search(query, k=3):
+    """Searches the query in the database and returns the k most similar."""
+    try:
         tokens = q_tokenizer(query, return_tensors="pt")
-        query_embed = q_encoder(**tokens)[0][0].detach().numpy()
-
-
+        query_embed = q_encoder(**tokens)[0][0].detach().numpy().tolist()
+
+        # Perform vector search in AstraDB
+        query = """
+            SELECT title, text, embeddings
+            FROM your_table_name
+            ORDER BY embeddings ANN OF %s LIMIT %s
+        """
+        rows = session.execute(query, (query_embed, k))
+
+        retrieved_examples = []
+        for row in rows:
+            retrieved_examples.append({
+                "title": row.title,
+                "text": row.text,
+                "embeddings": np.array(row.embeddings)
+            })
+
+        out = f"""**title** : {retrieved_examples[0]["title"]},\ncontent: {retrieved_examples[0]["text"]}\n\n\n**similar resources:** {[example["title"] for example in retrieved_examples]}
         """
     except Exception as e:
         out = f"error in search: {e}"
     return out
-
-def predict(query,file_paths, k=3):
-    """
-    try:
+
+def predict(query, file_paths, k=3):
+    """Predicts the most similar files to the query."""
+    try:
         df = process_pdfs(file_paths)
-
-        out =
+        process_dataset(df)
+        out = search(query, k=k)
     except Exception as e:
         out = f"error in predict: {e}"
     return out
 
-
+# Gradio interface
+with gr.Blocks() as demo:
     gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
     with gr.Row():
         with gr.Column():
-            files = gr.Files(label="Upload PDFs",type="filepath",file_count="multiple")
+            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
             query = gr.Text(label="query")
-            with gr.Accordion("number of references",open=False):
-                k = gr.Number(value=3,show_label=False,precision=0,minimum=1,container=False)
+            with gr.Accordion("number of references", open=False):
+                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
             button = gr.Button("search")
         with gr.Column():
             output = gr.Markdown(label="output")
-    button.click(predict, [query,files,k],outputs=output)
+    button.click(predict, [query, files, k], outputs=output)
 
 demo.launch()
-
-
-
-
-
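Note: the updated app.py assumes that the keyspace and table behind the placeholder names ("your_keyspace_name", "your_table_name") already exist, and that the "ORDER BY embeddings ANN OF" query has a vector index to run against. A minimal one-off setup sketch under those assumptions, mirroring the connection style used in this diff and a .env file providing ASTRADB_TOKEN and ASTRADB_API_ENDPOINT, could look like the following (all names and the key layout are illustrative, not part of the PR):

import os

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv())

# Same connection style as app.py; depending on how the Astra database is
# exposed, a secure-connect bundle may be required instead of a raw endpoint.
auth_provider = PlainTextAuthProvider(username="token", password=os.getenv("ASTRADB_TOKEN"))
cluster = Cluster([os.getenv("ASTRADB_API_ENDPOINT")], auth_provider=auth_provider)
session = cluster.connect()

# The DPR base encoders used in app.py emit 768-dimensional vectors, so the
# vector column is sized to match. (title, text) is used as the primary key
# here only because chunked pages share the same "{file}-page-{i}" title.
session.execute("""
    CREATE TABLE IF NOT EXISTS your_keyspace_name.your_table_name (
        title text,
        text text,
        embeddings vector<float, 768>,
        PRIMARY KEY (title, text)
    )
""")

# A storage-attached index (SAI) on the vector column is what makes
# "ORDER BY embeddings ANN OF ..." queries executable.
session.execute("""
    CREATE CUSTOM INDEX IF NOT EXISTS embeddings_ann_idx
    ON your_keyspace_name.your_table_name (embeddings)
    USING 'StorageAttachedIndex'
""")

cluster.shutdown()

With the table and index in place, the INSERT in process_dataset and the ANN query in search line up with this schema; SAI ranks rows by its configured similarity function (cosine by default), so the first returned row is the closest match that search formats into Markdown.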