YUNSUN7 committed
Commit 9971dc1 · verified · Parent: a27ff76

Create app.py

Files changed (1):
  1. app.py +501 -0
app.py ADDED
@@ -0,0 +1,501 @@
+ # Ke Chen
+ # Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
+ # The main script
+
+ import os
+ # cap the BLAS/OpenMP thread pools so the SDR calculation does not occupy all CPUs
+ os.environ["OMP_NUM_THREADS"] = "4"
+ os.environ["OPENBLAS_NUM_THREADS"] = "4"
+ os.environ["MKL_NUM_THREADS"] = "6"
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
+ os.environ["NUMEXPR_NUM_THREADS"] = "6"
+
+ import sys
+ import librosa
+ import numpy as np
+ import argparse
+ import logging
+
+ import torch
+ from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+
+ from utils import collect_fn, dump_config, create_folder, prepprocess_audio
+ import musdb
+
+ from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
+ from data_processor import LGSPDataset, MusdbDataset
+ import config
+ import htsat_config
+ from models.htsat import HTSAT_Swin_Transformer
+ from sed_model import SEDWrapper
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ from htsat_utils import process_idc
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
+ # Lightning data module wrapping the train/eval datasets.
+ # Note: the dataloaders read num_workers and batch_size from the module-level
+ # config import rather than from self.config.
+ class data_prep(pl.LightningDataModule):
+     def __init__(self, train_dataset, eval_dataset, device_num, config):
+         super().__init__()
+         self.train_dataset = train_dataset
+         self.eval_dataset = eval_dataset
+         self.device_num = device_num
+         self.config = config
+
+     def train_dataloader(self):
+         train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
+         train_loader = DataLoader(
+             dataset = self.train_dataset,
+             num_workers = config.num_workers,
+             batch_size = config.batch_size // self.device_num,
+             shuffle = False,
+             sampler = train_sampler,
+             collate_fn = collect_fn
+         )
+         return train_loader
+
+     def val_dataloader(self):
+         eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
+         eval_loader = DataLoader(
+             dataset = self.eval_dataset,
+             num_workers = config.num_workers,
+             batch_size = config.batch_size // self.device_num,
+             shuffle = False,
+             sampler = eval_sampler,
+             collate_fn = collect_fn
+         )
+         return eval_loader
+
+     def test_dataloader(self):
+         test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
+         test_loader = DataLoader(
+             dataset = self.eval_dataset,
+             num_workers = config.num_workers,
+             batch_size = config.batch_size // self.device_num,
+             shuffle = False,
+             sampler = test_sampler,
+             collate_fn = collect_fn
+         )
+         return test_loader
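+
+ # Usage sketch: the module is handed straight to a pytorch_lightning Trainer,
+ # exactly as train() does below, e.g.
+ #   audioset_data = data_prep(dataset, eval_dataset, device_num, config)
+ #   trainer.fit(model, audioset_data)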
+
+ # Build the per-class index (idc) files from the AudioSet hdf5 indexes.
+ def save_idc():
+     train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
+     eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
+     process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
+     process_idc(eval_index_path, config.classes_num, "eval_idc.npy")
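+     # train() later reloads these files from config.idc_path and passes them to
+     # LGSPDataset as its idc argument; the consuming side, for reference:
+     #   train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)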
+
+ # Resample the musdb tracks from the original 44100 Hz to 32000 Hz
+ def process_musdb():
+     # use musdb as the test set
+     test_data = musdb.DB(
+         root = config.musdb_path,
+         download = False,
+         subsets = "test",
+         is_wav = True
+     )
+     print(len(test_data.tracks))
+     mus_tracks = []
+     # in musdb, every track has the same sample rate (44100 Hz)
+     orig_fs = test_data.tracks[0].rate
+     print(orig_fs)
+     for track in test_data.tracks:
+         temp = {}
+         mixture = prepprocess_audio(
+             track.audio,
+             orig_fs, config.sample_rate,
+             config.test_type
+         )
+         temp["mixture"] = mixture
+         for dickey in config.test_key:
+             source = prepprocess_audio(
+                 track.targets[dickey].audio,
+                 orig_fs, config.sample_rate,
+                 config.test_type
+             )
+             temp[dickey] = source
+         print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
+         mus_tracks.append(temp)
+     print(len(mus_tracks))
+     # save the processed tracks to npy
+     np.save("musdb-32000fs.npy", mus_tracks)
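+     # test() loads this data back via np.load(config.testset_path, allow_pickle = True),
+     # so config.testset_path should point at the file saved above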
+
+ # Weight averaging over the given folder: outputs one checkpoint whose
+ # weights are the mean of all model checkpoints in the folder.
+ def weight_average():
+     model_ckpt = []
+     model_files = os.listdir(config.wa_model_folder)
+     wa_ckpt = {
+         "state_dict": {}
+     }
+
+     for model_file in model_files:
+         model_file = os.path.join(config.wa_model_folder, model_file)
+         model_ckpt.append(torch.load(model_file, map_location = "cpu")["state_dict"])
+     keys = model_ckpt[0].keys()
+     for key in keys:
+         # stack this parameter across checkpoints and average it
+         model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
+         model_ckpt_key = torch.mean(model_ckpt_key, dim = 0)
+         assert model_ckpt_key.shape == model_ckpt[0][key].shape, \
+             f"shape mismatch: {model_ckpt_key.shape} vs {model_ckpt[0][key].shape}"
+         wa_ckpt["state_dict"][key] = model_ckpt_key
+     torch.save(wa_ckpt, config.wa_model_path)
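+     # the averaged checkpoint at config.wa_model_path can then be used as
+     # config.resume_checkpoint for test() or inference()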
+
+
+ # Use the model to quickly separate a track given a query.
+ # It requires four variables in config.py:
+ #   inference_file: the track you want to separate
+ #   inference_query: a **folder** containing all samples from the same source
+ #   test_key: ["name"] indicates the source name (only used to label the final output)
+ #   wave_output_path: the output folder
+
+ # Make sure the query folder contains only samples from the same source.
+ # Each run separates one source from the track; to separate multiple sources,
+ # change the query folder between runs or script over it (example settings below).
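+ # Example config.py settings (hypothetical paths, purely illustrative):
+ #   inference_file = "examples/mixture.wav"
+ #   inference_query = "examples/query_vocals/"  # a folder of .wav query clips
+ #   test_key = ["vocals"]
+ #   wave_output_path = "outputs/"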
158
+ def inference():
159
+ # set exp settings
160
+ device_name = "cuda" if torch.cuda.is_available() else "cpu"
161
+ device = torch.device("cuda")
162
+ assert config.test_key is not None, "there should be a separate key"
163
+ create_folder(config.wave_output_path)
164
+ test_track, fs = librosa.load(config.inference_file, sr = None)
165
+ test_track = test_track[:,None]
166
+ print(test_track.shape)
167
+ print(fs)
168
+ # convert the track into 32000 Hz sample rate
169
+ test_track = prepprocess_audio(
170
+ test_track,
171
+ fs, config.sample_rate,
172
+ config.test_type
173
+ )
174
+ test_tracks = []
175
+ temp = [test_track]
176
+ for dickey in config.test_key:
177
+ temp.append(test_track)
178
+ temp = np.array(temp)
179
+ test_tracks.append(temp)
180
+ dataset = MusdbDataset(tracks = test_tracks) # the action is similar to musdbdataset, reuse it
181
+ loader = DataLoader(
182
+ dataset = dataset,
183
+ num_workers = 1,
184
+ batch_size = 1,
185
+ shuffle = False
186
+ )
187
+ # obtain the samples for query
188
+ queries = []
189
+ for query_file in os.listdir(config.inference_query):
190
+ f_path = os.path.join(config.inference_query, query_file)
191
+ if query_file.endswith(".wav"):
192
+ temp_q, fs = librosa.load(f_path, sr = None)
193
+ temp_q = temp_q[:, None]
194
+ temp_q = prepprocess_audio(
195
+ temp_q,
196
+ fs, config.sample_rate,
197
+ config.test_type
198
+ )
199
+ temp = [temp_q]
200
+ for dickey in config.test_key:
201
+ temp.append(temp_q)
202
+ temp = np.array(temp)
203
+ queries.append(temp)
204
+
205
+ assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
206
+
207
+ sed_model = HTSAT_Swin_Transformer(
208
+ spec_size=htsat_config.htsat_spec_size,
209
+ patch_size=htsat_config.htsat_patch_size,
210
+ in_chans=1,
211
+ num_classes=htsat_config.classes_num,
212
+ window_size=htsat_config.htsat_window_size,
213
+ config = htsat_config,
214
+ depths = htsat_config.htsat_depth,
215
+ embed_dim = htsat_config.htsat_dim,
216
+ patch_stride=htsat_config.htsat_stride,
217
+ num_heads=htsat_config.htsat_num_head
218
+ )
219
+ at_model = SEDWrapper(
220
+ sed_model = sed_model,
221
+ config = htsat_config,
222
+ dataset = None
223
+ )
224
+ ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
225
+ at_model.load_state_dict(ckpt["state_dict"])
226
+
227
+ trainer = pl.Trainer(
228
+ gpus = 1
229
+ )
230
+ avg_at = None
231
+ # obtain the latent embedding as query
232
+ if config.infer_type == "mean":
233
+ avg_dataset = MusdbDataset(tracks = queries)
234
+ avg_loader = DataLoader(
235
+ dataset = avg_dataset,
236
+ num_workers = 1,
237
+ batch_size = 1,
238
+ shuffle = False
239
+ )
240
+ at_wrapper = AutoTaggingWarpper(
241
+ at_model = at_model,
242
+ config = config,
243
+ target_keys = config.test_key
244
+ )
245
+ trainer.test(at_wrapper, test_dataloaders = avg_loader)
246
+ avg_at = at_wrapper.avg_at
247
+
248
+ # import seapration model
249
+ model = ZeroShotASP(
250
+ channels = 1, config = config,
251
+ at_model = at_model,
252
+ dataset = dataset
253
+ )
254
+ # resume checkpoint
255
+ ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
256
+ model.load_state_dict(ckpt["state_dict"], strict= False)
257
+ exp_model = SeparatorModel(
258
+ model = model,
259
+ config = config,
260
+ target_keys = config.test_key,
261
+ avg_at = avg_at,
262
+ using_wiener = False,
263
+ calc_sdr = False,
264
+ output_wav = True
265
+ )
266
+ trainer.test(exp_model, test_dataloaders = loader)
+
+ # Evaluate the separation model, mainly on musdb
+ def test():
+     # set exp settings
+     device_name = "cuda" if torch.cuda.is_available() else "cpu"
+     device = torch.device(device_name)
+     assert config.test_key is not None, "config.test_key should name the sources to separate"
+     create_folder(config.wave_output_path)
+     # use the preprocessed musdb tracks as the test set
+     test_data = np.load(config.testset_path, allow_pickle = True)
+     print(len(test_data))
+     mus_tracks = []
+     # load the dataset (each entry: [mixture, source_1, ..., source_k])
+     for track in test_data:
+         temp = []
+         mixture = track["mixture"]
+         temp.append(mixture)
+         for dickey in config.test_key:
+             source = track[dickey]
+             temp.append(source)
+         temp = np.array(temp)
+         print(temp.shape)
+         mus_tracks.append(temp)
+     print(len(mus_tracks))
+     dataset = MusdbDataset(tracks = mus_tracks)
+     loader = DataLoader(
+         dataset = dataset,
+         num_workers = 1,
+         batch_size = 1,
+         shuffle = False
+     )
+     assert config.resume_checkpoint is not None, "a saved separation checkpoint is required for testing"
+
+     sed_model = HTSAT_Swin_Transformer(
+         spec_size = htsat_config.htsat_spec_size,
+         patch_size = htsat_config.htsat_patch_size,
+         in_chans = 1,
+         num_classes = htsat_config.classes_num,
+         window_size = htsat_config.htsat_window_size,
+         config = htsat_config,
+         depths = htsat_config.htsat_depth,
+         embed_dim = htsat_config.htsat_dim,
+         patch_stride = htsat_config.htsat_stride,
+         num_heads = htsat_config.htsat_num_head
+     )
+     at_model = SEDWrapper(
+         sed_model = sed_model,
+         config = htsat_config,
+         dataset = None
+     )
+     ckpt = torch.load(htsat_config.resume_checkpoint, map_location = "cpu")
+     at_model.load_state_dict(ckpt["state_dict"])
+     trainer = pl.Trainer(
+         gpus = 1
+     )
+     avg_at = None
+     # obtain the queries of the four stems from the training set
+     if config.infer_type == "mean":
+         avg_data = np.load(config.testavg_path, allow_pickle = True)[:90]
+         print(len(avg_data))
+         avgmus_tracks = []
+         # load the dataset (same layout as above)
+         for track in avg_data:
+             temp = []
+             mixture = track["mixture"]
+             temp.append(mixture)
+             for dickey in config.test_key:
+                 source = track[dickey]
+                 temp.append(source)
+             temp = np.array(temp)
+             print(temp.shape)
+             avgmus_tracks.append(temp)
+         print(len(avgmus_tracks))
+         avg_dataset = MusdbDataset(tracks = avgmus_tracks)
+         avg_loader = DataLoader(
+             dataset = avg_dataset,
+             num_workers = 1,
+             batch_size = 1,
+             shuffle = False
+         )
+         at_wrapper = AutoTaggingWarpper(
+             at_model = at_model,
+             config = config,
+             target_keys = config.test_key
+         )
+         trainer.test(at_wrapper, test_dataloaders = avg_loader)
+         avg_at = at_wrapper.avg_at
+
+     model = ZeroShotASP(
+         channels = 1, config = config,
+         at_model = at_model,
+         dataset = dataset
+     )
+     ckpt = torch.load(config.resume_checkpoint, map_location = "cpu")
+     model.load_state_dict(ckpt["state_dict"], strict = False)
+     exp_model = SeparatorModel(
+         model = model,
+         config = config,
+         target_keys = config.test_key,
+         avg_at = avg_at,
+         using_wiener = config.using_wiener
+     )
+     trainer.test(exp_model, test_dataloaders = loader)
+
+ def train():
+     # set exp settings
+     device_num = torch.cuda.device_count()
+     print("each batch size:", config.batch_size // device_num)
+
+     train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
+     train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)
+
+     eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
+     eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle = True)
+
+     # set exp folder
+     exp_dir = os.path.join(config.workspace, "results", config.exp_name)
+     checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
+
+     if not config.debug:
+         create_folder(os.path.join(config.workspace, "results"))
+         create_folder(exp_dir)
+         create_folder(checkpoint_dir)
+         dump_config(config, os.path.join(exp_dir, config.exp_name), False)
+
+     # load the LGSPDataset (latent general source separation) datasets
+     dataset = LGSPDataset(
+         index_path = train_index_path,
+         idc = train_idc,
+         config = config,
+         factor = 0.05,
+         eval_mode = False
+     )
+     eval_dataset = LGSPDataset(
+         index_path = eval_index_path,
+         idc = eval_idc,
+         config = config,
+         factor = 0.05,
+         eval_mode = True
+     )
+
+     audioset_data = data_prep(train_dataset = dataset, eval_dataset = eval_dataset, device_num = device_num, config = config)
+     checkpoint_callback = ModelCheckpoint(
+         monitor = "mixture_sdr",
+         filename = 'l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
+         save_top_k = 10,
+         mode = "max"
+     )
+     # build the audio-tagging model used for the queries
+     sed_model = HTSAT_Swin_Transformer(
+         spec_size = htsat_config.htsat_spec_size,
+         patch_size = htsat_config.htsat_patch_size,
+         in_chans = 1,
+         num_classes = htsat_config.classes_num,
+         window_size = htsat_config.htsat_window_size,
+         config = htsat_config,
+         depths = htsat_config.htsat_depth,
+         embed_dim = htsat_config.htsat_dim,
+         patch_stride = htsat_config.htsat_stride,
+         num_heads = htsat_config.htsat_num_head
+     )
+     at_model = SEDWrapper(
+         sed_model = sed_model,
+         config = htsat_config,
+         dataset = None
+     )
+     # load the audio-tagging checkpoint
+     ckpt = torch.load(htsat_config.resume_checkpoint, map_location = "cpu")
+     at_model.load_state_dict(ckpt["state_dict"])
+
+     trainer = pl.Trainer(
+         deterministic = True,
+         default_root_dir = checkpoint_dir,
+         gpus = device_num,
+         val_check_interval = 0.2,
+         # check_val_every_n_epoch = 1,
+         max_epochs = config.max_epoch,
+         auto_lr_find = True,
+         sync_batchnorm = True,
+         callbacks = [checkpoint_callback],
+         accelerator = "ddp" if device_num > 1 else None,
+         resume_from_checkpoint = None, # config.resume_checkpoint,
+         replace_sampler_ddp = False,
+         gradient_clip_val = 1.0,
+         num_sanity_val_steps = 0,
+     )
+     model = ZeroShotASP(
+         channels = 1, config = config,
+         at_model = at_model,
+         dataset = dataset
+     )
+     if config.resume_checkpoint is not None:
+         ckpt = torch.load(config.resume_checkpoint, map_location = "cpu")
+         model.load_state_dict(ckpt["state_dict"])
+     # trainer.test(model, datamodule = audioset_data)
+     trainer.fit(model, audioset_data)
+
+ def main():
+     parser = argparse.ArgumentParser(description = "latent general source separation parser")
+     subparsers = parser.add_subparsers(dest = "mode")
+     parser_train = subparsers.add_parser("train")
+     parser_test = subparsers.add_parser("test")
+     parser_musdb = subparsers.add_parser("musdb_process")
+     parser_saveidc = subparsers.add_parser("save_idc")
+     parser_wa = subparsers.add_parser("weight_average")
+     parser_infer = subparsers.add_parser("inference")
+     args = parser.parse_args()
+     # default settings
+     logging.basicConfig(level = logging.INFO)
+     pl.utilities.seed.seed_everything(seed = config.random_seed)
+
+     if args.mode == "train":
+         train()
+     elif args.mode == "test":
+         test()
+     elif args.mode == "musdb_process":
+         process_musdb()
+     elif args.mode == "weight_average":
+         weight_average()
+     elif args.mode == "save_idc":
+         save_idc()
+     elif args.mode == "inference":
+         inference()
+     else:
+         raise Exception("Error Mode!")
+
+
+ if __name__ == '__main__':
+     main()
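+
+ # Example invocations:
+ #   python app.py save_idc        # build the index files used by train()
+ #   python app.py train           # train the ZeroShotASP separator
+ #   python app.py musdb_process   # resample musdb to 32000 Hz (musdb-32000fs.npy)
+ #   python app.py test            # evaluate on the preprocessed musdb set
+ #   python app.py weight_average  # average the checkpoints in config.wa_model_folder
+ #   python app.py inference       # separate config.inference_file with the query folder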