File size: 7,456 Bytes
9e608cc
 
 
 
 
25f188a
a00ee3f
9e608cc
 
 
 
637af8c
 
 
9e608cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
950c25c
 
 
 
 
 
 
62a81e7
950c25c
9e608cc
 
 
950c25c
9e608cc
 
 
 
 
 
 
 
 
 
 
950c25c
 
9e608cc
950c25c
 
 
 
9e608cc
 
 
 
 
950c25c
9e608cc
 
 
 
950c25c
9e608cc
950c25c
9e608cc
 
 
 
950c25c
9e608cc
 
 
 
 
 
950c25c
 
 
 
 
 
 
9e608cc
 
 
 
 
 
950c25c
9e608cc
 
 
 
950c25c
 
 
 
 
 
 
 
e75173f
9e608cc
 
a00ee3f
 
 
 
 
 
 
 
 
 
 
 
 
6b22bf5
a00ee3f
 
 
 
 
 
9e608cc
 
 
 
950c25c
 
9e608cc
 
 
 
0754ecf
 
 
9e608cc
 
a00ee3f
 
 
 
 
 
 
 
 
 
 
 
17b76bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a00ee3f
 
9e608cc
 
c99ac11
9e608cc
 
 
 
 
 
 
 
950c25c
9e608cc
 
 
 
 
 
 
 
c99ac11
9e608cc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import gradio as gr
import time
import requests
import json
import os
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter

API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")

print(f"API_URL: {API_URL}")
print(f"API_KEY: {API_KEY}")

url = f"{API_URL}/v1/chat/completions"

# The headers for the HTTP request
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}


def is_valid_json(data):
    """Try to parse *data* as JSON.

    Returns a two-tuple: ``(True, parsed_value)`` when parsing succeeds, or
    ``(False, error_message)`` when ``json.loads`` raises ``ValueError``
    (``json.JSONDecodeError`` is a subclass of ``ValueError``).
    """
    try:
        decoded = json.loads(data)
    except ValueError as exc:
        return False, str(exc)
    return True, decoded


