holyhigh666 committed on
Commit 5b68ef2 · verified · 1 Parent(s): e81ef56

Update app.py

Files changed (1)
  1. app.py +154 -102
app.py CHANGED
@@ -3,17 +3,25 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain_huggingface import HuggingFaceEmbeddings
  #from langchain.vectorstores import FAISS
  from langchain_community.vectorstores import FAISS
- import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from langchain_huggingface.llms import HuggingFacePipeline
  from langchain.prompts import PromptTemplate
- from transformers import pipeline
+
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnablePassthrough
  import glob
  import gradio as gr

- # Prepare the data
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+ import os
+
+ secret_value_hf = os.getenv("hf_token")
+
+ hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
+     api_key=secret_value_hf,
+     model_name="sentence-transformers/all-MiniLM-l6-v2"
+ )

  md_path = glob.glob( "md_files/*.md")

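The new embedding client calls the Hugging Face Inference API instead of loading a model locally, so it only works when the Space exposes a valid token under the hf_token secret. A minimal smoke test (not part of this commit) that checks the client before any index is built; the 384 in the comment is the embedding size all-MiniLM-L6-v2 returns:

import os

from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# Fail fast if the Space secret is missing instead of erroring later inside FAISS.from_documents.
token = os.getenv("hf_token")
if token is None:
    raise RuntimeError("hf_token is not set; add it as a Space secret.")

embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=token,
    model_name="sentence-transformers/all-MiniLM-l6-v2",
)

# A single embed_query call confirms the endpoint is reachable and returns a vector.
print(len(embeddings.embed_query("BaZrS3 is a chalcogenide perovskite.")))  # expected: 384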
@@ -30,62 +38,18 @@ doc_splits = text_splitter.split_documents(docs_list)
  # Create the embeddings + retriever

  db = FAISS.from_documents(doc_splits,
- HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5'))
+ hf_embeddings)
-
-
-
-
- # Load quantized model
-
- model_name = 'HuggingFaceH4/zephyr-7b-beta'
-
- bnb_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_use_double_quant=True,
-     bnb_4bit_quant_type="nf4",
-     bnb_4bit_compute_dtype=torch.bfloat16
- )
-
- model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # Setup the LLM chain
-
-
-
- text_generation_pipeline = pipeline(
-     model=model,
-     tokenizer=tokenizer,
-     task="text-generation",
-     temperature=0.2,
-     do_sample=True,
-     repetition_penalty=1.1,
-     return_full_text=True,
-     max_new_tokens=512,
- )
-
- llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
-
- # search in vector database
-


+ # prompt
  prompt_template = '''You are an assistant for question-answering tasks.
-
  Here is the context to use to answer the question:
-
  {context}
-
  Think carefully about the above context.
-
  Now, review the user question:
-
  {question}
-
  Provide an answer to this questions using only the above context.
-
  Use three sentences maximum and keep the answer concise.
-
  Answer:'''

  prompt = PromptTemplate(
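Because FAISS.from_documents above re-embeds every markdown file through the Inference API on each startup, one optional follow-up (not part of this commit) is to persist the index with the vector store's save_local / load_local. A sketch that reuses the FAISS, doc_splits, and hf_embeddings names from app.py; the faiss_index directory name is an arbitrary choice:

import os

INDEX_DIR = "faiss_index"  # hypothetical location inside the Space's working directory

if os.path.isdir(INDEX_DIR):
    # Reload the saved index; the same embeddings object is still needed for query-time encoding.
    db = FAISS.load_local(INDEX_DIR, hf_embeddings, allow_dangerous_deserialization=True)
else:
    db = FAISS.from_documents(doc_splits, hf_embeddings)
    db.save_local(INDEX_DIR)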
@@ -94,57 +58,145 @@ prompt = PromptTemplate(
  )


- llm_chain = prompt | llm | StrOutputParser()
-
-
-
-
- retriever = db.as_retriever()
-
- rag_chain = (
-     {"context": retriever, "question": RunnablePassthrough()}
-     | llm_chain
- )
-
- #question = "what is advantage of chalcogenide perovskite?"
-
-
-
-
-
- def get_output(is_RAG:str,questions:str):
-     if is_RAG== "RAG":
-         generation2=rag_chain.invoke(questions)
-         return generation2.content
-     else:
-         generation1=llm_chain.invoke({"context":"", "question": questions})
-         return generation1.content
-
- demo = gr.Interface(
-     fn=get_output,
-     inputs=[
-         gr.Radio(
+ # gradio interface
+
+ def get_output(model_name:str,is_RAG:str,questions:str):
+     if model_name=="mistralai/Mistral-7B-Instruct-v0.2":
+         #repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
+         llm = HuggingFaceEndpoint(
+             repo_id=model_name,
+             max_length=512,
+             temperature=0.2,
+             huggingfacehub_api_token=secret_value_hf,
+         )
+         llm_chain = prompt | llm | StrOutputParser()
+         retriever = db.as_retriever(
+             search_type="similarity",
+             search_kwargs={'k': 4}
+         )
+
+         rag_chain = (
+             {"context": retriever, "question": RunnablePassthrough()}
+             | llm_chain
+         )
+         if is_RAG== "RAG":
+             generation2=rag_chain.invoke(questions)
+             return generation2
+         else:
+             generation1=llm_chain.invoke({"context":"", "question": questions})
+             return generation1
+     elif model_name=="meta-llama/Llama-3.2-3B-Instruct":
+         llm = HuggingFaceEndpoint(
+             repo_id=model_name,
+             max_length=512,
+             temperature=0.2,
+             huggingfacehub_api_token=secret_value_hf,
+         )
+         llm_chain = prompt | llm | StrOutputParser()
+         retriever = db.as_retriever()
+
+         rag_chain = (
+             {"context": retriever, "question": RunnablePassthrough()}
+             | llm_chain
+         )
+         if is_RAG== "RAG":
+             generation2=rag_chain.invoke(questions)
+             return generation2
+         else:
+             generation1=llm_chain.invoke({"context":"", "question": questions})
+             return generation1
+     elif model_name=="Qwen/Qwen2.5-72B-Instruct":
+         llm = HuggingFaceEndpoint(
+             repo_id=model_name,
+             max_length=512,
+             temperature=0.2,
+             huggingfacehub_api_token=secret_value_hf,
+         )
+         llm_chain = prompt | llm | StrOutputParser()
+         retriever = db.as_retriever()
+
+         rag_chain = (
+             {"context": retriever, "question": RunnablePassthrough()}
+             | llm_chain
+         )
+         if is_RAG== "RAG":
+             generation2=rag_chain.invoke(questions)
+             return generation2
+         else:
+             generation1=llm_chain.invoke({"context":"", "question": questions})
+             return generation1
+
+
+
+ # Custom CSS to style the output area
+ custom_css = """
+ #output_area {
+     background-color: #1e1e1e; /* Dark background */
+     color: #ffffff; /* White text */
+     padding: 10px;
+     border-radius: 5px;
+     border: 1px solid #333333; /* Dark border */
+     margin-top: 10px;
+ }
+
+ #output_area h3 {
+     color: #ffcc00; /* Yellow title color */
+     margin-bottom: 10px;
+ }
+ """
+
+ with gr.Blocks(title="Ask Questions on Chalcogenide Perovskites",theme=gr.themes.Ocean(),css=custom_css) as demo:
+     gr.Markdown("""
+     ## Retrieval-Augmented Generation for Chalcogenide Perovskites
+     This space implements Retrieval-Augmented Generation (RAG) using large language models, based on Hui Haolei's work on chalcogenide perovskite papers. You can select different models and choose whether to use RAG to enhance the responses.
+     """)
+
+     with gr.Row():
+         model_name = gr.Radio(
+             choices=["mistralai/Mistral-7B-Instruct-v0.2", "meta-llama/Llama-3.2-3B-Instruct", "Qwen/Qwen2.5-72B-Instruct"],
+             value="mistralai/Mistral-7B-Instruct-v0.2",
+             label="Model Name",
+             info="Select the model you want to use."
+         )
+
+     with gr.Row():
+         rag = gr.Radio(
              choices=["RAG", "No RAG"],
-             type="value",
-             value="RAG", # Set default value to "Model 1"
-             label="RAG or not"
-         ),
-         gr.Textbox(label="Input Questions",info="input questions on chalcogenide perovskites")
-     ],
-     outputs="markdown",
-     title="RAG using llm zephyr-7b-beta, embedding model BAAI/bge-base-en-v1.5, based on chalcogenide perovskite papers",
-     description="""
- ## ask a question to get answer on chalcogenide perovskite; or click on the examples below.
-     """,
-     examples=[["RAG","what is advantage of BaZrS3?"],
-               ["RAG","what is bandgap of SrHfS3?"],
-               ["RAG","why is chalcogenide perovskite important?"]
-     ]
- )
-
- # Launch the Gradio app
- if __name__ == "__main__":
-     demo.launch(share=False)
-
-
-
+             value="RAG",
+             label="RAG or Not",
+             info="Choose whether to use Retrieval-Augmented Generation."
+         )
+
+     with gr.Row():
+         question = gr.Textbox(
+             label="Input Question",
+             placeholder="Enter your question about chalcogenide perovskites here...",
+             lines=2 # Increase the number of lines for better input experience
+         )
+
+     with gr.Row():
+         submit_button = gr.Button("Submit")
+
+     with gr.Row():
+         output = gr.Textbox(label="Response",
+             lines=10, # Increase the number of lines for the output area
+             elem_id="output_area" # Assign a custom ID for styling
+         )
+     submit_button.click(
+         fn=get_output,
+         inputs=[model_name, rag, question],
+         outputs=output
+     )
+
+     gr.Examples(
+         examples=[
+             ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the advantage of BaZrS3?"],
+             ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the bandgap of SrHfS3?"],
+             ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "Why is chalcogenide perovskite important?"]
+         ],
+         fn=get_output,
+         inputs=[model_name, rag, question],
+         outputs=output
+     )
+
+ demo.launch()
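The new get_output builds its LLM with HuggingFaceEndpoint, but none of the hunks above adds an import for it, and the surviving import from langchain_huggingface.llms only brings in HuggingFacePipeline. If HuggingFaceEndpoint is not already imported in a part of app.py outside these hunks, the top of the file needs a line such as:

from langchain_huggingface import HuggingFaceEndpoint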
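The three model branches in get_output differ only in the repo id and in the Mistral branch's explicit search_kwargs, so they could collapse into a single path. A minimal sketch of an equivalent version, assuming the prompt, db, and secret_value_hf objects defined above, applying k=4 retrieval to every model, and using max_new_tokens (the generation-length field HuggingFaceEndpoint declares) where the commit passes max_length:

def get_output(model_name: str, is_RAG: str, questions: str) -> str:
    # One endpoint per request; model_name is whichever repo id the radio button selected.
    llm = HuggingFaceEndpoint(
        repo_id=model_name,
        max_new_tokens=512,
        temperature=0.2,
        huggingfacehub_api_token=secret_value_hf,
    )
    llm_chain = prompt | llm | StrOutputParser()

    if is_RAG == "RAG":
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
        rag_chain = {"context": retriever, "question": RunnablePassthrough()} | llm_chain
        return rag_chain.invoke(questions)

    # Without retrieval, reuse the same prompt with an empty context string.
    return llm_chain.invoke({"context": "", "question": questions})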