import copy
import note_seq
from PIL import Image
import tempfile
import os
import colorama
from omegaconf import DictConfig, OmegaConf
import torch
from typing import List, Tuple, Dict
from dacite import from_dict
from collections.abc import MutableMapping
import sys
# NOTE: Imported from helibrunna.
def display_logo():
    """
    Display the logo by printing it line by line with a cyberpunk color scheme.

    Raises:
        FileNotFoundError: If the logo file is missing.
    """

    # Get the path of this script and use it to find the logo.
    script_path = os.path.dirname(os.path.realpath(__file__))
    search_path = os.path.dirname(script_path)

    # Load the logo.
    logo_path = os.path.join(search_path, "assets", "asciilogo.txt")
    if not os.path.exists(logo_path):
        raise FileNotFoundError("The logo file is missing.")
    with open(logo_path, "r") as f:
        logo = f.read()

    # Print the logo line by line. Use colorama to colorize the output with a cyberpunk color scheme.
    for line_index, line in enumerate(logo.split("\n")):
        color = colorama.Fore.GREEN
        style = colorama.Style.BRIGHT if line_index % 2 == 0 else colorama.Style.NORMAL
        print(color + style + line)
    print(colorama.Style.RESET_ALL)

# NOTE: Imported from helibrunna.
def model_from_config(model_config: DictConfig, device: str) -> torch.nn.Module:
    """
    Create a model based on the provided model configuration.

    Args:
        model_config (DictConfig): The configuration for the model.
        device (str): The device to move the model to, e.g. "cpu" or "cuda".

    Returns:
        The created model.

    Raises:
        ValueError: If the model type is unknown.
    """

    # Get the model type from the configuration.
    model_type = model_config.get("type", "xLSTMLMModel")

    # Create the xLSTMLMModel.
    if model_type == "xLSTMLMModel":
        print("Creating xLSTMLMModel...")
        from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

        # If there is no GPU, use the vanilla backend.
        if not torch.cuda.is_available():
            #model_config.backend = "vanilla"
            model_config.slstm_block.slstm.backend = "vanilla"
            model_config.mlstm_block.mlstm.backend = "vanilla"
        model_config_object = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(model_config))

        # Create the model.
        model = xLSTMLMModel(model_config_object)
        model.reset_parameters()

    # Create the GPT2LMModel.
    elif model_type == "gpt2":
        print("Creating GPT2LMModel...")
        from .models.gpttwo import GPT2LMModel, GPT2LMModelConfig
        model_config_object = from_dict(GPT2LMModelConfig, OmegaConf.to_container(model_config))
        model = GPT2LMModel(model_config_object)

    # Create the Mamba LM.
    elif model_type == "mamba":
        print("Creating Mamba LM...")
        from mambapy.lm import LM, MambaConfig
        model_config_object = from_dict(MambaConfig, OmegaConf.to_container(model_config))
        model = LM(model_config_object, model_config.vocab_size)

    # Create the Transformer.
    elif model_type == "transformer":
        from .models.transformer import TransformerConfig, Transformer
        model_config_object = from_dict(TransformerConfig, OmegaConf.to_container(model_config))
        model = Transformer(model_config_object)

    # Create a Pharia instance.
    elif model_type == "pharia":
        from .models.pharia import PhariaConfig, PhariaModel
        model_config_object = from_dict(PhariaConfig, OmegaConf.to_container(model_config))
        model = PhariaModel(model_config_object)

    # Unknown model type.
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Move the model to the device.
    model.to(device)
    return model

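# A minimal usage sketch for model_from_config (not executed on import). The
# config is expected to be a fully populated DictConfig whose fields match the
# chosen model type's dataclass; the YAML path mentioned below is hypothetical.
def _example_model_from_config(model_config: DictConfig) -> torch.nn.Module:
    # For example: model_config = OmegaConf.load("configs/transformer.yaml")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model_from_config(model_config, device)
    model.eval()
    return model
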
def convert_tokens_to_songdata(tokens):
    """
    Convert a token sequence (string or list) in the GARLAND_START/.../GARLAND_END
    format into a song-data dictionary with one entry per track, each holding its
    bars and notes.
    """
    if isinstance(tokens, str):
        tokens = tokens.split()
    song_data = {}
    song_data["tracks"] = []
    current_track_index = 0
    current_timestep = 0
    for token in tokens:
        if token == "GARLAND_START":
            pass
        elif token == "BAR_START":
            if current_track_index == len(song_data["tracks"]):
                song_data["tracks"] += [{"bars": [], "instrument": "0"}]
            bar_data = {"notes": []}
            song_data["tracks"][current_track_index]["bars"] += [bar_data]
            current_timestep = 0
        elif token.startswith("INST="):
            instrument = token.split("=")[1]
            song_data["tracks"][current_track_index]["instrument"] = instrument
        elif token.startswith("DENSITY="):
            pass
        elif token.startswith("NOTE_ON="):
            note_pitch = int(token.split("=")[1])
            note_data = {
                "note": note_pitch,
                "start": current_timestep,
                "end": current_timestep,
                "velocity": 80
            }
            song_data["tracks"][current_track_index]["bars"][-1]["notes"] += [note_data]
        elif token.startswith("TIME_DELTA="):
            current_timestep += int(token.split("=")[1])
        elif token.startswith("NOTE_OFF="):
            note_pitch = int(token.split("=")[1])
            # Close the first still-open note with a matching pitch.
            for note_data in song_data["tracks"][current_track_index]["bars"][-1]["notes"]:
                if note_data["note"] == note_pitch and note_data["start"] == note_data["end"]:
                    note_data["end"] = current_timestep
                    break
        elif token == "BAR_END":
            current_track_index += 1
        elif token == "NEXT":
            current_track_index = 0
        elif token == "GARLAND_END":
            pass
        elif token == "[PAD]":
            pass
        elif token == "[EOS]":
            pass
        else:
            raise Exception(f"Unknown token: {token}")
    assert isinstance(song_data, dict)
    return song_data

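# Hedged example of decoding a token sequence. The token string below is made
# up purely for illustration; real sequences come from the model/tokenizer used
# elsewhere in this project.
def _example_convert_tokens_to_songdata():
    tokens = (
        "GARLAND_START "
        "BAR_START INST=0 NOTE_ON=60 TIME_DELTA=8 NOTE_OFF=60 BAR_END "
        "GARLAND_END"
    )
    song_data = convert_tokens_to_songdata(tokens)
    # song_data["tracks"][0]["bars"][0]["notes"] now holds one note:
    # pitch 60, start step 0, end step 8.
    return song_data
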
def convert_songdata_to_notesequence(song_data: dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True):
    """
    Convert a song-data dictionary into a note_seq NoteSequence protobuf, sorting
    tracks by instrument and optionally dropping disabled tracks.
    """
    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"

    # Clone the song data.
    song_data = copy.deepcopy(song_data)

    # Sort the tracks by instrument.
    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"
    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])
    song_data["tracks"] = tracks

    # Remove tracks that are not enabled.
    if remove_disabled_tracks:
        song_data["tracks"] = [t for t in song_data["tracks"] if t.get("enabled", True)]

    # Create an empty note sequence.
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()

    # Add the tempo.
    bpm = song_data["bpm"] if "bpm" in song_data else 120
    note_sequence.tempos.add().qpm = bpm

    # Compute some lengths.
    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
    bar_length_seconds = 4 * step_length_seconds * quantize_steps_per_quarter

    # Get the instruments.
    instruments = list(set([t["instrument"] for t in song_data["tracks"]]))

    # Add the tracks.
    for track_index, track_data in enumerate(song_data["tracks"]):
        instrument = track_data["instrument"]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            bar_start_time = bar_index * bar_length_seconds
            for note_data in bar_data["notes"]:
                assert "note" in note_data
                assert "start" in note_data
                assert "end" in note_data
                note = note_sequence.notes.add()
                #note.instrument = instrument TODO
                note.pitch = note_data["note"]
                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
                if "velocity" in note_data:
                    note.velocity = note_data["velocity"]
                else:
                    note.velocity = 80
                note.instrument = track_index
                if instrument == "drums":
                    note.is_drum = True
                else:
                    note.is_drum = False
                    note.program = int(instrument)
    return note_sequence

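# Hedged sketch of building a NoteSequence from hand-written song data. The
# single middle-C note and the timing comment assume the defaults hard-coded
# above (120 bpm, 8 quantization steps per quarter, 4/4 bars).
def _example_convert_songdata_to_notesequence():
    song_data = {
        "tracks": [
            {
                "instrument": "0",
                "bars": [{"notes": [{"note": 60, "start": 0, "end": 8, "velocity": 80}]}],
            }
        ]
    }
    note_sequence = convert_songdata_to_notesequence(song_data)
    # One quantization step is 60 / 120 / 8 = 0.0625 s, so the note lasts 0.5 s.
    return note_sequence
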
def convert_songdata_to_pianoroll(song_data):
    """
    Render a song-data dictionary as a piano-roll image. Bars are 4/4 with 8
    quantization steps per quarter, i.e. 32 pixels per bar; the height spans the
    pitch range of the song. Returns a PIL image, or None if there are no tracks.
    """

    # Determine the number of bars. All tracks must have the same length.
    lengths = [len(track["bars"]) for track in song_data["tracks"]]
    if lengths == []:
        return None
    assert len(set(lengths)) == 1, f"Unequal number of bars: {lengths}"
    num_bars = lengths[0]

    # Get the note extremes.
    min_note = 128
    max_note = 0
    for track_data in song_data["tracks"]:
        for bar_data in track_data["bars"]:
            for note_data in bar_data["notes"]:
                min_note = min(min_note, note_data["note"])
                max_note = max(max_note, note_data["note"])

    # The width depends on the bars.
    width = 32 * num_bars

    # The height depends on the notes.
    height = 1 + max_note - min_note

    # Create the image.
    image = Image.new("RGB", (width, height), (14, 17, 23))

    # Derive one color per track by rotating the hue of the base color.
    import colorsys
    base_color = (255, 75, 75)
    adjustments = [1.2, 1.0, 0.8, 0.6]
    colors = []
    for adjustment in adjustments:
        rgb = [float(c) / 255.0 for c in base_color]
        hsv = colorsys.rgb_to_hsv(*rgb)
        # Rotate the hue.
        offset = (adjustment - 1.0) * 0.1
        hsv = (hsv[0] + offset, hsv[1], hsv[2])
        rgb = colorsys.hsv_to_rgb(*hsv)
        rgb = tuple([int(255.0 * c) for c in rgb])
        colors += [rgb]

    # Draw the notes.
    for track_index, track_data in enumerate(song_data["tracks"]):
        color = colors[track_index % len(colors)]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            x = bar_index * 32
            for note_data in bar_data["notes"]:
                y = max_note - note_data["note"]
                assert y >= 0 and y < height, f"Invalid y: {y}, note {note_data['note']} min_note: {min_note}, max_note: {max_note}, difference: {max_note - min_note}, height: {height}"
                for i in range(note_data["start"], note_data["end"]):
                    image.putpixel((x + i, y), color)

    # Resize the image. Use nearest neighbor for pixel art.
    factor = 4
    image = image.resize((width * factor, height * factor), Image.NEAREST)
    return image

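# Hedged usage sketch for the piano-roll rendering, reusing the illustrative
# token example above. The output filename is hypothetical.
def _example_convert_songdata_to_pianoroll():
    song_data = _example_convert_tokens_to_songdata()
    image = convert_songdata_to_pianoroll(song_data)
    if image is not None:
        image.save("pianoroll.png")  # hypothetical output path
    return image
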
def convert_notesequence_to_wave(note_sequence):
    """
    Synthesize a NoteSequence to an audio waveform. Prefers fluidsynth at 44.1 kHz
    and falls back to note_seq's built-in synthesizer. Returns None for empty sequences.
    """
    if len(note_sequence.notes) == 0:
        return None
    try:
        synthesizer = note_seq.fluidsynth
        wave = synthesizer(note_sequence, sample_rate=44100)
        return wave
    except Exception:
        synthesizer = note_seq.synthesize
        wave = synthesizer(note_sequence)
        return wave

def convert_notesequence_to_midi(note_sequence, filename="output.mid"):
    """
    Render a NoteSequence to MIDI and return the file content as bytes, or None
    if the sequence contains no notes. The MIDI is written to a temporary file.
    """
    if len(note_sequence.notes) == 0:
        return None

    # Write the sequence to a temporary MIDI file and return its content.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        filename = temp_file.name
    note_seq.sequence_proto_to_midi_file(note_sequence, filename)
    with open(filename, "rb") as file:
        content = file.read()
    os.remove(filename)
    return content

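# Hedged end-to-end sketch: song data (from the illustrative example above) to
# audio samples and raw MIDI bytes. Audio synthesis requires fluidsynth or
# note_seq's built-in synthesizer to be available in the environment.
def _example_render_audio_and_midi():
    note_sequence = _example_convert_songdata_to_notesequence()
    wave = convert_notesequence_to_wave(note_sequence)        # array of audio samples, or None
    midi_bytes = convert_notesequence_to_midi(note_sequence)  # raw MIDI file content, or None
    return wave, midi_bytes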