with gr.Blocks() as demo:

    # Header text shown above the chat window.
    markup = gr.Markdown(
        """
                         # Mistral 7B Instruct v0.2
                         This is a demo of the Mistral 7B Instruct model, quantized in GGUF (Q2) and hosted on a K8s cluster.

                         The original models can be found at [MaziyarPanahi/Mistral-7B-Instruct-v0.2-GGUF](https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.2-GGUF)"""
    )
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(lines=1, label="User Message")
    clear = gr.Button("Clear")

    # Generation settings, laid out in two equal-width columns. The wiring
    # below passes their current values as inputs to `bot` on every submit.
    with gr.Row():

        with gr.Column(scale=2):
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Type system prompt here...",
                value="You are a helpful assistant.",
            )
            temperature_input = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            max_new_tokens_input = gr.Slider(
                label="Max New Tokens", minimum=0, maximum=1024, value=256, step=1
            )

        with gr.Column(scale=2):
            top_p_input = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.01
            )
            top_k_input = gr.Slider(
                label="Top K", minimum=1, maximum=100, value=50, step=1
            )
            repetition_penalty_input = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.01,
            )

    def update_globals(
        system_prompt, temperature, max_new_tokens, top_p, top_k, repetition_penalty
    ):
        """Copy the given generation settings into module-level ``global_*`` names.

        NOTE(review): nothing visible in this file calls this function; ``bot``
        receives the same values directly as event inputs. Consider removing.
        """
        global global_system_prompt, global_temperature, global_max_new_tokens
        global global_top_p, global_top_k, global_repetition_penalty
        (
            global_system_prompt,
            global_temperature,
            global_max_new_tokens,
            global_top_p,
            global_top_k,
            global_repetition_penalty,
        ) = (
            system_prompt,
            temperature,
            max_new_tokens,
            top_p,
            top_k,
            repetition_penalty,
        )

    def user(user_message, history):
        """Record a submitted message as a new, not-yet-answered chat turn.

        Returns ``("", new_history)``. The empty string clears the textbox;
        ``new_history`` is a new list equal to ``history`` followed by
        ``[user_message, None]``, leaving the input list unmodified.
        """
        pending_turn = [user_message, None]
        updated_history = history + [pending_turn]
        return "", updated_history

    def _build_messages(history, system_prompt):
        """Convert Gradio chat history into chat-completions messages.

        Args:
            history: List of ``[user_text, assistant_text]`` pairs as held by
                the Chatbot; ``assistant_text`` is ``None``/empty for a turn
                that has no reply yet.
            system_prompt: System prompt text; a default is used when empty.

        Returns:
            A list of ``{"content": ..., "role": ...}`` dicts: the system
            message first, then user/assistant turns in conversation order.
        """
        messages = [
            {
                "content": system_prompt or "You are a helpful assistant.",
                "role": "system",
            }
        ]
        for turn in history:
            user_text, assistant_text = turn[0], turn[1]
            if user_text:
                messages.append({"content": user_text, "role": "user"})
            # Earlier replies must be sent back too, otherwise the model only
            # sees a run of consecutive user messages and loses context.
            if assistant_text:
                messages.append({"content": assistant_text, "role": "assistant"})
        return messages

    def _make_session():
        """Return a requests Session that retries 5xx responses and connection errors."""
        retries = Retry(
            total=5,  # Total number of retries to allow
            backoff_factor=1,  # Exponential backoff between attempts
            status_forcelist=[500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],
        )
        adapter = HTTPAdapter(max_retries=retries)
        session = requests.Session()
        # Mount for both schemes so retries apply whether API_URL is http or https.
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    def _iter_stream_content(response):
        """Yield non-empty text deltas from a streamed (SSE) chat-completions response.

        Each event line is expected to look like ``data: {json}``; the stream
        ends at ``data: [DONE]``. SSE comment lines (starting with ``:``) and
        blank keep-alive lines are skipped.
        """
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            payload = raw_line.decode("utf-8", errors="replace").strip()
            if payload.startswith(":"):
                continue
            # Remove the exact "data:" prefix (str.lstrip would strip a
            # character *set*, not a prefix).
            if payload.startswith("data:"):
                payload = payload[len("data:"):].strip()
            if payload == "[DONE]":
                break
            is_json, parsed = is_valid_json(payload)
            if not is_json:
                print(f"Error decoding JSON: {parsed} data: {payload}")
                continue
            if not isinstance(parsed, dict):
                continue
            # Guard against a missing or empty "choices" list / null "delta".
            choices = parsed.get("choices") or [{}]
            delta = choices[0].get("delta") or {}
            content = delta.get("content")
            if content:
                yield content

    def bot(
        history,
        system_prompt,
        temperature,
        max_new_tokens,
        top_p,
        top_k,
        repetition_penalty,
    ):
        """Stream the model's reply to the newest chat turn into ``history``.

        Sends the whole conversation (system prompt plus all prior user and
        assistant turns) to ``url`` with streaming enabled and appends each
        received text delta to ``history[-1][1]``.

        Yields:
            ``history`` after each appended chunk. On an HTTP error status or a
            ``requests`` exception, an error note is written into the reply
            slot and yielded once, so the failure is visible in the chat.
        """
        print(f"History in bot: {history}")
        print(f"System Prompt: {system_prompt}")
        print(f"Temperature: {temperature}")
        print(f"Max New Tokens: {max_new_tokens}")
        print(f"Top P: {top_p}")
        print(f"Top K: {top_k}")
        print(f"Repetition Penalty: {repetition_penalty}")

        if not history:
            return

        history_messages = _build_messages(history, system_prompt)
        history[-1][1] = ""
        print(history_messages)

        payload = {
            "messages": history_messages,
            "stream": True,
            "temperature": temperature,
            # Sliders may deliver floats; these fields are integer counts.
            "top_k": int(top_k),
            "top_p": top_p,
            "seed": 42,
            "repeat_penalty": repetition_penalty,
            "chat_format": "mistral-instruct",
            "max_tokens": int(max_new_tokens),
        }

        try:
            # Context managers release the connection/session even when the
            # client stops consuming this generator mid-stream.
            with _make_session() as session:
                with session.post(
                    url,
                    headers=headers,
                    data=json.dumps(payload),
                    stream=True,
                    timeout=(10, 30),
                ) as response:
                    if response.status_code != 200:
                        print(
                            f"Request failed with status {response.status_code}: "
                            f"{response.text[:500]}"
                        )
                        history[-1][1] = (
                            f"[Error: the model server returned HTTP "
                            f"{response.status_code}]"
                        )
                        yield history
                        return
                    for content in _iter_stream_content(response):
                        history[-1][1] += content
                        time.sleep(0.05)
                        yield history
        except requests.exceptions.RequestException as e:
            print(f"An error occurred: {e}")
            note = "[Error: the request to the model server failed]"
            partial = history[-1][1]
            # Keep any text already streamed before the failure.
            history[-1][1] = f"{partial}\n\n{note}" if partial else note
            yield history

    # Pressing Enter in the textbox runs two steps in sequence: `user` moves the
    # text into the chat as a new turn (and clears the box), then `bot` streams
    # the assistant reply into that turn using the current generation settings.
    msg.submit(
        user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=True,
        concurrency_limit=10,
    ).then(
        bot,
        inputs=[
            chatbot,
            system_prompt_input,
            temperature_input,
            max_new_tokens_input,
            top_p_input,
            top_k_input,
            repetition_penalty_input,
        ],
        outputs=chatbot,
    )

    # "Clear" resets the chatbot value to None; queue=False keeps this event
    # off the request queue.
    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)


# Turn on the request queue: events without their own concurrency_limit get a
# default of 20, at most 20 requests may wait, and api_open=False closes the
# backend REST routes so clients cannot bypass the queue.
demo.queue(
    max_size=20,
    default_concurrency_limit=20,
    api_open=False,
)

if __name__ == "__main__":
    # No public share link and no "Use via API" page.
    demo.launch(share=False, show_api=False)