Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ke Chen
|
2 | |
3 |
+
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
|
4 |
+
# The Main Script
|
5 |
+
|
6 |
+
import os
|
7 |
+
# this is to avoid the sdr calculation from occupying all cpus
|
8 |
+
os.environ["OMP_NUM_THREADS"] = "4"
|
9 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "4"
|
10 |
+
os.environ["MKL_NUM_THREADS"] = "6"
|
11 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
|
12 |
+
os.environ["NUMEXPR_NUM_THREADS"] = "6"
|
13 |
+
|
14 |
+
import sys
|
15 |
+
import librosa
|
16 |
+
import numpy as np
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
from torch.utils.data.distributed import DistributedSampler
|
23 |
+
|
24 |
+
from utils import collect_fn, dump_config, create_folder, prepprocess_audio
|
25 |
+
import musdb
|
26 |
+
|
27 |
+
from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
|
28 |
+
from data_processor import LGSPDataset, MusdbDataset
|
29 |
+
import config
|
30 |
+
import htsat_config
|
31 |
+
from models.htsat import HTSAT_Swin_Transformer
|
32 |
+
from sed_model import SEDWrapper
|
33 |
+
|
34 |
+
import pytorch_lightning as pl
|
35 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
36 |
+
|
37 |
+
from htsat_utils import process_idc
|
38 |
+
|
39 |
+
import warnings
|
40 |
+
warnings.filterwarnings("ignore")
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class data_prep(pl.LightningDataModule):
|
45 |
+
def __init__(self, train_dataset, eval_dataset, device_num, config):
|
46 |
+
super().__init__()
|
47 |
+
self.train_dataset = train_dataset
|
48 |
+
self.eval_dataset = eval_dataset
|
49 |
+
self.device_num = device_num
|
50 |
+
self.config = config
|
51 |
+
|
52 |
+
def train_dataloader(self):
|
53 |
+
train_sampler = DistributedSampler(self.train_dataset, shuffle = False) if self.device_num > 1 else None
|
54 |
+
train_loader = DataLoader(
|
55 |
+
dataset = self.train_dataset,
|
56 |
+
num_workers = config.num_workers,
|
57 |
+
batch_size = config.batch_size // self.device_num,
|
58 |
+
shuffle = False,
|
59 |
+
sampler = train_sampler,
|
60 |
+
collate_fn = collect_fn
|
61 |
+
)
|
62 |
+
return train_loader
|
63 |
+
def val_dataloader(self):
|
64 |
+
eval_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
|
65 |
+
eval_loader = DataLoader(
|
66 |
+
dataset = self.eval_dataset,
|
67 |
+
num_workers = config.num_workers,
|
68 |
+
batch_size = config.batch_size // self.device_num,
|
69 |
+
shuffle = False,
|
70 |
+
sampler = eval_sampler,
|
71 |
+
collate_fn = collect_fn
|
72 |
+
)
|
73 |
+
return eval_loader
|
74 |
+
def test_dataloader(self):
|
75 |
+
test_sampler = DistributedSampler(self.eval_dataset, shuffle = False) if self.device_num > 1 else None
|
76 |
+
test_loader = DataLoader(
|
77 |
+
dataset = self.eval_dataset,
|
78 |
+
num_workers = config.num_workers,
|
79 |
+
batch_size = config.batch_size // self.device_num,
|
80 |
+
shuffle = False,
|
81 |
+
sampler = test_sampler,
|
82 |
+
collate_fn = collect_fn
|
83 |
+
)
|
84 |
+
return test_loader
|
85 |
+
|
86 |
+
def save_idc():
|
87 |
+
train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
|
88 |
+
eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
|
89 |
+
process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
|
90 |
+
process_idc(eval_index_path, config.classes_num, "eval_idc.npy")
|
91 |
+
|
92 |
+
# Process the musdb tracks into the sample rate of 32000 Hz sample rate, the original is 44100 Hz
|
93 |
+
def process_musdb():
|
94 |
+
# use musdb as testset
|
95 |
+
test_data = musdb.DB(
|
96 |
+
root = config.musdb_path,
|
97 |
+
download = False,
|
98 |
+
subsets = "test",
|
99 |
+
is_wav = True
|
100 |
+
)
|
101 |
+
print(len(test_data.tracks))
|
102 |
+
mus_tracks = []
|
103 |
+
# in musdb, all fs is the same (44100)
|
104 |
+
orig_fs = test_data.tracks[0].rate
|
105 |
+
print(orig_fs)
|
106 |
+
for track in test_data.tracks:
|
107 |
+
temp = {}
|
108 |
+
mixture = prepprocess_audio(
|
109 |
+
track.audio,
|
110 |
+
orig_fs, config.sample_rate,
|
111 |
+
config.test_type
|
112 |
+
)
|
113 |
+
temp["mixture" ]= mixture
|
114 |
+
for dickey in config.test_key:
|
115 |
+
source = prepprocess_audio(
|
116 |
+
track.targets[dickey].audio,
|
117 |
+
orig_fs, config.sample_rate,
|
118 |
+
config.test_type
|
119 |
+
)
|
120 |
+
temp[dickey] = source
|
121 |
+
print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
|
122 |
+
mus_tracks.append(temp)
|
123 |
+
print(len(mus_tracks))
|
124 |
+
# save the file to npy
|
125 |
+
np.save("musdb-32000fs.npy", mus_tracks)
|
126 |
+
|
127 |
+
# weight average will perform in the given folder
|
128 |
+
# It will output one model checkpoint, which avergas the weight of all models in the folder
|
129 |
+
def weight_average():
|
130 |
+
model_ckpt = []
|
131 |
+
model_files = os.listdir(config.wa_model_folder)
|
132 |
+
wa_ckpt = {
|
133 |
+
"state_dict": {}
|
134 |
+
}
|
135 |
+
|
136 |
+
for model_file in model_files:
|
137 |
+
model_file = os.path.join(config.esm_model_folder, model_file)
|
138 |
+
model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
|
139 |
+
keys = model_ckpt[0].keys()
|
140 |
+
for key in keys:
|
141 |
+
model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
|
142 |
+
model_ckpt_key = torch.mean(model_ckpt_key, dim = 0)
|
143 |
+
assert model_ckpt_key.shape == model_ckpt[0][key].shape, "the shape is unmatched " + model_ckpt_key.shape + " " + model_ckpt[0][key].shape
|
144 |
+
wa_ckpt["state_dict"][key] = model_ckpt_key
|
145 |
+
torch.save(wa_ckpt, config.wa_model_path)
|
146 |
+
|
147 |
+
|
148 |
+
# use the model to quickly separate a track given a query
|
149 |
+
# it requires four variables in config.py:
|
150 |
+
# inference_file: the track you want to separate
|
151 |
+
# inference_query: a **folder** containing all samples from the same source
|
152 |
+
# test_key: ["name"] indicate the source name (just a name for final output, no other functions)
|
153 |
+
# wave_output_path: the output folder
|
154 |
+
|
155 |
+
# make sure the query folder contain the samples from the same source
|
156 |
+
# each time, the model is able to separate one source from the track
|
157 |
+
# if you want to separate multiple sources, you need to change the query folder or write a script to help you do that
|
158 |
+
def inference():
|
159 |
+
# set exp settings
|
160 |
+
device_name = "cuda" if torch.cuda.is_available() else "cpu"
|
161 |
+
device = torch.device("cuda")
|
162 |
+
assert config.test_key is not None, "there should be a separate key"
|
163 |
+
create_folder(config.wave_output_path)
|
164 |
+
test_track, fs = librosa.load(config.inference_file, sr = None)
|
165 |
+
test_track = test_track[:,None]
|
166 |
+
print(test_track.shape)
|
167 |
+
print(fs)
|
168 |
+
# convert the track into 32000 Hz sample rate
|
169 |
+
test_track = prepprocess_audio(
|
170 |
+
test_track,
|
171 |
+
fs, config.sample_rate,
|
172 |
+
config.test_type
|
173 |
+
)
|
174 |
+
test_tracks = []
|
175 |
+
temp = [test_track]
|
176 |
+
for dickey in config.test_key:
|
177 |
+
temp.append(test_track)
|
178 |
+
temp = np.array(temp)
|
179 |
+
test_tracks.append(temp)
|
180 |
+
dataset = MusdbDataset(tracks = test_tracks) # the action is similar to musdbdataset, reuse it
|
181 |
+
loader = DataLoader(
|
182 |
+
dataset = dataset,
|
183 |
+
num_workers = 1,
|
184 |
+
batch_size = 1,
|
185 |
+
shuffle = False
|
186 |
+
)
|
187 |
+
# obtain the samples for query
|
188 |
+
queries = []
|
189 |
+
for query_file in os.listdir(config.inference_query):
|
190 |
+
f_path = os.path.join(config.inference_query, query_file)
|
191 |
+
if query_file.endswith(".wav"):
|
192 |
+
temp_q, fs = librosa.load(f_path, sr = None)
|
193 |
+
temp_q = temp_q[:, None]
|
194 |
+
temp_q = prepprocess_audio(
|
195 |
+
temp_q,
|
196 |
+
fs, config.sample_rate,
|
197 |
+
config.test_type
|
198 |
+
)
|
199 |
+
temp = [temp_q]
|
200 |
+
for dickey in config.test_key:
|
201 |
+
temp.append(temp_q)
|
202 |
+
temp = np.array(temp)
|
203 |
+
queries.append(temp)
|
204 |
+
|
205 |
+
assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
|
206 |
+
|
207 |
+
sed_model = HTSAT_Swin_Transformer(
|
208 |
+
spec_size=htsat_config.htsat_spec_size,
|
209 |
+
patch_size=htsat_config.htsat_patch_size,
|
210 |
+
in_chans=1,
|
211 |
+
num_classes=htsat_config.classes_num,
|
212 |
+
window_size=htsat_config.htsat_window_size,
|
213 |
+
config = htsat_config,
|
214 |
+
depths = htsat_config.htsat_depth,
|
215 |
+
embed_dim = htsat_config.htsat_dim,
|
216 |
+
patch_stride=htsat_config.htsat_stride,
|
217 |
+
num_heads=htsat_config.htsat_num_head
|
218 |
+
)
|
219 |
+
at_model = SEDWrapper(
|
220 |
+
sed_model = sed_model,
|
221 |
+
config = htsat_config,
|
222 |
+
dataset = None
|
223 |
+
)
|
224 |
+
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
|
225 |
+
at_model.load_state_dict(ckpt["state_dict"])
|
226 |
+
|
227 |
+
trainer = pl.Trainer(
|
228 |
+
gpus = 1
|
229 |
+
)
|
230 |
+
avg_at = None
|
231 |
+
# obtain the latent embedding as query
|
232 |
+
if config.infer_type == "mean":
|
233 |
+
avg_dataset = MusdbDataset(tracks = queries)
|
234 |
+
avg_loader = DataLoader(
|
235 |
+
dataset = avg_dataset,
|
236 |
+
num_workers = 1,
|
237 |
+
batch_size = 1,
|
238 |
+
shuffle = False
|
239 |
+
)
|
240 |
+
at_wrapper = AutoTaggingWarpper(
|
241 |
+
at_model = at_model,
|
242 |
+
config = config,
|
243 |
+
target_keys = config.test_key
|
244 |
+
)
|
245 |
+
trainer.test(at_wrapper, test_dataloaders = avg_loader)
|
246 |
+
avg_at = at_wrapper.avg_at
|
247 |
+
|
248 |
+
# import seapration model
|
249 |
+
model = ZeroShotASP(
|
250 |
+
channels = 1, config = config,
|
251 |
+
at_model = at_model,
|
252 |
+
dataset = dataset
|
253 |
+
)
|
254 |
+
# resume checkpoint
|
255 |
+
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
|
256 |
+
model.load_state_dict(ckpt["state_dict"], strict= False)
|
257 |
+
exp_model = SeparatorModel(
|
258 |
+
model = model,
|
259 |
+
config = config,
|
260 |
+
target_keys = config.test_key,
|
261 |
+
avg_at = avg_at,
|
262 |
+
using_wiener = False,
|
263 |
+
calc_sdr = False,
|
264 |
+
output_wav = True
|
265 |
+
)
|
266 |
+
trainer.test(exp_model, test_dataloaders = loader)
|
267 |
+
|
268 |
+
# test the separation model, mainly in musdb
|
269 |
+
def test():
|
270 |
+
# set exp settings
|
271 |
+
device_name = "cuda" if torch.cuda.is_available() else "cpu"
|
272 |
+
device = torch.device("cuda")
|
273 |
+
assert config.test_key is not None, "there should be a separate key"
|
274 |
+
create_folder(config.wave_output_path)
|
275 |
+
# use musdb as testset
|
276 |
+
test_data = np.load(config.testset_path, allow_pickle = True)
|
277 |
+
print(len(test_data))
|
278 |
+
mus_tracks = []
|
279 |
+
# in musdb, all fs is the same (44100)
|
280 |
+
# load the dataset
|
281 |
+
for track in test_data:
|
282 |
+
temp = []
|
283 |
+
mixture = track["mixture"]
|
284 |
+
temp.append(mixture)
|
285 |
+
for dickey in config.test_key:
|
286 |
+
source = track[dickey]
|
287 |
+
temp.append(source)
|
288 |
+
temp = np.array(temp)
|
289 |
+
print(temp.shape)
|
290 |
+
mus_tracks.append(temp)
|
291 |
+
print(len(mus_tracks))
|
292 |
+
dataset = MusdbDataset(tracks = mus_tracks)
|
293 |
+
loader = DataLoader(
|
294 |
+
dataset = dataset,
|
295 |
+
num_workers = 1,
|
296 |
+
batch_size = 1,
|
297 |
+
shuffle = False
|
298 |
+
)
|
299 |
+
assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
|
300 |
+
|
301 |
+
sed_model = HTSAT_Swin_Transformer(
|
302 |
+
spec_size=htsat_config.htsat_spec_size,
|
303 |
+
patch_size=htsat_config.htsat_patch_size,
|
304 |
+
in_chans=1,
|
305 |
+
num_classes=htsat_config.classes_num,
|
306 |
+
window_size=htsat_config.htsat_window_size,
|
307 |
+
config = htsat_config,
|
308 |
+
depths = htsat_config.htsat_depth,
|
309 |
+
embed_dim = htsat_config.htsat_dim,
|
310 |
+
patch_stride=htsat_config.htsat_stride,
|
311 |
+
num_heads=htsat_config.htsat_num_head
|
312 |
+
)
|
313 |
+
at_model = SEDWrapper(
|
314 |
+
sed_model = sed_model,
|
315 |
+
config = htsat_config,
|
316 |
+
dataset = None
|
317 |
+
)
|
318 |
+
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
|
319 |
+
at_model.load_state_dict(ckpt["state_dict"])
|
320 |
+
trainer = pl.Trainer(
|
321 |
+
gpus = 1
|
322 |
+
)
|
323 |
+
avg_at = None
|
324 |
+
# obtain the query of four stems from the training set
|
325 |
+
if config.infer_type == "mean":
|
326 |
+
avg_data = np.load(config.testavg_path, allow_pickle = True)[:90]
|
327 |
+
print(len(avg_data))
|
328 |
+
avgmus_tracks = []
|
329 |
+
# in musdb, all fs is the same (44100)
|
330 |
+
# load the dataset
|
331 |
+
for track in avg_data:
|
332 |
+
temp = []
|
333 |
+
mixture = track["mixture"]
|
334 |
+
temp.append(mixture)
|
335 |
+
for dickey in config.test_key:
|
336 |
+
source = track[dickey]
|
337 |
+
temp.append(source)
|
338 |
+
temp = np.array(temp)
|
339 |
+
print(temp.shape)
|
340 |
+
avgmus_tracks.append(temp)
|
341 |
+
print(len(avgmus_tracks))
|
342 |
+
avg_dataset = MusdbDataset(tracks = avgmus_tracks)
|
343 |
+
avg_loader = DataLoader(
|
344 |
+
dataset = avg_dataset,
|
345 |
+
num_workers = 1,
|
346 |
+
batch_size = 1,
|
347 |
+
shuffle = False
|
348 |
+
)
|
349 |
+
at_wrapper = AutoTaggingWarpper(
|
350 |
+
at_model = at_model,
|
351 |
+
config = config,
|
352 |
+
target_keys = config.test_key
|
353 |
+
)
|
354 |
+
trainer.test(at_wrapper, test_dataloaders = avg_loader)
|
355 |
+
avg_at = at_wrapper.avg_at
|
356 |
+
|
357 |
+
model = ZeroShotASP(
|
358 |
+
channels = 1, config = config,
|
359 |
+
at_model = at_model,
|
360 |
+
dataset = dataset
|
361 |
+
)
|
362 |
+
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
|
363 |
+
model.load_state_dict(ckpt["state_dict"], strict= False)
|
364 |
+
exp_model = SeparatorModel(
|
365 |
+
model = model,
|
366 |
+
config = config,
|
367 |
+
target_keys = config.test_key,
|
368 |
+
avg_at = avg_at,
|
369 |
+
using_wiener = config.using_wiener
|
370 |
+
)
|
371 |
+
trainer.test(exp_model, test_dataloaders = loader)
|
372 |
+
|
373 |
+
def train():
|
374 |
+
# set exp settings
|
375 |
+
# device_name = "cuda" if torch.cuda.is_available() else "cpu"
|
376 |
+
# device = torch.device("cuda")
|
377 |
+
|
378 |
+
device_num = torch.cuda.device_count()
|
379 |
+
print("each batch size:", config.batch_size // device_num)
|
380 |
+
|
381 |
+
train_index_path = os.path.join(config.dataset_path, "hdf5s","indexes", config.index_type + ".h5")
|
382 |
+
train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle = True)
|
383 |
+
|
384 |
+
eval_index_path = os.path.join(config.dataset_path,"hdf5s", "indexes", "eval.h5")
|
385 |
+
eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle = True)
|
386 |
+
|
387 |
+
# set exp folder
|
388 |
+
exp_dir = os.path.join(config.workspace, "results", config.exp_name)
|
389 |
+
checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
|
390 |
+
|
391 |
+
if not config.debug:
|
392 |
+
create_folder(os.path.join(config.workspace, "results"))
|
393 |
+
create_folder(exp_dir)
|
394 |
+
create_folder(checkpoint_dir)
|
395 |
+
dump_config(config, os.path.join(exp_dir, config.exp_name), False)
|
396 |
+
|
397 |
+
# load data
|
398 |
+
# import dataset LGSPDataset (latent general source separation) and sampler
|
399 |
+
dataset = LGSPDataset(
|
400 |
+
index_path = train_index_path,
|
401 |
+
idc = train_idc,
|
402 |
+
config = config,
|
403 |
+
factor = 0.05,
|
404 |
+
eval_mode = False
|
405 |
+
)
|
406 |
+
eval_dataset = LGSPDataset(
|
407 |
+
index_path = eval_index_path,
|
408 |
+
idc = eval_idc,
|
409 |
+
config = config,
|
410 |
+
factor = 0.05,
|
411 |
+
eval_mode = True
|
412 |
+
)
|
413 |
+
|
414 |
+
audioset_data = data_prep(train_dataset=dataset,eval_dataset=eval_dataset,device_num=device_num, config=config)
|
415 |
+
checkpoint_callback = ModelCheckpoint(
|
416 |
+
monitor = "mixture_sdr",
|
417 |
+
filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
|
418 |
+
save_top_k = 10,
|
419 |
+
mode = "max"
|
420 |
+
)
|
421 |
+
# infer at model
|
422 |
+
sed_model = HTSAT_Swin_Transformer(
|
423 |
+
spec_size=htsat_config.htsat_spec_size,
|
424 |
+
patch_size=htsat_config.htsat_patch_size,
|
425 |
+
in_chans=1,
|
426 |
+
num_classes=htsat_config.classes_num,
|
427 |
+
window_size=htsat_config.htsat_window_size,
|
428 |
+
config = htsat_config,
|
429 |
+
depths = htsat_config.htsat_depth,
|
430 |
+
embed_dim = htsat_config.htsat_dim,
|
431 |
+
patch_stride=htsat_config.htsat_stride,
|
432 |
+
num_heads=htsat_config.htsat_num_head
|
433 |
+
)
|
434 |
+
at_model = SEDWrapper(
|
435 |
+
sed_model = sed_model,
|
436 |
+
config = htsat_config,
|
437 |
+
dataset = None
|
438 |
+
)
|
439 |
+
# load the checkpoint
|
440 |
+
ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
|
441 |
+
at_model.load_state_dict(ckpt["state_dict"])
|
442 |
+
|
443 |
+
trainer = pl.Trainer(
|
444 |
+
deterministic=True,
|
445 |
+
default_root_dir = checkpoint_dir,
|
446 |
+
gpus = device_num,
|
447 |
+
val_check_interval = 0.2,
|
448 |
+
# check_val_every_n_epoch = 1,
|
449 |
+
max_epochs = config.max_epoch,
|
450 |
+
auto_lr_find = True,
|
451 |
+
sync_batchnorm = True,
|
452 |
+
callbacks = [checkpoint_callback],
|
453 |
+
accelerator = "ddp" if device_num > 1 else None,
|
454 |
+
resume_from_checkpoint = None, #config.resume_checkpoint,
|
455 |
+
replace_sampler_ddp = False,
|
456 |
+
gradient_clip_val=1.0,
|
457 |
+
num_sanity_val_steps = 0,
|
458 |
+
)
|
459 |
+
model = ZeroShotASP(
|
460 |
+
channels = 1, config = config,
|
461 |
+
at_model = at_model,
|
462 |
+
dataset = dataset
|
463 |
+
)
|
464 |
+
if config.resume_checkpoint is not None:
|
465 |
+
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
|
466 |
+
model.load_state_dict(ckpt["state_dict"])
|
467 |
+
# trainer.test(model, datamodule = audioset_data)
|
468 |
+
trainer.fit(model, audioset_data)
|
469 |
+
|
470 |
+
def main():
|
471 |
+
parser = argparse.ArgumentParser(description="latent genreal source separation parser")
|
472 |
+
subparsers = parser.add_subparsers(dest = "mode")
|
473 |
+
parser_train = subparsers.add_parser("train")
|
474 |
+
parser_test = subparsers.add_parser("test")
|
475 |
+
parser_musdb = subparsers.add_parser("musdb_process")
|
476 |
+
parser_saveidc = subparsers.add_parser("save_idc")
|
477 |
+
parser_wa = subparsers.add_parser("weight_average")
|
478 |
+
parser_infer = subparsers.add_parser("inference")
|
479 |
+
args = parser.parse_args()
|
480 |
+
# default settings
|
481 |
+
logging.basicConfig(level=logging.INFO)
|
482 |
+
pl.utilities.seed.seed_everything(seed = config.random_seed)
|
483 |
+
|
484 |
+
if args.mode == "train":
|
485 |
+
train()
|
486 |
+
elif args.mode == "test":
|
487 |
+
test()
|
488 |
+
elif args.mode == "musdb_process":
|
489 |
+
process_musdb()
|
490 |
+
elif args.mode == "weight_average":
|
491 |
+
weight_average()
|
492 |
+
elif args.mode == "save_idc":
|
493 |
+
save_idc()
|
494 |
+
elif args.mode == "inference":
|
495 |
+
inference()
|
496 |
+
else:
|
497 |
+
raise Exception("Error Mode!")
|
498 |
+
|
499 |
+
|
500 |
+
if __name__ == '__main__':
|
501 |
+
main()
|