skytnt commited on
Commit
c51a1c9
1 Parent(s): a9b0cf6
Files changed (3) hide show
  1. app.py +48 -27
  2. javascript/app.js +4 -3
  3. midi_tokenizer.py +146 -35
app.py CHANGED
@@ -111,16 +111,19 @@ def create_msg(name, data):
111
  return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
112
 
113
 
114
- def send_msgs(msgs, msgs_history):
 
 
115
  msgs_history.append(msgs)
116
- if len(msgs_history) > 50:
117
- msgs_history.pop(0)
118
  return json.dumps(msgs_history)
119
 
120
 
121
- def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
122
  msgs_history = []
123
  mid_seq = []
 
124
  gen_events = int(gen_events)
125
  max_len = gen_events
126
 
@@ -129,6 +132,8 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
129
  if tab == 0:
130
  i = 0
131
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
 
 
132
  patches = {}
133
  if instruments is None:
134
  instruments = []
@@ -151,10 +156,10 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
151
  max_len += len(mid)
152
  for token_seq in mid:
153
  mid_seq.append(token_seq.tolist())
154
- init_msgs = [create_msg("visualizer_clear", None)]
155
  for tokens in mid_seq:
156
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
157
- yield mid_seq, None, None, send_msgs(init_msgs, msgs_history), msgs_history
158
  model = models[model_name]
159
  generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
160
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
@@ -163,22 +168,31 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
163
  token_seq = token_seq.tolist()
164
  mid_seq.append(token_seq)
165
  event = tokenizer.tokens2event(token_seq)
166
- yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history), msgs_history
167
  mid = tokenizer.detokenize(mid_seq)
168
  with open(f"output.mid", 'wb') as f:
169
  f.write(MIDI.score2midi(mid))
170
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
171
- yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history), msgs_history
 
 
 
 
 
172
 
173
 
174
- def cancel_run(mid_seq, msgs_history):
175
  if mid_seq is None:
176
  return None, None, []
177
  mid = tokenizer.detokenize(mid_seq)
178
  with open(f"output.mid", 'wb') as f:
179
  f.write(MIDI.score2midi(mid))
180
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
181
- return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history)
 
 
 
 
182
 
183
 
184
  def load_javascript(dir="javascript"):
@@ -200,6 +214,7 @@ def load_javascript(dir="javascript"):
200
 
201
 
202
  def hf_hub_download_retry(repo_id, filename):
 
203
  retry = 0
204
  err = None
205
  while retry < 30:
@@ -246,9 +261,9 @@ if __name__ == "__main__":
246
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
247
  "[Open In Colab]"
248
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
249
- " for faster running and longer generation"
 
250
  )
