Spaces:
Running
Running
fixs
Browse files- midi_synthesizer.py +1 -1
- midi_tokenizer.py +77 -6
midi_synthesizer.py
CHANGED
@@ -14,7 +14,7 @@ def synthesis(midi_opus, soundfont_path, sample_rate=44100):
|
|
14 |
event_list.append(event_new)
|
15 |
event_list = sorted(event_list, key=lambda e: e[1])
|
16 |
|
17 |
-
tempo = int((60 /
|
18 |
ss = np.empty((0, 2), dtype=np.int16)
|
19 |
fl = fluidsynth.Synth(samplerate=float(sample_rate))
|
20 |
sfid = fl.sfload(soundfont_path)
|
|
|
14 |
event_list.append(event_new)
|
15 |
event_list = sorted(event_list, key=lambda e: e[1])
|
16 |
|
17 |
+
tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
|
18 |
ss = np.empty((0, 2), dtype=np.int16)
|
19 |
fl = fluidsynth.Synth(samplerate=float(sample_rate))
|
20 |
sfid = fl.sfload(soundfont_path)
|
midi_tokenizer.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import PIL
|
2 |
import numpy as np
|
3 |
|
@@ -43,19 +45,31 @@ class MIDITokenizer:
|
|
43 |
def tokenize(self, midi_score, add_bos_eos=True):
|
44 |
ticks_per_beat = midi_score[0]
|
45 |
event_list = {}
|
46 |
-
track_num = len(midi_score[1:])
|
47 |
for track_idx, track in enumerate(midi_score[1:129]):
|
|
|
48 |
for event in track:
|
49 |
-
t = round(16 * event[1] / ticks_per_beat)
|
50 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
51 |
if event[0] == "note":
|
52 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
53 |
elif event[0] == "set_tempo":
|
54 |
new_event[4] = int(self.tempo2bpm(new_event[4]))
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
event_list[key] = new_event
|
57 |
event_list = list(event_list.values())
|
58 |
-
event_list = sorted(event_list, key=lambda e:
|
59 |
midi_seq = []
|
60 |
|
61 |
last_t1 = 0
|
@@ -113,18 +127,24 @@ class MIDITokenizer:
|
|
113 |
tracks_dict[track_idx] = []
|
114 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
115 |
tracks = list(tracks_dict.values())
|
116 |
-
|
|
|
117 |
track = tracks[i]
|
118 |
track = sorted(track, key=lambda e: e[1])
|
119 |
last_note_t = {}
|
|
|
120 |
for e in reversed(track):
|
121 |
if e[0] == "note":
|
122 |
t, d, c, p = e[1:5]
|
123 |
key = (c, p)
|
124 |
if key in last_note_t:
|
125 |
-
d = min(d, max(last_note_t[key] - t, 0))
|
126 |
last_note_t[key] = t
|
127 |
e[2] = d
|
|
|
|
|
|
|
|
|
128 |
tracks[i] = track
|
129 |
return [ticks_per_beat, *tracks]
|
130 |
|
@@ -148,3 +168,54 @@ class MIDITokenizer:
|
|
148 |
img[p, t: t + d] = colors[(tr, c)]
|
149 |
img = PIL.Image.fromarray(np.flip(img, 0))
|
150 |
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
import PIL
|
4 |
import numpy as np
|
5 |
|
|
|
45 |
def tokenize(self, midi_score, add_bos_eos=True):
|
46 |
ticks_per_beat = midi_score[0]
|
47 |
event_list = {}
|
|
|
48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
49 |
+
last_notes = {}
|
50 |
for event in track:
|
51 |
+
t = round(16 * event[1] / ticks_per_beat) # quantization
|
52 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
53 |
if event[0] == "note":
|
54 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
55 |
elif event[0] == "set_tempo":
|
56 |
new_event[4] = int(self.tempo2bpm(new_event[4]))
|
57 |
+
if event[0] == "note":
|
58 |
+
key = tuple(new_event[:4] + new_event[5:-1])
|
59 |
+
else:
|
60 |
+
key = tuple(new_event[:-1])
|
61 |
+
if event[0] == "note": # to eliminate note overlap due to quantization
|
62 |
+
cp = tuple(new_event[5:7])
|
63 |
+
if cp in last_notes:
|
64 |
+
last_note_key, last_note = last_notes[cp]
|
65 |
+
last_t = last_note[1] * 16 + last_note[2]
|
66 |
+
last_note[4] = max(0, min(last_note[4], t - last_t))
|
67 |
+
if last_note[4] == 0:
|
68 |
+
event_list.pop(last_note_key)
|
69 |
+
last_notes[cp] = (key, new_event)
|
70 |
event_list[key] = new_event
|
71 |
event_list = list(event_list.values())
|
72 |
+
event_list = sorted(event_list, key=lambda e: e[1:4])
|
73 |
midi_seq = []
|
74 |
|
75 |
last_t1 = 0
|
|
|
127 |
tracks_dict[track_idx] = []
|
128 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
129 |
tracks = list(tracks_dict.values())
|
130 |
+
|
131 |
+
for i in range(len(tracks)): # to eliminate note overlap
|
132 |
track = tracks[i]
|
133 |
track = sorted(track, key=lambda e: e[1])
|
134 |
last_note_t = {}
|
135 |
+
zero_len_notes = []
|
136 |
for e in reversed(track):
|
137 |
if e[0] == "note":
|
138 |
t, d, c, p = e[1:5]
|
139 |
key = (c, p)
|
140 |
if key in last_note_t:
|
141 |
+
d = min(d, max(last_note_t[key] - t, 0))
|
142 |
last_note_t[key] = t
|
143 |
e[2] = d
|
144 |
+
if d == 0:
|
145 |
+
zero_len_notes.append(e)
|
146 |
+
for e in zero_len_notes:
|
147 |
+
track.remove(e)
|
148 |
tracks[i] = track
|
149 |
return [ticks_per_beat, *tracks]
|
150 |
|
|
|
168 |
img[p, t: t + d] = colors[(tr, c)]
|
169 |
img = PIL.Image.fromarray(np.flip(img, 0))
|
170 |
return img
|
171 |
+
|
172 |
+
def augment(self, midi_seq, max_pitch_shift=4, max_vel_shift=10, max_cc_val_shift=10, max_bpm_shift=10):
|
173 |
+
pitch_shift = random.randint(-max_pitch_shift, max_pitch_shift)
|
174 |
+
vel_shift = random.randint(-max_vel_shift, max_vel_shift)
|
175 |
+
cc_val_shift = random.randint(-max_cc_val_shift, max_cc_val_shift)
|
176 |
+
bpm_shift = random.randint(-max_bpm_shift, max_bpm_shift)
|
177 |
+
midi_seq_new = []
|
178 |
+
for tokens in midi_seq:
|
179 |
+
tokens_new = [*tokens]
|
180 |
+
if tokens[0] in self.id_events:
|
181 |
+
name = self.id_events[tokens[0]]
|
182 |
+
if name == "note":
|
183 |
+
c = tokens[5] - self.parameter_ids["channel"][0]
|
184 |
+
p = tokens[6] - self.parameter_ids["pitch"][0]
|
185 |
+
v = tokens[7] - self.parameter_ids["velocity"][0]
|
186 |
+
if c != 9: # no shift for drums
|
187 |
+
p += pitch_shift
|
188 |
+
if not 0 <= p < 128:
|
189 |
+
return midi_seq
|
190 |
+
v += vel_shift
|
191 |
+
v = max(1, min(127, v))
|
192 |
+
tokens_new[6] = self.parameter_ids["pitch"][p]
|
193 |
+
tokens_new[7] = self.parameter_ids["velocity"][v]
|
194 |
+
elif name == "control_change":
|
195 |
+
cc = tokens[5] - self.parameter_ids["controller"][0]
|
196 |
+
val = tokens[6] - self.parameter_ids["value"][0]
|
197 |
+
if cc in [1, 2, 7, 11]:
|
198 |
+
val += cc_val_shift
|
199 |
+
val = max(1, min(127, val))
|
200 |
+
tokens_new[6] = self.parameter_ids["value"][val]
|
201 |
+
elif name == "set_tempo":
|
202 |
+
bpm = tokens[4] - self.parameter_ids["bpm"][0]
|
203 |
+
bpm += bpm_shift
|
204 |
+
bpm = max(1, min(255, bpm))
|
205 |
+
tokens_new[4] = self.parameter_ids["bpm"][bpm]
|
206 |
+
midi_seq_new.append(tokens_new)
|
207 |
+
return midi_seq_new
|
208 |
+
|
209 |
+
def check_alignment(self, midi_seq, threshold=0.4):
|
210 |
+
total = 0
|
211 |
+
hist = [0] * 16
|
212 |
+
for tokens in midi_seq:
|
213 |
+
if tokens[0] in self.id_events and self.id_events[tokens[0]] == "note":
|
214 |
+
t2 = tokens[2] - self.parameter_ids["time2"][0]
|
215 |
+
total += 1
|
216 |
+
hist[t2] += 1
|
217 |
+
if total == 0:
|
218 |
+
return False
|
219 |
+
hist = sorted(hist, reverse=True)
|
220 |
+
p = sum(hist[:2]) / total
|
221 |
+
return p > threshold
|