File size: 4,046 Bytes
f9e7dbf
746ca46
 
f9e7dbf
746ca46
 
49fba9e
746ca46
f9e7dbf
746ca46
d1789cc
 
 
 
9da61be
 
 
 
cc040f7
d1789cc
 
cc040f7
 
d1789cc
 
 
746ca46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c539b4
 
f9e7dbf
0ec2418
49fba9e
 
 
 
 
 
 
746ca46
49fba9e
746ca46
 
5c539b4
 
 
746ca46
5c539b4
746ca46
5c539b4
 
746ca46
5c539b4
746ca46
49fba9e
 
 
 
 
 
 
 
 
 
 
 
f9e7dbf
 
 
 
 
 
49fba9e
f9e7dbf
49fba9e
 
 
 
f9e7dbf
49fba9e
746ca46
 
 
 
81f7d00
 
 
 
746ca46
 
 
f9e7dbf
 
 
 
 
 
 
 
 
 
746ca46
 
 
 
 
 
f9e7dbf
 
 
49fba9e
 
746ca46
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import spaces
import os
import threading

import gradio as gr
from huggingface_hub import snapshot_download

from vptq.app_utils import get_chat_loop_generator

# Checkpoints offered in the "Select Model" dropdown and pre-fetched in the
# background. "name" is the Hugging Face repo id (it is passed to
# ``snapshot_download(repo_id=...)``); "bits" is a display label shown next to it.
models = [
    {"name": repo_id, "bits": bits}
    for repo_id, bits in (
        ("VPTQ-community/Meta-Llama-3.1-8B-Instruct-v12-k65536-4096-woft", "2.3 bits"),
        ("VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft", "3 bits"),
        ("VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-4096-woft", "3.5 bits"),
        ("VPTQ-community/Meta-Llama-3.1-70B-Instruct-v8-k32768-0-woft", "1.85 bits"),
    )
]

def initialize_history():
    """Seed the GPU history buffers with 100 samples.

    Each iteration takes one reading from ``get_gpu_info()`` and appends its
    ``gpu_util`` and ``mem_percent`` values (0 when absent), rounded to one
    decimal place, to ``gpu_util_history`` and ``mem_usage_history``.

    NOTE(review): ``get_gpu_info``, ``gpu_util_history`` and
    ``mem_usage_history`` are neither defined nor imported in this file, and
    nothing in this file calls this function, so calling it as-is raises
    ``NameError``. It looks tied to the disabled GPU-chart feature — confirm
    where those names should come from before re-enabling it.
    """
    sample_count = 100
    for _sample in range(sample_count):
        reading = get_gpu_info()
        gpu_util_history.append(round(reading.get('gpu_util', 0), 1))
        mem_usage_history.append(round(reading.get('mem_percent', 0), 1))


# Dropdown labels look like "<repo id> (<bits>)"; ``display_to_model`` maps a
# label back to the bare repo id that ``respond`` loads.
model_choices = ["{} ({})".format(entry["name"], entry["bits"]) for entry in models]
display_to_model = dict(zip(model_choices, (entry["name"] for entry in models)))


def download_model(model):
    """Download the Hugging Face snapshot for one ``models`` entry.

    Args:
        model: A dict whose ``"name"`` key is the repo id to fetch.
    """
    repo_id = model['name']
    print(f"Downloading {repo_id}...")
    snapshot_download(repo_id=repo_id)


def download_models_in_background():
    """Pre-fetch every entry in ``models`` via ``download_model``.

    Used as the target of the module-level background download thread.
    Failures are isolated per model: if one download raises, the error is
    printed and the loop moves on to the next model. An uncaught exception
    here would end the thread and leave every later model unfetched.
    """
    print('Downloading models for the first time...')
    for model in models:
        try:
            download_model(model)
        except Exception as err:  # thread boundary: report and keep going
            print(f"Failed to download {model['name']}: {err!r}")


# Pre-fetch all models at import time on a separate thread so module import
# and UI construction are not blocked by the downloads. The thread is created
# without ``daemon=True``, so interpreter shutdown waits for it to finish.
download_thread = threading.Thread(
    target=download_models_in_background,
)
download_thread.start()

# Single-entry model cache used by ``respond``: the generator returned by
# ``get_chat_loop_generator`` and the repo id it was created for.
loaded_model = loaded_model_name = None

def _build_messages(
    system_message,
    history: list[tuple[str, str]],
    message,
) -> list[dict[str, str]]:
    """Convert tuple-style chat history into a role/content message list.

    Args:
        system_message: Content of the leading ``system`` message.
        history: ``(user, assistant)`` pairs; an empty/falsy side of a pair
            is skipped rather than emitted as an empty message.
        message: The new user message, appended last.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    selected_model_display_label,
):
    """Stream a chat reply from the model selected in the dropdown.

    The model is loaded with ``get_chat_loop_generator`` and cached in the
    module-level ``loaded_model`` / ``loaded_model_name`` globals. It is only
    reloaded when the selection changes.

    Args:
        message: The new user message.
        history: Previous ``(user, assistant)`` turns.
        system_message: System prompt placed first in the conversation.
        max_tokens, temperature, top_p: Forwarded to the chat generator.
        selected_model_display_label: A label from ``model_choices``.

    Yields:
        The cumulative response text after each streamed token.

    Raises:
        KeyError: If ``selected_model_display_label`` is not a known label.
    """
    global loaded_model, loaded_model_name

    model_name = display_to_model[selected_model_display_label]

    # Compare by value: ``is not`` tested object identity and only matched
    # because the same str object happened to be reused on every call.
    if model_name != loaded_model_name:
        loaded_model = get_chat_loop_generator(model_name)
        loaded_model_name = model_name

    # Capture the model locally so the stream below uses one consistent object.
    chat_completion = loaded_model

    messages = _build_messages(system_message, history, message)

    response = ""
    # ``token`` (not ``message``) to avoid shadowing the user-message parameter.
    for token in chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        response += token
        yield response


# Chat UI. Customisation options for ChatInterface are described in the Gradio
# docs: https://www.gradio.app/docs/chatinterface
#
# A live GPU utilisation chart (``enable_gpu_info`` / ``_update_charts``) used
# to be drawn above the chat; it is currently disabled.
with gr.Blocks(fill_height=True) as demo:
    with gr.Column():
        # The additional inputs are listed in the same order as ``respond``'s
        # parameters after (message, history). They are built inline so that
        # ChatInterface itself renders them.
        chat_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(
                    label="System message",
                    value="You are a friendly Chatbot.",
                ),
                gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=2048,
                    step=1,
                    value=512,
                ),
                gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.7,
                ),
                gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.95,
                ),
                gr.Dropdown(
                    label="Select Model",
                    choices=model_choices,
                    value=model_choices[0],
                ),
            ],
        )

if __name__ == "__main__":
    # A public Gradio share link is opt-in: SHARE_LINK must be "1", "true" or
    # "True"; anything else (or the variable being unset) keeps it off.
    share = os.getenv("SHARE_LINK") in ("1", "true", "True")
    demo.launch(share=share)