File size: 3,416 Bytes
fe5256f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# import required libraries
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from sentence_transformers import SentenceTransformer
from langchain_core.prompts import ChatPromptTemplate
from langchain import PromptTemplate

import streamlit as st
import sys,yaml
import chromadb
import Utilities as ut

# Module-level configuration placeholders, all initialised to empty strings.
# NOTE: get_collections() reads its settings from ut.get_tokens() into local
# variables and contains no `global` statement, so nothing in this file
# assigns to these module-level names after this point.
hf_token=""
chromadbpath=""
chromadbcollname=""
embedding_model_id=""
llm_repo_id=""
#embeddings=None
#chroma_client=None


def filterdistance(distcoll, max_distance=50):
    """Keep a query result only if its first distance is below a threshold.

    The decision is based on the first distance value found while walking the
    nested ``distcoll['distances']`` lists in order; later distances are not
    inspected.

    Args:
        distcoll: Mapping with a ``'distances'`` entry holding a list of
            lists of numbers. ``None``, an empty mapping, or a mapping with
            no distances at all is treated as "no match".
        max_distance: Exclusive upper bound for the first distance.
            Defaults to 50.

    Returns:
        ``distcoll`` itself (unmodified) when the first distance is strictly
        less than ``max_distance``; otherwise an empty dict. A dict is always
        returned, so callers can safely use ``len()`` or truthiness on it.
    """
    if not distcoll:
        return {}

    distance_rows = distcoll.get('distances') or []
    # First distance in iteration order; None when every row is empty.
    first_distance = next(
        (distance for row in distance_rows for distance in row),
        None,
    )
    if first_distance is None:
        return {}

    return distcoll if first_distance < max_distance else {}
           
def get_collections(query):
    """Look up the nearest stored document for *query* and summarise it via an LLM.

    Steps performed on every call:
      1. Read settings from ``ut.get_tokens()`` (keys ``hf_token``,
         ``embedding_model``, ``dataset_chroma_db``,
         ``dataset_chroma_db_collection_name`` and ``llm_repoid``).
      2. Encode ``query`` with a ``SentenceTransformer`` model.
      3. Query the persistent Chroma collection for the single nearest
         result, requesting documents and distances.
      4. Pass the result through ``filterdistance``.
      5. If anything is left, format it into the prompt template and send
         it to the ``HuggingFaceHub`` LLM.

    Args:
        query: Search text to embed and look up.

    Returns:
        The value returned by ``llm.invoke`` when a match passes the distance
        filter, otherwise the empty string ``""``.

    Raises:
        KeyError: If one of the expected settings keys is missing.
    """
    settings = ut.get_tokens()
    hf_token = settings["hf_token"]
    embedding_model_id = settings["embedding_model"]
    chromadbpath = settings["dataset_chroma_db"]
    chromadbcollname = settings["dataset_chroma_db_collection_name"]
    llm_repo_id = settings["llm_repoid"]

    embedding_model = SentenceTransformer(embedding_model_id)
    chroma_client = chromadb.PersistentClient(path=chromadbpath)
    collection = chroma_client.get_collection(name=chromadbcollname)

    query_vector = embedding_model.encode(query).tolist()
    output = collection.query(
        query_embeddings=[query_vector],
        n_results=1,
        include=['documents', 'distances'],
    )

    # Drop the hit when it is too far from the query. A truthiness check is
    # used so that both an empty dict and None are treated as "no match".
    output = filterdistance(output)
    if not output:
        return ""

    # The template text (including its indentation) is sent to the model
    # verbatim, so it is kept exactly as-is.
    template = """
        <s>[INST] <<SYS>>
        Act as a patent assistant who is helping summarize and neatly format the results for better readability. Ensure the output is gramatically correct and easily understandable 
        <</SYS>>

        {text} [/INST]
        """
    prompt = PromptTemplate(
        input_variables=["text"],
        template=template,
    )

    llm = HuggingFaceHub(
        huggingfacehub_api_token=hf_token,
        repo_id=llm_repo_id,
        model_kwargs={"temperature": 0.2, "max_new_tokens": 50},
    )

    # The whole filtered query result is interpolated into the prompt.
    result = llm.invoke(prompt.format(text=output))
    print (result)
    return result

# Page header.
st.title("BIG Patent Search")

# Search form: one text field plus a submit button.
with st.form("chat_form"):
    query = st.text_input("Enter the abstract search for similar patents: ")
    submit_button = st.form_submit_button("Send")

# Runs only on the script pass where the form was submitted.
if submit_button:
    st.write("Fetching results..\n")
    results = get_collections(query)

    if len(results) == 0:
        response = "No results"
    else:
        # Prefer the text after an "[/ASSistant]" marker; when that yields
        # nothing, fall back to the text after the "[/INST]" marker.
        answer = results.partition("[/ASSistant]")[-1]
        if len(answer) == 0:
            answer = results.partition("[/INST]")[-1]
        response = "There are existing patents related to -    " + str(answer)

    st.write(response)