zetta-brandon-etocha committed on
Commit ddf3104 • 1 Parent(s): ce3604a

Update app.py

Files changed (1)
app.py +79 -71
app.py CHANGED
@@ -1,6 +1,6 @@
-import os
+import os
 import PyPDF2
-import pandas as pd
+import pandas as pd
 import warnings
 import re
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
@@ -8,115 +8,123 @@ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
 import torch
 import gradio as gr
 from typing import Union
-from datasets import Dataset
+import numpy as np
+from cassandra.cluster import Cluster
+from cassandra.auth import PlainTextAuthProvider
+from dotenv import load_dotenv, find_dotenv
+
 warnings.filterwarnings("ignore")
 
+# Load environment variables
+load_dotenv(find_dotenv())
+ASTRADB_TOKEN = os.getenv("ASTRADB_TOKEN")
+ASTRADB_API_ENDPOINT = os.getenv("ASTRADB_API_ENDPOINT")
 
+# AstraDB connection setup using token and endpoint
+auth_provider = PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN)
+cluster = Cluster([ASTRADB_API_ENDPOINT], auth_provider=auth_provider)
+session = cluster.connect("your_keyspace_name")
 
+# Load DPR models and tokenizers
 ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
-
 q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
 q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
 
-
-
-def process_pdfs(parent_dir: Union[str,list]):
-    """ processess the PDF files and returns a dataframe with the text of each page in a
-    different line""" # XD
-    # creating a pdf file object
-    df = pd.DataFrame(columns = ["title","text"])
-    if type(parent_dir) == str :
+def process_pdfs(parent_dir: Union[str, list]):
+    """Processes the PDF files and returns a dataframe with the text of each page in a different line."""
+    df = pd.DataFrame(columns=["title", "text"])
+    if type(parent_dir) == str:
         parent_dir = [parent_dir]
     for file_path in parent_dir:
-        if ".pdf" not in file_path : # skip non pdf files
+        if ".pdf" not in file_path:  # Skip non-pdf files
             raise Exception("only pdf files are supported")
-        # creating a pdf file object
         pdfFileObj = open(file_path, 'rb')
-
-        # creating a pdf reader object
         pdfReader = PyPDF2.PdfReader(pdfFileObj)
-        # printing number of pages in pdf file
        num_pages = len(pdfReader.pages)
-        for i in range(num_pages) :
+        for i in range(num_pages):
             pageObj = pdfReader.pages[i]
-            # extracting text from page
-            txt = pageObj.extract_text()
-            txt = txt.replace("\n","") # strip return to line
-            txt = txt.replace("\t","") # strip tabs
-            txt = re.sub(r" +"," ",txt) # strip extra space
-            # 512 is related to the positional encoding "facebook/dpr-ctx_encoder-single-nq-base" model
+            txt = pageObj.extract_text().replace("\n", "").replace("\t", "")
+            txt = re.sub(r" +", " ", txt)  # Strip extra space
             file_name = file_path.split("/")[-1]
-            if len(txt) < 512 :
-                new_data = pd.DataFrame([[f"{file_name}-page-{i}",txt]],columns=["title","text"])
-                df = pd.concat([df,new_data],ignore_index=True)
-            else :
-                while len(txt) > 512 :
-                    new_data = pd.DataFrame([[f"{file_name}-page-{i}",txt[:512]]],columns=["title","text"])
-                    df = pd.concat([df,new_data],ignore_index=True)
+            if len(txt) < 512:
+                new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
+                df = pd.concat([df, new_data], ignore_index=True)
+            else:
+                while len(txt) > 512:
+                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
+                    df = pd.concat([df, new_data], ignore_index=True)
                     txt = txt[512:]
-
-        # closing the pdf file object
         pdfFileObj.close()
     return df
 
-def process(example):
-    """process the bathces of the dataset and returns the embeddings"""
-    try :
-        tokens = ctx_tokenizer(example["text"], return_tensors="pt")
-        embed = ctx_encoder(**tokens)[0][0].detach().numpy()
-        return {'embeddings': embed}
-    except Exception as e:
-        raise Exception(f"error in process: {e}")
-
 def process_dataset(df):
-    """processess the dataframe and returns a dataset variable"""
-    if len(df) == 0 :
+    """Processes the dataframe and stores embeddings in AstraDB."""
+    if len(df) == 0:
         raise Exception("empty pdf files, or can't read text from them")
-    ds = Dataset.from_pandas(df)
-    ds = ds.map(process)
-    ds.add_faiss_index(column='embeddings') # add faiss index
-    return ds
-
-def search(query, ds, k=3):
-    """searches the query in the dataset and returns the k most similar"""
-    try :
+
+    for _, row in df.iterrows():
+        title = row['title']
+        text = row['text']
+        tokens = ctx_tokenizer(text, return_tensors="pt")
+        embed = ctx_encoder(**tokens)[0][0].detach().numpy().tolist()
+
+        query = "INSERT INTO your_table_name (title, text, embeddings) VALUES (%s, %s, %s)"
+        session.execute(query, (title, text, embed))
+
+    return df
+
+def search(query, k=3):
+    """Searches the query in the database and returns the k most similar."""
+    try:
         tokens = q_tokenizer(query, return_tensors="pt")
-        query_embed = q_encoder(**tokens)[0][0].detach().numpy()
-        scores, retrieved_examples = ds.get_nearest_examples("embeddings", query_embed, k=k)
-        out = f"""**title** : {retrieved_examples["title"][0]},\ncontent: {retrieved_examples["text"][0]}\n\n\n**similar resources:** {retrieved_examples["title"]}
+        query_embed = q_encoder(**tokens)[0][0].detach().numpy().tolist()
+
+        # Perform vector search in AstraDB
+        query = """
+        SELECT title, text, embeddings
+        FROM your_table_name
+        ORDER BY embeddings ANN OF %s LIMIT %s
+        """
+        rows = session.execute(query, (query_embed, k))
+
+        retrieved_examples = []
+        for row in rows:
+            retrieved_examples.append({
+                "title": row.title,
+                "text": row.text,
+                "embeddings": np.array(row.embeddings)
+            })
+
+        out = f"""**title** : {retrieved_examples[0]["title"]},\ncontent: {retrieved_examples[0]["text"]}\n\n\n**similar resources:** {[example["title"] for example in retrieved_examples]}
         """
     except Exception as e:
         out = f"error in search: {e}"
     return out
-
-def predict(query,file_paths, k=3):
-    """predicts the most similar files to the query"""
-    try :
+
+def predict(query, file_paths, k=3):
+    """Predicts the most similar files to the query."""
+    try:
         df = process_pdfs(file_paths)
-        ds = process_dataset(df)
-        out = search(query,ds,k=k)
+        process_dataset(df)
+        out = search(query, k=k)
     except Exception as e:
         out = f"error in predict: {e}"
     return out
 
-with gr.Blocks() as demo :
+# Gradio interface
+with gr.Blocks() as demo:
     gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
     with gr.Row():
         with gr.Column():
-            files = gr.Files(label="Upload PDFs",type="filepath",file_count="multiple")
+            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
             query = gr.Text(label="query")
-            with gr.Accordion("number of references",open=False):
-                k = gr.Number(value=3,show_label=False,precision=0,minimum=1,container=False)
+            with gr.Accordion("number of references", open=False):
+                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
             button = gr.Button("search")
         with gr.Column():
             output = gr.Markdown(label="output")
-    button.click(predict, [query,files,k],outputs=output)
+    button.click(predict, [query, files, k], outputs=output)
 
 demo.launch()
-
-
-
-
-
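
A note on the new connection block: the commit authenticates with PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN) but passes ASTRADB_API_ENDPOINT straight to Cluster() as a contact point. With the DataStax Python driver, an Astra DB connection is normally made through a secure connect bundle; the HTTPS API endpoint is intended for the Data API, not for CQL drivers. A minimal sketch of the bundle-based variant, assuming a hypothetical bundle path and the same .env token (the keyspace stays the commit's your_keyspace_name placeholder):

import os
from dotenv import load_dotenv, find_dotenv
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

load_dotenv(find_dotenv())  # .env is expected to define ASTRADB_TOKEN

# Hypothetical path: the bundle zip is downloaded from the Astra dashboard
cloud_config = {"secure_connect_bundle": "secure-connect-your-db.zip"}
auth_provider = PlainTextAuthProvider(username="token", password=os.getenv("ASTRADB_TOKEN"))
cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
session = cluster.connect("your_keyspace_name")  # placeholder keyspace from the commit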
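The INSERT and the ANN SELECT both assume a table the commit itself never creates. A sketch of a schema that would satisfy them, keeping the your_table_name placeholder and assuming 768-dimensional vectors (the output size of the DPR base encoders). The id column is an addition for illustration, since title alone would collide once a page is split into several 512-character chunks, and the ANN ORDER BY requires a storage-attached (SAI) index on the vector column:

# One-time setup against the same session (a sketch, not part of the commit)
session.execute("""
    CREATE TABLE IF NOT EXISTS your_table_name (
        id uuid PRIMARY KEY,
        title text,
        text text,
        embeddings vector<float, 768>
    )
""")
session.execute("""
    CREATE CUSTOM INDEX IF NOT EXISTS embeddings_idx
    ON your_table_name (embeddings) USING 'StorageAttachedIndex'
""")

With an id column, the commit's INSERT would also need to supply it, e.g. INSERT INTO your_table_name (id, title, text, embeddings) VALUES (%s, %s, %s, %s), binding uuid.uuid4() as the first value.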
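If the table declares vector<float, 768>, the encoder output has to match. This quick check mirrors the exact indexing the commit uses ([0] picks the pooled output, [0][0] the single vector in the batch):

tokens = ctx_tokenizer("quick sanity check", return_tensors="pt")
embed = ctx_encoder(**tokens)[0][0].detach().numpy().tolist()
assert len(embed) == 768  # facebook/dpr-ctx_encoder-single-nq-base embeds into 768 dims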