cocktailpeanut commited on
Commit
2354f08
Β·
1 Parent(s): 187c1c3
Files changed (3) hide show
  1. README.md +10 -9
  2. app.py +249 -0
  3. requirements.txt +22 -0
README.md CHANGED
@@ -1,13 +1,14 @@
1
- ---
2
  title: AudioGen
3
- emoji: πŸš€
4
- colorFrom: blue
 
 
 
 
 
 
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.39.0
8
- app_file: app.py
9
- pinned: false
10
  license: cc-by-nc-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  title: AudioGen
2
+ python_version: '3.9'
3
+ tags:
4
+ - audio generation
5
+ - language models
6
+ - LLMs
7
+ app_file: app.py
8
+ emoji: πŸ”Š
9
+ colorFrom: white
10
  colorTo: blue
11
  sdk: gradio
12
+ sdk_version: 3.34.0
13
+ pinned: true
 
14
  license: cc-by-nc-4.0
 
 
 
app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
8
+ # also released under the MIT license.
9
+
10
+ import argparse
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ import os
13
+ from pathlib import Path
14
+ import subprocess as sp
15
+ from tempfile import NamedTemporaryFile
16
+ import time
17
+ import typing as tp
18
+ import warnings
19
+
20
+ import torch
21
+ import gradio as gr
22
+
23
+ from audiocraft.data.audio_utils import convert_audio
24
+ from audiocraft.data.audio import audio_write
25
+ from audiocraft.models import AudioGen, MultiBandDiffusion
26
+
27
+
28
+ MODEL = None # Last used model
29
+ INTERRUPTING = False
30
+ # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
31
+ _old_call = sp.call
32
+
33
+
34
+ def _call_nostderr(*args, **kwargs):
35
+ # Avoid ffmpeg vomiting on the logs.
36
+ kwargs['stderr'] = sp.DEVNULL
37
+ kwargs['stdout'] = sp.DEVNULL
38
+ _old_call(*args, **kwargs)
39
+
40
+
41
+ sp.call = _call_nostderr
42
+ # Preallocating the pool of processes.
43
+ pool = ProcessPoolExecutor(4)
44
+ pool.__enter__()
45
+
46
+
47
+ def interrupt():
48
+ global INTERRUPTING
49
+ INTERRUPTING = True
50
+
51
+
52
+ class FileCleaner:
53
+ def __init__(self, file_lifetime: float = 3600):
54
+ self.file_lifetime = file_lifetime
55
+ self.files = []
56
+
57
+ def add(self, path: tp.Union[str, Path]):
58
+ self._cleanup()
59
+ self.files.append((time.time(), Path(path)))
60
+
61
+ def _cleanup(self):
62
+ now = time.time()
63
+ for time_added, path in list(self.files):
64
+ if now - time_added > self.file_lifetime:
65
+ if path.exists():
66
+ path.unlink()
67
+ self.files.pop(0)
68
+ else:
69
+ break
70
+
71
+
72
+ file_cleaner = FileCleaner()
73
+
74
+
75
+ def make_waveform(*args, **kwargs):
76
+ # Further remove some warnings.
77
+ be = time.time()
78
+ with warnings.catch_warnings():
79
+ warnings.simplefilter('ignore')
80
+ out = gr.make_waveform(*args, **kwargs)
81
+ print("Make a video took", time.time() - be)
82
+ return out
83
+
84
+
85
+ def load_model(version='facebook/audiogen-medium'):
86
+ global MODEL
87
+ print("Loading model", version)
88
+ if MODEL is None or MODEL.name != version:
89
+ MODEL = AudioGen.get_pretrained(version)
90
+
91
+
92
+ def load_diffusion():
93
+ global MBD
94
+ print("loading MBD")
95
+ MBD = MultiBandDiffusion.get_mbd_musicgen()
96
+
97
+
98
+ def _do_predictions(texts, duration, progress=False, **gen_kwargs):
99
+ MODEL.set_generation_params(duration=duration, **gen_kwargs)
100
+ be = time.time()
101
+ target_sr = 32000
102
+ target_ac = 1
103
+
104
+ outputs = MODEL.generate(texts, progress=progress)
105
+ if USE_DIFFUSION:
106
+ outputs_diffusion = MBD.tokens_to_wav(outputs[1])
107
+ outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
108
+ outputs = outputs.detach().cpu().float()
109
+ pending_videos = []
110
+ out_wavs = []
111
+ for output in outputs:
112
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
113
+ audio_write(
114
+ file.name, output, MODEL.sample_rate, strategy="loudness",
115
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
116
+ pending_videos.append(pool.submit(make_waveform, file.name))
117
+ out_wavs.append(file.name)
118
+ file_cleaner.add(file.name)
119
+ out_videos = [pending_video.result() for pending_video in pending_videos]
120
+ for video in out_videos:
121
+ file_cleaner.add(video)
122
+ print("batch finished", len(texts), time.time() - be)
123
+ print("Tempfiles currently stored: ", len(file_cleaner.files))
124
+ return out_videos, out_wavs
125
+
126
+
127
+
128
+ def predict_full(model, decoder, text, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
129
+ global INTERRUPTING
130
+ global USE_DIFFUSION
131
+ INTERRUPTING = False
132
+ if temperature < 0:
133
+ raise gr.Error("Temperature must be >= 0.")
134
+ if topk < 0:
135
+ raise gr.Error("Topk must be non-negative.")
136
+ if topp < 0:
137
+ raise gr.Error("Topp must be non-negative.")
138
+
139
+ topk = int(topk)
140
+ if decoder == "MultiBand_Diffusion":
141
+ USE_DIFFUSION = True
142
+ load_diffusion()
143
+ else:
144
+ USE_DIFFUSION = False
145
+ load_model(model)
146
+
147
+ def _progress(generated, to_generate):
148
+ progress((min(generated, to_generate), to_generate))
149
+ if INTERRUPTING:
150
+ raise gr.Error("Interrupted.")
151
+ MODEL.set_custom_progress_callback(_progress)
152
+
153
+ videos, wavs = _do_predictions(
154
+ [text], duration, progress=True,
155
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
156
+ if USE_DIFFUSION:
157
+ return videos[0], wavs[0], videos[1], wavs[1]
158
+ return videos[0], wavs[0], None, None
159
+ return videos[0], wavs[0]
160
+
161
+
162
+
163
+ def toggle_diffusion(choice):
164
+ if choice == "MultiBand_Diffusion":
165
+ return [gr.update(visible=True)] * 2
166
+ else:
167
+ return [gr.update(visible=False)] * 2
168
+
169
+
170
+ def ui_full(launch_kwargs):
171
+ with gr.Blocks() as interface:
172
+ gr.Markdown(
173
+ """
174
+ # AudioGen
175
+ This is your private demo for [AudioGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/AUDIOGEN.md),
176
+ a simple and controllable model for audio generation
177
+ """
178
+ )
179
+ with gr.Row():
180
+ with gr.Column():
181
+ with gr.Row():
182
+ text = gr.Text(label="Input Text", interactive=True)
183
+ with gr.Row():
184
+ submit = gr.Button("Submit")
185
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
186
+ _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
187
+ with gr.Row():
188
+ model = gr.Radio(["facebook/audiogen-medium"], label="Model", value="facebook/audiogen-medium", interactive=True)
189
+ with gr.Row():
190
+ decoder = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False)
191
+ with gr.Row():
192
+ duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
193
+ with gr.Row():
194
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
195
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
196
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
197
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
198
+ with gr.Column():
199
+ output = gr.Video(label="Generated Audio")
200
+ audio_output = gr.Audio(label="Generated Audio (wav)", type='filepath')
201
+ submit.click(predict_full, inputs=[model, decoder, text, duration, topk, topp, temperature, cfg_coef], outputs=[output, audio_output])
202
+
203
+ interface.queue().launch(**launch_kwargs)
204
+
205
+
206
+
207
+ if __name__ == "__main__":
208
+ parser = argparse.ArgumentParser()
209
+ parser.add_argument(
210
+ '--listen',
211
+ type=str,
212
+ default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
213
+ help='IP to listen on for connections to Gradio',
214
+ )
215
+ parser.add_argument(
216
+ '--username', type=str, default='', help='Username for authentication'
217
+ )
218
+ parser.add_argument(
219
+ '--password', type=str, default='', help='Password for authentication'
220
+ )
221
+ parser.add_argument(
222
+ '--server_port',
223
+ type=int,
224
+ default=0,
225
+ help='Port to run the server listener on',
226
+ )
227
+ parser.add_argument(
228
+ '--inbrowser', action='store_true', help='Open in browser'
229
+ )
230
+ parser.add_argument(
231
+ '--share', action='store_true', help='Share the gradio UI'
232
+ )
233
+
234
+ args = parser.parse_args()
235
+
236
+ launch_kwargs = {}
237
+ launch_kwargs['server_name'] = args.listen
238
+
239
+ if args.username and args.password:
240
+ launch_kwargs['auth'] = (args.username, args.password)
241
+ if args.server_port:
242
+ launch_kwargs['server_port'] = args.server_port
243
+ if args.inbrowser:
244
+ launch_kwargs['inbrowser'] = args.inbrowser
245
+ if args.share:
246
+ launch_kwargs['share'] = args.share
247
+
248
+ # Show the interface
249
+ ui_full(launch_kwargs)
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # please make sure you have already a pytorch install that is cuda enabled!
2
+ av
3
+ einops
4
+ flashy>=0.0.1
5
+ hydra-core>=1.1
6
+ hydra_colorlog
7
+ julius
8
+ num2words
9
+ numpy
10
+ sentencepiece
11
+ spacy==3.5.2
12
+ torch>=2.0.0
13
+ torchaudio>=2.0.0
14
+ huggingface_hub
15
+ tqdm
16
+ transformers>=4.31.0 # need Encodec there.
17
+ xformers
18
+ demucs
19
+ librosa
20
+ gradio
21
+ torchmetrics
22
+ encodec