"""PDF search demo.

Splits uploaded PDFs into text chunks, embeds them with the DPR context
encoder, stores the rows through a Cassandra/AstraDB session, and answers
queries via a vector (ANN) search exposed in a small Gradio UI.
"""

import os
import re
import warnings
from typing import Union

import gradio as gr
import numpy as np
import pandas as pd
import PyPDF2
import torch
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from dotenv import find_dotenv, load_dotenv
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

warnings.filterwarnings("ignore")

# Load environment variables
load_dotenv(find_dotenv())
ASTRADB_TOKEN = os.getenv("ASTRADB_TOKEN")
ASTRADB_API_ENDPOINT = os.getenv("ASTRADB_API_ENDPOINT")

# AstraDB connection setup using token and endpoint.
# NOTE(review): Astra DB is usually reached through a secure-connect bundle
# (Cluster(cloud=...)) rather than plain contact points -- confirm that this
# endpoint works as a contact point. "your_keyspace_name" is a placeholder.
auth_provider = PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN)
cluster = Cluster([ASTRADB_API_ENDPOINT], auth_provider=auth_provider)
session = cluster.connect("your_keyspace_name")

# Load DPR models and tokenizers
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

# Maximum number of characters stored per row; longer pages are split.
_CHUNK_SIZE = 512
# Token cap handed to the tokenizers so an over-long input is truncated
# instead of reaching the encoder (512 is presumed to be the model limit --
# TODO confirm against the DPR model configuration).
_MAX_TOKENS = 512
# Collapses runs of spaces left behind after newlines/tabs are removed.
_MULTI_SPACE = re.compile(r" +")


def _chunk_text(text: str, size: int = _CHUNK_SIZE) -> list:
    """Split *text* into consecutive pieces of at most *size* characters."""
    return [text[start:start + size] for start in range(0, len(text), size)]


def _embed(tokenizer, encoder, text: str) -> list:
    """Return the encoder's first output for *text* as a plain list of floats."""
    tokens = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=_MAX_TOKENS
    )
    # Inference only: no autograd graph is needed for the embedding.
    with torch.no_grad():
        return encoder(**tokens)[0][0].detach().numpy().tolist()


def process_pdfs(parent_dir: Union[str, list]) -> pd.DataFrame:
    """Extract the text of the given PDF files into a dataframe.

    Each page is cleaned (newlines and tabs removed, repeated spaces
    collapsed) and split into chunks of at most 512 characters. Every chunk
    becomes one row titled ``<file name>-page-<page index>``.

    Args:
        parent_dir: A single PDF path or a list of PDF paths. ``None`` is
            treated as "no files".

    Returns:
        A dataframe with the columns ``title`` and ``text``. It is empty when
        no text could be extracted.

    Raises:
        ValueError: If a path does not end in ``.pdf``.
    """
    if parent_dir is None:
        parent_dir = []
    elif isinstance(parent_dir, str):
        parent_dir = [parent_dir]

    rows = []
    for file_path in parent_dir:
        # Check the extension itself rather than any ".pdf" substring.
        if not file_path.lower().endswith(".pdf"):
            raise ValueError("only pdf files are supported")

        file_name = os.path.basename(file_path)
        # The context manager closes the file even if parsing fails.
        with open(file_path, "rb") as pdf_file:
            reader = PyPDF2.PdfReader(pdf_file)
            for i, page in enumerate(reader.pages):
                txt = (page.extract_text() or "").replace("\n", "").replace("\t", "")
                txt = _MULTI_SPACE.sub(" ", txt)  # Strip extra space
                title = f"{file_name}-page-{i}"
                # Keep every chunk, including the final short remainder, but
                # skip blank ones so unreadable PDFs yield an empty dataframe.
                rows.extend(
                    [title, chunk] for chunk in _chunk_text(txt) if chunk.strip()
                )

    # Build the frame once; concatenating per chunk is quadratic.
    return pd.DataFrame(rows, columns=["title", "text"])


def process_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Embed every row of *df* and insert it into the database table.

    Args:
        df: Dataframe with ``title`` and ``text`` columns.

    Returns:
        The same dataframe, unchanged.

    Raises:
        ValueError: If *df* has no rows.
    """
    if df.empty:
        raise ValueError("empty pdf files, or can't read text from them")

    insert_cql = "INSERT INTO your_table_name (title, text, embeddings) VALUES (%s, %s, %s)"
    for _, row in df.iterrows():
        title = row["title"]
        text = row["text"]
        embed = _embed(ctx_tokenizer, ctx_encoder, text)
        session.execute(insert_cql, (title, text, embed))
    return df


def search(query, k=3):
    """Run a vector search for *query* and format the best matches.

    Args:
        query: The question text to embed with the DPR question encoder.
        k: Maximum number of rows to retrieve.

    Returns:
        A Markdown string with the best match and the titles of all retrieved
        rows, a notice when nothing was found, or an ``error in search``
        message if anything fails.
    """
    try:
        query_embed = _embed(q_tokenizer, q_encoder, query)

        # Perform vector search in AstraDB
        ann_cql = """
            SELECT title, text, embeddings
            FROM your_table_name
            ORDER BY embeddings ANN OF %s
            LIMIT %s
        """
        # The UI may deliver the count as a float; LIMIT needs an integer.
        rows = session.execute(ann_cql, (query_embed, int(k)))

        retrieved_examples = [
            {
                "title": row.title,
                "text": row.text,
                "embeddings": np.array(row.embeddings),
            }
            for row in rows
        ]
        if not retrieved_examples:
            # Avoid an opaque "list index out of range" on an empty result.
            return "no matching documents found"

        best = retrieved_examples[0]
        titles = [example["title"] for example in retrieved_examples]
        out = f"""**title** : {best["title"]},\ncontent: {best["text"]}\n\n\n**similar resources:** {titles} """
    except Exception as e:
        out = f"error in search: {e}"
    return out


def predict(query, file_paths, k=3):
    """Index the uploaded PDFs, then search them for *query*.

    Every call re-processes and re-inserts the given files before searching.

    Args:
        query: The question text.
        file_paths: A PDF path or list of PDF paths to index.
        k: Maximum number of results to retrieve.

    Returns:
        The formatted search result, or an ``error in predict`` message if
        processing or indexing fails.
    """
    try:
        df = process_pdfs(file_paths)
        process_dataset(df)
        out = search(query, k=k)
    except Exception as e:
        out = f"error in predict: {e}"
    return out


# Gradio interface
with gr.Blocks() as demo:
    # The title literal was broken across physical lines (a syntax error);
    # the same text is kept, with the line breaks written as escapes.
    gr.Markdown("\n\nPDF Search Engine\n\n")
    with gr.Row():
        with gr.Column():
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, [query, files, k], outputs=output)

demo.launch()