# EssayGPTSpace / app.py
# AngoHF's picture
# Update app.py
# 1057753 verified
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer
)
from peft import PeftModel
import torch
import time
# Base chat model and the LoRA adapter that is merged into it at start-up.
model_path = "Qwen/Qwen1.5-1.8B-Chat"
lora_path = "AngoHF/EssayGPT"  # + "/checkpoint-100"

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)

config_kwargs = {"device_map": device}
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    **config_kwargs,
)
# Load the adapter on top of the base weights, then merge it in so that the
# resulting model no longer needs the PEFT wrapper.
model = PeftModel.from_pretrained(model, lora_path).merge_and_unload()
model.eval()
# NOTE(review): `generate()` is later called on this compiled wrapper — confirm
# that the compiled forward pass is actually the one used during generation.
model = torch.compile(model)
model.config.use_cache = True

# Number of material tabs (and dropdown choices) offered by the UI.
MAX_MATERIALS = 4
def _build_query(related_materials, materials, question):
    """Join the selected materials and the question into one prompt string."""
    sections = [
        f"材料{index + 1}\n{material}"
        for index, material in enumerate(materials)
        if index in related_materials
    ]
    sections.append(f"问题:{question}")
    return "\n".join(sections)


def call(related_materials, materials, question):
    """Answer ``question`` with the model, using only the selected materials.

    Args:
        related_materials: Zero-based indices into ``materials`` that should
            be included in the prompt.
        materials: Sequence of material texts; the item at position ``i`` is
            labelled ``材料{i + 1}`` in the prompt.
        question: Question text appended after the selected materials.

    Returns:
        The decoded model response, with the prompt tokens and special tokens
        removed.
    """
    messages = [
        {"role": "system", "content": "请你根据以下提供的材料来回答问题"},
        {"role": "user", "content": _build_query(related_materials, materials, question)},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    print(f"Input Token Length: {len(model_inputs.input_ids[0])}")

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        do_sample=False,
        max_length=8096,
    )
    print(f"Inference Cost Time: {time.perf_counter() - start_time}")

    # Each output row starts with its prompt tokens; keep only what follows.
    completions = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
def create_ui():
    """Assemble the Gradio Blocks layout and register the submit handler.

    Returns:
        The constructed ``gr.Blocks`` application; launching it is left to
        the caller.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>EssayGPT-申论大模型</center>""")
        gr.Markdown(
            """<center><font size=4>1.把材料填入对应位置 2.输入问题和要求 3.选择解答问题需要的相关材料 4.点击"提问!"</center>""")
        with gr.Row():
            # One tab, holding one textbox, per material slot.
            with gr.Column():
                material_boxes = []
                for tab_number in range(1, MAX_MATERIALS + 1):
                    with gr.Tab(f"材料{tab_number}"):
                        material_boxes.append(gr.Textbox(label="材料内容"))
            # Material selection, question input, submit button and answer.
            with gr.Column():
                material_picker = gr.Dropdown(
                    choices=list(range(1, MAX_MATERIALS + 1)),
                    multiselect=True,
                    label="问题所需相关材料",
                )
                question_box = gr.Textbox(label="问题")
                ask_button = gr.Button("提问!")
                answer_box = gr.Textbox(label="回答")
        widgets = {
            "materials": material_boxes,
            "related_materials": material_picker,
            "question": question_box,
            "submit": ask_button,
            "answer": answer_box,
        }
        build_ui(widgets)
    return demo
def build_ui(components):
    """Register the click handler on the submit component.

    Args:
        components: Mapping with the keys ``"submit"``, ``"related_materials"``,
            ``"question"``, ``"materials"`` (an iterable of input components)
            and ``"answer"``.
    """

    def func(related_materials, question, *materials):
        """Reject an empty selection, otherwise forward the request to ``call``."""
        if related_materials:
            # Selected numbers are shifted down by one before being used as
            # indices into ``materials``.
            zero_based = [number - 1 for number in related_materials]
            return call(zero_based, materials, question)
        return "请选择问题所需相关材料"

    register_click = components["submit"].click
    handler_inputs = [components["related_materials"], components["question"]]
    handler_inputs.extend(components["materials"])
    register_click(func, handler_inputs, components["answer"])
def run():
    """Build the UI, call ``queue()`` on it and then launch it."""
    demo = create_ui()
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    run()