251
- js_msg_history_state = gr.State(value=[])
252
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
253
  js_msg.change(None, [js_msg], [], js="""
254
  (msg_json) =>{
@@ -262,19 +277,25 @@ if __name__ == "__main__":
262
  tab_select = gr.State(value=0)
263
  with gr.Tabs():
264
  with gr.TabItem("instrument prompt") as tab1:
265
- input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
266
  multiselect=True, max_choices=15, type="value")
267
- input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
268
  value="None")
 
 
 
269
  example1 = gr.Examples([
270
  [[], "None"],
271
  [["Acoustic Grand"], "None"],
272
- [["Acoustic Grand", "Violin", "Viola", "Cello", "Contrabass"], "Orchestra"],
273
- [["Flute", "Cello", "Bassoon", "Tuba"], "None"],
274
- [["Violin", "Viola", "Cello", "Contrabass", "Trumpet", "French Horn", "Brass Section",
275
- "Flute", "Piccolo", "Tuba", "Trombone", "Timpani"], "Orchestra"],
276
- [["Acoustic Guitar(nylon)", "Acoustic Guitar(steel)", "Electric Guitar(jazz)",
277
- "Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
 
 
 
278
  "Electric Bass(finger)"], "Standard"]
279
  ], [input_instruments, input_drum_kit])
280
  with gr.TabItem("midi prompt") as tab2:
@@ -292,19 +313,19 @@ if __name__ == "__main__":
292
  with gr.Accordion("options", open=False):
293
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
294
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
295
- input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=48)
296
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
297
- example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
298
  run_btn = gr.Button("generate", variant="primary")
299
  stop_btn = gr.Button("stop and output")
300
  output_midi_seq = gr.State()
301
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
302
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
303
  output_midi = gr.File(label="output midi", file_types=[".mid"])
304
- run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
305
- input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
306
- input_allow_cc],
307
- [output_midi_seq, output_midi, output_audio, js_msg, js_msg_history_state],
308
  concurrency_limit=3)
309
- stop_btn.click(cancel_run, [output_midi_seq, js_msg_history_state], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
310
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
111
  return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
112
 
113
 
114
+ def send_msgs(msgs, msgs_history=None):
115
+ if msgs_history is None:
116
+ msgs_history = []
117
  msgs_history.append(msgs)
118
+ if len(msgs_history) > 25:
119
+ msgs_history= msgs_history[1:]
120
  return json.dumps(msgs_history)
121
 
122
 
123
+ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
124
  msgs_history = []
125
  mid_seq = []
126
+ bpm = int(bpm)
127
  gen_events = int(gen_events)
128
  max_len = gen_events
129
 
 
132
  if tab == 0:
133
  i = 0
134
  mid = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
135
+ if bpm != 0:
136
+ mid.append(tokenizer.event2tokens(["set_tempo",0,0,0, bpm]))
137
  patches = {}
138
  if instruments is None:
139
  instruments = []
 
156
  max_len += len(mid)
157
  for token_seq in mid:
158
  mid_seq.append(token_seq.tolist())
159
+ init_msgs = [create_msg("visualizer_clear", False)]
160
  for tokens in mid_seq:
161
  init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
162
+ yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
163
  model = models[model_name]
164
  generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
165
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
 
168
  token_seq = token_seq.tolist()
169
  mid_seq.append(token_seq)
170
  event = tokenizer.tokens2event(token_seq)
171
+ yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
172
  mid = tokenizer.detokenize(mid_seq)
173
  with open(f"output.mid", 'wb') as f:
174
  f.write(MIDI.score2midi(mid))
175
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
176
+ # resend all msgs
177
+ msgs = [create_msg("visualizer_end", None), create_msg("visualizer_clear", True)]
178
+ for tokens in mid_seq:
179
+ msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
180
+ msgs.append(create_msg("visualizer_end", None))
181
+ yield mid_seq, "output.mid", (44100, audio), send_msgs(msgs)
182
 
183
 
184
+ def cancel_run(mid_seq):
185
  if mid_seq is None:
186
  return None, None, []
187
  mid = tokenizer.detokenize(mid_seq)
188
  with open(f"output.mid", 'wb') as f:
189
  f.write(MIDI.score2midi(mid))
190
  audio = synthesis(MIDI.score2opus(mid), soundfont_path)
191
+ msgs = [create_msg("visualizer_end", None), create_msg("visualizer_clear", True)]
192
+ for tokens in mid_seq:
193
+ msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
194
+ msgs.append(create_msg("visualizer_end", None))
195
+ return "output.mid", (44100, audio), send_msgs(msgs)
196
 
197
 
198
  def load_javascript(dir="javascript"):
 
214
 
215
 
216
  def hf_hub_download_retry(repo_id, filename):
217
+ print(f"downloading {repo_id} {filename}")
218
  retry = 0
219
  err = None
220
  while retry < 30:
 
261
  "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
262
  "[Open In Colab]"
263
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
264
+ " for faster running and longer generation\n\n"
265
+ "**Update v1.2**: Optimise the tokenizer and dataset"
266
  )
 
267
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
268
  js_msg.change(None, [js_msg], [], js="""
269
  (msg_json) =>{
 
277
  tab_select = gr.State(value=0)
278
  with gr.Tabs():
279
  with gr.TabItem("instrument prompt") as tab1:
280
+ input_instruments = gr.Dropdown(label="🪗instruments (auto if empty)", choices=list(patch2number.keys()),
281
  multiselect=True, max_choices=15, type="value")
282
+ input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
283
  value="None")
284
+ input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
285
+ step=1,
286
+ value=0)
287
  example1 = gr.Examples([
288
  [[], "None"],
289
  [["Acoustic Grand"], "None"],
290
+ [['Acoustic Grand', 'SynthStrings 2', 'SynthStrings 1', 'Pizzicato Strings',
291
+ 'Pad 2 (warm)', 'Tremolo Strings', 'String Ensemble 1'], "Orchestra"],
292
+ [['Trumpet', 'Oboe', 'Trombone', 'String Ensemble 1', 'Clarinet',
293
+ 'French Horn', 'Pad 4 (choir)', 'Bassoon', 'Flute'], "None"],
294
+ [['Flute', 'French Horn', 'Clarinet', 'String Ensemble 2', 'English Horn', 'Bassoon',
295
+ 'Oboe', 'Pizzicato Strings'], "Orchestra"],
296
+ [['Electric Piano 2', 'Lead 5 (charang)', 'Electric Bass(pick)', 'Lead 2 (sawtooth)',
297
+ 'Pad 1 (new age)', 'Orchestra Hit', 'Cello', 'Electric Guitar(clean)'], "Standard"],
298
+ [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
299
  "Electric Bass(finger)"], "Standard"]
300
  ], [input_instruments, input_drum_kit])
