from queue import Queue
from threading import Thread
from typing import Optional
import numpy as np
import torch
from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer
import gradio as gr
model = None
processor = None


def load_model():
    """Load the MusicGen model and processor once and cache them in module globals."""
    global model, processor
    if model is None:
        model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    if processor is None:
        processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")
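# A minimal optional sketch (an assumption, not part of the original app): move the model
# to GPU when one is available. Everything else in this file works unchanged on CPU.
#
#     if torch.cuda.is_available():
#         model.to("cuda")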
title = "MusicGen Streaming"
class MusicgenStreamer(BaseStreamer):
    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """
        Streamer that decodes the token ids generated so far into playback-ready audio every
        `play_steps` decoder steps and puts each new chunk on a queue, to be consumed as an iterator.
        """
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.play_steps = play_steps

        if stride is not None:
            self.stride = stride
        else:
            # Audio samples produced per decoder step (the EnCodec hop length); the stride
            # overlaps consecutive chunks to hide edge artefacts from the codec.
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6

        self.token_cache = None
        self.to_yield = 0

        # Handshake between the generation thread (producer) and the iterator (consumer).
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout
    def apply_delay_pattern_mask(self, input_ids):
        # Rebuild the delay pattern mask so the staggered codebook ids can be realigned.
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # Strip the pad ids introduced by the delay pattern and restore the
        # (1, num_codebooks, seq_len) layout the audio encoder expects.
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )
        input_ids = input_ids[None, ...]
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()
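    # For intuition: MusicGen's decoder emits its codebooks in a "delay pattern", where
    # codebook k lags codebook 0 by k steps. The helper above undoes that offset, drops the
    # pad ids the pattern introduces, and hands the EnCodec decoder a dense
    # (batch, num_codebooks, seq_len) tensor. With musicgen-small's 4 codebooks, a cache of
    # N steps therefore yields roughly N - 4 aligned frames of audio.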
    def put(self, value):
        # `generate` calls put() once per decoding step with the ids for all codebooks.
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        # Every `play_steps` tokens, decode the whole cache and emit only the new samples,
        # holding back `stride` samples of overlap for the next chunk.
        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride
    def end(self):
        # Flush whatever remains in the cache, then signal the end of the stream.
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)
        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio on the queue; if the stream is ending, also put the stop signal."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)
    def __iter__(self):
        return self

    def __next__(self):
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value
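# Minimal usage sketch for the streamer outside of Gradio (assumes `model` and `processor`
# are loaded; `prompt_inputs` is a hypothetical name for the processed text prompt):
#
#     prompt_inputs = processor(text="lo-fi beat", padding=True, return_tensors="pt")
#     streamer = MusicgenStreamer(model, play_steps=50)
#     Thread(target=model.generate, kwargs=dict(**prompt_inputs, streamer=streamer, max_new_tokens=256)).start()
#     for chunk in streamer:  # np.float32 waveform chunks at the audio encoder's sampling rate
#         ...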
# The model must be loaded before its audio-encoder config can be read.
load_model()

sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate
def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )

    streamer = MusicgenStreamer(model, play_steps=play_steps)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    # Seed before starting generation so the sampling in the worker thread is reproducible.
    set_seed(seed)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    try:
        for new_audio in streamer:
            yield sampling_rate, new_audio
    except Exception as e:
        # On failure, yield one second of silence so the streaming output terminates cleanly.
        print(f"Error during generation: {e}")
        yield sampling_rate, np.zeros(sampling_rate)
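# Note: gr.Audio(streaming=True) treats each yielded (sampling_rate, np.ndarray) tuple as the
# next chunk of one continuous clip, so the overlap handling in MusicgenStreamer is what keeps
# the chunk boundaries free of clicks.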
demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
        gr.Slider(10, 600, value=15, step=5, label="Audio length in seconds"),
        gr.Slider(
            0.5,
            2.5,
            value=1.5,
            step=0.5,
            label="Streaming interval in seconds",
            info="Lower = shorter chunks, lower latency, more codec steps",
        ),
        gr.Slider(0, 10, value=5, step=1, label="Seed for random generations"),
    ],
    outputs=[
        gr.Audio(label="Generated Music", streaming=True, autoplay=True),
    ],
    title=title,
    cache_examples=False,
)
demo.queue(concurrency_count=5).launch()