# Hugging Face Space app (page status when captured: "Runtime error")
import html
import json
import pickle
import random
import urllib
import urllib.parse

import gradio as gr
import torch

from model import GPT, GPTConfig
from utils import sample
device = torch.device('cpu')
# Create the model. These hyper-parameters must match the ones the checkpoint
# below was trained with, otherwise load_state_dict fails.
vocab_size = 147
block_size = 128
mconf = GPTConfig(vocab_size, block_size,
                  n_layer=6, n_head=8, n_embd=256)
model = GPT(mconf)
# Load checkpoint (map_location lets a GPU-saved checkpoint load on CPU).
model.load_state_dict(torch.load('another_epoch_1.75total.ckpt', map_location=device))
# This app only does inference: switch the module to eval mode so training-time
# behaviour (e.g. dropout, if the GPT implementation uses any) does not
# perturb sampling.
model.eval()
# Vocab lookup tables. Presumably stoi maps byte value -> token id and itos maps
# token id -> byte value (that is how they are used for prompting/decoding);
# verify against the training code.
# NOTE(review): pickle can execute arbitrary code; only load trusted files
# shipped with the app.
with open('stoi.pkl', 'rb') as fh:
    stoi = pickle.load(fh)
with open('itos.pkl', 'rb') as fh:
    itos = pickle.load(fh)
# Post-process generation
def completion_to_song(c):
    """Trim raw model output down to the first ABC tune it contains.

    Lines are kept in order until, after at least one line of music (a line
    containing a bar line ``|``) has been seen, either

    * a ``T:`` (title) field line appears, which marks the start of another
      tune, or
    * a line with fewer than two non-whitespace characters appears.

    Args:
        c: Text decoded from the model's sampled tokens.

    Returns:
        The kept lines joined with ``'\\n'``; the line that triggered the stop
        is not included.
    """
    kept_lines = []
    notes = False  # True once a line of music has been seen
    for line in c.split('\n'):
        # Stop if we go back to the start of another song. Match the "T:" field
        # prefix rather than any capital "T": inside an ABC tune body a bare
        # "T" is the trill decoration shorthand, and matching it would cut the
        # tune off at its first trill. Checked against *earlier* lines' music
        # state so a title that itself contains "|" is not treated as music.
        if notes and line.lstrip().startswith('T:'):
            break
        # Record if we've hit music (bar lines only occur in the tune body).
        if '|' in line:
            notes = True
        # Stop on a blank or nearly blank (< 2 visible characters) line.
        if notes and len(line.strip()) < 2:
            break
        # Otherwise keep the line
        kept_lines.append(line)
    return '\n'.join(kept_lines)
# Generate function
def generate_song(randomize, title, nu, ks, key):
    """Sample a tune from the model and render it as HTML.

    Args:
        randomize: If true, prompt with only ``"T:"`` and let the model invent
            every header field; the other arguments are ignored.
        title: Tune title, written after ``T:``.
        nu: Default note length, written as the ``L:`` field (e.g. ``'1/8'``).
        ks: Time signature, written as the ``M:`` field. The UI value
            ``'Random'`` is replaced by one of ``'3/4'``, ``'4/4'``, ``'6/8'``.
        key: Key, written as the ``K:`` field.

    Returns:
        An HTML string: a link that opens the tune in the drawthedots abcjs
        editor, followed by the HTML-escaped ABC text of the first tune.
    """
    # Start sequence. The vocabulary is indexed by byte value, so the prompt
    # is built as UTF-8 bytes.
    context = b"T:"
    if not randomize:
        if ks == 'Random':
            # "Random" is a UI choice, not an ABC meter; pick a real one rather
            # than prompting the model with "M:Random".
            ks = random.choice(('3/4', '4/4', '6/8'))
        header = title + '\n' + 'M:' + ks + '\n' + 'K:' + key + '\n' + 'L:' + nu + '\n'
        context += header.encode('utf-8')
    # Model inputs. Bytes the vocabulary does not know (e.g. parts of some
    # non-ASCII title characters) are skipped instead of raising KeyError.
    token_ids = [stoi[b] for b in context if b in stoi]
    x = torch.tensor(token_ids, dtype=torch.long)[None, ...].to(device)
    # Completion: 400 new tokens with top-k sampling; [0] drops the batch dim.
    y = sample(model, x, 400, temperature=1.0, sample=True, top_k=10)[0]
    # Map token ids back to bytes and decode with the same encoding used for
    # the prompt (assumes itos values are byte values 0-255, mirroring stoi).
    completion = bytes(itos[int(i)] for i in y).decode('utf-8', errors='replace')
    # Return the first song
    song = completion_to_song(completion)
    # Escape before embedding in HTML: ABC uses "<" / ">" for broken rhythms
    # and the title is free user text.
    html_song = html.escape(song).replace('\n', '<br>')
    # The safe set appears to mirror JavaScript's encodeURI.
    url_song = urllib.parse.quote(song, safe='~@#$&()*!+=:;,?/\'')
    html_text = (
        f'<p><a href="https://editor.drawthedots.com?t={url_song}" target="_blank">'
        '<b>EDIT LINK - click to open abcjs editor (allows download and playback)</b></a></p>'
        f'<p>{html_song}</p>'
    )
    return html_text
# Gradio demo
demo = gr.Blocks()
with demo:
    gr.Markdown("Quick demo for [WhistleGen v2](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) which lets you generate folk music using a transformer model. I can't get the javascript needed for rendering and playback working with gradio, so this shows the raw ABC notation from the model and a link to view it properly in an external editor.")
    with gr.Row():
        title = gr.Text(label='Title', value='The March of AI')
        with gr.Column():
            nu = gr.Text(label='Note unit', value='1/8')
    with gr.Row():
        # This is the time signature (ABC "M:" field), not the key signature.
        time_signature = gr.Dropdown(['3/4', '4/4', '6/8', 'Random'], value='4/4', label='Time Signature')
        with gr.Column():
            key = gr.Text(label='Key', value='D')
    with gr.Row():
        randomize = gr.Checkbox(label='Randomize (ignores settings above)', value=True)
    with gr.Row():
        out = gr.HTML(label="Output", value='Output should appear here (takes ~30s)')
    btn = gr.Button("Run")
    # Input order must match generate_song(randomize, title, nu, ks, key).
    btn.click(fn=generate_song, inputs=[randomize, title, nu, time_signature, key], outputs=out)
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=whistlegen_v2_space)")
    gr.Markdown("This is currently using an early model. See the [report](https://wandb.ai/johnowhitaker/whistlegen_v2/reports/WhistleGen-v2--VmlldzoyMTAwNjAz) for training info and updates.")
# `launch(enable_queue=True)` was deprecated in Gradio 3.x and removed in 4.0;
# enable queuing on the Blocks instance instead (each generation is slow).
demo.queue().launch()