301
  with gr.TabItem("midi prompt") as tab2:
 
313
  with gr.Accordion("options", open=False):
314
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
315
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
316
+ input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
317
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
318
+ example3 = gr.Examples([[1, 0.98, 20], [1, 0.98, 12]], [input_temp, input_top_p, input_top_k])
319
  run_btn = gr.Button("generate", variant="primary")
320
  stop_btn = gr.Button("stop and output")
321
  output_midi_seq = gr.State()
322
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
323
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
324
  output_midi = gr.File(label="output midi", file_types=[".mid"])
325
+ run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
326
+ input_midi, input_midi_events, input_gen_events, input_temp,
327
+ input_top_p, input_top_k, input_allow_cc],
328
+ [output_midi_seq, output_midi, output_audio, js_msg],
329
  concurrency_limit=3)
330
+ stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
331
  app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
javascript/app.js CHANGED
@@ -146,13 +146,14 @@ class MidiVisualizer extends HTMLElement{
146
  this.setPlayTime(0);
147
  }
148
 
149
- clearMidiEvents(){
150
  this.pause()
151
  this.midiEvents = [];
152
  this.activeNotes = [];
153
  this.midiTimes = [];
154
  this.t1 = 0
155
- this.colorMap.clear()
 
156
  this.setPlayTime(0);
157
  this.totalTimeMs = 0;
158
  this.playTimeMs = 0
@@ -426,7 +427,7 @@ customElements.define('midi-visualizer', MidiVisualizer);
426
  handled_msgs.push(msg.uuid);
427
  switch (msg.name) {
428
  case "visualizer_clear":
429
- midi_visualizer.clearMidiEvents();
430
  createProgressBar(midi_visualizer_container_inited)
431
  break;
432
  case "visualizer_append":
 
146
  this.setPlayTime(0);
147
  }
148
 
149
+ clearMidiEvents(keepColor=false){
150
  this.pause()
151
  this.midiEvents = [];
152
  this.activeNotes = [];
153
  this.midiTimes = [];
154
  this.t1 = 0
155
+ if (!keepColor)
156
+ this.colorMap.clear()
157
  this.setPlayTime(0);
158
  this.totalTimeMs = 0;
159
  this.playTimeMs = 0
 
427
  handled_msgs.push(msg.uuid);
428
  switch (msg.name) {
429
  case "visualizer_clear":
430
+ midi_visualizer.clearMidiEvents(msg.data);
431
  createProgressBar(midi_visualizer_container_inited)
432
  break;
433
  case "visualizer_append":
midi_tokenizer.py CHANGED
@@ -42,22 +42,48 @@ class MIDITokenizer:
42
  tempo = int((60 / bpm) * 10 ** 6)
43
  return tempo
44
 
45
- def tokenize(self, midi_score, add_bos_eos=True):
46
  ticks_per_beat = midi_score[0]
47
  event_list = {}
48
  for track_idx, track in enumerate(midi_score[1:129]):
49
  last_notes = {}
 
 
 
50
  for event in track:
 
 
51
  t = round(16 * event[1] / ticks_per_beat) # quantization
52
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
53
  if event[0] == "note":
54
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
55
  elif event[0] == "set_tempo":
56
- new_event[4] = int(self.tempo2bpm(new_event[4]))
 
 
 
57
  if event[0] == "note":
58
  key = tuple(new_event[:4] + new_event[5:-1])
59
  else:
60
  key = tuple(new_event[:-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if event[0] == "note": # to eliminate note overlap due to quantization
62
  cp = tuple(new_event[5:7])
63
  if cp in last_notes:
@@ -71,21 +97,39 @@ class MIDITokenizer:
71
  event_list = list(event_list.values())
72
  event_list = sorted(event_list, key=lambda e: e[1:4])
73
  midi_seq = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  last_t1 = 0
76
  for event in event_list:
77
- name = event[0]
78
- if name in self.event_ids:
79
- params = event[1:]
80
- cur_t1 = params[0]
81
- params[0] = params[0] - last_t1
82
- if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
83
- continue
84
- tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
85
- for i, p in enumerate(self.events[name])]
86
- tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
87
- midi_seq.append(tokens)
88
- last_t1 = cur_t1
89
 
90
  if add_bos_eos:
91
  bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
@@ -96,6 +140,8 @@ class MIDITokenizer:
96
  def event2tokens(self, event):
97
  name = event[0]
98
  params = event[1:]
 
 
99
  tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
100
  for i, p in enumerate(self.events[name])]
101
  tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
@@ -120,14 +166,10 @@ class MIDITokenizer:
120
  t1 = 0
121
  for tokens in midi_seq:
122
  if tokens[0] in self.id_events:
123
- name = self.id_events[tokens[0]]
124
- if len(tokens) <= len(self.events[name]):
125
  continue
126
- params = tokens[1:]
127
- params = [params[i] - self.parameter_ids[p][0] for i, p in enumerate(self.events[name])]
128
- if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
129
- continue
130
- event = [name] + params
131
  if name == "set_tempo":
132
  event[4] = self.bpm2tempo(event[4])
133
  if event[0] == "note":
@@ -183,7 +225,7 @@ class MIDITokenizer:
183
  return img
184
 
185
  def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
186
- max_track_shift=128, max_channel_shift=16):
187
  pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
188
  vel_shift = random.randint(-max_vel_shift, max_vel_shift)
189
  cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
@@ -239,16 +281,85 @@ class MIDITokenizer:
239
  midi_seq_new.append(tokens_new)
240
  return midi_seq_new
241
 
242
- def check_alignment(self, midi_seq, threshold=0.3):
243
- total = 0
244
- hist = [0] * 16
245
- for tokens in midi_seq:
246
- if tokens[0] in self.id_events and self.id_events[tokens[0]] == "note":
247
- t2 = tokens[2] - self.parameter_ids["time2"][0]
248
- total += 1
249
- hist[t2] += 1
250
- if total == 0:
251
- return False
252
- hist = sorted(hist, reverse=True)
253
- p = sum(hist[:2]) / total
254
- return p > threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  tempo = int((60 / bpm) * 10 ** 6)
43
  return tempo
44
 
45
+ def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4):
46
  ticks_per_beat = midi_score[0]
47
  event_list = {}
48
  for track_idx, track in enumerate(midi_score[1:129]):
49
  last_notes = {}
50
+ patch_dict = {}
51
+ control_dict = {}
52
+ last_tempo = 0
53
  for event in track:
54
+ if event[0] not in self.events:
55
+ continue
56
  t = round(16 * event[1] / ticks_per_beat) # quantization
57
  new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
58
  if event[0] == "note":
59
  new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
60
  elif event[0] == "set_tempo":
61
+ if new_event[4] == 0: # invalid tempo
62
+ continue
63
+ bpm = int(self.tempo2bpm(new_event[4]))
64
+ new_event[4] = min(bpm, 255)
65
  if event[0] == "note":
66
  key = tuple(new_event[:4] + new_event[5:-1])
67
  else:
68
  key = tuple(new_event[:-1])
69
+ if event[0] == "patch_change":
70
+ c, p = event[2:]
71
+ last_p = patch_dict.setdefault(c, None)
72
+ if last_p == p:
73
+ continue
74
+ patch_dict[c] = p
75
+ elif event[0] == "control_change":
76
+ c, cc, v = event[2:]
77
+ last_v = control_dict.setdefault((c, cc), 0)
78
+ if abs(last_v - v) < cc_eps:
79
+ continue
80
+ control_dict[(c, cc)] = v
81
+ elif event[0] == "set_tempo":
82
+ tempo = new_event[-1]
83
+ if abs(last_tempo - tempo) < tempo_eps:
84
+ continue
85
+ last_tempo = tempo
86
+
87
  if event[0] == "note": # to eliminate note overlap due to quantization
88
  cp = tuple(new_event[5:7])
89
  if cp in last_notes:
 
97
  event_list = list(event_list.values())
98
  event_list = sorted(event_list, key=lambda e: e[1:4])
99
  midi_seq = []
100
+ setup_events = {}
101
+ notes_in_setup = False
102
+ for i, event in enumerate(event_list): # optimise setup
103
+ new_event = [*event]
104
+ if event[0] != "note":
105
+ new_event[1] = 0
106
+ new_event[2] = 0
107
+ has_next = False
108
+ has_pre = False
109
+ if i < len(event_list) - 1:
110
+ next_event = event_list[i + 1]
111
+ has_next = event[1] + event[2] == next_event[1] + next_event[2]
112
+ if notes_in_setup and i > 0:
113
+ pre_event = event_list[i - 1]
114
+ has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
115
+ if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
116
+ event_list = sorted(setup_events.values(), key=lambda e: 1 if e[0] == "note" else 0) + event_list[i:]
117
+ break
118
+ else:
119
+ if event[0] == "note":
120
+ notes_in_setup = True
121
+ key = tuple(event[3:-1])
122
+ setup_events[key] = new_event
123
 
124
  last_t1 = 0
125
  for event in event_list:
126
+ cur_t1 = event[1]
127
+ event[1] = event[1] - last_t1
128
+ tokens = self.event2tokens(event)
129
+ if not tokens:
130
+ continue
131
+ midi_seq.append(tokens)
132
+ last_t1 = cur_t1
 
 
 
 
 
133
 
134
  if add_bos_eos:
135
  bos = [self.bos_id] + [self.pad_id] * (self.max_token_seq - 1)
 
140
  def event2tokens(self, event):
141
  name = event[0]
142
  params = event[1:]
143
+ if not all([0 <= params[i] < self.event_parameters[p] for i, p in enumerate(self.events[name])]):
144
+ return []
145
  tokens = [self.event_ids[name]] + [self.parameter_ids[p][params[i]]
146
  for i, p in enumerate(self.events[name])]
147
  tokens += [self.pad_id] * (self.max_token_seq - len(tokens))
 
166
  t1 = 0
167
  for tokens in midi_seq:
168
  if tokens[0] in self.id_events:
169
+ event = self.tokens2event(tokens)
170
+ if not event:
171
  continue
172
+ name = event[0]
 
 
 
 
173
  if name == "set_tempo":
174
  event[4] = self.bpm2tempo(event[4])
175
  if event[0] == "note":
 
225
  return img
226
 
227
  def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10,
228
+ max_track_shift=0, max_channel_shift=16):
229
  pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
