import os
import getpass
import pinecone
from transformers import AutoTokenizer, AutoModel
import torch
# from tqdm import tqdm
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Pinecone
from langchain.document_loaders import TextLoader, DirectoryLoader
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(".env"), override=True)
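# The .env file loaded above is expected to define the two variables read
# below via os.getenv (names taken from this script):
#   PINECONE_API_KEY=<your Pinecone API key>
#   PINECONE_ENV=<your Pinecone environment>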
# Alternatively, prompt for the credentials interactively:
# os.environ["PINECONE_API_KEY"] = getpass.getpass("Pinecone API Key:")
# os.environ["PINECONE_ENV"] = getpass.getpass("Pinecone Environment:")
print("Downloading model")
print()
# model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("models/models--BAAI--bge-base-en-v1.5/snapshots/617ca489d9e86b49b8167676d8220688b99db36e")
tokenizer = AutoTokenizer.from_pretrained("models/models--BAAI--bge-base-en-v1.5/snapshots/617ca489d9e86b49b8167676d8220688b99db36e")
# tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
print()
print("Models downloaded")
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to the API key in the console
)
index_name = "ophtal-knowledge-base"
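# Assumption: the Pinecone index was created with dimension 768, matching the
# embedding size of BAAI/bge-base-en-v1.5 loaded above.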
def get_bert_embeddings(sentence):
    """Return the [CLS] token embedding of `sentence` as a flat list of floats."""
    input_ids = tokenizer.encode(sentence, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model(input_ids)
    # Take the [CLS] embedding and drop the batch dimension so Pinecone
    # receives a flat vector rather than a nested list.
    embedding = output.last_hidden_state[:, 0, :].squeeze(0).numpy().tolist()
    return embedding
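# How the "ophtal-knowledge-base" index is populated is not shown in this file.
# The sketch below is one plausible ingestion path using the imported loaders
# and splitter; the directory name "ophtal_docs", the id scheme, and the
# chunking parameters are assumptions, not part of the original pipeline.
def build_index(doc_dir="ophtal_docs", chunk_size=1000, chunk_overlap=100):
    """Sketch: split local text files into chunks and upsert them into Pinecone
    with "text" and "source" metadata, which fetch_top_k below relies on."""
    docs = DirectoryLoader(doc_dir, loader_cls=TextLoader).load()
    chunks = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    ).split_documents(docs)
    index = pinecone.Index(index_name)
    vectors = []
    for i, chunk in enumerate(chunks):
        vectors.append(
            (
                f"doc-seg-{i}",  # hypothetical id scheme
                get_bert_embeddings(chunk.page_content),
                {"text": chunk.page_content, "source": chunk.metadata["source"]},
            )
        )
    # Single upsert call for simplicity; batch this for large corpora.
    index.upsert(vectors=vectors)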
def fetch_top_k(input_data, top_k):
    """Retrieve the top_k most similar chunks for `input_data` and return them
    as numbered "Document: n" strings."""
    index = pinecone.Index(index_name)
    # Wrap the existing index in a LangChain vectorstore; "text" is the
    # metadata field holding the original chunk text.
    vectorstore = Pinecone(index, get_bert_embeddings, "text")
    query = vectorstore.similarity_search(input_data, k=top_k)

    text_list = []
    source_list = []  # collected for potential citation use; not returned yet
    for doc in query:
        text_list.append(doc.page_content)
        source_list.append(doc.metadata["source"])

    # Enumerate instead of range(top_k) in case fewer matches come back.
    list_ = []
    for i, text in enumerate(text_list):
        list_.append(f"Document: {i+1}.\n{text}")
    return list_
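# The same retrieval can be done against the Pinecone index directly, without
# the LangChain wrapper. A minimal sketch, assuming every vector was upserted
# with "text" and "source" metadata (as in build_index above):
def fetch_top_k_raw(input_data, top_k):
    """Sketch: query Pinecone directly and return parallel text/source lists."""
    index = pinecone.Index(index_name)
    emb = get_bert_embeddings(input_data)
    result = index.query(vector=emb, top_k=top_k, include_metadata=True)
    text_list, source_list = [], []
    for match in result["matches"]:
        text_list.append(match["metadata"]["text"])
        source_list.append(match["metadata"]["source"])
    return text_list, source_list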
if __name__ == "__main__":
    # Smoke test with a sample NEET-style ophthalmology question.
    neet = (
        "Which of the following is true regarding Mittendorf dot?\n"
        "A. Glial tissue projecting from optic disc\n"
        "B. Obliterated vessel running forward into the vitreous\n"
        "C. Associated with posterior polar cataract\n"
        "D. Commonest congenital anomaly of hyaloid system"
    )
    docs = fetch_top_k(neet, top_k=5)
    print("\n\n".join(docs))