maskgct-audio-lab / models /tta /ldm /audioldm_dataset.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
4.71 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from models.base.base_dataset import (
BaseOfflineCollator,
BaseOfflineDataset,
BaseTestDataset,
BaseTestCollator,
)
import librosa
from transformers import AutoTokenizer
class AudioLDMDataset(BaseOfflineDataset):
    """Offline dataset for AudioLDM training/validation.

    On top of the features returned by ``BaseOfflineDataset.__getitem__``,
    each item can carry a mel spectrogram (``melspec``), a waveform (``wav``)
    and a text caption (``caption``). Each is enabled by the matching
    ``cfg.preprocess.use_melspec`` / ``use_wav`` / ``use_caption`` flag.
    """

    # librosa.load resamples every waveform to this rate (hard-coded 16 kHz).
    WAV_SAMPLE_RATE = 16000

    def __init__(self, cfg, dataset, is_valid=False):
        """Build per-utterance lookup tables from ``self.metadata``.

        Args:
            cfg: experiment config; ``cfg.preprocess`` supplies the ``use_*``
                flags, ``processed_dir``, ``melspec_dir``, ``wav_dir`` and
                (when captions are used) ``cond_mask_prob``.
            dataset: forwarded unchanged to ``BaseOfflineDataset.__init__``.
            is_valid: forwarded unchanged to ``BaseOfflineDataset.__init__``.

        Raises:
            ValueError: if captions are enabled and ``cond_mask_prob`` is not
                in [0, 1].
        """
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.cfg = cfg
        preprocess = cfg.preprocess

        use_melspec = preprocess.use_melspec
        use_wav = preprocess.use_wav
        use_caption = preprocess.use_caption

        # Each table only exists when its feature is enabled, as before.
        if use_melspec:
            self.utt2melspec_path = {}
        if use_wav:
            self.utt2wav_path = {}
        if use_caption:
            self.utt2caption = {}
            cond_mask_prob = preprocess.cond_mask_prob
            # __getitem__ uses this as a Bernoulli probability; reject values
            # np.random.choice would previously have refused at load time.
            if not 0.0 <= cond_mask_prob <= 1.0:
                raise ValueError(
                    "cond_mask_prob must be in [0, 1], got {}".format(cond_mask_prob)
                )

        # Metadata is only read when at least one feature needs it.
        if not (use_melspec or use_wav or use_caption):
            return

        for utt_info in self.metadata:
            # Named so it does not shadow the ``dataset`` constructor argument.
            utt_dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = self._utt_key(utt_info)
            if use_melspec:
                self.utt2melspec_path[utt] = os.path.join(
                    preprocess.processed_dir,
                    utt_dataset,
                    preprocess.melspec_dir,
                    uid + ".npy",
                )
            if use_wav:
                self.utt2wav_path[utt] = os.path.join(
                    preprocess.processed_dir,
                    utt_dataset,
                    preprocess.wav_dir,
                    uid + ".wav",
                )
            if use_caption:
                self.utt2caption[utt] = utt_info["Caption"]

    @staticmethod
    def _utt_key(utt_info):
        """Return the ``"<Dataset>_<Uid>"`` key shared by all lookup tables."""
        return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

    def __getitem__(self, index):
        """Return the base features for ``index`` plus the enabled extras.

        Added keys:
            melspec: array loaded with ``np.load`` (documented upstream as
                (n_mels, T)).
            wav: 1-D waveform from ``librosa.load`` at ``WAV_SAMPLE_RATE``.
            caption: the utterance caption, replaced by ``""`` with
                probability ``cfg.preprocess.cond_mask_prob``.
        """
        single_feature = BaseOfflineDataset.__getitem__(self, index)
        utt = self._utt_key(self.metadata[index])
        preprocess = self.cfg.preprocess

        if preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
        if preprocess.use_wav:
            wav, _ = librosa.load(self.utt2wav_path[utt], sr=self.WAV_SAMPLE_RATE)
            single_feature["wav"] = wav
        if preprocess.use_caption:
            # Single uniform draw from the global NumPy RNG; True with
            # probability cond_mask_prob. Presumably this is condition dropout
            # for classifier-free guidance — NOTE(review): confirm.
            drop_caption = np.random.random() < preprocess.cond_mask_prob
            single_feature["caption"] = "" if drop_caption else self.utt2caption[utt]
        return single_feature

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)
class AudioLDMCollator(BaseOfflineCollator):
    """Collate ``AudioLDMDataset`` items into batched tensors.

    Output keys (each produced only when the matching key is in the items):

    * ``melspec``: (B, n_mels, T) tensor. Each mel is truncated to at most
      ``MAX_MEL_FRAMES`` frames, then zero-padded on the right to the longest
      truncated mel in the batch.
    * ``wav``: (B, T) tensor, zero-padded to the longest waveform.
    * ``text_input_ids`` / ``text_attention_mask``: (B, L) tensors from the
      T5 tokenizer, padded to the longest caption.

    Item keys other than ``melspec``, ``wav`` and ``caption`` are not copied
    into the batch.
    """

    # Upper bound on the number of mel frames kept per item.
    MAX_MEL_FRAMES = 624

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        # Caption tokenizer, loaded once per collator instance.
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

    def __call__(self, batch):
        """Pack a list of per-item feature dicts into a dict of batch tensors.

        Args:
            batch: non-empty list of dicts as produced by
                ``AudioLDMDataset.__getitem__``; the keys of ``batch[0]``
                decide which outputs are built.

        Returns:
            dict containing any of ``melspec``, ``wav``, ``text_input_ids``
            and ``text_attention_mask``.
        """
        packed_batch_features = dict()
        for key in batch[0].keys():
            if key == "melspec":
                packed_batch_features["melspec"] = self._pad_melspecs(
                    [b["melspec"] for b in batch]
                )
            elif key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
            elif key == "caption":
                captions = [b[key] for b in batch]
                text_input = self.tokenizer(
                    captions, return_tensors="pt", truncation=True, padding="longest"
                )
                packed_batch_features["text_input_ids"] = text_input["input_ids"]
                packed_batch_features["text_attention_mask"] = text_input[
                    "attention_mask"
                ]
        return packed_batch_features

    def _pad_melspecs(self, melspecs):
        """Truncate (n_mels, T) mels to ``MAX_MEL_FRAMES`` and right-pad to a common T.

        Stacking with ``np.array`` requires every truncated mel to have the
        same frame count, so any clip shorter than ``MAX_MEL_FRAMES`` broke
        the batch. When all lengths already match, the result equals that
        stacking.
        """
        # pad_sequence pads along dim 0, so put time first: (T, n_mels).
        time_major = [
            torch.from_numpy(
                np.ascontiguousarray(mel[:, : self.MAX_MEL_FRAMES])
            ).transpose(0, 1)
            for mel in melspecs
        ]
        # NOTE(review): zero padding mirrors the wav branch; whether 0 is a
        # neutral value for these mel features depends on preprocessing — confirm.
        padded = pad_sequence(time_major, batch_first=True, padding_value=0)
        # (B, T_max, n_mels) -> (B, n_mels, T_max), contiguous like np.array output.
        return padded.transpose(1, 2).contiguous()
class AudioLDMTestDataset(BaseTestDataset):
    """Test-time AudioLDM dataset; adds nothing on top of ``BaseTestDataset``."""
class AudioLDMTestCollator(BaseTestCollator):
    """Test-time AudioLDM collator; adds nothing on top of ``BaseTestCollator``."""