edwardjiang committed
Commit a763090 · 1 Parent(s): f086839

Update app.py

Files changed (1)
  1. app.py +8 -319
app.py CHANGED
@@ -1,321 +1,10 @@
-import gradio as gr
-import argparse
-import torch
-import transformers
-from distutils.util import strtobool
-from tokenizers import pre_tokenizers
-
-from transformers.generation.utils import logger
-import mdtex2html
-import warnings
-
-
-logger.setLevel("ERROR")
-warnings.filterwarnings("ignore")
-
-
-warnings.filterwarnings("ignore")
-
-
-def _strtobool(x):
-    return bool(strtobool(x))
-
-
-QA_SPECIAL_TOKENS = {
-    "Question": "<|prompter|>",
-    "Answer": "<|assistant|>",
-    "System": "<|system|>",
-    "StartPrefix": "<|prefix_begin|>",
-    "EndPrefix": "<|prefix_end|>",
-    "InnerThought": "<|inner_thoughts|>",
-    "EndOfThought": "<eot>"
-}
-
-
-def format_pairs(pairs, eos_token, add_initial_reply_token=False):
-    conversations = [
-        "{}{}{}".format(
-            QA_SPECIAL_TOKENS["Question" if i % 2 == 0 else "Answer"], pairs[i], eos_token)
-        for i in range(len(pairs))
-    ]
-    if add_initial_reply_token:
-        conversations.append(QA_SPECIAL_TOKENS["Answer"])
-    return conversations
-
-
-def format_system_prefix(prefix, eos_token):
-    return "{}{}{}".format(
-        QA_SPECIAL_TOKENS["System"],
-        prefix,
-        eos_token,
-    )
-
-
-def get_specific_model(
-    model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
-):
-    # encoder-decoder support for Flan-T5 like models
-    # for now, we can use an argument but in the future,
-    # we can automate this
-
-    model = transformers.LlamaForCausalLM.from_pretrained(model_name, **kwargs)
-
-    return model
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, required=True)
-parser.add_argument("--max_new_tokens", type=int, default=200)
-parser.add_argument("--top_k", type=int, default=40)
-parser.add_argument("--do_sample", type=_strtobool, default=True)
-# parser.add_argument("--system_prefix", type=str, default=None)
-parser.add_argument("--per-digit-tokens", action="store_true")
-
-
-args = parser.parse_args()
-
-# # Open-domain QA
-# system_prefix = \
-#     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
-# - EduChat是一个由华东师范大学开发的对话式语言模型。
-# EduChat的工具
-# - Web search: Disable.
-# - Calculators: Disable.
-# EduChat的能力
-# - Inner Thought: Disable.
-# 对话主题
-# - General: Enable.
-# - Psychology: Disable.
-# - Socrates: Disable.'''"</s>"
-
-# # Heuristic (Socratic) teaching
-# system_prefix = \
-#     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
-# - EduChat是一个由华东师范大学开发的对话式语言模型。
-# EduChat的工具
-# - Web search: Disable.
-# - Calculators: Disable.
-# EduChat的能力
-# - Inner Thought: Disable.
-# 对话主题
-# - General: Disable.
-# - Psychology: Disable.
-# - Socrates: Enable.'''"</s>"
-
-# Emotional support
-system_prefix = \
-    "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
-- EduChat是一个由华东师范大学开发的对话式语言模型。
-EduChat的工具
-- Web search: Disable.
-- Calculators: Disable.
-EduChat的能力
-- Inner Thought: Disable.
-对话主题
-- General: Disable.
-- Psychology: Enable.
-- Socrates: Disable.'''"</s>"
-
-# # Emotional support (with InnerThought)
-# system_prefix = \
-#     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
-# - EduChat是一个由华东师范大学开发的对话式语言模型。
-# EduChat的工具
-# - Web search: Disable.
-# - Calculators: Disable.
-# EduChat的能力
-# - Inner Thought: Enable.
-# 对话主题
-# - General: Disable.
-# - Psychology: Enable.
-# - Socrates: Disable.'''"</s>"
-
-
-print('Loading model...')
-
-model = get_specific_model("models/ecnu-icalk/educhat-sft-002-7b")
-
-model.half().cuda()
-model.gradient_checkpointing_enable()  # reduce number of stored activations
-
-print('Loading tokenizer...')
-tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
-
-tokenizer.add_special_tokens(
-    {
-        "pad_token": "</s>",
-        "eos_token": "</s>",
-        "sep_token": "<s>",
-    }
-)
-additional_special_tokens = (
-    []
-    if "additional_special_tokens" not in tokenizer.special_tokens_map
-    else tokenizer.special_tokens_map["additional_special_tokens"]
-)
-additional_special_tokens = list(
-    set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
-
-print("additional_special_tokens:", additional_special_tokens)
-
-tokenizer.add_special_tokens(
-    {"additional_special_tokens": additional_special_tokens})
-
-if args.per_digit_tokens:
-    tokenizer._tokenizer.pre_processor = pre_tokenizers.Digits(True)
-
-human_token_id = tokenizer.additional_special_tokens_ids[
-    tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
-]
-
-print('Type "quit" to exit')
-print("Press Control + C to restart conversation (spam to exit)")
-
-conversation_history = []
-
-
-"""Override Chatbot.postprocess"""
-
-
-def postprocess(self, y):
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            None if message is None else mdtex2html.convert((message)),
-            None if response is None else mdtex2html.convert(response),
-        )
-    return y
-
-
-gr.Chatbot.postprocess = postprocess
-
-
-def parse_text(text):
-    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split('`')
-            if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = f'<br></code></pre>'
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>"+line
-    text = "".join(lines)
-    return text
-
-
-def predict(input, chatbot, max_length, top_p, temperature, history):
-    query = parse_text(input)
-    chatbot.append((query, ""))
-    conversation_history = []
-    for i, (old_query, response) in enumerate(history):
-        conversation_history.append(old_query)
-        conversation_history.append(response)
-
-    conversation_history.append(query)
-
-    query_str = "".join(format_pairs(conversation_history,
-                                     tokenizer.eos_token, add_initial_reply_token=True))
-
-    if system_prefix:
-        query_str = system_prefix + query_str
-    print("query:", query_str)
-
-    batch = tokenizer.encode(
-        query_str,
-        return_tensors="pt",
-    )
-
-    with torch.cuda.amp.autocast():
-        out = model.generate(
-            input_ids=batch.to(model.device),
-            # The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
-            max_new_tokens=args.max_new_tokens,
-            do_sample=args.do_sample,
-            max_length=max_length,
-            top_k=args.top_k,
-            top_p=top_p,
-            temperature=temperature,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.eos_token_id,
-        )
-
-    if out[0][-1] == tokenizer.eos_token_id:
-        response = out[0][:-1]
-    else:
-        response = out[0]
-
-    response = tokenizer.decode(out[0]).split(QA_SPECIAL_TOKENS["Answer"])[-1]
-
-    conversation_history.append(response)
-
-    with open("./educhat_query_record.txt", 'a+') as f:
-        f.write(str(conversation_history) + '\n')
-
-    chatbot[-1] = (query, parse_text(response))
-    history = history + [(query, response)]
-    print(f"chatbot is {chatbot}")
-    print(f"history is {history}")
-
-    return chatbot, history
-
-
-def reset_user_input():
-    return gr.update(value='')
-
-
-def reset_state():
-    return [], []
-
-
-with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">欢迎使用 EduChat 人工智能助手!</h1>""")
-
-    chatbot = gr.Chatbot()
-    with gr.Row():
-        with gr.Column(scale=4):
-            with gr.Column(scale=12):
-                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
-                    container=False)
-            with gr.Column(min_width=32, scale=1):
-                submitBtn = gr.Button("Submit", variant="primary")
-        with gr.Column(scale=1):
-            emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(
-                0, 2048, value=2048, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.2, step=0.01,
-                              label="Top P", interactive=True)
-            temperature = gr.Slider(
-                0, 1, value=1, step=0.01, label="Temperature", interactive=True)
-
-    history = gr.State([])  # (message, bot_message)
-
-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
-                    show_progress=True)
-    submitBtn.click(reset_user_input, [], [user_input])
-
-    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
-
-demo.queue().launch(inbrowser=True, share=True)
-
-# gr.Interface.load("models/ecnu-icalk/educhat-sft-002-7b").launch()
+import os
+
+os.system("pwd")
+# os.system("nvidia-smi")
+os.system(
+    f"git clone https://github.com/icalk-nlp/EduChat.git")
+# os.system(f"cd /home/user/app/bitsandbytes && CUDA_VERSION=113 make cuda11x && python setup.py install")
+# os.system(f"cd EduChat")
+
+os.system(f"python launch.py")