ZeyuXie committed on
Commit d1abaff
1 Parent(s): 594eb84

Update app.py

Files changed (1)
  1. app.py +101 -129
app.py CHANGED
@@ -1,140 +1,112 @@
  import os
  import json
- import numpy as np
- import torch
- import soundfile as sf
- import gradio as gr
- from diffusers import DDPMScheduler
- from pico_model import PicoDiffusion
- from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
- from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
-
- class dotdict(dict):
-     """dot.notation access to dictionary attributes"""
-     __getattr__ = dict.get
-     __setattr__ = dict.__setitem__
-     __delattr__ = dict.__delitem__
-
- class InferRunner:
-     def __init__(self, device):
-         vae_config = json.load(open("ckpts/ldm/vae_config.json"))
-         self.vae = AutoencoderKL(**vae_config).to(device)
-         vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
-         self.vae.load_state_dict(vae_weights)
-
-         train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
-         self.pico_model = PicoDiffusion(
-             scheduler_name=train_args.scheduler_name,
-             unet_model_config_path=train_args.unet_model_config,
-             snr_gamma=train_args.snr_gamma,
-             freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
-             diffusion_pt="ckpts/pico_model/diffusion.pt",
-         ).eval().to(device)
-         self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- runner = InferRunner(device)
- event_list = get_event()
- def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
-     with torch.no_grad():
-         latents = runner.pico_model.demo_inference(caption, runner.scheduler, num_steps=num_steps, guidance_scale=guidance_scale, num_samples_per_prompt=1, disable_progress=True)
-         mel = runner.vae.decode_first_stage(latents)
-         wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
-     outpath = f"output.wav"
-     sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
-     return outpath
-
- def preprocess(caption):
-     output = preprocess_gemini(caption)
-     return output, output
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         gr.Markdown("## PicoAudio")
-     with gr.Row():
-         description_text = f"Support 18 events: {', '.join(event_list)}"
-         gr.Markdown(description_text)
-
-     with gr.Row():
-         gr.Markdown("## Step1")
-     with gr.Row():
-         preprocess_description_text = f"Preprocess: transfer free-text into timestamp caption via LLM. " +\
-             "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. " +\
-             "We also provide the GPT version consistent with the paper in the file 'Files/llm_reprocessing.py'. You can use your own api_key to modify and run 'Files/inference.py' for local inference."
-         gr.Markdown(preprocess_description_text)
-     with gr.Row():
-         with gr.Column():
-             freetext_prompt = gr.Textbox(label="Free-text prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
-                                          value="a dog barks three times.",)
-             preprocess_run_button = gr.Button()
-             prompt = None
-         with gr.Column():
-             freetext_prompt_out = gr.Textbox(label="Timestamp Caption: Preprocess output")
-     with gr.Row():
-         with gr.Column():
-             gr.Examples(
-                 examples = [["spraying two times then gunshot three times."],
-                             ["a dog barks three times."],
-                             ["cow mooing two times."],],
-                 inputs = [freetext_prompt],
-                 outputs = [prompt]
-             )
-         with gr.Column():
-             pass
-
-
-     with gr.Row():
-         gr.Markdown("## Step2")
-     with gr.Row():
-         generate_description_text = f"Generate audio based on timestamp caption."
-         gr.Markdown(generate_description_text)
-     with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Timestamp Caption: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
-                                 value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
-             generate_run_button = gr.Button()
-             with gr.Accordion("Advanced options", open=False):
-                 num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
-                 guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
-         with gr.Column():
-             outaudio = gr.Audio()
-     preprocess_run_button.click(fn=preprocess, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
-     generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])
-
-     with gr.Row():
-         with gr.Column():
-             gr.Examples(
-                 examples = [["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
-                             ["dog_barking at 0.562-2.562_4.25-6.25."],
-                             ["cow_mooing at 0.958-3.582_5.272-7.896."],],
-                 inputs = [prompt, num_steps, guidance_scale],
-                 outputs = [outaudio]
-             )
-         with gr.Column():
-             pass
-
- demo.launch()
-
- # description_text = f"18 events: {', '.join(event_list)}"
- # prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
- #                     value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
- # outaudio = gr.Audio()
- # num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
- # guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
- # gr_interface = gr.Interface(
- #     fn=infer,
- #     inputs=[prompt, num_steps, guidance_scale],
- #     outputs=[outaudio],
- #     title="PicoAudio",
- #     description=description_text,
- #     allow_flagging=False,
- #     examples=[
- #         ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
- #         ["dog_barking at 0.562-2.562_4.25-6.25."],
- #         ["cow_mooing at 0.958-3.582_5.272-7.896."],
- #     ],
- #     cache_examples="lazy",  # Turn on to cache.
- # )
- # gr_interface.queue(10).launch()
 
+ """
+ At the command line, you only need to run this once to install the package via pip:
+
+ $ pip install google-generativeai
+ """
+
+ from pathlib import Path
  import os
  import json
+ import re
+
+ # local debugging proxy; remove or make these configurable before deploying
+ os.environ['HTTP_PROXY'] = 'http://127.0.0.1:58591'
+ os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:58591'
+
+ def get_event():
+     event_list = [
+         "burping_belching",           # 0
+         "car_horn_honking",
+         "cat_meowing",
+         "cow_mooing",
+         "dog_barking",
+         "door_knocking",
+         "door_slamming",
+         "explosion",
+         "gunshot",                    # 8
+         "sheep_goat_bleating",
+         "sneeze",
+         "spraying",
+         "thump_thud",
+         "train_horn",
+         "tapping_clicking_clanking",
+         "woman_laughing",
+         "duck_quacking",              # 16
+         "whistling",
+     ]
+     return event_list
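
The removed app.py built its demo banner directly from this list; a minimal usage sketch, reusing the names defined above:

    event_list = get_event()
    print(f"Support {len(event_list)} events: {', '.join(event_list)}")  # 18 events
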
+
+ def get_prompt():
+     train_json_list = ["data/train_multi-event_v3.json",
+                        "data/train_single-event_multi_v3.json",
+                        "data/train_single-event_single_v3.json"]
+     learn_pair = ""
+     for train_json in train_json_list:
+         with open(train_json, 'r') as train_file:
+             for idx, line in enumerate(train_file):
+                 if idx >= 100:
+                     break
+                 data = json.loads(line.strip())
+                 learn_pair += f"{idx}:{data['captions']}~{data['onset']}. "
+     prefix_prompt = "I'm doing audio event generation, which is a harmless job that will contain some sound events. For example, a gunshot is a sound that is harmless. " +\
+         "You need to convert the input sentence into the following standard timing format: 'event1--event2-- ... --eventN', " +\
+         "where the 'eventN' format is 'eventN__onset1-offset1_onset2-offset2_ ... _onsetK-offsetK'. " +\
+         "The 'onset-offset' values need to be determined based on common sense and the examples I provide, with each duration not less than 1 and not greater than 4. Every 'onsetK-offsetK' placeholder should be replaced by numbers. " +\
+         "The very strict constraint is that the total duration is less than 10 seconds, meaning all times are less than 10. It is preferred that events do not overlap as much as possible. " +\
+         "Now, I will provide you with 300 examples from the training set for your learning, each example in the format 'index: input~output'. " +\
+         learn_pair
+
+     print(len(prefix_prompt))
+     return prefix_prompt
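
get_prompt() assumes each training JSON file holds one record per line with 'captions' and 'onset' fields; an illustrative line (hypothetical values, shaped after the timing format the prompt describes) would be:

    {"captions": "a dog barks three times", "onset": "dog_barking__0.5-1.5_3.0-4.0_6.5-7.5"}
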
+ def postprocess(caption):
+     # strip surrounding whitespace and any trailing period, then map the
+     # LLM's delimiters ('__', '--') to the 'event at onset-offset' caption format
+     caption = caption.strip('\n').strip(' ').strip('.')
+     caption = caption.replace('__', ' at ').replace('--', ' and ')
+     return caption
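
For reference, postprocess() turns the model's delimiter output into the caption format the generator consumes; for example (illustrative input string):

    postprocess("spraying__0.38-1.176--gunshot__1.729-3.729.\n")
    # -> 'spraying at 0.38-1.176 and gunshot at 1.729-3.729'
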
+ def preprocess_gemini(free_text_caption):
+     prefix_prompt = get_prompt()
+     import google.generativeai as genai
+     # the API key is read from the environment rather than hardcoded
+     genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+     print(free_text_caption)
+     # Set up the model
+     generation_config = {
+         "temperature": 1,
+         "top_p": 0.95,
+         "top_k": 64,
+         "max_output_tokens": 8192,
+     }
+
+     model = genai.GenerativeModel(model_name="gemini-1.5-flash",
+                                   generation_config=generation_config)
+
+     prompt_parts = [
+         prefix_prompt +
+         f"Please convert the following inputs into the standard timing format: {free_text_caption}. You should only output results in the standard timing format. Do not output anything other than the format and do not add symbols.",
+     ]
+
+     timestampCaption = model.generate_content(prompt_parts).text
+     print(timestampCaption)
+     return postprocess(timestampCaption)
+
+ def preprocess_gpt(free_text_caption):
+     prefix_prompt = get_prompt()
+     from openai import OpenAI
+     # the API key is read from the environment rather than hardcoded
+     client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+     completion_start = client.chat.completions.create(
+         model="gpt-4-1106-preview",
+         messages=[{
+             "role": "user",
+             "content":
+                 prefix_prompt +
+                 f"Please convert the following inputs into the standard timing format: {free_text_caption}. You should only output results in the standard timing format. Do not output anything other than the format and do not add symbols."
+         }]
+     )
+
+     timestampCaption = completion_start.choices[0].message.content
+     return postprocess(timestampCaption)
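
Both backends return captions in the same format, so a thin wrapper can pick one at call time; a sketch (hypothetical helper, mirroring the preprocess() wrapper the removed app.py defined around preprocess_gemini):

    def preprocess(free_text_caption, backend="gemini"):
        if backend == "gemini":
            return preprocess_gemini(free_text_caption)
        return preprocess_gpt(free_text_caption)
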
+
+ if __name__ == "__main__":
+     caption = preprocess_gemini("spraying two times then gunshot three times.")
+     print(caption)
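
For this example prompt, the script should print a timestamp caption in the standard format; exact output varies between LLM runs, but the demo's own example for the same input looks like:

    spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.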