Update app.py
Browse files
app.py
CHANGED
@@ -3,17 +3,25 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
3 |
from langchain_huggingface import HuggingFaceEmbeddings
|
4 |
#from langchain.vectorstores import FAISS
|
5 |
from langchain_community.vectorstores import FAISS
|
6 |
-
import torch
|
7 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
8 |
from langchain_huggingface.llms import HuggingFacePipeline
|
9 |
from langchain.prompts import PromptTemplate
|
10 |
-
|
11 |
from langchain_core.output_parsers import StrOutputParser
|
12 |
from langchain_core.runnables import RunnablePassthrough
|
13 |
import glob
|
14 |
import gradio as gr
|
15 |
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
md_path = glob.glob( "md_files/*.md")
|
19 |
|
@@ -30,62 +38,18 @@ doc_splits = text_splitter.split_documents(docs_list)
|
|
30 |
# Create the embeddings + retriever
|
31 |
|
32 |
db = FAISS.from_documents(doc_splits,
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
# Load quantized model
|
39 |
-
|
40 |
-
model_name = 'HuggingFaceH4/zephyr-7b-beta'
|
41 |
-
|
42 |
-
bnb_config = BitsAndBytesConfig(
|
43 |
-
load_in_4bit=True,
|
44 |
-
bnb_4bit_use_double_quant=True,
|
45 |
-
bnb_4bit_quant_type="nf4",
|
46 |
-
bnb_4bit_compute_dtype=torch.bfloat16
|
47 |
-
)
|
48 |
-
|
49 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
|
50 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
51 |
-
|
52 |
-
# Setup the LLM chain
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
text_generation_pipeline = pipeline(
|
57 |
-
model=model,
|
58 |
-
tokenizer=tokenizer,
|
59 |
-
task="text-generation",
|
60 |
-
temperature=0.2,
|
61 |
-
do_sample=True,
|
62 |
-
repetition_penalty=1.1,
|
63 |
-
return_full_text=True,
|
64 |
-
max_new_tokens=512,
|
65 |
-
)
|
66 |
-
|
67 |
-
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
|
68 |
-
|
69 |
-
# search in vector database
|
70 |
-
|
71 |
|
72 |
|
|
|
73 |
prompt_template = '''You are an assistant for question-answering tasks.
|
74 |
-
|
75 |
Here is the context to use to answer the question:
|
76 |
-
|
77 |
{context}
|
78 |
-
|
79 |
Think carefully about the above context.
|
80 |
-
|
81 |
Now, review the user question:
|
82 |
-
|
83 |
{question}
|
84 |
-
|
85 |
Provide an answer to this questions using only the above context.
|
86 |
-
|
87 |
Use three sentences maximum and keep the answer concise.
|
88 |
-
|
89 |
Answer:'''
|
90 |
|
91 |
prompt = PromptTemplate(
|
@@ -94,57 +58,145 @@ prompt = PromptTemplate(
|
|
94 |
)
|
95 |
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
)
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
choices=["RAG", "No RAG"],
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
)
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
)
|
144 |
-
|
145 |
-
#
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from langchain_huggingface import HuggingFaceEmbeddings
|
4 |
#from langchain.vectorstores import FAISS
|
5 |
from langchain_community.vectorstores import FAISS
|
|
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
7 |
from langchain_huggingface.llms import HuggingFacePipeline
|
8 |
from langchain.prompts import PromptTemplate
|
9 |
+
|
10 |
from langchain_core.output_parsers import StrOutputParser
|
11 |
from langchain_core.runnables import RunnablePassthrough
|
12 |
import glob
|
13 |
import gradio as gr
|
14 |
|
15 |
+
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
16 |
+
|
17 |
+
import os
|
18 |
+
|
19 |
+
secret_value_hf = os.getenv("hf_token")
|
20 |
+
|
21 |
+
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
|
22 |
+
api_key=secret_value_hf,
|
23 |
+
model_name="sentence-transformers/all-MiniLM-l6-v2"
|
24 |
+
)
|
25 |
|
26 |
md_path = glob.glob( "md_files/*.md")
|
27 |
|
|
|
38 |
# Create the embeddings + retriever
|
39 |
|
40 |
db = FAISS.from_documents(doc_splits,
|
41 |
+
hf_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
+
# prompt
|
45 |
prompt_template = '''You are an assistant for question-answering tasks.
|
|
|
46 |
Here is the context to use to answer the question:
|
|
|
47 |
{context}
|
|
|
48 |
Think carefully about the above context.
|
|
|
49 |
Now, review the user question:
|
|
|
50 |
{question}
|
|
|
51 |
Provide an answer to this questions using only the above context.
|
|
|
52 |
Use three sentences maximum and keep the answer concise.
|
|
|
53 |
Answer:'''
|
54 |
|
55 |
prompt = PromptTemplate(
|
|
|
58 |
)
|
59 |
|
60 |
|
61 |
+
# gradio interface
|
62 |
+
|
63 |
+
def get_output(model_name:str,is_RAG:str,questions:str):
|
64 |
+
if model_name=="mistralai/Mistral-7B-Instruct-v0.2":
|
65 |
+
#repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
|
66 |
+
llm = HuggingFaceEndpoint(
|
67 |
+
repo_id=model_name,
|
68 |
+
max_length=512,
|
69 |
+
temperature=0.2,
|
70 |
+
huggingfacehub_api_token=secret_value_hf,
|
71 |
+
)
|
72 |
+
llm_chain = prompt | llm | StrOutputParser()
|
73 |
+
retriever = db.as_retriever(
|
74 |
+
search_type="similarity",
|
75 |
+
search_kwargs={'k': 4}
|
76 |
+
)
|
77 |
+
|
78 |
+
rag_chain = (
|
79 |
+
{"context": retriever, "question": RunnablePassthrough()}
|
80 |
+
| llm_chain
|
81 |
+
)
|
82 |
+
if is_RAG== "RAG":
|
83 |
+
generation2=rag_chain.invoke(questions)
|
84 |
+
return generation2
|
85 |
+
else:
|
86 |
+
generation1=llm_chain.invoke({"context":"", "question": questions})
|
87 |
+
return generation1
|
88 |
+
elif model_name=="meta-llama/Llama-3.2-3B-Instruct":
|
89 |
+
llm = HuggingFaceEndpoint(
|
90 |
+
repo_id=model_name,
|
91 |
+
max_length=512,
|
92 |
+
temperature=0.2,
|
93 |
+
huggingfacehub_api_token=secret_value_hf,
|
94 |
+
)
|
95 |
+
llm_chain = prompt | llm | StrOutputParser()
|
96 |
+
retriever = db.as_retriever()
|
97 |
+
|
98 |
+
rag_chain = (
|
99 |
+
{"context": retriever, "question": RunnablePassthrough()}
|
100 |
+
| llm_chain
|
101 |
+
)
|
102 |
+
if is_RAG== "RAG":
|
103 |
+
generation2=rag_chain.invoke(questions)
|
104 |
+
return generation2
|
105 |
+
else:
|
106 |
+
generation1=llm_chain.invoke({"context":"", "question": questions})
|
107 |
+
return generation1
|
108 |
+
elif model_name=="Qwen/Qwen2.5-72B-Instruct":
|
109 |
+
llm = HuggingFaceEndpoint(
|
110 |
+
repo_id=model_name,
|
111 |
+
max_length=512,
|
112 |
+
temperature=0.2,
|
113 |
+
huggingfacehub_api_token=secret_value_hf,
|
114 |
+
)
|
115 |
+
llm_chain = prompt | llm | StrOutputParser()
|
116 |
+
retriever = db.as_retriever()
|
117 |
+
|
118 |
+
rag_chain = (
|
119 |
+
{"context": retriever, "question": RunnablePassthrough()}
|
120 |
+
| llm_chain
|
121 |
+
)
|
122 |
+
if is_RAG== "RAG":
|
123 |
+
generation2=rag_chain.invoke(questions)
|
124 |
+
return generation2
|
125 |
+
else:
|
126 |
+
generation1=llm_chain.invoke({"context":"", "question": questions})
|
127 |
+
return generation1
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
# Custom CSS to style the output area
|
132 |
+
custom_css = """
|
133 |
+
#output_area {
|
134 |
+
background-color: #1e1e1e; /* Dark background */
|
135 |
+
color: #ffffff; /* White text */
|
136 |
+
padding: 10px;
|
137 |
+
border-radius: 5px;
|
138 |
+
border: 1px solid #333333; /* Dark border */
|
139 |
+
margin-top: 10px;
|
140 |
+
}
|
141 |
+
|
142 |
+
#output_area h3 {
|
143 |
+
color: #ffcc00; /* Yellow title color */
|
144 |
+
margin-bottom: 10px;
|
145 |
+
}
|
146 |
+
"""
|
147 |
+
|
148 |
+
with gr.Blocks(title="Ask Questions on Chalcogenide Perovskites",theme=gr.themes.Ocean(),css=custom_css) as demo:
|
149 |
+
gr.Markdown("""
|
150 |
+
## Retrieval-Augmented Generation for Chalcogenide Perovskites
|
151 |
+
This space implements Retrieval-Augmented Generation (RAG) using large language models, based on Hui Haolei's work on chalcogenide perovskite papers. You can select different models and choose whether to use RAG to enhance the responses.
|
152 |
+
""")
|
153 |
+
|
154 |
+
with gr.Row():
|
155 |
+
model_name = gr.Radio(
|
156 |
+
choices=["mistralai/Mistral-7B-Instruct-v0.2", "meta-llama/Llama-3.2-3B-Instruct", "Qwen/Qwen2.5-72B-Instruct"],
|
157 |
+
value="mistralai/Mistral-7B-Instruct-v0.2",
|
158 |
+
label="Model Name",
|
159 |
+
info="Select the model you want to use."
|
160 |
+
)
|
161 |
+
|
162 |
+
with gr.Row():
|
163 |
+
rag = gr.Radio(
|
164 |
choices=["RAG", "No RAG"],
|
165 |
+
value="RAG",
|
166 |
+
label="RAG or Not",
|
167 |
+
info="Choose whether to use Retrieval-Augmented Generation."
|
168 |
+
)
|
169 |
+
|
170 |
+
with gr.Row():
|
171 |
+
question = gr.Textbox(
|
172 |
+
label="Input Question",
|
173 |
+
placeholder="Enter your question about chalcogenide perovskites here...",
|
174 |
+
lines=2 # Increase the number of lines for better input experience
|
175 |
+
)
|
176 |
+
|
177 |
+
with gr.Row():
|
178 |
+
submit_button = gr.Button("Submit")
|
179 |
+
|
180 |
+
with gr.Row():
|
181 |
+
output = gr.Textbox(label="Response",
|
182 |
+
lines=10, # Increase the number of lines for the output area
|
183 |
+
elem_id="output_area" # Assign a custom ID for styling
|
184 |
+
)
|
185 |
+
submit_button.click(
|
186 |
+
fn=get_output,
|
187 |
+
inputs=[model_name, rag, question],
|
188 |
+
outputs=output
|
189 |
+
)
|
190 |
+
|
191 |
+
gr.Examples(
|
192 |
+
examples=[
|
193 |
+
["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the advantage of BaZrS3?"],
|
194 |
+
["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the bandgap of SrHfS3?"],
|
195 |
+
["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "Why is chalcogenide perovskite important?"]
|
196 |
+
],
|
197 |
+
fn=get_output,
|
198 |
+
inputs=[model_name, rag, question],
|
199 |
+
outputs=output
|
200 |
+
)
|
201 |
+
|
202 |
+
demo.launch()
|