skytnt commited on
Commit
bb51a02
1 Parent(s): c0d7f61
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import spaces
 
2
  import argparse
3
  import glob
4
  import json
@@ -131,7 +132,7 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
131
  gen_events = int(gen_events)
132
  max_len = gen_events
133
  if seed_rand:
134
- seed = np.random.randint(0, MAX_SEED)
135
  generator = torch.Generator(opt.device).manual_seed(seed)
136
  disable_patch_change = False
137
  disable_channels = None
@@ -400,7 +401,7 @@ if __name__ == "__main__":
400
  with gr.Accordion("options", open=False):
401
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
402
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
403
- input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=12)
404
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
405
  example3 = gr.Examples([[1, 0.95, 128], [1, 0.98, 20], [1, 0.98, 12]],
406
  [input_temp, input_top_p, input_top_k])
 
1
  import spaces
2
+ import random
3
  import argparse
4
  import glob
5
  import json
 
132
  gen_events = int(gen_events)
133
  max_len = gen_events
134
  if seed_rand:
135
+ seed = random.randint(0, MAX_SEED)
136
  generator = torch.Generator(opt.device).manual_seed(seed)
137
  disable_patch_change = False
138
  disable_channels = None
 
401
  with gr.Accordion("options", open=False):
402
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
403
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
404
+ input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=30)
405
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
406
  example3 = gr.Examples([[1, 0.95, 128], [1, 0.98, 20], [1, 0.98, 12]],
407
  [input_temp, input_top_p, input_top_k])