skytnt committed
Commit: 573d12d
1 Parent(s): 15942da

Update app.py

Files changed (1): app.py (+21 -13)
app.py CHANGED
@@ -41,7 +41,7 @@ def sample_top_p_k(probs, p, k):
     return next_token
 
 
-def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
+def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
              disable_patch_change=False, disable_control_change=False, disable_channels=None):
     if disable_channels is not None:
         disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
@@ -63,7 +63,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
     with bar:
         while cur_len < max_len:
             end = False
-            hidden = model_base.run(None, {'x': input_tensor})[0][:, -1]
+            hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
             next_token_seq = np.empty((1, 0), dtype=np.int64)
             event_name = ""
             for i in range(max_token_seq):
@@ -81,7 +81,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                     if param_name == "channel":
                         mask_ids = [i for i in mask_ids if i not in disable_channels]
                     mask[mask_ids] = 1
-                logits = model_token.run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
+                logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
                 scores = softmax(logits / temp, -1) * mask
                 sample = sample_top_p_k(scores, top_p, top_k)
                 if i == 0:
@@ -107,7 +107,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                 break
 
 
-def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
+def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
     mid_seq = []
     max_len = int(gen_events)
     img_len = 1024
@@ -172,7 +172,8 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
         for token_seq in mid:
            mid_seq.append(token_seq)
            draw_event(token_seq)
-    generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
+    model = models[model_name]
+    generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
                          disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
                          disable_channels=disable_channels)
     for token_seq in generator:
@@ -208,13 +209,18 @@ if __name__ == "__main__":
     parser.add_argument("--max-gen", type=int, default=1024, help="max")
     opt = parser.parse_args()
     soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
-    model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
-    model_token_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx")
-
+    models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
+                   "symphony finetune model": ["skytnt/midi-model-ft", "symphony/"],
+                   "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}
+    models = {}
     tokenizer = MIDITokenizer()
     providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-    model_base = rt.InferenceSession(model_base_path, providers=providers)
-    model_token = rt.InferenceSession(model_token_path, providers=providers)
+    for name, (repo_id, path) in models_info.items():
+        model_base_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
+        model_token_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
+        model_base = rt.InferenceSession(model_base_path, providers=providers)
+        model_token = rt.InferenceSession(model_token_path, providers=providers)
+        models[name] = [model_base, model_token]
 
     app = gr.Blocks()
     with app:
@@ -229,6 +235,8 @@ if __name__ == "__main__":
 
         tab_select = gr.Variable(value=0)
         with gr.Tabs():
+            input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
+                                      type="value", value=list(models.keys())[0])
             with gr.TabItem("instrument prompt") as tab1:
                 input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
                                                 multiselect=True, max_choices=15, type="value")
@@ -260,7 +268,7 @@ if __name__ == "__main__":
         with gr.Accordion("options", open=False):
             input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
             input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
-            input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=12)
+            input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=20)
             input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
         example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
         run_btn = gr.Button("generate", variant="primary")
@@ -269,8 +277,8 @@ if __name__ == "__main__":
         output_midi_img = gr.Image(label="output image")
         output_midi = gr.File(label="output midi", file_types=[".mid"])
         output_audio = gr.Audio(label="output audio", format="mp3")
-        run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
-                                        input_gen_events, input_temp, input_top_p, input_top_k,
+        run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
+                                        input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
                                         input_allow_cc],
                                  [output_midi_seq, output_midi_img, output_midi, output_audio])
         stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
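
For reference, a minimal standalone sketch of the multi-model loading pattern this commit introduces. The repo ids, file names, and the [model_base, model_token] session pair are taken from the diff above; the surrounding argument parsing and Gradio UI are omitted, and the onnxruntime alias rt is assumed to match the rest of app.py.

from huggingface_hub import hf_hub_download
import onnxruntime as rt

# model name -> [repo id, path prefix inside the repo], as in the diff above
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
               "symphony finetune model": ["skytnt/midi-model-ft", "symphony/"],
               "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

models = {}
for name, (repo_id, path) in models_info.items():
    # each entry becomes a pair of ONNX sessions: [base model, token-level model]
    model_base_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
    model_token_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
    models[name] = [rt.InferenceSession(model_base_path, providers=providers),
                    rt.InferenceSession(model_token_path, providers=providers)]

Keeping both sessions under one key is what lets run() pass models[model_name] straight through to generate() as a single model argument.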
 
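The generate() signature change works because the two sessions in the pair play distinct roles: model[0] (model_base.onnx) consumes the event sequence and yields the last hidden state, and model[1] (model_token.onnx) decodes the next token group from that hidden state. Below is a hedged sketch of that call pattern, mirroring only the first iteration of the inner loop; the step() helper, tensor shapes, masking, and sampling are not part of this diff.

import numpy as np

def step(model, input_tensor):
    # model is a [model_base, model_token] pair as stored in the models dict above
    # base session: whole event sequence in, per-step hidden states out; keep the last one
    hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
    # token session: decodes token by token, conditioned on that hidden state;
    # an empty next_token_seq corresponds to the i == 0 case in generate()
    next_token_seq = np.empty((1, 0), dtype=np.int64)
    logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
    return hidden, logits

Because only the handle passed in changes, none of the sampling logic inside generate() needs to know which of the three checkpoints was selected in the UI dropdown.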