File size: 7,653 Bytes
a89c362
 
 
629d1bf
 
 
 
 
 
 
 
 
 
 
 
 
a89c362
629d1bf
 
 
 
 
 
a89c362
629d1bf
 
 
 
 
 
 
 
 
a89c362
629d1bf
101c1cd
629d1bf
 
 
 
 
 
 
 
 
a89c362
629d1bf
 
 
d1abaff
629d1bf
 
 
 
 
 
d1abaff
629d1bf
 
 
 
101c1cd
629d1bf
d1abaff
629d1bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101c1cd
629d1bf
101c1cd
629d1bf
 
 
 
 
101c1cd
629d1bf
101c1cd
 
 
629d1bf
 
101c1cd
629d1bf
 
 
 
 
 
 
 
 
 
 
 
d1abaff
629d1bf
101c1cd
629d1bf
 
 
 
 
101c1cd
629d1bf
101c1cd
 
 
629d1bf
 
 
 
 
 
 
 
 
 
 
 
d1abaff
629d1bf
 
 
 
 
 
 
 
 
 
 
a89c362
d1abaff
629d1bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

import os
import json
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
class dotdict(dict):
    """A ``dict`` whose keys can also be used as attributes.

    ``d.key`` reads ``d.get("key")`` (so an absent key yields ``None`` instead
    of raising), ``d.key = v`` stores ``d["key"] = v`` and ``del d.key`` removes
    the entry (raising ``KeyError`` if it is absent).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the mapping.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

class InferRunner:
    """Load every model the demo needs for text-to-audio inference.

    Attributes:
        vae: ``AutoencoderKL`` built from ``ckpts/ldm/vae_config.json`` and
            loaded with ``ckpts/ldm/pytorch_model_vae.bin``; used to decode
            generated latents to a waveform.
        pico_model: ``PicoDiffusion`` configured from the first JSON record of
            ``ckpts/pico_model/summary.jsonl``, in eval mode.
        scheduler: ``DDPMScheduler`` loaded from the ``scheduler_name`` stored
            in that same training summary.
    """

    def __init__(self, device):
        """Build all models and move them to *device* (e.g. ``"cuda"``/``"cpu"``)."""
        # Context managers close the config files right away instead of leaving
        # the handles open until garbage collection.
        with open("ckpts/ldm/vae_config.json", encoding="utf-8") as f:
            vae_config = json.load(f)
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)
        # The VAE is only used for inference; put it in eval mode like
        # pico_model so train-only behaviour (e.g. dropout) is disabled.
        self.vae.eval()

        # Only the first record of the summary holds the training arguments,
        # so read just that line instead of the whole file.
        with open("ckpts/pico_model/summary.jsonl", encoding="utf-8") as f:
            train_args = dotdict(json.loads(f.readline()))
        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
            diffusion_pt="ckpts/pico_model/diffusion.pt",
        ).eval().to(device)
        self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")

# Run on the GPU whenever torch can see one, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Models are loaded once at import time and reused by infer() on every request.
runner = InferRunner(device)
# Supported event names; each one becomes a quick-insert button in the UI.
event_list = get_event()
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    """Generate audio for a timestamp caption and save it as a WAV file.

    Args:
        caption: Timestamp caption, e.g.
            ``"dog_barking at 0.562-2.562_4.25-6.25."``.
        num_steps: Number of denoising steps passed to ``demo_inference``.
            Cast to ``int`` defensively in case the UI slider hands over a float.
        guidance_scale: Guidance scale passed to ``demo_inference``.
        audio_len: Maximum number of samples kept from the decoded waveform
            (default: 10 s at the 16 kHz rate used for writing).

    Returns:
        Path to a newly created 16 kHz, 16-bit PCM WAV file.
    """
    import tempfile  # only needed here

    sample_rate = 16000
    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption,
            runner.scheduler,
            num_steps=int(num_steps),
            guidance_scale=guidance_scale,
            num_samples_per_prompt=1,
            disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        # [0] presumably selects the single generated sample; keep at most
        # audio_len samples of it.
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
    # Write to a unique file per call: a fixed "output.wav" in the working
    # directory lets concurrent requests overwrite each other's result.
    fd, outpath = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(outpath, wave, samplerate=sample_rate, subtype='PCM_16')
    return outpath

def preprocess(caption):
    """Turn a free-text caption into a timestamp caption via Gemini.

    The same text is returned twice: the preprocess button's click handler
    sends it both to the Step2 prompt box and to the Step1 output box.
    """
    timestamp_caption = preprocess_gemini(caption)
    return (timestamp_caption,) * 2

def update_textbox(event_name, current_text):
    """Append an ``"<event_name> two times."`` clause to the free-text prompt.

    Args:
        event_name: Label of the event button that was clicked.
        current_text: Current contents of the free-text prompt; may be empty,
            whitespace-only or ``None``.

    Returns:
        ``"<event_name> two times."`` if the prompt is blank. Otherwise, the
        existing prompt (surrounding whitespace and periods removed) followed
        by ``" then <event_name> two times."``.
    """
    event = event_name + ' two times.'
    # Strip whitespace as well as periods, so "a dog barks. " loses its
    # trailing period and whitespace-only text is treated as an empty prompt.
    existing = (current_text or '').strip().strip('.').strip()
    if existing:
        return existing + ' then ' + event
    return event

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## PicoAudio")
    with gr.Row():
        # Count taken from event_list so the header always matches the buttons.
        description_text = f"{len(event_list)} events supported:"
        gr.Markdown(description_text)

    # One button per supported event, laid out six per row. Built from
    # event_list instead of hard-coded index ranges so the layout follows
    # whatever get_event() returns.
    events_per_row = 6
    btn_event = []
    for row_start in range(0, len(event_list), events_per_row):
        with gr.Row():
            for i in range(row_start, min(row_start + events_per_row, len(event_list))):
                event_name = f"{event_list[i]}"
                btn_event.append(gr.Button(event_name))

    with gr.Row():
        gr.Markdown("## Step1-Preprocess")
    with gr.Row():
        preprocess_description_text = (
            "Transfer free-text into timestamp caption via LLM. "
            "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. "
            "We also provide the GPT version consistent with the paper in the file 'Files/llm_reprocessing.py'. "
            "You can use your own api_key to modify and run 'Files/inference.py' for local inference."
        )
        gr.Markdown(preprocess_description_text)
    with gr.Row():
        with gr.Column():
            freetext_prompt = gr.Textbox(
                label="Free-text Prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
                value="a dog barks three times.",
            )
            with gr.Row():
                preprocess_run_button = gr.Button()
                preprocess_run_clear = gr.ClearButton([freetext_prompt])
        with gr.Column():
            freetext_prompt_out = gr.Textbox(label="Timestamp Caption: Preprocess output")
    with gr.Row():
        with gr.Column():
            # Clicking an example only fills the free-text box. No fn is
            # attached, so no output components are needed.
            gr.Examples(
                examples=[["spraying two times then gunshot three times."],
                          ["a dog barks three times."],
                          ["cow mooing two times."]],
                inputs=[freetext_prompt],
            )
        with gr.Column():
            pass

    with gr.Row():
        gr.Markdown("## Step2-Generate")
    with gr.Row():
        generate_description_text = "Generate audio based on timestamp caption."
        gr.Markdown(generate_description_text)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Timestamp Caption: Specify your timestamp caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
                value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
            )
            with gr.Row():
                generate_run_button = gr.Button()
                generate_run_clear = gr.ClearButton([prompt])
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
                guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
        with gr.Column():
            outaudio = gr.Audio()

    # Each event button appends its own event name to the free-text prompt.
    for i, button in enumerate(btn_event):
        event_name = f"{event_list[i]}"
        button.click(fn=update_textbox, inputs=[gr.State(event_name), freetext_prompt], outputs=freetext_prompt)
    # preprocess() returns the timestamp caption twice: for the Step2 prompt
    # and for the Step1 output box.
    preprocess_run_button.click(fn=preprocess, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
    generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])

    with gr.Row():
        with gr.Column():
            # The examples provide only a caption, so only the prompt box is
            # listed as an input.
            gr.Examples(
                examples=[["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
                          ["dog_barking at 0.562-2.562_4.25-6.25."],
                          ["cow_mooing at 0.958-3.582_5.272-7.896."]],
                inputs=[prompt],
                outputs=[outaudio],
            )
        with gr.Column():
            pass

demo.launch()
        
    
# description_text = f"18 events: {', '.join(event_list)}"
# prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
#     value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
# outaudio = gr.Audio()
# num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
# guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)    
# gr_interface = gr.Interface(
        #     fn=infer,
        #     inputs=[prompt, num_steps, guidance_scale], 
        #     outputs=[outaudio],
        #     title="PicoAudio",
        #     description=description_text,
        #     allow_flagging=False,
        #     examples=[
        #         ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
        #         ["dog_barking at 0.562-2.562_4.25-6.25."],
        #         ["cow_mooing at 0.958-3.582_5.272-7.896."],
        #     ],
        #     cache_examples="lazy", # Turn on to cache.
        # )
        # gr_interface.queue(10).launch()