asigalov61 commited on
Commit
78b066f
·
verified ·
1 Parent(s): fa9d7e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -8
app.py CHANGED
@@ -167,17 +167,67 @@ def Generate_Rock_Song(input_midi, input_melody_seed_number):
167
  print('=' * 70)
168
  print('Generating...')
169
 
170
- x = (torch.tensor(seed_melody, dtype=torch.long, device='cuda')[None, ...])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- with ctx:
173
- out = model.generate(x,
174
- 1536,
175
- temperature=0.9,
176
- return_prime=False,
177
- verbose=False)
 
178
 
179
- output = out[0].tolist()
 
 
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  print('=' * 70)
182
  print('Done!')
183
  print('=' * 70)
 
167
  print('=' * 70)
168
  print('Generating...')
169
 
170
+ #==================================================================
171
+
172
+ def generate_tokens(seq, max_num_ptcs=10):
173
+
174
+ input = copy.deepcopy(seq)
175
+
176
+ pcount = 0
177
+ y = 545
178
+
179
+ gen_tokens = []
180
+
181
+ while pcount < max_num_ptcs and y > 255:
182
+
183
+ x = torch.tensor(input, dtype=torch.long, device='cuda')
184
+
185
+ with ctx:
186
+ out = model.generate(x,
187
+ 1,
188
+ filter_logits_fn=top_k,
189
+ filter_kwargs={'k': 10},
190
+ temperature=0.9,
191
+ return_prime=False,
192
+ verbose=False)
193
+
194
+ y = out[0].tolist()[0]
195
+
196
+ if pcount < max_num_ptcs and y > 255:
197
+ input.append(y)
198
+ gen_tokens.append(y)
199
+ if y > 544:
200
+ pcount += 1
201
+
202
+ return gen_tokens
203
 
204
+ #==================================================================
205
+
206
+ num_prime_chords = 128
207
+ pass_chan_dur_tok = False
208
+ match_ptcs_counts = False
209
+
210
+ song = []
211
 
212
+ for i in range(num_prime_chords):
213
+ song.extend(prime_toks[i])
214
+
215
+ for i in tqdm.tqdm(range(num_prime_chords, len(score_toks))):
216
 
217
+ song.extend(score_toks[i])
218
+
219
+ if control_toks[i]:
220
+ for ct in control_toks[i]:
221
+ if pass_chan_dur_tok:
222
+ song.append(ct[0])
223
+ if match_ptcs_counts:
224
+ out_seq = generate_tokens(song, ct[1])
225
+ else:
226
+ out_seq = generate_tokens(song)
227
+ song.extend(out_seq)
228
+
229
+ #==================================================================
230
+
231
  print('=' * 70)
232
  print('Done!')
233
  print('=' * 70)