import glob
import os
import random
import shutil
import string
import subprocess

# NOTE(review): `spaces` is kept ahead of `sentence_transformers`, matching the
# original import order — confirm before reordering.
import gradio as gr
import spaces
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# Runtime configuration, overridable through environment variables.
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
chunk_size = int(os.environ.get("CHUNK_SIZE", 128))
default_max_characters = int(os.environ.get("DEFAULT_MAX_CHARACTERS", 258))

model = SentenceTransformer(model_name)
# model.to(device="cuda")


@spaces.GPU
def embed(queries, chunks) -> dict[str, list[tuple[int, float]]]:
    """Score every chunk against every query.

    Args:
        queries: Sequence of query strings, encoded with ``prompt_name="query"``.
        chunks: Sequence of document chunk strings, encoded without a prompt.

    Returns:
        A mapping ``{query: [(chunk_idx, score), ...]}`` where ``score`` is the
        dot product of the query embedding and the chunk embedding. Duplicate
        query strings collapse into a single key.
    """
    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)

    # One row per query, one column per chunk.
    scores = query_embeddings @ document_embeddings.T

    results = {}
    for query, query_scores in zip(queries, scores):
        results[query] = list(enumerate(query_scores))
    return results


def extract_text_from_pdf(reader):
    """Concatenate the text of every non-empty page of a PDF.

    Args:
        reader: An object exposing ``pages``, each with ``extract_text()``
            (a ``PdfReader`` at the call site in ``convert``).

    Returns:
        The text of all pages that yielded text, each preceded by a
        ``---- Page N ----`` header (N is the zero-based page index), with
        leading/trailing whitespace stripped.
    """
    parts = []
    for idx, page in enumerate(reader.pages):
        # Extract once per page; the result is reused for both the emptiness
        # check and the output.
        text = page.extract_text()
        if text:
            parts.append(f"---- Page {idx} ----\n" + text + "\n\n")
    return "".join(parts).strip()


def convert(filename) -> str:
    """Return the textual content of a supported file.

    Args:
        filename: Path to the file. The type is decided by its extension,
            compared case-insensitively.

    Returns:
        The file content for plain-text formats, or the extracted page text
        for PDFs.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    lowered = filename.lower()

    # Already a plain text file that wouldn't benefit from pandoc so return the content
    if lowered.endswith(plain_text_filetypes):
        # Explicit encoding so the result does not depend on the platform locale.
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")


def chunk_to_length(text, max_length=512):
    """Split ``text`` into consecutive pieces of at most ``max_length`` characters.

    Args:
        text: The string to split.
        max_length: Maximum number of characters per chunk; must be positive.

    Returns:
        A list of chunks in order. Empty input yields ``[""]`` so callers
        always receive at least one chunk.

    Raises:
        ValueError: If ``max_length`` is not positive (a non-positive value
            could never make progress through the text).
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")

    # Slice by offset rather than repeatedly re-slicing the remainder, which
    # would copy the tail of the string on every iteration.
    chunks = [
        text[start:start + max_length]
        for start in range(0, len(text), max_length)
    ]
    return chunks or [text]


@spaces.GPU
def predict(query, max_characters) -> dict[str, list[str]]:
    """Return the chunks most similar to ``query``, grouped by source file.

    Chunks from all documents in the module-level ``docs`` index are ranked by
    the dot product between their embedding and the query embedding, then
    taken in that order until the next chunk would exceed the character budget.

    Args:
        query: The query string.
        max_characters: Total character budget across all returned chunks.
            ``None`` falls back to ``default_max_characters`` (presumably what
            the UI sends when the number field is cleared — TODO confirm).

    Returns:
        A mapping ``{filename: [chunk, ...]}`` containing only files that
        contributed at least one chunk.
    """
    if max_characters is None:
        max_characters = default_max_characters

    # Embed the query
    query_embedding = model.encode(query, prompt_name="query")

    # Collect (filename, chunk, similarity) across all documents
    all_chunks = []
    for filename, doc in docs.items():
        # Dot product between each chunk embedding and the query embedding
        similarities = doc["embeddings"] @ query_embedding.T
        all_chunks.extend(
            (filename, chunk, sim)
            for chunk, sim in zip(doc["chunks"], similarities)
        )

    # Most similar first
    all_chunks.sort(key=lambda item: item[2], reverse=True)

    # Take chunks in rank order; stop at the first one that would overflow
    # the budget (lower-ranked, shorter chunks are intentionally not tried).
    relevant_chunks = {}
    total_chars = 0
    for filename, chunk, _ in all_chunks:
        if total_chars + len(chunk) > max_characters:
            break
        relevant_chunks.setdefault(filename, []).append(chunk)
        total_chars += len(chunk)

    return relevant_chunks


# Build the in-memory index once at startup: every file under sources/ is
# converted to text, chunked, and embedded.
docs = {}

for filename in glob.glob("sources/*"):
    converted_doc = convert(filename)

    chunks = chunk_to_length(converted_doc, chunk_size)
    embeddings = model.encode(chunks)

    docs[filename] = {
        "chunks": chunks,
        "embeddings": embeddings,
    }


gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Max output characters", value=default_max_characters),
    ],
    outputs=[gr.JSON(label="Relevant chunks")],
).launch()