#=========================================================================
# https://huggingface.co./spaces/asigalov61/Score-2-Performance-Transformer
#=========================================================================

import os
import time as reqtime
import datetime
from pytz import timezone

import copy
from itertools import groupby
import tqdm

import spaces
import gradio as gr

import torch

from x_transformer_1_23_2 import *

import random

import TMIDIX

from midi_to_colab_audio import midi_to_colab_audio

from huggingface_hub import hf_hub_download

# =================================================================================================

print('Loading model...')

SEQ_LEN = 1802
PAD_IDX = 771
DEVICE = 'cuda' # 'cpu'

# instantiate the model

model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024,
                          depth = 8,
                          heads = 8,
                          rotary_pos_emb = True,
                          attn_flash = True
                          )
)

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_checkpoint = hf_hub_download(repo_id='asigalov61/Score-2-Performance-Transformer',
                                   filename='Score_2_Performance_Transformer_Final_Small_Trained_Model_4496_steps_1.5185_loss_0.5589_acc.pth'
                                   )

model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

model = torch.compile(model, mode='max-autotune')

dtype = torch.bfloat16

ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('=' * 70)
print('Done!')
print('=' * 70)

# =================================================================================================

def load_midi(midi_file):

    print('Loading MIDI...')

    raw_score = TMIDIX.midi2single_track_ms_score(midi_file)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)

    if escore_notes[0]:
        escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)

    pe = escore_notes[0]

    melody_chords = []

    seen = []

    for e in escore_notes:

        if e[3] != 9: # skip drums (channel 9)

            #=======================================================

            # Delta start-time
            dtime = max(0, min(255, e[1]-pe[1]))

            if dtime != 0:
                seen = []

            # Durations
            dur = max(1, min(255, e[2]))

            # Pitches
            ptc = max(1, min(127, e[4]))

            # Velocities
            vel = max(1, min(127, e[5]))

            # Skip pitches repeated within the same chord
            if ptc not in seen:
                melody_chords.append([dtime, dur, ptc, vel])
                seen.append(ptc)

            pe = e

    print('=' * 70)
    print('Number of notes in a composition:', len(melody_chords))
    print('=' * 70)

    src_melody_chords_f = []

    # Split the composition into chunks of 300 notes with a 150-note overlap
    for i in range(0, len(melody_chords), 150):

        chunk = melody_chords[i:i+300]

        src = []

        for mm in chunk:
            src.append([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])

        clen = len(src)

        # Pad a short (final) chunk to exactly 300 notes by repetition
        if clen < 300:
            chunk_mult = (300 // clen) + 1
            src += src * chunk_mult

        src_melody_chords_f.append([clen, src[:300]])

    print('Done!')
    print('=' * 70)
    print('Number of composition chunks:', len(src_melody_chords_f))
    print('=' * 70)

    return src_melody_chords_f
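# =================================================================================================

# Token vocabulary reference (added note, summarizing the encoding used above and in the
# generation loop below; timings come out of midi2single_track_ms_score in milliseconds,
# so after timings_divider=16 one time step is 16 ms):
#
#   0..255   delta start-time (16 ms steps)
#   256..383 MIDI pitch + 256
#   384..639 duration + 384 (16 ms steps)
#   640..767 velocity + 640
#   768      sequence start token
#   769      score/performance separator token
#   771      padding token (PAD_IDX)

def encode_perf_note(dtime, dur, ptc, vel):
    """Hypothetical helper (not called by the app) illustrating how a single
    performance note maps to the four-token group the model generates."""
    return [max(0, min(255, dtime)),
            256 + max(1, min(127, ptc)),
            384 + max(1, min(255, dur)),
            640 + max(1, min(127, vel))]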
# =================================================================================================

@spaces.GPU
def Convert_Score_to_Performance(input_midi,
                                 input_midi_type,
                                 input_conv_type,
                                 input_number_prime_notes,
                                 input_number_conv_notes,
                                 input_model_dur_top_k,
                                 input_model_dur_temperature,
                                 input_model_vel_temperature
                                 ):

    #===============================================================================

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi)
    fn1 = fn.split('.')[0]

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI type:', input_midi_type)
    print('Conversion type:', input_conv_type)
    print('Number of prime notes:', input_number_prime_notes)
    print('Number of notes to convert:', input_number_conv_notes)
    print('Model durations sampling top k value:', input_model_dur_top_k)
    print('Model durations temperature:', input_model_dur_temperature)
    print('Model velocities temperature:', input_model_vel_temperature)
    print('=' * 70)

    #==================================================================

    src_melody_chords_f = load_midi(input_midi.name)

    #==================================================================

    print('Sample output events', src_melody_chords_f[0][1][:3])
    print('=' * 70)
    print('Generating...')

    model.to(DEVICE)
    model.eval()

    #==================================================================

    # Priming improves the results, but it is not necessary and can be set to zero
    num_prime_notes = input_number_prime_notes

    if input_midi_type == 'Score':
        dur_top_k = 1
        dur_temperature = 1.1
        vel_temperature = 1.5

    elif input_midi_type == 'Performance':
        dur_top_k = 100
        dur_temperature = 1.5
        vel_temperature = 1.9

    else: # Custom
        # Use k == 1 if the source composition is a score and k > 1 if it is a performance
        dur_top_k = input_model_dur_top_k
        # For best results, durations temperature should be above 1.0 but below velocities temperature
        dur_temperature = input_model_dur_temperature
        # For best results, velocities temperature should be above 1.3 and above durations temperature
        vel_temperature = input_model_vel_temperature

    #==================================================================

    final_song = []

    for cc, (song_chunk_len, song_chunk) in enumerate(src_melody_chords_f):

        print('=' * 70)
        print('Rendering song chunk #', cc)
        print('=' * 70)

        #========================================================================

        song = [768]

        if cc == 0:
            for m in song_chunk:
                song.extend(m[:2])

            song.append(769)

            sidx = 0
            eidx = 300

        else:
            # psrc carries the score of the 150 overlapping notes from the previous
            # chunk; extend it with the new second half of the current chunk
            for m in song_chunk[150:]:
                psrc.extend(m[:2])

            psrc.append(769)

            # Keep the leading 768 token so the prompt prefix is again 602 tokens long
            song.extend(copy.deepcopy(psrc + ptrg))

            sidx = 150
            eidx = 300

        #========================================================================

        for i in tqdm.tqdm(range(sidx, eidx)):

            song.extend(song_chunk[i][:2])

            if 'Durations' in input_conv_type:

                if i < num_prime_notes and cc == 0:
                    song.append(song_chunk[i][2])

                else:
                    # Durations
                    x = torch.LongTensor(song).cuda()

                    y = 0

                    while not 384 < y < 640:

                        with ctx:
                            out = model.generate(x,
                                                 1,
                                                 temperature=dur_temperature,
                                                 filter_logits_fn=top_k,
                                                 filter_kwargs={'k': dur_top_k},
                                                 return_prime=False,
                                                 verbose=False)

                        y = out.tolist()[0][0]

                    song.append(y)

            else:
                song.append(song_chunk[i][2])

            #========================================================================

            if 'Velocities' in input_conv_type:

                if i < num_prime_notes and cc == 0:
                    song.append(song_chunk[i][3])

                else:
                    # Velocities
                    x = torch.LongTensor(song).cuda()

                    y = 0

                    while not 640 < y < 768:

                        with ctx:
                            out = model.generate(x,
                                                 1,
                                                 temperature=vel_temperature,
                                                 return_prime=False,
                                                 verbose=False)

                        y = out.tolist()[0][0]

                    song.append(y)

            else:
                song.append(song_chunk[i][3])

        #========================================================================
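        # Sequence layout note (added): the prompt is [768] + 600 score tokens
        # (300 notes x 2 tokens) + [769], a fixed prefix of 602 tokens, so
        # song[602:] is always the performance stream at 4 tokens per note.
        # Chunks overlap by 150 notes, so for every chunk after the first the
        # first 150 * 4 = 600 performance tokens are priming carried over from
        # the previous chunk and only the tokens after them are new.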
        if cc == 0:
            final_song.extend(song[602:][:(song_chunk_len * 4)])

        else:
            final_song.extend(song[602:][600:(song_chunk_len * 4)])

        # Carry over the score and the generated performance of the chunk's
        # last 150 notes to prime the next chunk
        psrc = copy.deepcopy(song[301:601])
        ptrg = copy.deepcopy(song[602:][600:1200])

        #========================================================================

        if len(final_song) >= input_number_conv_notes * 4:
            break

    #========================================================================

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================

    print('Rendering results...')

    print('=' * 70)
    print('Sample INTs', final_song[:15])
    print('=' * 70)

    song_f = []

    if len(final_song) != 0:

        time = 0
        dur = 0
        vel = 90
        pitch = 60
        channel = 0
        patch = 0

        patches = [0] * 16

        # Reassemble notes from the token stream; a velocity token closes a note.
        # E.g. [10, 316, 450, 730] -> time += 160 ms, pitch 60, dur 1056 ms, vel 90
        for ss in final_song:

            if 0 <= ss < 256:
                time += ss * 16

            if 256 <= ss < 384:
                pitch = ss - 256

            if 384 <= ss < 640:
                dur = (ss - 384) * 16

            if 640 <= ss < 768:
                vel = ss - 640

                song_f.append(['note', time, dur, channel, pitch, vel, patch])

        fn1 = "Score-2-Performance-Transformer-Composition"

        detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                                  output_signature = 'Score 2 Performance Transformer',
                                                                  output_file_name = fn1,
                                                                  track_name='Project Los Angeles',
                                                                  list_of_MIDI_patches=patches
                                                                  )

        new_fn = fn1 + '.mid'

        audio = midi_to_colab_audio(new_fn,
                                    soundfont_path=soundfont,
                                    sample_rate=16000,
                                    volume_scale=10,
                                    output_for_gradio=True
                                    )

        print('Done!')
        print('=' * 70)

        #========================================================

        output_midi_title = str(fn1)
        output_midi_summary = str(song_f[:3])
        output_midi = str(new_fn)
        output_audio = (16000, audio)

        output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

        print('Output MIDI file name:', output_midi)
        print('Output MIDI title:', output_midi_title)
        print('Output MIDI summary:', output_midi_summary)
        print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot

# =================================================================================================
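# Offline usage sketch (an assumption, not part of the original Space): outside
# Gradio, the converter can be driven directly with a minimal stand-in for the
# value a gr.File input produces (a str subclass with a .name attribute).
# `PDT` and `soundfont` from the __main__ block below must be defined first.
#
# class _NamedFile(str):
#     name = property(lambda self: str(self))
#
# Convert_Score_to_Performance(_NamedFile('asap_midi_score_21.mid'),
#                              'Score', 'Durations and Velocities',
#                              8, 600, 1, 1.1, 1.5)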
if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("# Score 2 Performance Transformer")
        gr.Markdown("### Convert any MIDI score to a nice performance")

        gr.Markdown("## Upload your MIDI or select a sample example MIDI below")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

        gr.Markdown("## Select MIDI type")

        input_midi_type = gr.Radio(["Score", "Performance", "Custom"],
                                   value="Score",
                                   label="Input MIDI type",
                                   info="Select 'Custom' option to enable model top_k and temperature settings below"
                                   )

        gr.Markdown("## Select conversion type")

        input_conv_type = gr.Radio(["Durations and Velocities", "Durations", "Velocities"],
                                   value="Durations and Velocities",
                                   label="Conversion type"
                                   )

        gr.Markdown("## Conversion options")

        input_number_prime_notes = gr.Slider(0, 512, value=0, step=8, label="Number of prime notes")
        input_number_conv_notes = gr.Slider(8, 2048, value=512, step=8, label="Number of notes to convert")

        gr.Markdown("## Custom MIDI type model options")

        input_model_dur_top_k = gr.Slider(1, 100, value=1, step=1, label="Model sampling top k value for durations")
        input_model_dur_temperature = gr.Slider(0.5, 1.5, value=1.1, step=0.05, label="Model temperature for durations")
        input_model_vel_temperature = gr.Slider(0.5, 1.5, value=1.5, step=0.05, label="Model temperature for velocities")

        run_btn = gr.Button("convert", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(Convert_Score_to_Performance,
                                  [input_midi,
                                   input_midi_type,
                                   input_conv_type,
                                   input_number_prime_notes,
                                   input_number_conv_notes,
                                   input_model_dur_top_k,
                                   input_model_dur_temperature,
                                   input_model_vel_temperature
                                   ],
                                  [output_midi_title,
                                   output_midi_summary,
                                   output_midi,
                                   output_audio,
                                   output_plot])

        gr.Examples(
            [["asap_midi_score_21.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_45.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_69.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_118.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_167.mid", "Score", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
            ],
            [input_midi,
             input_midi_type,
             input_conv_type,
             input_number_prime_notes,
             input_number_conv_notes,
             input_model_dur_top_k,
             input_model_dur_temperature,
             input_model_vel_temperature
             ],
            [output_midi_title,
             output_midi_summary,
             output_midi,
             output_audio,
             output_plot],
            Convert_Score_to_Performance
        )

    app.queue().launch()