Carlosito16 commited on
Commit
e28c4fa
·
1 Parent(s): bc292d9

add prompt and add chain_type_kwargs on `load_retriever`

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -23,9 +23,33 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
23
  from langchain import HuggingFacePipeline
24
  from langchain.chains import RetrievalQA
25
 
 
26
 
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  st.set_page_config(
30
  page_title = 'aitGPT',
31
  page_icon = '✅')
@@ -80,7 +104,8 @@ def load_llm_model():
80
 
81
  def load_retriever(llm, db):
82
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
83
- retriever=db.as_retriever())
 
84
 
85
  return qa_retriever
86
 
 
23
  from langchain import HuggingFacePipeline
24
  from langchain.chains import RetrievalQA
25
 
26
+ from langchain.prompts import PromptTemplate
27
 
28
 
29
 
30
+
31
+ prompt_template = """
32
+
33
+ You are the chatbot and the face of Asian Institute of Technology (AIT). Your job is to give answers to prospective and current students about the school.
34
+ Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
35
+ Always make sure to be elaborate. And try to use vibrant, positive tone to represent good branding of the school.
36
+ Never answer with any unfinished response.
37
+
38
+ {context}
39
+
40
+ Question: {question}
41
+
42
+ Always make sure to elaborate your response and use vibrant, positive tone to represent good branding of the school.
43
+ Never answer with any unfinished response.
44
+
45
+
46
+ """
47
+ PROMPT = PromptTemplate(
48
+ template=prompt_template, input_variables=["context", "question"]
49
+ )
50
+ chain_type_kwargs = {"prompt": PROMPT}
51
+
52
+
53
  st.set_page_config(
54
  page_title = 'aitGPT',
55
  page_icon = '✅')
 
104
 
105
  def load_retriever(llm, db):
106
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
107
+ retriever=db.as_retriever(),
108
+ chain_type_kwargs= chain_type_kwargs)
109
 
110
  return qa_retriever
111