Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer | |
) | |
from peft import PeftModel | |
import torch | |
import time | |
# Base chat model and the LoRA adapter fine-tuned on top of it.
# (An intermediate adapter checkpoint can be used via lora_path + "/checkpoint-100".)
model_path = "Qwen/Qwen1.5-1.8B-Chat"
lora_path = "AngoHF/EssayGPT"

# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Extra keyword arguments forwarded to from_pretrained (device placement).
config_kwargs = {"device_map": device}
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    **config_kwargs,
)
# Attach the LoRA adapter, then fold its weights into the base model so
# inference runs on a plain transformers model without the PEFT wrapper.
model = PeftModel.from_pretrained(model, lora_path)
model = model.merge_and_unload()
model.eval()
# torch.compile() is deliberately not applied: a compiled wrapper only
# accelerates calls made to the wrapper itself, whereas ``model.generate``
# is looked up on (and runs) the uncompiled inner module, so wrapping the
# model added indirection without any speed-up.
model.config.use_cache = True

# Number of material tabs offered in the UI.
MAX_MATERIALS = 4
def call(related_materials, materials, question, max_length=8096):
    """Answer ``question`` with the model, using the selected materials as context.

    Args:
        related_materials: zero-based indices into ``materials`` whose texts
            are included in the prompt.
        materials: material texts in tab order.
        question: the user's question and requirements.
        max_length: total token budget (prompt + answer) for generation.

    Returns:
        The decoded answer text, or a short notice when the prompt alone
        already uses up ``max_length`` tokens.
    """
    selected = set(related_materials)
    query_texts = [
        f"材料{i + 1}\n{material}"
        for i, material in enumerate(materials)
        if i in selected
    ]
    query_texts.append(f"问题:{question}")
    query = "\n".join(query_texts)
    messages = [
        {"role": "system", "content": "请你根据以下提供的材料来回答问题"},
        {"role": "user", "content": query},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    input_length = model_inputs.input_ids.shape[1]
    print(f"Input Token Length: {input_length}")

    # max_length counts the prompt too; when the prompt already fills the
    # budget, generate() would produce an empty/truncated answer, so report it.
    max_new_tokens = max_length - input_length
    if max_new_tokens <= 0:
        return "输入内容过长，请减少所选材料或精简材料内容"

    start_time = time.perf_counter()
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )
    print(f"Inference Cost Time: {time.perf_counter() - start_time}")
    # Drop the echoed prompt tokens so only the newly generated answer is decoded.
    new_tokens = generated_ids[:, input_length:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
def create_ui():
    """Build the Gradio app: material tabs, material selector, question and answer.

    Returns:
        The ``gr.Blocks`` app with the submit button wired by ``build_ui``.
    """
    with gr.Blocks() as app:
        gr.Markdown("""<center><font size=8>EssayGPT-申论大模型</font></center>""")
        gr.Markdown(
            """<center><font size=4>1.把材料填入对应位置 2.输入问题和要求 3.选择解答问题需要的相关材料 4.点击"提问!"</font></center>""")
        with gr.Row():
            with gr.Column():
                # One tab with a textbox per material slot.
                materials = []
                for i in range(MAX_MATERIALS):
                    with gr.Tab(f"材料{i + 1}"):
                        materials.append(gr.Textbox(label="材料内容"))
            with gr.Column():
                # 1-based material numbers; build_ui converts them to indices.
                related_materials = gr.Dropdown(
                    choices=list(range(1, MAX_MATERIALS + 1)), multiselect=True,
                    label="问题所需相关材料")
                question = gr.Textbox(label="问题")
                submit = gr.Button("提问!")
                answer = gr.Textbox(label="回答")
        # Event listeners must be attached inside the Blocks context.
        build_ui({"materials": materials, "related_materials": related_materials, "question": question,
                  "submit": submit, "answer": answer})
    return app
def build_ui(components):
    """Connect the submit button to the model call.

    Args:
        components: dict holding the UI widgets under the keys "materials"
            (list of textboxes), "related_materials", "question", "submit"
            and "answer".
    """
    # The name ``func`` is kept: Gradio uses it as the default API endpoint name.
    def func(related_materials, question, *materials):
        # No material picked (None or empty selection): prompt the user.
        if not related_materials:
            return "请选择问题所需相关材料"
        # The dropdown lists 1-based material numbers; call() wants 0-based indices.
        indices = []
        for number in related_materials:
            indices.append(number - 1)
        return call(indices, materials, question)

    event_inputs = [components["related_materials"], components["question"]]
    event_inputs.extend(components["materials"])
    components["submit"].click(func, event_inputs, components["answer"])
def run():
    """Create the UI, enable request queuing, and start the web server."""
    demo = create_ui()
    demo.queue()  # serialize/queue requests to the model
    demo.launch()
if __name__ == "__main__":
    # Launch the app only when executed as a script, not when imported.
    run()