"""Build embeddings with a local BGE model and retrieve top-k documents from Pinecone.

Module side effects on import: loads .env, downloads/loads the embedding model
and tokenizer from a local snapshot path, and initializes the Pinecone client.
"""

import os
import pinecone
import torch
from transformers import AutoTokenizer, AutoModel
from langchain.vectorstores import Pinecone
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv(".env"), override=True)

print("Downloading model")
print()
# Local snapshot of BAAI/bge-base-en-v1.5 (see models/ directory).
model = AutoModel.from_pretrained("models/models--BAAI--bge-base-en-v1.5/snapshots/617ca489d9e86b49b8167676d8220688b99db36e")
tokenizer = AutoTokenizer.from_pretrained("models/models--BAAI--bge-base-en-v1.5/snapshots/617ca489d9e86b49b8167676d8220688b99db36e")
print()
print("Models downloaded")

pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to api key in console
)

index_name = "ophtal-knowledge-base"


def get_bert_embeddings(sentence):
    """Embed *sentence* with the CLS-token hidden state of the loaded model.

    Returns a nested list of floats with shape (1, hidden_size), i.e. the
    batch dimension is kept.
    NOTE(review): downstream langchain/pinecone code may expect a flat
    list[float] — confirm whether the caller unwraps the outer list.
    """
    # truncation=True guards against inputs longer than the model's max
    # sequence length, which would otherwise raise at inference time.
    input_ids = tokenizer.encode(sentence, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model(input_ids)
    # CLS-token (position 0) embedding, standard pooling for BGE models.
    embedding = output.last_hidden_state[:, 0, :].numpy().tolist()
    return embedding


def fetch_top_k(input_data, top_k):
    """Return up to *top_k* matching documents from the Pinecone index.

    Each element is a formatted string: "Document: <rank>.\\n<page content>".
    May return fewer than *top_k* entries when the index holds fewer matches.
    """
    index = pinecone.Index(index_name)
    # "text" is the metadata key langchain reads the document body from.
    vectorstore = Pinecone(index, get_bert_embeddings, "text")
    matches = vectorstore.similarity_search(input_data, k=top_k)

    # Iterate over the results actually returned rather than range(top_k):
    # the index may contain fewer than top_k matches, and indexing past the
    # end of the result list would raise IndexError.
    return [
        f"Document: {rank}.\n{doc.page_content}"
        for rank, doc in enumerate(matches, start=1)
    ]