Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -167,17 +167,67 @@ def Generate_Rock_Song(input_midi, input_melody_seed_number):
|
|
167 |
print('=' * 70)
|
168 |
print('Generating...')
|
169 |
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
178 |
|
179 |
-
|
|
|
|
|
|
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
print('=' * 70)
|
182 |
print('Done!')
|
183 |
print('=' * 70)
|
|
|
167 |
print('=' * 70)
|
168 |
print('Generating...')
|
169 |
|
170 |
+
#==================================================================
|
171 |
+
|
172 |
+
def generate_tokens(seq, max_num_ptcs=10):
|
173 |
+
|
174 |
+
input = copy.deepcopy(seq)
|
175 |
+
|
176 |
+
pcount = 0
|
177 |
+
y = 545
|
178 |
+
|
179 |
+
gen_tokens = []
|
180 |
+
|
181 |
+
while pcount < max_num_ptcs and y > 255:
|
182 |
+
|
183 |
+
x = torch.tensor(input, dtype=torch.long, device='cuda')
|
184 |
+
|
185 |
+
with ctx:
|
186 |
+
out = model.generate(x,
|
187 |
+
1,
|
188 |
+
filter_logits_fn=top_k,
|
189 |
+
filter_kwargs={'k': 10},
|
190 |
+
temperature=0.9,
|
191 |
+
return_prime=False,
|
192 |
+
verbose=False)
|
193 |
+
|
194 |
+
y = out[0].tolist()[0]
|
195 |
+
|
196 |
+
if pcount < max_num_ptcs and y > 255:
|
197 |
+
input.append(y)
|
198 |
+
gen_tokens.append(y)
|
199 |
+
if y > 544:
|
200 |
+
pcount += 1
|
201 |
+
|
202 |
+
return gen_tokens
|
203 |
|
204 |
+
#==================================================================
|
205 |
+
|
206 |
+
num_prime_chords = 128
|
207 |
+
pass_chan_dur_tok = False
|
208 |
+
match_ptcs_counts = False
|
209 |
+
|
210 |
+
song = []
|
211 |
|
212 |
+
for i in range(num_prime_chords):
|
213 |
+
song.extend(prime_toks[i])
|
214 |
+
|
215 |
+
for i in tqdm.tqdm(range(num_prime_chords, len(score_toks))):
|
216 |
|
217 |
+
song.extend(score_toks[i])
|
218 |
+
|
219 |
+
if control_toks[i]:
|
220 |
+
for ct in control_toks[i]:
|
221 |
+
if pass_chan_dur_tok:
|
222 |
+
song.append(ct[0])
|
223 |
+
if match_ptcs_counts:
|
224 |
+
out_seq = generate_tokens(song, ct[1])
|
225 |
+
else:
|
226 |
+
out_seq = generate_tokens(song)
|
227 |
+
song.extend(out_seq)
|
228 |
+
|
229 |
+
#==================================================================
|
230 |
+
|
231 |
print('=' * 70)
|
232 |
print('Done!')
|
233 |
print('=' * 70)
|