Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,231 Bytes
133ccd4 96007f4 8453f63 133ccd4 de46ee3 133ccd4 df03c6b 8453f63 3dd0e68 8453f63 5b360d1 8453f63 5b360d1 e229c68 3dd0e68 5b360d1 3dd0e68 8453f63 c80330f 8453f63 2ffbebd c80330f 2ffbebd 3dd0e68 2ffbebd c80330f 8453f63 e229c68 186f625 94706c2 e229c68 2ffbebd df03c6b e229c68 df03c6b 23c274d df03c6b a8c784b df03c6b 3dd0e68 aa942ba 3dd0e68 f4019fd de46ee3 23c274d de46ee3 23c274d de46ee3 186f625 2ffbebd 186f625 01094e0 186f625 de46ee3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import argparse
import glob
import os.path
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
import onnxruntime as rt
import tqdm
import json
from midi_synthesizer import synthesis
import TMIDIX
# True only when the SYSTEM environment variable is set to exactly "spaces".
in_space = os.environ.get("SYSTEM") == "spaces"
#=================================================================================================
def generate(
        start_tokens,
        seq_len,
        max_seq_len = 2048,
        temperature = 0.9,
        verbose=False,
        return_prime=False,
        progress=gr.Progress()):
    """Sample ``seq_len`` new tokens, one at a time, after ``start_tokens``.

    On every step the most recent ``max_seq_len`` tokens are passed to the
    module-level ONNX ``session`` under the input name ``'input'``. The
    output at the last position is divided by ``temperature``, turned into
    probabilities with softmax, and one token is drawn from them.

    Args:
        start_tokens: List of integer token ids used as the prompt.
        seq_len: Number of tokens to sample.
        max_seq_len: Size of the context window fed to the model.
        temperature: Divisor applied to the logits before softmax.
        verbose: When true, print the requested length before sampling.
        return_prime: When true, the prompt is kept in the returned tensor.
        progress: Gradio progress tracker used to report the loop's progress.

    Returns:
        A ``torch.LongTensor`` with a leading dimension of 1 holding the
        sampled tokens, preceded by ``start_tokens`` if ``return_prime``.
    """
    out = torch.LongTensor([start_tokens])
    prime_len = len(start_tokens)

    if verbose:
        print("Generating sequence of max length:", seq_len)

    progress(0, desc="Starting...")

    # ``gr.Progress.tqdm`` is itself the iterable wrapper; it has no nested
    # ``.tqdm`` attribute, so it must be called directly.
    for _ in progress.tqdm(range(seq_len)):
        window = out[:, -max_seq_len:].tolist()[0]

        # NOTE(review): indexing with [:, -1] assumes the first model output
        # is laid out as (batch, sequence, vocabulary) -- confirm.
        logits = torch.FloatTensor(session.run(None, {'input': [window]})[0])[:, -1]

        probs = F.softmax(logits / temperature, dim=-1)
        sample = torch.multinomial(probs, 1)
        out = torch.cat((out, sample), dim=-1)

    if return_prime:
        return out
    return out[:, prime_len:]
#=================================================================================================
def create_msg(name, data):
    """Bundle ``name`` and ``data`` into a dict with keys "name" and "data"."""
    message = dict(name=name, data=data)
    return message
