# patent_app_v1 / pages / Patent_Search.py
# Uploaded by saswatdas123 ("Upload 6 files"), commit fe5256f (verified), 3.42 kB.
# import required libraries
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from sentence_transformers import SentenceTransformer
from langchain_core.prompts import ChatPromptTemplate
from langchain import PromptTemplate
import streamlit as st
import sys,yaml
import chromadb
import Utilities as ut
# Module-level placeholders for the app settings. Nothing in this file assigns
# them after this point: get_collections() loads the real values from
# ut.get_tokens() into its own local variables.
hf_token = chromadbpath = chromadbcollname = embedding_model_id = llm_repo_id = ""
def filterdistance(distcoll, threshold=50):
    """Keep a Chroma query result only if it contains a close-enough match.

    Args:
        distcoll: The dict returned by ``collection.query(...)``;
            ``distcoll['distances']`` is read as a list of distance lists
            (one inner list per query embedding).
        threshold: A hit counts as relevant when its distance is strictly
            below this value. Defaults to 50, the previously hard-coded cutoff.

    Returns:
        ``distcoll`` itself when at least one distance is below ``threshold``,
        otherwise an empty dict. An empty dict is also returned for an empty or
        ``None`` input, or when no distances are present, so callers can always
        use ``len()`` or truthiness on the result.
    """
    if not distcoll:
        return {}
    # Defensive: treat a missing/None 'distances' entry (or a None inner list)
    # as "no matches" instead of raising.
    distance_lists = distcoll.get('distances') or []
    if any(
        distance < threshold
        for distances in distance_lists
        for distance in (distances or [])
    ):
        return distcoll
    return {}
# Llama-2 style instruction prompt used to summarise the retrieved patent(s).
_PATENT_SUMMARY_TEMPLATE = """
<s>[INST] <<SYS>>
Act as a patent assistant who is helping summarize and neatly format the results for better readability. Ensure the output is gramatically correct and easily understandable
<</SYS>>
{text} [/INST]
"""


def _query_collection(query, embedding_model_id, chromadbpath, chromadbcollname):
    """Embed ``query`` and fetch its single nearest document from Chroma.

    Returns the dict produced by ``collection.query`` with ``n_results=1`` and
    the ``documents`` and ``distances`` fields requested.
    """
    embedding_model = SentenceTransformer(embedding_model_id)
    chroma_client = chromadb.PersistentClient(path=chromadbpath)
    # Uses an existing collection; it is not created here.
    collection = chroma_client.get_collection(name=chromadbcollname)
    query_vector = embedding_model.encode(query).tolist()
    return collection.query(
        query_embeddings=[query_vector],
        n_results=1,
        include=['documents', 'distances'],
    )


def _summarize(text, hf_token, llm_repo_id):
    """Send ``text`` to the Hugging Face Hub LLM with the patent-summary prompt.

    Returns whatever ``llm.invoke`` returns for the formatted prompt.
    """
    prompt = PromptTemplate(
        input_variables=["text"],
        template=_PATENT_SUMMARY_TEMPLATE,
    )
    llm = HuggingFaceHub(
        huggingfacehub_api_token=hf_token,
        repo_id=llm_repo_id,
        # NOTE(review): 50 new tokens allows only a very short summary — confirm intended.
        model_kwargs={"temperature": 0.2, "max_new_tokens": 50},
    )
    return llm.invoke(prompt.format(text=text))


def get_collections(query):
    """Search the patent vector store for ``query`` and summarise the hit with an LLM.

    Settings come from ``ut.get_tokens()`` (keys ``hf_token``,
    ``embedding_model``, ``dataset_chroma_db``,
    ``dataset_chroma_db_collection_name`` and ``llm_repoid``). The query is
    embedded with a SentenceTransformer model, the persistent Chroma
    collection is searched for the single nearest document, and the result is
    kept only if ``filterdistance`` accepts it.

    Args:
        query: Free-text abstract entered by the user.

    Returns:
        The LLM output (and prints it) when a match survives the distance
        filter; otherwise an empty dict.
    """
    config = ut.get_tokens()
    # Read every setting up front so a missing key fails before any model loads.
    hf_token = config["hf_token"]
    embedding_model_id = config["embedding_model"]
    chromadbpath = config["dataset_chroma_db"]
    chromadbcollname = config["dataset_chroma_db_collection_name"]
    llm_repo_id = config["llm_repoid"]

    output = _query_collection(query, embedding_model_id, chromadbpath, chromadbcollname)

    # Filter for distances. `not output` also covers a None result, so the
    # caller can always call len() on what this function returns.
    output = filterdistance(output)
    if not output:
        return {}

    # NOTE(review): the whole query-result dict (documents, distances and
    # whatever else Chroma returns) is rendered into the prompt; passing only
    # output["documents"] may give the LLM cleaner input — confirm before changing.
    result = _summarize(output, hf_token, llm_repo_id)
    print(result)
    return result
def _extract_answer(results):
    """Return the part of the LLM output that follows the prompt.

    Prefers the non-empty text after ``[/ASSistant]``; otherwise the text
    after ``[/INST]``. If ``[/INST]`` does not occur at all (e.g. the output
    does not echo the prompt), the output is returned unchanged rather than
    collapsing to an empty string.
    """
    _, _, assistant_tail = results.partition("[/ASSistant]")
    if assistant_tail:
        return assistant_tail
    _, inst_marker, inst_tail = results.partition("[/INST]")
    return inst_tail if inst_marker else results


# Streamlit page: one text field; on submit, search the patent collection and
# show the LLM summary, or "No results" when nothing passes the distance filter.
st.title("BIG Patent Search")

# Main chat form
with st.form("chat_form"):
    query = st.text_input("Enter the abstract search for similar patents: ")
    submit_button = st.form_submit_button("Send")

if submit_button:
    if not query.strip():
        # Avoid loading models and querying the vector store for an empty abstract.
        st.warning("Please enter an abstract to search for similar patents.")
    else:
        st.write("Fetching results..\n")
        results = get_collections(query)
        if results:
            response = "There are existing patents related to - " + _extract_answer(str(results))
        else:
            response = "No results"
        st.write(response)