BlinkDL committed
Commit 5a6a14d · verified · 1 Parent(s): 4bd612e

Update app.py

Files changed (1):
  1. app.py +56 -314

app.py CHANGED
@@ -1,59 +1,38 @@
  import gradio as gr
- import os, gc, copy, torch, re
  from datetime import datetime
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
- ctx_limit = 1536
- gen_limit = 500
- gen_limit_long = 800
- title = "RWKV-x060-World-7B-v3-20241112-ctx4096"

- os.environ["RWKV_JIT_ON"] = '1'
- os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
-
- from rwkv.model import RWKV
-
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-6-world", filename=f"{title}.pth")
- model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
- # model_path = '/mnt/e/RWKV-Runner/models/rwkv-final-v6-2.1-7b' # conda activate torch2; cd /mnt/program/_RWKV_/_ref_/_gradio_/RWKV-Gradio-2; python app_tab.py
- # model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')

  from rwkv.utils import PIPELINE, PIPELINE_ARGS
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
-
- args = model.args
- eng_name = 'rwkv6-world-v3-7b-eng_QA-20241114-ctx2048'
- eng_file = hf_hub_download(repo_id="BlinkDL/rwkv-6-misc", filename=f"states/{eng_name}.pth")
- state_eng_raw = torch.load(eng_file)
- state_eng = [None] * args.n_layer * 3

- chn_name = 'rwkv6-world-v3-7b-chn_问答QA-20241114-ctx2048'
- chn_file = hf_hub_download(repo_id="BlinkDL/rwkv-6-misc", filename=f"states/{chn_name}.pth")
- state_chn_raw = torch.load(chn_file)
- state_chn = [None] * args.n_layer * 3

- wyw_name = 'rwkv6-world-v3-7b-chn_文言文QA-20241114-ctx2048'
- wyw_file = hf_hub_download(repo_id="BlinkDL/rwkv-6-misc", filename=f"states/{wyw_name}.pth")
- state_wyw_raw = torch.load(wyw_file)
- state_wyw = [None] * args.n_layer * 3

- for i in range(args.n_layer):
-     dd = model.strategy[i]
-     dev = dd.device
-     atype = dd.atype
-     state_eng[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-     state_eng[i*3+1] = state_eng_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
-     state_eng[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-
-     state_chn[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-     state_chn[i*3+1] = state_chn_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
-     state_chn[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-
-     state_wyw[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
-     state_wyw[i*3+1] = state_wyw_raw[f'blocks.{i}.att.time_state'].transpose(1,2).to(dtype=torch.float, device=dev).requires_grad_(False).contiguous()
-     state_wyw[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()

  def generate_prompt(instruction, input=""):
      instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
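The removed block above packs each downloaded state-tuning checkpoint into the flat list that model.forward(tokens, state) consumes: three slots per layer. As a minimal sketch of that layout (slot roles inferred from the rwkv pip package for the v6 world models; the helper name is illustrative and not part of app.py):

    import torch

    def build_tuned_state(model, state_raw):
        # 3 slots per layer: [att token-shift buffer, trained WKV time-state, FFN token-shift buffer]
        args = model.args
        state = [None] * args.n_layer * 3
        for i in range(args.n_layer):
            dd = model.strategy[i]
            dev, atype = dd.device, dd.atype
            state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, device=dev).contiguous()  # att x_prev
            state[i*3+1] = state_raw[f'blocks.{i}.att.time_state'].transpose(1, 2).to(dtype=torch.float, device=dev).contiguous()  # tuned WKV state
            state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, device=dev).contiguous()  # ffn x_prev
        return state

Passing a deep copy of such a list as the initial state (as the removed evaluate_eng / evaluate_chn / evaluate_wyw functions did) is what made the state-tuned tabs behave like instruction-following models without changing any model weights.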
@@ -61,22 +40,20 @@ def generate_prompt(instruction, input=""):
      if input:
          return f"""Instruction: {instruction}\n\nInput: {input}\n\nResponse:"""
      else:
-         return f"""User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\nUser: {instruction}\n\nAssistant:"""

  def qa_prompt(instruction):
      instruction = instruction.strip().replace('\r\n','\n')
      instruction = re.sub(r'\n+', '\n', instruction)
      return f"User: {instruction}\n\nAssistant:"""

- penalty_decay = 0.996
-
  def evaluate(
      ctx,
-     token_count=gen_limit,
      temperature=1.0,
-     top_p=0.3,
-     presencePenalty = 0.3,
-     countPenalty = 0.3,
  ):
      args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                       alpha_frequency = countPenalty,
@@ -90,186 +67,39 @@ def evaluate(
      occurrence = {}
      state = None
      for i in range(int(token_count)):
-         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
-         for n in occurrence:
-             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
-
-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
-         if token in args.token_stop:
-             break
-         all_tokens += [token]
-         for xxx in occurrence:
-             occurrence[xxx] *= penalty_decay
-         if token not in occurrence:
-             occurrence[token] = 1
-         else:
-             occurrence[token] += 1
-
-         tmp = pipeline.decode(all_tokens[out_last:])
-         if '\ufffd' not in tmp:
-             out_str += tmp
-             yield out_str.strip()
-             out_last = i + 1
-
-     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-     del out
-     del state
-     gc.collect()
-     torch.cuda.empty_cache()
-     yield out_str.strip()
-
- def evaluate_eng(
-     ctx,
-     token_count=gen_limit,
-     temperature=1.0,
-     top_p=0.3,
-     presencePenalty=0.3,
-     countPenalty=0.3,
- ):
-     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-                      alpha_frequency = countPenalty,
-                      alpha_presence = presencePenalty,
-                      token_ban = [], # ban the generation of some tokens
-                      token_stop = [0]) # stop generation whenever you see any token here
-     ctx = qa_prompt(ctx)
-     all_tokens = []
-     out_last = 0
-     out_str = ''
-     occurrence = {}
-     state = copy.deepcopy(state_eng)
-     for i in range(int(token_count)):
-         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
-         for n in occurrence:
-             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
-
-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
-         if token in args.token_stop:
-             break
-         all_tokens += [token]
-         for xxx in occurrence:
-             occurrence[xxx] *= penalty_decay
-         if token not in occurrence:
-             occurrence[token] = 1
-         else:
-             occurrence[token] += 1
-
-         tmp = pipeline.decode(all_tokens[out_last:])
-         if '\ufffd' not in tmp:
-             out_str += tmp
-             yield out_str.strip()
-             out_last = i + 1
-         if '\n\n' in out_str:
-             break
-
-     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-     del out
-     del state
-     gc.collect()
-     torch.cuda.empty_cache()
-     yield out_str.strip()
-
- def evaluate_chn(
-     ctx,
-     token_count=gen_limit,
-     temperature=1.0,
-     top_p=0.3,
-     presencePenalty=0.3,
-     countPenalty=0.3,
- ):
-     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-                      alpha_frequency = countPenalty,
-                      alpha_presence = presencePenalty,
-                      token_ban = [], # ban the generation of some tokens
-                      token_stop = [0]) # stop generation whenever you see any token here
-     ctx = qa_prompt(ctx)
-     all_tokens = []
-     out_last = 0
-     out_str = ''
-     occurrence = {}
-     state = copy.deepcopy(state_chn)
-     for i in range(int(token_count)):
-         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
-         for n in occurrence:
-             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
-
-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
-         if token in args.token_stop:
-             break
-         all_tokens += [token]
-         for xxx in occurrence:
-             occurrence[xxx] *= penalty_decay
-         if token not in occurrence:
-             occurrence[token] = 1
-         else:
-             occurrence[token] += 1
-
-         tmp = pipeline.decode(all_tokens[out_last:])
-         if '\ufffd' not in tmp:
-             out_str += tmp
-             yield out_str.strip()
-             out_last = i + 1
-         if '\n\n' in out_str:
-             break

-     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-     del out
-     del state
-     gc.collect()
-     torch.cuda.empty_cache()
-     yield out_str.strip()
-
- def evaluate_wyw(
-     ctx,
-     token_count=gen_limit,
-     temperature=1.0,
-     top_p=0.3,
-     presencePenalty=0.3,
-     countPenalty=0.3,
- ):
-     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-                      alpha_frequency = countPenalty,
-                      alpha_presence = presencePenalty,
-                      token_ban = [], # ban the generation of some tokens
-                      token_stop = [0]) # stop generation whenever you see any token here
-     ctx = qa_prompt(ctx)
-     all_tokens = []
-     out_last = 0
-     out_str = ''
-     occurrence = {}
-     state = copy.deepcopy(state_wyw)
-     for i in range(int(token_count)):
-         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

-         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
              occurrence[xxx] *= penalty_decay
          if token not in occurrence:
-             occurrence[token] = 1
          else:
-             occurrence[token] += 1
-
-         tmp = pipeline.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()
              out_last = i + 1
-         if '\n\n' in out_str:
-             break

      gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
      timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
      del out
      del state
      gc.collect()
@@ -277,132 +107,44 @@ def evaluate_wyw(
      yield out_str.strip()

  examples = [
-     ["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
-     ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response.", gen_limit, 1, 0.3, 0.5, 0.5],
      [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), gen_limit, 1, 0.3, 0.5, 0.5],
      [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), gen_limit, 1, 0.3, 0.5, 0.5],
      ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", gen_limit, 1, 0.3, 0.5, 0.5],
      ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.3, 0.5, 0.5],
-     [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), 500, 1, 0.3, 0.5, 0.5],
-     ['''Japanese: 春の初め、桜の花が満開になる頃、小さな町の片隅にある古びた神社の境内は、特別な雰囲気に包まれていた。\n\nEnglish:''', gen_limit, 1, 0.3, 0.5, 0.5],
      ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", gen_limit, 1, 0.3, 0.5, 0.5],
      ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
  ]

- examples_eng = [
-     ["How can I craft an engaging story featuring vampires on Mars?", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["Compare the business models of Apple and Google.", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["In JSON format, list the top 5 tourist attractions in Paris.", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["Write an outline for a fantasy novel where dreams can alter reality.", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["Can fish get thirsty?", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["Write a Bash script to check disk usage and send alerts if it's too high.", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ]
-
- examples_chn = [
-     ["怎样写一个在火星上的吸血鬼的有趣故事?", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["比较苹果和谷歌的商业模式。", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["鱼会口渴吗?", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["以 JSON 格式解释冰箱是如何工作的。", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["编写一个Bash脚本来检查磁盘使用情况,如果使用量过高则发送警报。", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["用HTML编写一个简单的网站。当用户点击按钮时,从4个笑话的列表中随机显示一个笑话。", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ]
-
- examples_wyw = [
-     ["我和前男友分手了", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["量子计算机的原理", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["李白和杜甫的结拜故事", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["林黛玉和伏地魔的关系是什么?", gen_limit_long, 1, 0.2, 0.3, 0.3],
-     ["我被同事陷害了,帮我写一篇文言文骂他", gen_limit_long, 1, 0.2, 0.3, 0.3],
- ]
-
- ##########################################################################
-
- with gr.Blocks(title=title) as demo:
-     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title}</h1>\n</div>")

      with gr.Tab("=== Base Model (Raw Generation) ==="):
-         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) base model. Supports 100+ world languages and code. RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
          with gr.Row():
              with gr.Column():
-                 prompt = gr.Textbox(lines=2, label="Raw Input", value="Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response.")
                  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
                  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                  presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.5)
                  count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.5)
-             with gr.Column():
-                 with gr.Row():
-                     submit = gr.Button("Submit", variant="primary")
-                     clear = gr.Button("Clear", variant="secondary")
-                 output = gr.Textbox(label="Output", lines=30)
-         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-         submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
-         clear.click(lambda: None, [], [output])
-         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
-
-     with gr.Tab("=== English Q/A ==="):
-         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [English Q/A](https://huggingface.co/BlinkDL/rwkv-6-misc/tree/main/states). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
-         with gr.Row():
-             with gr.Column():
-                 prompt = gr.Textbox(lines=2, label="Prompt", value="How can I craft an engaging story featuring vampires on Mars?")
-                 token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
-                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
-                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
-                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
-                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
              with gr.Column():
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")
                      clear = gr.Button("Clear", variant="secondary")
                  output = gr.Textbox(label="Output", lines=20, max_lines=100)
-         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_eng, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-         submit.click(evaluate_eng, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
          clear.click(lambda: None, [], [output])
          data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

-     with gr.Tab("=== Chinese Q/A ==="):
-         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [Chinese Q/A](https://huggingface.co/BlinkDL/rwkv-6-misc/tree/main/states). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
-         with gr.Row():
-             with gr.Column():
-                 prompt = gr.Textbox(lines=2, label="Prompt", value="怎样写一个在火星上的吸血鬼的有趣故事?")
-                 token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
-                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
-                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
-                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
-                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
-             with gr.Column():
-                 with gr.Row():
-                     submit = gr.Button("Submit", variant="primary")
-                     clear = gr.Button("Clear", variant="secondary")
-                 output = gr.Textbox(label="Output", lines=20, max_lines=100)
-         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_chn, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-         submit.click(evaluate_chn, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
-         clear.click(lambda: None, [], [output])
-         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
-
-     with gr.Tab("=== WenYanWen Q/A ==="):
-         gr.Markdown(f"This is [RWKV-6](https://huggingface.co/BlinkDL/rwkv-6-world) state-tuned to [WenYanWen 文言文 Q/A](https://huggingface.co/BlinkDL/rwkv-6-misc/tree/main/states). RWKV is a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM), and we have [300+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). Demo limited to ctxlen {ctx_limit}.")
-         with gr.Row():
-             with gr.Column():
-                 prompt = gr.Textbox(lines=2, label="Prompt", value="我和前男友分手了")
-                 token_count = gr.Slider(10, gen_limit_long, label="Max Tokens", step=10, value=gen_limit_long)
-                 temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
-                 top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.2)
-                 presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.3)
-                 count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.3)
-             with gr.Column():
-                 with gr.Row():
-                     submit = gr.Button("Submit", variant="primary")
-                     clear = gr.Button("Clear", variant="secondary")
-                 output = gr.Textbox(label="Output", lines=20, max_lines=100)
-         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples_wyw, samples_per_page=50, label="Examples", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-         submit.click(evaluate_wyw, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
-         clear.click(lambda: None, [], [output])
-         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
-
-
  demo.queue(concurrency_count=1, max_size=10)
- demo.launch(share=False)
 
+ import os, copy
+ os.environ["RWKV_V7_ON"] = '1'
+ os.environ["RWKV_JIT_ON"] = '1'
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+
+ from rwkv.model import RWKV
+
+ import gc, re
  import gradio as gr
+ import base64
+ from io import BytesIO
+ import torch
+ import torch.nn.functional as F
  from datetime import datetime
+ from transformers import CLIPImageProcessor
  from huggingface_hub import hf_hub_download
  from pynvml import *
  nvmlInit()
  gpu_h = nvmlDeviceGetHandleByIndex(0)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ ctx_limit = 4000
+ gen_limit = 1000

+ ########################## text rwkv ################################################################
  from rwkv.utils import PIPELINE, PIPELINE_ARGS

+ title_v6 = "rwkv7-g1-0.1b-20250307-ctx4096"
+ model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title_v6}.pth")
+ model_v6 = RWKV(model=model_path_v6.replace('.pth',''), strategy='cuda fp16')
+ pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")

+ args = model_v6.args

+ penalty_decay = 0.996

  def generate_prompt(instruction, input=""):
      instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
      if input:
          return f"""Instruction: {instruction}\n\nInput: {input}\n\nResponse:"""
      else:
+         return f"""User: {instruction}\n\nAssistant:"""

  def qa_prompt(instruction):
      instruction = instruction.strip().replace('\r\n','\n')
      instruction = re.sub(r'\n+', '\n', instruction)
      return f"User: {instruction}\n\nAssistant:"""

  def evaluate(
      ctx,
+     token_count=200,
      temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
  ):
      args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                       alpha_frequency = countPenalty,

      occurrence = {}
      state = None
      for i in range(int(token_count)):

+         input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+         out, state = model_v6.forward(input_ids, state)
          for n in occurrence:
              out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

+         token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
          if token in args.token_stop:
              break
          all_tokens += [token]
          for xxx in occurrence:
              occurrence[xxx] *= penalty_decay
+
+         ttt = pipeline_v6.decode([token])
+         www = 1
+         if ttt in ' \t0123456789':
+             www = 0
+         #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+         #    www = 0.5
          if token not in occurrence:
+             occurrence[token] = www
          else:
+             occurrence[token] += www
+
+         tmp = pipeline_v6.decode(all_tokens[out_last:])
          if '\ufffd' not in tmp:
              out_str += tmp
              yield out_str.strip()
              out_last = i + 1

      gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
      timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
      del out
      del state
      gc.collect()

      yield out_str.strip()

  examples = [
+     ["User: simulate SpaceX mars landing using python\n\nAssistant: <think", gen_limit, 1, 0.3, 0.5, 0.5],
+     [generate_prompt("Please give the pros and cons of hodl versus active trading."), gen_limit, 1, 0.3, 0.5, 0.5],
+     ["Assistant: How can we craft an engaging story featuring vampires on Mars? Let's think step by step and provide an expert response:", gen_limit, 1, 0.3, 0.5, 0.5],
+     ["Assistant: How can we persuade Elon Musk to follow you on Twitter? Let's think step by step and provide an expert response:", gen_limit, 1, 0.3, 0.5, 0.5],
      [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), gen_limit, 1, 0.3, 0.5, 0.5],
      [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), gen_limit, 1, 0.3, 0.5, 0.5],
      ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", gen_limit, 1, 0.3, 0.5, 0.5],
      ['''Edward: I am Edward Elric from Fullmetal Alchemist.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:''', gen_limit, 1, 0.3, 0.5, 0.5],
+     [generate_prompt("Write a simple webpage. When a user clicks the button, it shows a random joke from a list of 4 jokes."), gen_limit, 1, 0.3, 0.5, 0.5],
      ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", gen_limit, 1, 0.3, 0.5, 0.5],
      ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
  ]

+ ##################################################################################################################
+ with gr.Blocks(title=title_v6) as demo:
+     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")

      with gr.Tab("=== Base Model (Raw Generation) ==="):
+         gr.Markdown(f'This is [RWKV7 G1](https://huggingface.co/BlinkDL/rwkv7-g1) 0.1B (!!!) L12-D768 reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [400+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.')
          with gr.Row():
              with gr.Column():
+                 prompt = gr.Textbox(lines=6, label="Prompt", value=generate_prompt("User: simulate SpaceX mars landing using python\n\nAssistant: <think"))
                  token_count = gr.Slider(10, gen_limit, label="Max Tokens", step=10, value=gen_limit)
                  temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                  top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                  presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.5)
                  count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.5)
              with gr.Column():
                  with gr.Row():
                      submit = gr.Button("Submit", variant="primary")
                      clear = gr.Button("Clear", variant="secondary")
                  output = gr.Textbox(label="Output", lines=20, max_lines=100)
+         data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, samples_per_page=50, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+         submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
          clear.click(lambda: None, [], [output])
          data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

  demo.queue(concurrency_count=1, max_size=10)
+ demo.launch(share=False)
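The new evaluate() combines a presence penalty and a frequency penalty with a per-step decay, and it skips penalizing whitespace and digit tokens. A minimal standalone sketch of just that mechanism (the logits object and penalty values below are hypothetical; only the arithmetic mirrors the loop above):

    penalty_decay = 0.996       # older repetitions are gradually forgiven
    alpha_presence = 0.5        # flat penalty for any token already generated
    alpha_frequency = 0.5       # extra penalty per (decayed) occurrence
    occurrence = {}             # token id -> decayed occurrence count

    def apply_penalty(logits, occurrence):
        # subtract the presence + frequency penalty from every token seen so far
        for t, count in occurrence.items():
            logits[t] -= alpha_presence + count * alpha_frequency
        return logits

    def record_token(token, decoded_piece, occurrence):
        # decay all existing counts, then bump the sampled token;
        # whitespace and digit pieces get weight 0, so they are never penalized
        for t in occurrence:
            occurrence[t] *= penalty_decay
        weight = 0 if decoded_piece in ' \t0123456789' else 1
        occurrence[token] = occurrence.get(token, 0) + weight

Exempting spaces and digits keeps the penalty from degrading code and numeric output, while the 0.996 decay lets a long generation reuse words it has not emitted for a while.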