sergey21000 committed
Commit a10dd76 (verified)
Parent: 508b4a6

Update app.py

Files changed (1)
  1. app.py +262 -258
app.py CHANGED
@@ -1,259 +1,263 @@
from pathlib import Path
from shutil import rmtree
from typing import Union, List, Dict, Tuple, Optional
from tqdm import tqdm

import requests
import gradio as gr
from llama_cpp import Llama


# ================== ANNOTATIONS ========================

CHAT_HISTORY = List[Tuple[Optional[str], Optional[str]]]
MODEL_DICT = Dict[str, Llama]


# ================== FUNCS =============================

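# Stream a remote file to disk in 4 KiB chunks, mirroring download progress
# to both the console (tqdm) and the Gradio UI (gr.Progress).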
def download_file(file_url: str, file_path: Union[str, Path]) -> None:
    response = requests.get(file_url, stream=True)
    if response.status_code != 200:
        raise Exception(f'The file is not available for download at: {file_url}')
    total_size = int(response.headers.get('content-length', 0))
    progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
    progress_gradio = gr.Progress()
    completed_size = 0
    with open(file_path, 'wb') as file:
        for data in response.iter_content(chunk_size=4096):
            size = file.write(data)
            progress_tqdm.update(size)
            completed_size += size
            desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
            progress_gradio(completed_size/total_size, desc=desc)
    progress_tqdm.close()


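# Download a GGUF file (skipped if already cached in MODELS_PATH) and
# initialize a llama-cpp-python model from it. Always returns a
# (model_dict, support_system_role, log) triple to match the three Gradio
# outputs wired to it. System-role support is inferred from the chat template
# embedded in the GGUF metadata: templates that raise 'System role not
# supported' (as Gemma-2's does) are treated as lacking system-role support.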
def download_gguf_and_init_model(gguf_url: str, model_dict: MODEL_DICT) -> Tuple[MODEL_DICT, bool, str]:
    log = ''
    support_system_role = False
    if not gguf_url.endswith('.gguf'):
        log += 'The link must be a direct link to a GGUF file\n'
        return model_dict, support_system_role, log

    gguf_filename = gguf_url.rsplit('/')[-1]
    model_path = MODELS_PATH / gguf_filename
    progress = gr.Progress()

    if not model_path.is_file():
        progress(0.3, desc='Step 1/2: Loading GGUF model file')
        try:
            download_file(gguf_url, model_path)
            log += f'Model file {gguf_filename} downloaded successfully\n'
        except Exception as ex:
            log += f'Error downloading model from link {gguf_url}:\n{ex}\n'
            curr_model = model_dict.get('model')
            if curr_model is None:
                log += 'Model is missing from dictionary "model_dict"\n'
                return model_dict, support_system_role, log
            # keep the currently initialized model and report its system-role support
            curr_model_filename = Path(curr_model.model_path).name
            support_system_role = 'System role not supported' not in curr_model.metadata.get('tokenizer.chat_template', '')
            log += f'Current initialized model: {curr_model_filename}\n'
            return model_dict, support_system_role, log
    else:
        log += f'Model file {gguf_filename} is already downloaded, initializing model...\n'

    progress(0.7, desc='Step 2/2: Model initialization')
    model = Llama(model_path=str(model_path), n_gpu_layers=-1, verbose=True)
    model_dict = {'model': model}
    support_system_role = 'System role not supported' not in model.metadata.get('tokenizer.chat_template', '')
    log += f'Model {gguf_filename} initialized\n'
    return model_dict, support_system_role, log


def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
    if user_message:
        chatbot.append((user_message, None))
    return '', chatbot


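# Generator that streams the model's reply token by token into the last
# chatbot turn. Sampling parameters arrive positionally in the same order as
# GENERATE_KWARGS; with do_sample off they are overridden so that decoding
# is effectively greedy.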
def bot_response_to_chatbot(
        chatbot: CHAT_HISTORY,
        model_dict: MODEL_DICT,
        system_prompt: str,
        support_system_role: bool,
        history_len: int,
        do_sample: bool,
        *generate_args,
        ):
    model = model_dict.get('model')
    user_message = chatbot[-1][0]
    messages = []

    gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
    gen_kwargs['top_k'] = int(gen_kwargs['top_k'])

    if not do_sample:
        # disable sampling: nucleus off, only the top token, no repetition penalty
        gen_kwargs['top_p'] = 0.0
        gen_kwargs['top_k'] = 1
        gen_kwargs['repeat_penalty'] = 1.0

    if support_system_role and system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})

    if history_len != 0:
        for user_msg, bot_msg in chatbot[:-1][-history_len:]:
            messages.append({'role': 'user', 'content': user_msg})
            messages.append({'role': 'assistant', 'content': bot_msg})

    messages.append({'role': 'user', 'content': user_message})
    stream_response = model.create_chat_completion(
        messages=messages,
        stream=True,
        **gen_kwargs,
    )

    chatbot[-1][1] = ''
    for chunk in stream_response:
        token = chunk['choices'][0]['delta'].get('content')
        if token is not None:
            chatbot[-1][1] += token
            yield chatbot


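# UI helpers: Gradio event handlers can replace a rendered component by
# returning a new instance from a function, which is how the system-prompt
# textbox is swapped and the sampling sliders are shown or hidden.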
def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    value = '' if interactive else 'System prompt is not supported by this model'
    return gr.Textbox(value=value, label='System prompt', interactive=interactive)


def get_generate_args(do_sample: bool) -> List[gr.Slider]:
    visible = do_sample
    generate_args = [
        gr.Slider(label='temperature', value=GENERATE_KWARGS['temperature'], minimum=0.1, maximum=3, step=0.1, visible=visible),
        gr.Slider(label='top_p', value=GENERATE_KWARGS['top_p'], minimum=0.1, maximum=1, step=0.1, visible=visible),
        gr.Slider(label='top_k', value=GENERATE_KWARGS['top_k'], minimum=1, maximum=50, step=5, visible=visible),
        gr.Slider(label='repeat_penalty', value=GENERATE_KWARGS['repeat_penalty'], minimum=1, maximum=5, step=0.1, visible=visible),
    ]
    return generate_args


# ================== VARIABLES =============================

MODELS_PATH = Path('models')
MODELS_PATH.mkdir(exist_ok=True)
DEFAULT_GGUF_URL = 'https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q8_0.gguf'

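# The default model is downloaded and initialized once at import time, so the
# app starts with a working model already loaded.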
145
+ start_model_dict, start_support_system_role, start_load_log = download_gguf_and_init_model(
146
+ gguf_url=DEFAULT_GGUF_URL, model_dict={},
147
+ )
148
+
149
+ GENERATE_KWARGS = dict(
150
+ temperature=0.2,
151
+ top_p=0.95,
152
+ top_k=40,
153
+ repeat_penalty=1.0,
154
+ )
155
+
156
+ theme = gr.themes.Base(primary_hue='green', secondary_hue='yellow', neutral_hue='zinc').set(
157
+ loader_color='rgb(0, 255, 0)',
158
+ slider_color='rgb(0, 200, 0)',
159
+ body_text_color_dark='rgb(0, 200, 0)',
160
+ button_secondary_background_fill_dark='green',
161
+ )
162
+ css = '''.gradio-container {width: 60% !important}'''


# ================== INTERFACE =============================

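# The Llama instance is held in a plain dict inside gr.State; presumably the
# dict wrapper is there because a Llama object itself cannot be safely
# deep-copied, which Gradio does with state values.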
with gr.Blocks(theme=theme, css=css) as interface:
    model_dict = gr.State(start_model_dict)
    support_system_role = gr.State(start_support_system_role)

    # ================= CHAT BOT PAGE ======================
    with gr.Tab('Chat bot'):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(show_copy_button=True, bubble_full_width=False, height=480)
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')

                system_prompt = get_system_prompt_component(interactive=support_system_role.value)

            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('Length of message history')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                    )

                with gr.Group():
                    gr.Markdown('Generation parameters')
                    do_sample = gr.Checkbox(
                        value=False,
                        label='do_sample',
                        info='Activate random sampling',
                    )
                    generate_args = get_generate_args(do_sample.value)
                    do_sample.change(
                        fn=get_generate_args,
                        inputs=do_sample,
                        outputs=generate_args,
                        show_progress=False,
                    )

        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
        ).then(
            fn=bot_response_to_chatbot,
            inputs=[chatbot, model_dict, system_prompt, support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )
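        # Stop cancels the in-flight streaming generation; Clear resets the
        # chatbot history.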
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
        )
        clear_btn.click(
            fn=lambda: None,
            inputs=None,
            outputs=[chatbot],
        )

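    # On a successful load, the system-prompt textbox is rebuilt so that it
    # matches the new model's system-role support.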
    # ================= LOAD MODELS PAGE ======================
    with gr.Tab('Load model'):
        gguf_url = gr.Textbox(
            value='',
            label='Link to GGUF',
            placeholder='Direct URL to a model file in GGUF format',
        )
        load_model_btn = gr.Button('Download GGUF and initialize the model')
        load_log = gr.Textbox(
            value=start_load_log,
            label='Model loading status',
            lines=3,
        )

        load_model_btn.click(
            fn=download_gguf_and_init_model,
            inputs=[gguf_url, model_dict],
            outputs=[model_dict, support_system_role, load_log],
        ).success(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )

+     gr.HTML("""<h3 style='text-align: center'>
+     <a href="https://github.com/sergey21000/gradio-llamacpp-chatbot" target='_blank'>GitHub Page</a></h3>
+     """)

interface.launch(server_name='0.0.0.0', server_port=7860)
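
For reference, the streaming API used above also works outside Gradio. A minimal sketch, assuming the default GGUF file has already been downloaded into models/ as this app does:

from llama_cpp import Llama

# load the same default model the Space downloads at startup
model = Llama(model_path='models/gemma-2-2b-it-Q8_0.gguf', n_gpu_layers=-1, verbose=False)
stream = model.create_chat_completion(
    messages=[{'role': 'user', 'content': 'Hello!'}],
    stream=True,
    temperature=0.2, top_p=0.95, top_k=40, repeat_penalty=1.0,
)
for chunk in stream:
    token = chunk['choices'][0]['delta'].get('content')
    if token is not None:
        print(token, end='', flush=True)
print()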