Linly-ChatFlow / app.py
wmpscc's picture
Update app.py
6f161b0
raw
history blame
2.78 kB
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Sample prompts offered as clickable examples in the chat UI.
examples = [
    # "What are the main differences between the Python and JavaScript languages?"
    "Python和JavaScript编程语言的主要区别是什么?",
    # "What are the main factors that influence consumer behaviour?"
    "影响消费者行为的主要因素是什么?",
    # "Write code for a fully-connected layer with a ReLU activation in PyTorch."
    "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
    # "In C++: find the first (0-based) index of `needle` in `haystack`, or -1."
    "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
    # "How do I use `ssh -L`? Please explain with a concrete example."
    "如何使用ssh -L,请用具体例子说明",
    # "What is the most effective way to cope with stress?"
    "应对压力最有效的方法是什么?",
]
def init_model(model_name="Linly-AI/Chinese-LLaMA-2-7B-hf", device="cuda:0"):
    """Load the causal LM, its tokenizer and a text streamer.

    Args:
        model_name: Hugging Face Hub repo id (or local path) used for both the
            model and the tokenizer. Defaults to the Linly Chinese-LLaMA-2 7B
            checkpoint the app was built for.
        device: Value passed as ``device_map`` when loading the model.

    Returns:
        A ``(model, tokenizer, streamer)`` tuple. The streamer is configured to
        skip the prompt and special tokens, with a 30 s ``timeout``.

    NOTE: the returned streamer is a single stateful object; a caller serving
    several requests should create its own streamer per generation.
    """
    # bfloat16 weights, placed directly on `device` via device_map.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # use_fast=False explicitly selects the non-"fast" tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.)
    return model, tokenizer, streamer
def _build_prompt(message, history):
    """Return the model prompt: all prior turns in `history`, then `message`.

    Prior turns are rendered as `` User: <user> Bot: <bot>`` and the new message
    as `` ### Instruction:<message> ### Response:``. Each piece starts with a
    space, exactly as the original string concatenation produced.
    """
    def _turn_text(value):
        # A missing reply (None) becomes "" instead of the literal text "None".
        return "" if value is None else str(value).strip(' ')

    parts = [f" User: {_turn_text(turn[0])} Bot: {_turn_text(turn[1])}" for turn in history]
    parts.append(f" ### Instruction:{message.strip()} ### Response:")
    return "".join(parts)


def process(message, history):
    """Stream the model's reply to `message`, given the prior chat `history`.

    Generator used by gr.ChatInterface: yields the cumulative response text
    after each decoded chunk. If generation fails (including a streamer
    timeout), it yields a user-facing error message instead.

    Reads the module-level `model` and `tokenizer` set up in the ``__main__``
    block.
    """
    input_prompt = _build_prompt(message, history)
    # Send inputs to whatever device the model was loaded on.
    inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device)
    # Fresh streamer per request, with the same settings as the one built in
    # init_model. TextIteratorStreamer buffers text in an internal queue, so a
    # shared instance would let tokens from an abandoned (e.g. timed-out)
    # generation leak into the next reply.
    request_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.)
    generation_kwargs = dict(inputs, streamer=request_streamer, max_new_tokens=2048, do_sample=True,
                             top_k=20, top_p=0.84, temperature=1.0, repetition_penalty=1.15,
                             eos_token_id=2, bos_token_id=1, pad_token_id=0)
    try:
        # generate() runs in a background thread; this generator consumes the
        # streamer and forwards partial text to the UI.
        t = Thread(target=model.generate, kwargs=generation_kwargs)
        t.start()
        response = ""
        for text in request_streamer:
            response += text
            yield response
        print('-log:', response)
    except Exception as e:
        print('-error:', str(e))
        # A `return <value>` inside a generator is never displayed by Gradio,
        # so the error message must be yielded.
        yield "Error: 遇到错误,请开启新的会话重新尝试~"
if __name__ == '__main__':
    # Load once at startup; `process` reads these module-level globals.
    model, tokenizer, streamer = init_model()

    # UI components, built in the same order as before (chat window, then input).
    chat_window = gr.Chatbot(height=600, show_label=True, label="Linly")
    input_box = gr.Textbox(placeholder="Input", container=True, scale=7, lines=3, show_label=False)
    button_labels = {
        "retry_btn": "Retry",
        "undo_btn": "Delete Previous",
        "clear_btn": "Clear",
    }

    demo = gr.ChatInterface(
        process,
        chatbot=chat_window,
        textbox=input_box,
        title="Linly ChatFlow",
        description="",
        theme="soft",
        examples=examples,
        # Ask Gradio to cache the responses for the example prompts.
        cache_examples=True,
        **button_labels,
    )

    # Enable the request queue (needed for streaming generators), then serve.
    app = demo.queue()
    app.launch()