# riffusion-demo / app.py
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
import gradio as gr
import os
from scipy.io.wavfile import read

# Clone the riffusion inference repo: it provides the audio utilities that
# convert the generated spectrogram images back into waveforms.
os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')
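# A more defensive variant (an assumption, not part of the original app) would
# skip the clone when the directory already exists, e.g. after a restart:
#
#   if not os.path.isdir('riffusion'):
#       os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')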
# Load the Riffusion checkpoint and swap in the DPM-Solver++ multistep
# scheduler, which reaches usable quality in relatively few steps.
repo_id = "riffusion/riffusion-model-v1"
pipe = DiffusionPipeline.from_pretrained(repo_id)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

if torch.cuda.is_available():
    pipe = pipe.to("cuda")
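# Optional (not in the original code): on a CUDA GPU, memory use can be
# roughly halved by loading the weights in half precision. A minimal sketch,
# assuming fp16 weights work for this checkpoint; the scheduler swap above
# would need to be repeated after reloading:
#
#   pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
#   pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
#   pipe = pipe.to("cuda")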
def infer(prompt, steps):
    # Imported lazily: the riffusion package only exists after the clone above.
    from riffusion.riffusion import audio
    # Render the prompt as a mel spectrogram image, then convert it to audio.
    mel_spectr = pipe(prompt, num_inference_steps=steps).images[0]
    wav_bytes, duration_s = audio.wav_bytes_from_spectrogram_image(mel_spectr)
    # scipy's read() returns (sample_rate, samples), the tuple gr.Audio expects.
    return read(wav_bytes)
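# A quick smoke test outside the UI (hypothetical, not part of the original
# app), writing one generated clip to disk for inspection:
#
#   from scipy.io.wavfile import write
#   rate, samples = infer("jazzy trumpet solo", steps=25)
#   write("sample.wav", rate, samples)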
with gr.Blocks() as app:
    gr.Markdown("## Riffusion Demo")
    gr.Markdown("Generate audio clips from text prompts using the [Riffusion model](https://huggingface.co./riffusion/riffusion-model-v1).")

    with gr.Group(elem_id='gr-interface'):
        prompt = gr.Textbox(lines=1, label="Prompt")
        steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps")
        audio = gr.Audio(label="Audio")
        btn_generate = gr.Button(value="Generate")

    inputs = [prompt, steps]
    outputs = [audio]

    # Generate either on Enter in the prompt box or on the button click.
    prompt.submit(infer, inputs, outputs)
    btn_generate.click(infer, inputs, outputs)
    examples = gr.Examples(
        examples=[["rap battle freestyle"], ["techno club banger"], ["acoustic folk ballad"],
                  ["blues guitar riff"], ["jazzy trumpet solo"], ["classical symphony orchestra"],
                  ["rock and roll power chord"], ["soulful R&B love song"], ["reggae dub beat"],
                  ["country western twangy guitar"], ["all 25 steps"]],
        inputs=[prompt])

app.launch(debug=True, share=True)