from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
import torch
from PIL import Image, ImageDraw
import os
import numpy as np
from scipy.io.wavfile import read

from share_btn import community_icon_html, loading_icon_html, share_js


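# Pin the Gradio version at runtime (a common Hugging Face Spaces workaround) before importing it.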
os.system('pip install gradio==3.15.0')
import gradio as gr


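# Clone the Riffusion inference repo at runtime so its pipeline, datatypes,
# and audio helpers can be imported below as a local package.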
os.system('git clone https://github.com/hmartiro/riffusion-inference.git riffusion')
from riffusion.riffusion.riffusion_pipeline import RiffusionPipeline
from riffusion.riffusion.datatypes import PromptInput, InferenceInput
from riffusion.riffusion.audio import wav_bytes_from_spectrogram_image
import struct
import random

repo_id = "riffusion/riffusion-model-v1"

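# Load the Riffusion pipeline in fp16; the safety checker is replaced with a
# no-op lambda that passes images through unchanged.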
model = RiffusionPipeline.from_pretrained(
    repo_id,
    revision="main",
    torch_dtype=torch.float16,
    safety_checker=lambda images, **kwargs: (images, False),
)

if torch.cuda.is_available():
    model.to("cuda")
    model.enable_xformers_memory_efficient_attention()

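# Reuse the same Riffusion checkpoint as an inpainting pipeline; it is used
# below to "outpaint" each new spectrogram section.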
pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    safety_checker=lambda images, **kwargs: (images, False),
)
pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(pipe_inpaint.scheduler.config)

# pipe_inpaint.enable_xformers_memory_efficient_attention()

if torch.cuda.is_available():
    pipe_inpaint = pipe_inpaint.to("cuda")
    pipe_inpaint.enable_xformers_memory_efficient_attention()


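# Build the init image for the next section: paste the rightmost `overlap`
# fraction of the previous spectrogram onto the left edge of a fresh seed image.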
def get_init_image(image, overlap, feel):

    width, height = image.size
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    # Crop the rightmost `overlap` fraction of the previous image...
    cropped_img = image.crop((width - int(width * overlap), 0, width, height))
    # ...and paste it onto the left edge of the seed image
    init_image.paste(cropped_img, (0, 0))

    return init_image

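# Build the inpainting mask: black preserves the pasted overlap region on the
# left, white marks the area for the inpainting pipeline to generate.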
def get_mask(image, overlap):

    width, height = image.size

    mask = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(mask)
    draw.rectangle((0, 0, int(overlap * width), height), fill="black")
    return mask

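# Generate the first spectrogram section from a seed image via model.riffuse,
# using identical start and end prompts.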
def i2i(prompt, steps, feel, seed):
#   return pipe_i2i(
#       prompt,
#       num_inference_steps=steps,
#       image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB"),
#       ).images[0]

    prompt_input_start = PromptInput(prompt=prompt, seed=seed)
    prompt_input_end = PromptInput(prompt=prompt, seed=seed)

    return model.riffuse(
        inputs=InferenceInput(
            start=prompt_input_start,
            end=prompt_input_end,
            alpha=1.0,
            num_inference_steps=steps),
        init_image=Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")
    )

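# Extend the spectrogram: inpaint the white (masked) region to the right of
# the preserved overlap.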
def outpaint(prompt, init_image, mask, steps):
    return pipe_inpaint(
        prompt,
        num_inference_steps=steps,
        image=init_image,
        mask_image=mask,
    ).images[0]


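# Single-prompt mode: generate one section, then repeatedly outpaint new
# sections that overlap the previous one by 50%, stitching them side by side
# before decoding the full spectrogram into audio.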
def generate(prompt, steps, num_iterations, feel, seed):

    if seed == 0:
        seed = random.randint(0,4294967295)

    num_images = num_iterations
    overlap = 0.5
    image_width, image_height = 512, 512  # dimensions of each output image
    total_width = num_images * image_width - (num_images - 1) * int(overlap * image_width)  # total width of the stitched image

    # Create a blank image with the desired dimensions
    stitched_image = Image.new("RGB", (total_width, image_height), color="white")

    # Initialize the x position for pasting the next image
    x_pos = 0

    image = i2i(prompt, steps, feel, seed)

    for i in range(num_images):
        # Generate the prompt, initial image, and mask for this iteration
        init_image = get_init_image(image, overlap, feel)
        mask = get_mask(init_image, overlap)
        
        # Run the outpaint function to generate the output image
        image = outpaint(prompt, init_image, mask, steps)

        # Paste the output image onto the stitched image
        stitched_image.paste(image, (x_pos, 0))
        
        # Update the x position for the next iteration
        x_pos += int((1 - overlap) * image_width)

    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(stitched_image)

    # mask = Image.new("RGB", (512, 512), color="white")
    # bg_image = outpaint(prompt, init_image, mask, steps)
    # bg_image.save("bg_image.png")
    init_image.save("bg_image.png")

    # return read(wav_bytes)
    with open("output.wav", "wb") as f:
        f.write(wav_bytes.read())

    return gr.make_waveform("output.wav", bg_image="bg_image.png", bar_count=int(duration_s*25))


###############################################

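# Bi-prompt mode helper: one call produces ~5 seconds of audio, interpolating
# between the start and end prompts with weight `alpha`.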
def riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start=0.75, guidance_start=7.0, prompt_end=None, seed_end=None, denoising_end=0.75, guidance_end=7.0, alpha=0.5):

    prompt_input_start = PromptInput(prompt=prompt_start, seed=seed_start, denoising=denoising_start, guidance=guidance_start)

    prompt_input_end = PromptInput(prompt=prompt_end, seed=seed_end, denoising=denoising_end, guidance=guidance_end)

    inference_input = InferenceInput(
        start=prompt_input_start,
        end=prompt_input_end,
        alpha=alpha,
        num_inference_steps=steps,
        seed_image_id=feel,
        # mask_image_id="mask_beat_lines_80.png"
    )

    image = model.riffuse(inputs=inference_input, init_image=init_image)

    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(image)

    return wav_bytes, image

def generate_riffuse(prompt_start, steps, num_iterations, feel, prompt_end=None, seed_start=None, seed_end=None, denoising_start=0.75, denoising_end=0.75, guidance_start=7.0, guidance_end=7.0):
    """Generate a WAV file of length seconds using the Riffusion model.

    Args:
        length (int): Length of the WAV file in seconds, must be divisible by 5.
        prompt_start (str): Prompt to start with.
        prompt_end (str, optional): Prompt to end with. Defaults to prompt_start.
        overlap (float, optional): Overlap between audio clips as a fraction of the image size. Defaults to 0.2.
        """

    # open the initial image and convert it to RGB
    init_image = Image.open(f"riffusion/seed_images/{feel}.png").convert("RGB")

    if prompt_end is None:
        prompt_end = prompt_start
    if not seed_start:  # 0 or None -> pick a random seed
        seed_start = random.randint(0, 4294967295)
    if seed_end is None:
        seed_end = seed_start

    # one riffuse() generates 5 seconds of audio
    wav_list = []

    for i in range(int(num_iterations)):

        # Interpolation weight: 0 = pure start prompt, 1 = pure end prompt
        alpha = i / (num_iterations - 1)
        wav_bytes, image = riffuse(steps, feel, init_image, prompt_start, seed_start, denoising_start, guidance_start, prompt_end, seed_end, denoising_end, guidance_end, alpha=alpha)
        wav_list.append(wav_bytes)

        # Feed the generated spectrogram back in as the next init image
        init_image = image

        seed_start = seed_end
        seed_end = seed_start + 1

    # return read(wav_bytes)
    # return wav_list_to_wav(wav_list)

    # mask = Image.new("RGB", (512, 512), color="white")
    # bg_image = outpaint(f"{prompt_start} and {prompt_end}", init_image, mask, steps)
    # bg_image.save("bg_image.png")
    init_image.save("bg_image.png")

    with open("output.wav", "wb") as f:
        f.write(wav_list_to_wav(wav_list))

    return gr.make_waveform("output.wav", bg_image="bg_image.png")


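# Concatenate several in-memory WAV clips into a single 16-bit mono 44.1 kHz
# WAV file (assumes each clip carries a canonical 44-byte RIFF header).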
def wav_list_to_wav(wav_list):

    # Strip the 44-byte RIFF header from each clip, keeping the raw PCM data
    data = [wav.read()[44:] for wav in wav_list]

    # Concatenate the raw sample data
    concatenated_data = b"".join(data)

    # Create a new RIFF header for 16-bit mono PCM
    channels = 1
    sample_rate = 44100
    bytes_per_second = channels * sample_rate * 2  # 2 bytes per 16-bit sample
    new_header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", len(concatenated_data) + 44 - 8, b"WAVE", b"fmt ", 16, 1, channels, sample_rate, bytes_per_second, 2, 16, b"data", len(concatenated_data))

    # Combine the header and data to create the final WAV file
    final_wav = new_header + concatenated_data
    return final_wav

###############################################

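# Route a submission: single-prompt mode calls generate(), bi-prompt mode
# calls generate_riffuse().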
def on_submit(prompt_1, prompt_2, feel, num_iterations, steps=25, seed=0):
    if prompt_1 == "":
        return None, gr.update(value="First prompt is required."), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    if prompt_2 == "":
        return generate(prompt_1, steps, num_iterations, feel, seed), None, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
    else:
        return generate_riffuse(prompt_1, steps, num_iterations, feel, prompt_end=prompt_2, seed_start=seed), None, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)


def on_num_iterations_change(n, prompt_2):
    if n is None:
        return gr.update(value="")

    if prompt_2 != "":
        # Bi-prompt mode: each section is a full ~5-second clip
        total_length = 5 * n
    else:
        # Single-prompt mode: sections overlap by 50%, so each adds ~2.5 s
        total_length = 2.5 + 2.5 * n
    return gr.update(value=f"Total length: {total_length:.2f} seconds")


css = '''
    #share-btn-container {
        display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
    }
    #share-btn {
        all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
    }
    #share-btn * {
        all: unset;
    }
    #share-btn-container div:nth-child(-n+2){
        width: auto !important;
        min-height: 0px !important;
    }
    #share-btn-container .wrap {
        display: none !important;
    }
'''

with gr.Blocks(css=css) as app:
    gr.Markdown("## Riffusion Demo")
    gr.Markdown("""Generate audio using the [Riffusion](https://huggingface.co./riffusion/riffusion-model-v1) model.<br>
                In single-prompt mode you can generate up to ~1 minute of audio with smooth transitions between sections. (beta)<br>
                Bi-prompt mode interpolates between two prompts. It can generate up to ~2 minutes of audio, but transitions between sections are more abrupt.""")
    gr.Markdown(f"""Running on {"**GPU 🔥**" if torch.cuda.is_available() else "**CPU 🥶**. For faster inference, it is recommended to **upgrade to GPU in the Space's Settings**"}<br>
                [![Duplicate Space](https://bit.ly/3gLdBN6)](https://huggingface.co./spaces/anzorq/riffusion-demo?duplicate=true)""")
    
    with gr.Row():
        with gr.Group():
            with gr.Row():
                prompt_1 = gr.Textbox(lines=1, label="Start from", placeholder="Starting prompt", elem_id="riff-prompt_1")
                prompt_2 = gr.Textbox(lines=1, label="End with (optional)", placeholder="Prompt to shift towards at the end", elem_id="riff-prompt_2")
            with gr.Row():
                steps = gr.Slider(minimum=1, maximum=100, value=25, label="Steps per section")
                num_iterations = gr.Slider(minimum=2, maximum=25, value=2, step=1, label="Number of sections")
            with gr.Row():
                feel = gr.Dropdown(["og_beat", "agile", "vibes", "motorway", "marim"], value="og_beat", label="Feel", elem_id="riff-feel")
                seed = gr.Slider(minimum=0, maximum=4294967295, value=0, step=1, label="Seed (0 for random)", elem_id="riff-seed")

            btn_generate = gr.Button(value="Generate").style(full_width=True)
            info = gr.Markdown()
        with gr.Column():
            video = gr.Video(elem_id="riff-video")

            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html, elem_id="share-btn-share-icon", visible=False)
                loading_icon = gr.HTML(loading_icon_html, elem_id="share-btn-loading-icon", visible=False)
                share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

    inputs = [prompt_1, prompt_2, feel, num_iterations, steps, seed]
    outputs = [video, info, community_icon, loading_icon, share_button]

    num_iterations.change(on_num_iterations_change, [num_iterations, prompt_2], [info])
    prompt_1.submit(on_submit, inputs, outputs)
    prompt_2.submit(on_submit, inputs, outputs)
    btn_generate.click(on_submit, inputs, outputs)

    share_button.click(None, [], [], _js=share_js)

    examples = gr.Examples(
        fn=on_submit,
        examples=[
            ["typing", "dance beat", "og_beat", 10],
            ["synthwave", "jazz", "agile", 10],
            ["rap battle freestyle", "", "og_beat", 10],
            # ["techno club banger", "", "og_beat", 10],
            ["reggae dub beat", "sunset chill", "og_beat", 10],
            ["acoustic folk ballad", "", "agile", 10],
            ["blues guitar riff", "", "agile", 5],
            ["jazzy trumpet solo", "", "og_beat", 5],
            ["classical symphony orchestra", "", "vibes", 10],
            ["rock and roll power chord", "", "motorway", 5],
            ["soulful R&B love song", "", "marim", 10],
            ["country western twangy guitar", "", "agile", 10]],
        inputs=[prompt_1, prompt_2, feel, num_iterations],
        outputs=outputs,
        cache_examples=True)
    
    gr.HTML("""
    <div style="border-top: 1px solid #303030;">
      <br>
      <p>Space by:<br>
      <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a><br>
      <a href="https://github.com/qunash"><img alt="GitHub followers" src="https://img.shields.io/github/followers/qunash?style=social" alt="Github Follow"></a></p><br>
      <a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 24px !important;width: 81px !important;" ></a><br><br>
      <p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.riffusion-demo" alt="visitors"></p>
    </div>
    """)

app.queue(max_size=250, concurrency_count=6).launch()