230
  vel_shift = random.randint(-max_vel_shift, max_vel_shift)
231
  cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
 
281
  midi_seq_new.append(tokens_new)
282
  return midi_seq_new
283
 
284
+ def check_quality(self, midi_seq, alignment_min=0.4, tonality_min=0.8, piano_max=0.7, notes_bandwidth_min=3, notes_density_max=30, notes_density_min=2.5, total_notes_max=10000, total_notes_min=500, note_window_size=16):
285
+ total_notes = 0
286
+ channels = []
287
+ time_hist = [0] * 16
288
+ note_windows = {}
289
+ notes_sametime = []
290
+ notes_density_list = []
291
+ tonality_list = []
292
+ notes_bandwidth_list = []
293
+ instruments = {}
294
+ piano_channels = []
295
+ undef_instrument = False
296
+ abs_t1 = 0
297
+ last_t = 0
298
+ for tsi, tokens in enumerate(midi_seq):
299
+ event = self.tokens2event(tokens)
300
+ if not event:
301
+ continue
302
+ t1, t2, tr = event[1:4]
303
+ abs_t1 += t1
304
+ t = abs_t1 * 16 + t2
305
+ c = None
306
+ if event[0] == "note":
307
+ d, c, p, v = event[4:]
308
+ total_notes += 1
309
+ time_hist[t2] += 1
310
+ if c != 9: # ignore drum channel
311
+ if c not in instruments:
312
+ undef_instrument = True
313
+ note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
314
+ if last_t != t:
315
+ notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
316
+ notes_sametime_p = [p_ for _, p_ in notes_sametime]
317
+ if len(notes_sametime) > 0:
318
+ notes_bandwidth_list.append(max(notes_sametime_p) - min(notes_sametime_p))
319
+ notes_sametime.append((t + d - 1, p))
320
+ elif event[0] == "patch_change":
321
+ c, p = event[4:]
322
+ instruments[c] = p
323
+ if p == 0 and c not in piano_channels:
324
+ piano_channels.append(c)
325
+ if c is not None and c not in channels:
326
+ channels.append(c)
327
+ last_t = t
328
+ reasons = []
329
+ if total_notes < total_notes_min:
330
+ reasons.append("total_min")
331
+ if total_notes > total_notes_max:
332
+ reasons.append("total_max")
333
+ if undef_instrument:
334
+ reasons.append("undef_instr")
335
+ if len(note_windows) == 0 and total_notes > 0:
336
+ reasons.append("drum_only")
337
+ if reasons:
338
+ return False, reasons
339
+ time_hist = sorted(time_hist, reverse=True)
340
+ alignment = sum(time_hist[:2]) / total_notes
341
+ for notes in note_windows.values():
342
+ key_hist = [0] * 12
343
+ for p in notes:
344
+ key_hist[p % 12] += 1
345
+ key_hist = sorted(key_hist, reverse=True)
346
+ tonality_list.append(sum(key_hist[:7]) / len(notes))
347
+ notes_density_list.append(len(notes) / note_window_size)
348
+ tonality_list = sorted(tonality_list)
349
+ tonality = sum(tonality_list)/len(tonality_list)
350
+ notes_bandwidth = sum(notes_bandwidth_list)/len(notes_bandwidth_list) if notes_bandwidth_list else 0
351
+ notes_density = max(notes_density_list) if notes_density_list else 0
352
+ piano_ratio = len(piano_channels) / len(channels)
353
+ if len(channels) <=3: # ignore piano threshold if it is a piano solo midi
354
+ piano_max = 1
355
+ if alignment < alignment_min: # check weather the notes align to the bars (because some midi files are recorded)
356
+ reasons.append("alignment")
357
+ if tonality < tonality_min: # check whether the music is tonal
358
+ reasons.append("tonality")
359
+ if notes_bandwidth < notes_bandwidth_min: # check whether music is melodic line only
360
+ reasons.append("bandwidth")
361
+ if not notes_density_min < notes_density < notes_density_max:
362
+ reasons.append("density")
363
+ if piano_ratio > piano_max: # check whether most instruments is piano (because some midi files don't have instruments assigned correctly)
364
+ reasons.append("piano")
365
+ return not reasons, reasons