# test_gradio.py: streaming Gradio chat demo for RWKV-Red-Team/ARWKV-7B-Preview-0.1
import threading

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

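# Load the checkpoint in fp16; trust_remote_code=True is needed because the
# ARWKV architecture ships custom modeling code alongside the weights.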
model = AutoModelForCausalLM.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1"
)

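# Use the model's own device rather than hard-coding "cuda"; with
# device_map="auto" this is wherever the first parameters were placed.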
device = model.device

def convert_history_to_messages(history):
    """Convert Gradio's [(user, bot), ...] pair history into chat-format messages."""
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages

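# stream_chat runs generate() on a worker thread and reads decoded tokens from
# a TextIteratorStreamer, so partial output can be yielded to the UI as it is
# produced instead of waiting for the full completion.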
def stream_chat(prompt, history):
    """Stream a reply to `prompt`, yielding the updated chat history."""
    messages = convert_history_to_messages(history)
    messages.append({"role": "user", "content": prompt})
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    # skip_prompt=True keeps the echoed prompt out of the streamed text.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.5,
        top_p=0.2,
        top_k=0,  # 0 disables top-k filtering
    )
    # generate() blocks until finished, so run it on a background thread and
    # consume the streamer on this one.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield history + [(prompt, response)]

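# UI: a chat panel, a textbox that submits on Enter, and a clear button.
# Submitting first records the user turn (user), then streams the reply (bot).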
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with LLM", height=750)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        # Append the new user turn with an empty bot slot and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(history):
        # Take the pending user turn and stream the assistant's reply into it.
        prompt = history[-1][0]
        history[-1][1] = ""
        for updated_history in stream_chat(prompt, history[:-1]):
            yield updated_history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch(server_name="0.0.0.0")
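
# To try the demo (assuming a single CUDA GPU with enough memory for the 7B
# checkpoint in fp16): run `python test_gradio.py`, then open
# http://<host>:7860 in a browser (7860 is Gradio's default port).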