import copy
import colorsys
import os
import tempfile

import colorama
import note_seq
import torch
from dacite import from_dict
from omegaconf import DictConfig, OmegaConf
from PIL import Image


# NOTE: Imported from helibrunna.
def display_logo():
    """
    Display the logo by printing it line by line with a cyberpunk color scheme.

    Raises:
        FileNotFoundError: If the logo file is missing.
    """

    # Get the path of this script and use it to find the logo.
    script_path = os.path.dirname(os.path.realpath(__file__))
    search_path = os.path.dirname(script_path)

    # Load the logo.
    logo_path = os.path.join(search_path, "assets", "asciilogo.txt")
    if not os.path.exists(logo_path):
        raise FileNotFoundError("The logo file is missing.")
    with open(logo_path, "r") as f:
        logo = f.read()

    # Print the logo line by line, using colorama to colorize it with a cyberpunk color scheme.
    for line_index, line in enumerate(logo.split("\n")):
        color = colorama.Fore.GREEN
        style = colorama.Style.BRIGHT if line_index % 2 == 0 else colorama.Style.NORMAL
        print(color + style + line)
    print(colorama.Style.RESET_ALL)


# NOTE: Imported from helibrunna.
def model_from_config(model_config: DictConfig, device: str) -> torch.nn.Module:
    """
    Create a model based on the provided model configuration.

    Args:
        model_config (DictConfig): The configuration for the model.
        device (str): The device to move the model to, e.g. "cpu" or "cuda".

    Returns:
        The created model.

    Raises:
        ValueError: If the model type is unknown.
    """
    
    # Get the model type from the configuration.
    model_type = model_config.get("type", "xLSTMLMModel")
    
    # Create the xLSTMLMModel.
    if model_type == "xLSTMLMModel":
        print("Creating xLSTMLMModel...")
        from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
        
        # If there is no GPU, fall back to the vanilla (pure PyTorch) backends.
        if not torch.cuda.is_available():
            model_config.slstm_block.slstm.backend = "vanilla"
            model_config.mlstm_block.mlstm.backend = "vanilla"
        model_config_object = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(model_config))
        
        # Create the model.
        model = xLSTMLMModel(model_config_object)
        model.reset_parameters()
    
    # Create the GPT2LMModel.
    elif model_type == "gpt2":
        print("Creating GPT2LMModel...")
        from .models.gpttwo import GPT2LMModel, GPT2LMModelConfig
        model_config_object = from_dict(GPT2LMModelConfig, OmegaConf.to_container(model_config))
        model = GPT2LMModel(model_config_object)
    
    # Create the MambaLM.
    elif model_type == "mamba":
        print("Creating Mamba LM...")
        from mambapy.lm import LM, MambaConfig
        model_config_object = from_dict(MambaConfig, OmegaConf.to_container(model_config))
        model = LM(model_config_object, model_config.vocab_size)
    
    # Create the Transformer.
    elif model_type == "transformer":
        from .models.transformer import TransformerConfig, Transformer
        model_config_object = from_dict(TransformerConfig, OmegaConf.to_container(model_config))
        model = Transformer(model_config_object)
    
    # Create a Pharia instance.
    elif model_type == "pharia":
        from .models.pharia import PhariaConfig, PhariaModel
        model_config_object = from_dict(PhariaConfig, OmegaConf.to_container(model_config))
        model = PhariaModel(model_config_object)
    
    # Fail on unknown model types.
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    # Move the model to the device.
    model.to(device)
    return model
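

# Example usage (a minimal sketch; the config values shown are assumptions and
# must match what the chosen model type actually requires):
#
#   config = OmegaConf.create({"type": "transformer", "vocab_size": 512})
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model_from_config(config, device)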


def convert_tokens_to_songdata(tokens):
    """
    Convert a token sequence into a song-data dictionary.

    Args:
        tokens: The tokens, either as a whitespace-separated string or as a list.

    Returns:
        A dictionary with a "tracks" list; each track holds an "instrument" and its "bars" of notes.

    Raises:
        ValueError: If an unknown token is encountered.
    """

    if isinstance(tokens, str):
        tokens = tokens.split()

    song_data = {"tracks": []}

    current_track_index = 0
    current_timestep = 0
    for token in tokens:
        if token == "GARLAND_START":
            pass
        elif token == "BAR_START":
            if current_track_index == len(song_data["tracks"]):
                song_data["tracks"] += [{"bars": [], "instrument": "0"}]
            bar_data = {"notes": []}
            song_data["tracks"][current_track_index]["bars"] += [bar_data]
            current_timestep = 0
        elif token.startswith("INST="):
            instrument = token.split("=")[1]
            song_data["tracks"][current_track_index]["instrument"] = instrument
        elif token.startswith("DENSITY="):
            pass
        elif token.startswith("NOTE_ON="):
            note_pitch = int(token.split("=")[1])
            note_data = {
                "note": note_pitch,
                "start": current_timestep,
                "end": current_timestep,
                "veloctiy": 80
            }
            song_data["tracks"][current_track_index]["bars"][-1]["notes"] += [note_data]
            pass
        elif token.startswith("TIME_DELTA="):
            current_timestep += int(token.split("=")[1])
        elif token.startswith("NOTE_OFF="):
            note_pitch = int(token.split("=")[1])
            for note_data in song_data["tracks"][current_track_index]["bars"][-1]["notes"]:
                if note_data["note"] == note_pitch and note_data["start"] == note_data["end"]:
                    note_data["end"] = current_timestep
                    break
            pass
        elif token == "BAR_END":
            current_track_index += 1
        elif token == "NEXT":
            current_track_index = 0
        elif token == "GARLAND_END":
            pass
        elif token == "[PAD]":
            pass
        elif token == "[EOS]":
            pass
        else:
            raise ValueError(f"Unknown token: {token}")

    return song_data
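

# Example (illustrative tokens, not real model output): one track, one bar,
# with a single note held for eight steps.
#
#   tokens = ("GARLAND_START BAR_START INST=0 NOTE_ON=60 TIME_DELTA=8 "
#             "NOTE_OFF=60 TIME_DELTA=24 BAR_END GARLAND_END")
#   song_data = convert_tokens_to_songdata(tokens)
#   # -> one track with instrument "0" containing the note
#   #    {"note": 60, "start": 0, "end": 8, "velocity": 80}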


def convert_songdata_to_notesequence(song_data:dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True):

    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"

    # Clone the song data.
    song_data = copy.deepcopy(song_data)

    # Sort the tracks by instrument.
    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"
    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])
    song_data["tracks"] = tracks

    # Remove tracks that are not enabled.
    if remove_disabled_tracks:
        song_data["tracks"] = [t for t in song_data["tracks"] if t.get("enabled", True)]

    # Create an empty note sequence.
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()

    # Add the tempo.
    bpm = song_data["bpm"] if "bpm" in song_data else 120
    note_sequence.tempos.add().qpm = bpm

    # Compute some lengths.
    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
    bar_length_seconds = 4 * step_length_seconds * quantize_steps_per_quarter

    # Add the tracks.
    for track_index, track_data in enumerate(song_data["tracks"]):
        instrument = track_data["instrument"]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            bar_start_time = bar_index * bar_length_seconds
            for note_data in bar_data["notes"]:
                assert "note" in note_data
                assert "start" in note_data
                assert "end" in note_data
                note = note_sequence.notes.add()
                note.pitch = note_data["note"]
                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
                note.velocity = note_data.get("velocity", 80)
                note.instrument = track_index
                if instrument == "drums":
                    note.is_drum = True
                else:
                    note.is_drum = False
                    note.program = int(instrument)

    return note_sequence
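

# Timing sketch: at the default 120 bpm with 8 quantization steps per quarter,
# step_length_seconds = 60 / 120 / 8 = 0.0625 s and a 4/4 bar spans
# 32 steps = 2.0 s, so a note with start=0 and end=8 in the second bar
# plays from 2.0 s to 2.5 s.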


def convert_songdata_to_pianoroll(song_data):

    # The bars are 4/4 and the quantization is 8 steps per quarter, i.e. 32 steps per bar.
    # We render a grid that is 32 pixels wide per bar and spans the used note range vertically.

    # Determine the number of bars, which must be equal across tracks.
    lengths = [len(track["bars"]) for track in song_data["tracks"]]
    if lengths == []:
        return None
    assert len(set(lengths)) == 1, f"Unequal number of bars: {lengths}"
    num_bars = lengths[0]

    # Get the note extremes.
    min_note = 128
    max_note = 0
    for track_data in song_data["tracks"]:
        for bar_data in track_data["bars"]:
            for note_data in bar_data["notes"]:
                min_note = min(min_note, note_data["note"])
                max_note = max(max_note, note_data["note"])

    # The width depends on the number of bars.
    width = 32 * num_bars

    # The height depends on the note range.
    height = 1 + max_note - min_note

    # Create the image.
    image = Image.new("RGB", (width, height), (14, 17, 23))

    # Define the track colors by rotating the hue of a base color.
    base_color = (255, 75, 75)
    adjustments = [1.2, 1.0, 0.8, 0.6]
    colors = []
    for adjustment in adjustments:
        rgb = [float(c) / 255.0 for c in base_color]
        hsv = colorsys.rgb_to_hsv(*rgb)
        offset = (adjustment - 1.0) * 0.1
        hsv = (hsv[0] + offset, hsv[1], hsv[2])
        rgb = colorsys.hsv_to_rgb(*hsv)
        colors += [tuple(int(255.0 * c) for c in rgb)]

    # Draw the grid.
    for track_index, track_data in enumerate(song_data["tracks"]):
        color = colors[track_index % len(colors)]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            x = bar_index * 32
            
            for note_data in bar_data["notes"]:
                y = max_note - note_data["note"]
                assert y >= 0 and y < height, f"Invalid y: {y}, note {note_data['note']} min_note: {min_note}, max_note: {max_note}, difference: {max_note - min_note}, height: {height}"
                for i in range(note_data["start"], note_data["end"]):
                    image.putpixel((x + i, y), color)

    # Resize the image. Use nearest neighbor for pixel art.
    factor = 4
    image = image.resize((width * factor, height * factor), Image.NEAREST)

    return image


def convert_notesequence_to_wave(note_sequence):

    if len(note_sequence.notes) == 0:
        return None

    # Prefer fluidsynth for rendering; fall back to note_seq's simple
    # synthesizer if fluidsynth is unavailable or fails.
    try:
        return note_seq.fluidsynth(note_sequence, sample_rate=44100)
    except Exception:
        return note_seq.synthesize(note_sequence)
    

def convert_notesequence_to_midi(note_sequence):

    if len(note_sequence.notes) == 0:
        return None

    # Render to a temporary file, read the MIDI bytes back, and clean up.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as temp_file:
        temp_path = temp_file.name
    note_seq.sequence_proto_to_midi_file(note_sequence, temp_path)
    with open(temp_path, "rb") as file:
        content = file.read()
    os.remove(temp_path)
    return content
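

if __name__ == "__main__":
    # End-to-end sketch (the tokens are illustrative, not real model output;
    # assumes note_seq is installed for the MIDI export).
    example_tokens = (
        "GARLAND_START "
        "BAR_START INST=0 NOTE_ON=60 TIME_DELTA=8 NOTE_OFF=60 TIME_DELTA=24 BAR_END "
        "GARLAND_END"
    )
    song_data = convert_tokens_to_songdata(example_tokens)
    note_sequence = convert_songdata_to_notesequence(song_data)
    midi_bytes = convert_notesequence_to_midi(note_sequence)
    if midi_bytes is not None:
        with open("example_output.mid", "wb") as midi_file:
            midi_file.write(midi_bytes)
    pianoroll = convert_songdata_to_pianoroll(song_data)
    if pianoroll is not None:
        pianoroll.save("example_pianoroll.png")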