import sys
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)

MODEL = "microsoft/Phi-3.5-mini-instruct"

if torch.cuda.is_available():
    device = "cuda"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


# 4-bit NF4 quantization (QLoRA-style): weights are stored in 4-bit
# NormalFloat, double quantization compresses the quantization constants
# themselves, and matmuls are computed in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # bitsandbytes 4-bit quantization requires CUDA; fall back to
    # unquantized bf16 weights on MPS/CPU.
    quantization_config=quantization_config if device == "cuda" else None,
)


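# On ZeroGPU Spaces, @spaces.GPU() requests a GPU for each call to the
# decorated function; outside such environments the decorator is a no-op.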
@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    print(f"message: {message}")
    print(f"history: {history}")

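    # Gradio's ChatInterface (tuple format) supplies history as a list of
    # (user, assistant) pairs; rebuild it as an OpenAI-style message list.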
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend(
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ]
        )

    conversation.append({"role": "user", "content": message})

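    # apply_chat_template renders the conversation into the model's prompt
    # format and tokenizes it; add_generation_prompt=True appends the header
    # that cues the assistant's turn.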
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

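    # TextIteratorStreamer exposes generated text as an iterator of decoded
    # chunks; skip_prompt=True keeps the input prompt out of the stream.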
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        # The original hardcoded eos_token_id=[128001, 128008, 128009]; those
        # are Llama-3 special tokens, out of range for Phi-3.5's ~32k-token
        # vocabulary. Omitting the argument lets the model's own generation
        # config supply the correct stop tokens.
        streamer=streamer,
    )

    # Run generation in a background thread so tokens can be streamed as they
    # arrive. Wrapping thread.start() in torch.no_grad() (as the original
    # did) has no effect: grad mode is thread-local, and generate() already
    # runs without gradients.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Accumulate streamed chunks, yielding the growing string so the Gradio
    # UI re-renders the full reply as it grows.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    stream_chat,
    additional_inputs=[
        gr.Textbox(value="You are an ARM Assembly language decoder. You receive a line of Arm assembly and respond with a description of what the instruction does.", label="System message"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
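        # Hypothetical additions: the original UI never exposed top_k or
        # penalty, so stream_chat always fell back to their defaults. These
        # sliders wire them through; the ranges are reasonable guesses.
        gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Top-k"),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.2,
            step=0.1,
            label="Repetition penalty",
        ),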
    ],
)


if __name__ == "__main__":
    demo.launch()