def _decode_tokens(tokens):
    """Convert a flat list of model tokens into ``['note', ...]`` events."""
    notes = []
    start_time = 0
    dur = 0
    vel = 0

    for tok in tokens:
        if 0 < tok < 256:
            # Timing token: advances the running start time.
            start_time += tok * 8
        elif 256 <= tok < 1280:
            # Combined token: duration in the quotient, velocity in the
            # remainder (eight velocity levels).
            dur = ((tok - 256) // 8) * 32
            vel = (((tok - 256) % 8) + 1) * 15
        elif 1280 <= tok < 2816:
            # Combined token: channel in the quotient, pitch in the
            # remainder. A note is emitted with the latest time/dur/vel.
            channel = (tok - 1280) // 128
            pitch = (tok - 1280) % 128
            notes.append(['note', start_time, dur, channel, pitch, vel])
        # Any other token value (0 or >= 2816) produces no event.

    return notes


def GenerateMIDI():
    """Generate a composition, write it as MIDI, and render it to audio.

    Samples 512 tokens from ``generate`` with a fixed three-token prompt,
    decodes them into note events, prepends one patch change per MIDI
    channel plus a track name, writes the result to
    ``Allegro-Music-Transformer-Music-Composition.mid`` in the working
    directory and synthesizes audio from it with the bundled soundfont.

    Yields:
        A single 4-tuple: the score structure, the MIDI file name,
        ``(44100, audio)``, and a one-element list holding the
        "visualizer_end" message.
    """
    melody_chords_f = generate([3087, 3073+1, 3075+1], 512)
    melody_chords_f = melody_chords_f.tolist()[0]

    print('=' * 70)
    print('Sample INTs', melody_chords_f[:12])
    print('=' * 70)

    # Always bound, even for an empty token list, so the score assembly
    # below cannot hit an undefined name.
    song_f = _decode_tokens(melody_chords_f)

    output_signature = 'Allegro Music Transformer'
    output_file_name = 'Allegro-Music-Transformer-Music-Composition'
    track_name = 'Project Los Angeles'
    list_of_MIDI_patches = [0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0]
    number_of_ticks_per_quarter = 500
    text_encoding = 'ISO-8859-1'

    output_header = [number_of_ticks_per_quarter,
                     [['track_name', 0, bytes(output_signature, text_encoding)]]]

    # One patch change at time 0 for each of the 16 channels, in channel order.
    patch_list = [['patch_change', 0, channel, patch]
                  for channel, patch in enumerate(list_of_MIDI_patches)]
    patch_list.append(['track_name', 0, bytes(track_name, text_encoding)])

    output = output_header + [patch_list + song_f]

    midi_path = output_file_name + '.mid'
    midi_data = TMIDIX.score2midi(output, text_encoding)
    with open(midi_path, 'wb') as f:
        f.write(midi_data)

    audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')

    yield output, midi_path, (44100, audio), [create_msg("visualizer_end", None)]
#=================================================================================================
def cancel_run(output_midi_seq):
    """Write and render whatever score is currently stored.

    Used as the handler for the "stop and output" button, which is bound to
    three output components (MIDI file, audio, JS message).

    Args:
        output_midi_seq: The stored score structure, or ``None`` when no
            score has been produced yet.

    Returns:
        A 3-tuple ``(midi_file_name, (44100, audio), messages)``. All three
        entries are ``None`` when ``output_midi_seq`` is ``None``.
    """
    if output_midi_seq is None:
        # One value per bound output; a shorter tuple is rejected by Gradio.
        return None, None, None

    midi_path = "Allegro-Music-Transformer-Music-Composition.mid"
    with open(midi_path, 'wb') as f:
        f.write(TMIDIX.score2midi(output_midi_seq))

    audio = synthesis(TMIDIX.score2opus(output_midi_seq), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')

    return midi_path, (44100, audio), [create_msg("visualizer_end", None)]
def load_javascript(dir="javascript"):
    """Patch Gradio's template response so pages embed the local ``app.js``.

    Every file matched by the glob pattern ``app.js`` is read and wrapped in
    a ``<script>`` tag. ``gr.routes.templates.TemplateResponse`` is then
    replaced by a wrapper that inserts those tags immediately before
    ``</head>`` in each rendered body.

    Args:
        dir: Not used by the current implementation.
    """
    snippets = []
    for script_path in glob.glob("app.js"):
        with open(script_path, "r", encoding="utf8") as handle:
            snippets.append(f"\n<!-- {script_path} --><script>{handle.read()}</script>")
    javascript = "".join(snippets)

    original_response = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        response = original_response(*args, **kwargs)
        head_with_scripts = f'{javascript}</head>'.encode("utf8")
        response.body = response.body.replace(b'</head>', head_with_scripts)
        # The body changed, so the headers are rebuilt from it.
        response.init_headers()
        return response

    gr.routes.templates.TemplateResponse = template_response
class JSMsgReceiver(gr.HTML):
    """Invisible HTML component with element id ``msg_receiver``.

    Truthy values sent to it are JSON-encoded and wrapped in a ``<p>`` tag
    before being handed to the parent class.
    """

    def __init__(self, **kwargs):
        super().__init__(elem_id="msg_receiver", visible=False, **kwargs)

    def postprocess(self, y):
        """Serialize a truthy ``y`` to ``<p>JSON</p>``; pass falsy values on unchanged."""
        payload = f"<p>{json.dumps(y)}</p>" if y else y
        return super().postprocess(payload)

    def get_block_name(self) -> str:
        """Report this component under the block name "html"."""
        return "html"
#=================================================================================================
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the ONNX model into the
    # module-level ``session`` read by generate(), build the UI and serve it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    opt = parser.parse_args()

    print('Loading model...')
    # Providers are listed in priority order; the CPU entry is an explicit
    # fallback for hosts where the CUDA provider is unavailable.
    session = rt.InferenceSession(
        'Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.onnx',
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    print('Done!')

    # Must run before the app is served so rendered pages include app.js.
    load_javascript()

    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Allegro Music Transformer</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Allegro-Music-Transformer&style=flat)\n\n"
                    "Full-attention multi-instrumental music transformer featuring asymmetrical encoding with octo-velocity, and chords counters tokens, optimized for speed and performance\n\n"
                    "Check out [Allegro Music Transformer](https://github.com/asigalov61/Allegro-Music-Transformer) on GitHub!\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/Allegro-Music-Transformer/blob/main/Allegro_Music_Transformer_Composer.ipynb)"
                    " for faster execution and endless generation"
                    )

        js_msg = JSMsgReceiver()
        run_btn = gr.Button("generate", variant="primary")
        stop_btn = gr.Button("stop and output")

        # Holds the generated score between the run and stop handlers.
        output_midi_seq = gr.Variable()
        output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
        output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
        output_midi = gr.File(label="output midi", file_types=[".mid"])

        run_event = run_btn.click(GenerateMIDI, [], [output_midi_seq, output_midi, output_audio, js_msg])
        # Stopping cancels the in-flight generation and bypasses the queue.
        stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)

    app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)