Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import random | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from utils.data_utils import * | |
from models.base.base_dataset import ( | |
BaseOfflineCollator, | |
BaseOfflineDataset, | |
BaseTestDataset, | |
BaseTestCollator, | |
) | |
import librosa | |
class AutoencoderKLDataset(BaseOfflineDataset):
    """Offline dataset that yields mel-spectrograms and/or raw waveforms.

    Per-utterance file paths are resolved once, at construction, from the
    ``"Dataset"`` / ``"Uid"`` fields of ``self.metadata`` and the
    ``cfg.preprocess`` directory settings. The features themselves are loaded
    lazily in ``__getitem__``.
    """

    # Rate passed to ``librosa.load``; waveforms are resampled to this rate.
    WAV_SAMPLE_RATE = 16000

    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        cfg = self.cfg

        # utt2melspec: "<Dataset>_<Uid>" -> precomputed mel ``.npy`` path
        if cfg.preprocess.use_melspec:
            self.utt2melspec_path = self._build_utt2path(
                cfg.preprocess.melspec_dir, ".npy"
            )

        # utt2wav: "<Dataset>_<Uid>" -> ``.wav`` path
        if cfg.preprocess.use_wav:
            self.utt2wav_path = self._build_utt2path(cfg.preprocess.wav_dir, ".wav")

    @staticmethod
    def _utt_key(utt_info):
        """Return the ``"<Dataset>_<Uid>"`` key used to index the path maps."""
        return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

    def _build_utt2path(self, sub_dir, ext):
        """Map each utterance in ``self.metadata`` to
        ``<processed_dir>/<Dataset>/<sub_dir>/<Uid><ext>``."""
        processed_dir = self.cfg.preprocess.processed_dir
        return {
            self._utt_key(utt_info): os.path.join(
                processed_dir,
                utt_info["Dataset"],
                sub_dir,
                utt_info["Uid"] + ext,
            )
            for utt_info in self.metadata
        }

    def __getitem__(self, index):
        """Return the base features for ``index`` plus the enabled extras.

        Adds ``"melspec"`` (the array stored in the ``.npy`` file; the original
        code describes it as ``(n_mels, T)``) when ``use_melspec`` is set.
        Adds ``"wav"`` (mono waveform from ``librosa.load`` at
        ``WAV_SAMPLE_RATE``) when ``use_wav`` is set.
        """
        single_feature = BaseOfflineDataset.__getitem__(self, index)
        utt = self._utt_key(self.metadata[index])

        if self.cfg.preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
        if self.cfg.preprocess.use_wav:
            # librosa returns (samples, sr); sr always equals the requested rate.
            wav, _ = librosa.load(self.utt2wav_path[utt], sr=self.WAV_SAMPLE_RATE)
            single_feature["wav"] = wav
        return single_feature

    def __len__(self):
        """Return the number of utterances in the metadata."""
        return len(self.metadata)
class AutoencoderKLCollator(BaseOfflineCollator):
    """Collate dataset items into batched ``melspec`` and/or ``wav`` tensors.

    Only the ``"melspec"`` and ``"wav"`` keys are collated. Any other keys
    present in the items are not included in the returned batch.
    """

    # Number of frames each mel-spectrogram is cropped to before stacking.
    DEFAULT_MEL_FRAMES = 624

    def __init__(self, cfg, mel_frames=DEFAULT_MEL_FRAMES):
        """Create the collator.

        Args:
            cfg: experiment config, forwarded to ``BaseOfflineCollator``.
            mel_frames: maximum number of frames kept per mel-spectrogram
                (defaults to the previously hard-coded 624).
        """
        BaseOfflineCollator.__init__(self, cfg)
        self.mel_frames = mel_frames

    def __call__(self, batch):
        """Pack a list of per-item feature dicts into batched tensors.

        - ``"melspec"``: each item is cropped along its second axis to at most
          ``self.mel_frames`` frames, then stacked. Items are presumably
          ``(n_mels, T)``, giving ``(B, n_mels, frames)``. All cropped items
          must share a shape; ``np.stack`` raises ``ValueError`` otherwise.
        - ``"wav"``: 1-D arrays are right-padded with zeros to the longest
          item, giving ``(B, T_max)``.

        An empty ``batch`` yields an empty dict.
        """
        packed_batch_features = {}
        if not batch:
            return packed_batch_features

        # Iterate over the first item's keys so the output key order matches it.
        for key in batch[0]:
            if key == "melspec":
                mels = [b["melspec"][:, : self.mel_frames] for b in batch]
                # np.stack fails loudly on ragged crops instead of building an
                # object array that torch.from_numpy cannot convert.
                packed_batch_features["melspec"] = torch.from_numpy(np.stack(mels))
            elif key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features
class AutoencoderKLTestDataset(BaseTestDataset):
    """Test-time dataset for AutoencoderKL.

    Defines nothing of its own; all behavior is inherited from
    ``BaseTestDataset``.
    """
class AutoencoderKLTestCollator(BaseTestCollator):
    """Test-time collator for AutoencoderKL.

    Defines nothing of its own; all behavior is inherited from
    ``BaseTestCollator``.
    """