File size: 3,782 Bytes
3888ab7
 
 
 
 
 
2278710
3888ab7
 
 
 
 
 
 
 
c04453c
 
0ef2079
58881e0
c04453c
58881e0
819a468
d85cb0d
8c8ea80
3888ab7
8c8ea80
4b3e8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eba33c8
 
41327e9
3888ab7
1801257
 
eba33c8
3888ab7
 
 
 
 
 
eec4853
3888ab7
eec4853
1aa6fb6
eec4853
 
3888ab7
 
 
1aa6fb6
3888ab7
e469266
 
 
 
 
3888ab7
 
1aa6fb6
 
 
 
3888ab7
1aa6fb6
3888ab7
 
1aa6fb6
3888ab7
 
56ab42f
 
 
3888ab7
 
 
 
56ab42f
3888ab7
56ab42f
 
c04453c
3888ab7
 
4b3e8cb
 
c04453c
8c8ea80
4b3e8cb
eba33c8
d85cb0d
7dd6e93
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import glob
import os.path

import gradio as gr

import pickle
import tqdm
import json

import MIDI
from midi_synthesizer import synthesis

# Flag: is the SYSTEM environment variable set to "spaces"? (presumably marks a
# hosted Space deployment; not referenced elsewhere in the visible code)
in_space = os.environ.get("SYSTEM") == "spaces"

def run(search_prompt, mid=None, progress=gr.Progress()):
    """Select a MIDI score, render it, and yield metadata, file, audio and plot.

    Args:
        search_prompt: Text from the "search prompt" box.
            NOTE(review): not used yet -- the dataset scan below simply ends on
            the last ``meta_data`` entry. TODO: implement the actual search.
        mid: Optional raw MIDI file contents (bytes, as produced by a
            ``gr.File(type="binary")`` input). When given, it is used instead of
            ``meta_data``.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        One tuple ``(metadata, midi_path, (44100, audio), scatter_plot)``.

    Raises:
        gr.Error: if ``mid`` is None and ``meta_data`` is empty.

    Relies on the module globals ``meta_data`` and ``soundfont_path``, which are
    only defined when this file is executed as a script.
    """
    import pandas as pd  # gradio already depends on pandas; plots need a DataFrame

    if mid is None:
        entry = None
        # Scan the whole dataset so the progress bar reflects the work; only
        # the final entry is kept (see NOTE above about search_prompt).
        for entry in progress.tqdm(meta_data):
            pass
        if entry is None:
            raise gr.Error("No MIDI meta-data loaded; nothing to search.")
        record = entry[1]
        # Layout taken from the indexing used here: first 16 items are metadata,
        # item 16 holds the ticks value at [1], the events follow, and the last
        # item is skipped. TODO confirm against the pickle's schema.
        mdata = record[:16]
        mid_seq_ticks = record[16][1]
        mid_seq = record[17:-1]
    else:
        score = MIDI.midi2score(mid)
        # midi2score returns [ticks, track1, track2, ...]; merge all tracks into
        # one time-ordered event list so both input sources share one shape.
        mid_seq_ticks = score[0]
        mid_seq = sorted(
            (event for track in score[1:] for event in track),
            key=lambda event: event[1],
        )
        mdata = []  # uploaded files carry no dataset metadata

    # Score note events look like ['note', start, duration, channel, pitch, velocity].
    notes = [event for event in mid_seq if event[0] == 'note']
    plot_df = pd.DataFrame({
        "time": [note[1] / mid_seq_ticks for note in notes],  # start in ticks / ticks value
        "pitch": [note[4] for note in notes],
        "channel": [str(note[3]) for note in notes],  # str -> categorical colouring
    })
    plot = gr.ScatterPlot(value=plot_df, x="time", y="pitch", color="channel")

    out_path = "output.mid"
    with open(out_path, 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
    yield mdata, out_path, (44100, audio), plot

if __name__ == "__main__":
    # Script entry point: parse CLI options, load the pickled meta-data and
    # build/launch the Gradio UI. `soundfont_path` and `meta_data` become module
    # globals here and are read by `run`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    
    opt = parser.parse_args()
    
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "meta-data/LAMD_META_10000.pickle"
    
    models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
                   "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
                   "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}


    print('Loading meta-data...')
    # NOTE: pickle.load executes arbitrary code from the file; this is only
    # acceptable because the path is a bundled, trusted local asset.
    with open(meta_data_path, 'rb') as f:
        meta_data = pickle.load(f)
    print('Done!')
    
    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Search</h1>")
        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Search&style=flat)\n\n"
                    "MIDI Search and Explore\n\n"
                    "Demo for [MIDI Search](https://github.com/asigalov61)\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/MIDI-Search/blob/main/demo.ipynb)"
                    " for faster running and longer generation"
                    )
        
        with gr.Tabs():
            with gr.TabItem("instrument prompt") as tab1:
                
                search_prompt = gr.Textbox(label="search prompt")
                
            with gr.TabItem("midi prompt") as tab2:
                input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")

        with gr.Accordion("options", open=False):
 
            input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
            
        search_btn = gr.Button("search", variant="primary")
        
        output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
        output_midi = gr.File(label="output midi", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="output midi metadata")
        output_plot = gr.ScatterPlot(label="output midi plot")
        
        # NOTE(review): only `search_prompt` is passed, so `run` always gets
        # mid=None and the "midi prompt" tab's upload is ignored. Add
        # `input_midi` to the inputs once `run` handles uploaded scores.
        run_event = search_btn.click(run, [search_prompt],
                                  [output_midi_seq, output_midi, output_audio, output_plot])
        # The plot is shown via the `output_plot` component above; the former
        # `gr.show(...)` call is not a Gradio API and raised AttributeError.
        
    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)