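"""Gradio streaming chat demo for RWKV-Red-Team/ARWKV-7B-Preview-0.1.

Generation runs on a background thread; tokens are streamed back into the
Gradio Chatbot via transformers' TextIteratorStreamer.
"""
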
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TextIteratorStreamer
import threading


# ARWKV ships custom modeling code, so trust_remote_code=True is required.
# device_map="auto" lets Accelerate place the weights, so inputs are moved to
# model.device below rather than a hard-coded "cuda" string (which would fail
# on CPU-only machines).
model = AutoModelForCausalLM.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "RWKV-Red-Team/ARWKV-7B-Preview-0.1"
)
device = model.device


def convert_history_to_messages(history):
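    """Convert Gradio's [(user, bot), ...] tuple history into the
    message-dict format expected by the tokenizer's chat template."""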
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def stream_chat(prompt, history):
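    """Generate a streamed reply to `prompt`, yielding the growing chat
    history after each newly decoded chunk of text."""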

    messages = convert_history_to_messages(history)
    messages.append({"role": "user", "content": prompt})

    # Render the conversation with the model's chat template, then tokenize.
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # The streamer yields decoded text as it is generated; skip_prompt keeps
    # the echoed input out of the stream.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        temperature=1.5,
        top_p=0.2,
        top_k=0,
    )
    # generate() blocks until completion, so run it on a background thread and
    # consume the streamer on this one.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate partial output and yield the full history after each chunk so
    # the Chatbot re-renders incrementally.
    response = ""
    for new_text in streamer:
        response += new_text
        yield history + [(prompt, response)]


# Minimal Blocks UI: a chat pane, a message box, and a clear button.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Chat with LLM", height=750)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")

    def user(user_message, history):
        # Stage the user's turn with an empty placeholder reply and clear the
        # input box; the bot handler fills the placeholder in below.
        return "", history + [[user_message, None]]

    def bot(history):
        # Stream the model's reply. stream_chat() rebuilds the final
        # (prompt, response) pair itself, so pass the history without the
        # placeholder turn that user() just appended.
        prompt = history[-1][0]
        yield from stream_chat(prompt, history[:-1])

    # Two-step event chain: append the user turn immediately (unqueued), then
    # stream the bot reply into the chatbot.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# queue() is required for streaming (generator) handlers; server_name="0.0.0.0"
# makes the app reachable from other machines/containers, not just localhost.
demo.queue().launch(server_name="0.0.0.0")
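
# To try it locally (assuming this file is saved as app.py):
#   python app.py
# then open http://localhost:7860 (Gradio's default port).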