Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import random | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from utils.data_utils import * | |
from tqdm import tqdm | |
from g2p_en import G2p | |
import librosa | |
from torch.utils.data import Dataset | |
import pandas as pd | |
import time | |
import io | |
# Sampling rate (Hz) that VALLEDataset.__getitem__ passes to librosa.load.
SAMPLE_RATE = 16_000

# g2p: one grapheme-to-phoneme processor, instantiated at import time and
# reused for every utterance the dataset loads.
from .g2p_processor import G2pProcessor

phonemizer_g2p = G2pProcessor()
class VALLEDataset(Dataset):
    """Speech/phone dataset built from LibriTTS-style ``*.book.tsv`` / ``*.trans.tsv`` files.

    One row per utterance listed in the ``*.trans.tsv`` transcripts of every
    split named in ``args.dataset_list``. Durations come from the
    ``*.book.tsv`` metadata. Only utterances lasting between 3.0 and 15.0
    seconds (inclusive) are kept.
    """

    def __init__(self, args):
        """Scan the requested splits and build the metadata/transcript caches.

        Args:
            args: config object exposing ``dataset_list`` (split names, which
                must be keys of ``self.dataset2dir``) and ``data_dir`` (the
                directory holding the split folders).

        Raises:
            KeyError: if a split name is unknown, or a transcript ID has no
                matching row in the book metadata.
        """
        print(f"Initializing VALLEDataset")
        self.dataset_list = args.dataset_list
        print(f"using sampling rate {SAMPLE_RATE}")

        # Column names for the header-less TSV files.
        book_col_name = [
            "ID",
            "Original_text",
            "Normalized_text",
            "Aligned_or_not",
            "Start_time",
            "End_time",
            "Signal_to_noise_ratio",
        ]
        trans_col_name = [
            "ID",
            "Original_text",
            "Normalized_text",
            "Dir_path",
            "Duration",
        ]

        ######## add data dir to dataset2dir ##########
        self.dataset2dir = {
            "dev-clean": f"{args.data_dir}/dev-clean",
            "dev-other": f"{args.data_dir}/dev-other",
            "test-clean": f"{args.data_dir}/test-clean",
            "test-other": f"{args.data_dir}/test-other",
            "train-clean-100": f"{args.data_dir}/train-clean-100",
            "train-clean-360": f"{args.data_dir}/train-clean-360",
            "train-other-500": f"{args.data_dir}/train-other-500",
        }

        # Files discovered across all requested splits.
        self.book_files_list = []
        self.trans_files_list = []
        # Per-file tables are collected first and concatenated once at the
        # end. Concatenating/indexing inside the loop dropped the IDs of
        # earlier splits and was quadratic in the number of files.
        book_frames = []
        trans_frames = []

        ###### load metadata and transcripts #####
        for dataset_name in self.dataset_list:
            print("Initializing dataset: ", dataset_name)
            dataset_dir = self.dataset2dir[dataset_name]
            book_files = self.get_metadata_files(dataset_dir)
            trans_files = self.get_trans_files(dataset_dir)
            self.book_files_list.extend(book_files)
            self.trans_files_list.extend(trans_files)

            # book.tsv is not filtered: it may list utterances whose audio does
            # not exist, but it carries the start/end times used for durations.
            print("reading paths for dataset...")
            for book_path in tqdm(book_files):
                # quoting=3 is csv.QUOTE_NONE: quote characters stay literal.
                book_frames.append(
                    pd.read_csv(book_path, sep="\t", names=book_col_name, quoting=3)
                )

            print("creating transcripts for dataset...")
            for trans_path in tqdm(trans_files):
                tmp_cache = pd.read_csv(
                    trans_path, sep="\t", names=trans_col_name, quoting=3
                )
                # __getitem__ looks for the audio next to the transcript file.
                tmp_cache["Dir_path"] = os.path.dirname(trans_path)
                trans_frames.append(tmp_cache)

        self.metadata_cache, self.trans_cache = self._build_caches(
            book_frames, trans_frames, book_col_name, trans_col_name
        )

    @staticmethod
    def _build_caches(
        book_frames,
        trans_frames,
        book_col_name,
        trans_col_name,
        min_duration=3.0,
        max_duration=15.0,
    ):
        """Merge per-file tables into ID-indexed caches and filter by duration.

        Args:
            book_frames: DataFrames read from ``*.book.tsv`` files.
            trans_frames: DataFrames read from ``*.trans.tsv`` files.
            book_col_name / trans_col_name: column names used when a list is
                empty, so the caches still have the expected columns.
            min_duration / max_duration: inclusive duration bounds in seconds.

        Returns:
            ``(metadata_cache, trans_cache)``, both indexed by ``ID``.
            ``trans_cache`` gets a ``Duration`` column (End_time - Start_time)
            and keeps only rows with ``min_duration <= Duration <= max_duration``.

        Raises:
            KeyError: if a transcript ID is missing from the book metadata.
        """
        metadata_cache = (
            pd.concat(book_frames, ignore_index=True)
            if book_frames
            else pd.DataFrame(columns=book_col_name)
        ).set_index("ID")
        trans_cache = (
            pd.concat(trans_frames, ignore_index=True)
            if trans_frames
            else pd.DataFrame(columns=trans_col_name)
        ).set_index("ID")

        ## calc duration
        # .loc returns rows in trans_cache order, so assign positionally
        # (avoids re-aligning on possibly duplicated ID labels).
        ends = metadata_cache["End_time"].loc[trans_cache.index]
        starts = metadata_cache["Start_time"].loc[trans_cache.index]
        trans_cache["Duration"] = (ends - starts).to_numpy()

        print(
            f"Filtering files with duration between {min_duration} and {max_duration} seconds"
        )
        print(f"Before filtering: {len(trans_cache)}")
        trans_cache = trans_cache[
            (trans_cache["Duration"] >= min_duration)
            & (trans_cache["Duration"] <= max_duration)
        ]
        print(f"After filtering: {len(trans_cache)}")
        return metadata_cache, trans_cache

    @staticmethod
    def _find_files(directory, suffix):
        """Return paths (rooted at ``directory``) of non-hidden files ending in ``suffix``."""
        found = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith(suffix) and not file.startswith("."):
                    found.append(os.path.join(root, file))
        return found

    def get_metadata_files(self, directory):
        """Return paths of all non-hidden ``*.book.tsv`` files under ``directory``."""
        return self._find_files(directory, ".book.tsv")

    def get_trans_files(self, directory):
        """Return paths of all non-hidden ``*.trans.tsv`` files under ``directory``."""
        return self._find_files(directory, ".trans.tsv")

    def get_audio_files(self, directory):
        """Return paths, relative to ``directory``, of ``.flac``/``.wav``/``.opus`` files.

        Unlike the TSV finders, hidden files are not skipped and paths are
        relative rather than rooted at ``directory``.
        """
        audio_files = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith((".flac", ".wav", ".opus")):
                    audio_files.append(
                        os.path.relpath(os.path.join(root, file), directory)
                    )
        return audio_files

    def get_num_frames(self, index):
        """Return ``int(duration * 75)`` for the utterance at position ``index``.

        75 is presumably the codec frame rate in frames/second (TODO confirm
        against the codec config). ``index`` is positional, matching
        ``__getitem__``.
        """
        # The original read a non-existent ``self.meta_data_cache``; durations
        # live in trans_cache (see _build_caches).
        duration = self.trans_cache["Duration"].iloc[index]
        return int(duration * 75)

    def __len__(self):
        """Number of utterances that survived the duration filter."""
        return len(self.trans_cache)

    def __getitem__(self, idx):
        """Load the utterance at position ``idx``.

        Returns:
            dict with ``"speech"``, the waveform returned by ``librosa.load`` at
            ``SAMPLE_RATE``, and ``"phone"``, element ``[1]`` of the
            phonemizer output for the normalized text (presumably the phone id
            sequence; confirm against G2pProcessor).
        """
        # Use positional access throughout so duplicate IDs cannot turn a
        # scalar lookup into a Series.
        row = self.trans_cache.iloc[idx]
        uid = self.trans_cache.index[idx]
        # Audio is expected as "<ID>.wav" next to the transcript file.
        full_file_path = os.path.join(row["Dir_path"], uid + ".wav")

        phone = phonemizer_g2p(row["Normalized_text"], "en")[1]
        speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE)

        inputs = {}
        inputs["speech"] = speech  # waveform at SAMPLE_RATE
        inputs["phone"] = phone
        return inputs
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): | |
if len(batch) == 0: | |
return 0 | |
if len(batch) == max_sentences: | |
return 1 | |
if num_tokens > max_tokens: | |
return 1 | |
return 0 | |
def batch_by_size( | |
indices, | |
num_tokens_fn, | |
max_tokens=None, | |
max_sentences=None, | |
required_batch_size_multiple=1, | |
): | |
""" | |
Yield mini-batches of indices bucketed by size. Batches may contain | |
sequences of different lengths. | |
Args: | |
indices (List[int]): ordered list of dataset indices | |
num_tokens_fn (callable): function that returns the number of tokens at | |
a given index | |
max_tokens (int, optional): max number of tokens in each batch | |
(default: None). | |
max_sentences (int, optional): max number of sentences in each | |
batch (default: None). | |
required_batch_size_multiple (int, optional): require batch size to | |
be a multiple of N (default: 1). | |
""" | |
bsz_mult = required_batch_size_multiple | |
sample_len = 0 | |
sample_lens = [] | |
batch = [] | |
batches = [] | |
for i in range(len(indices)): | |
idx = indices[i] | |
num_tokens = num_tokens_fn(idx) | |
sample_lens.append(num_tokens) | |
sample_len = max(sample_len, num_tokens) | |
assert ( | |
sample_len <= max_tokens | |
), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format( | |
idx, sample_len, max_tokens | |
) | |
num_tokens = (len(batch) + 1) * sample_len | |
if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): | |
mod_len = max( | |
bsz_mult * (len(batch) // bsz_mult), | |
len(batch) % bsz_mult, | |
) | |
batches.append(batch[:mod_len]) | |
batch = batch[mod_len:] | |
sample_lens = sample_lens[mod_len:] | |
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 | |
batch.append(idx) | |
if len(batch) > 0: | |
batches.append(batch) | |
return batches | |
def test():
    """Manual smoke test for VALLEDataset.

    Builds the dataset from the exp_ar_libritts.json config, prints the first
    ten transcript rows, then stops in the debugger for inspection.
    """
    from utils.util import load_config

    config_path = "./egs/tts/VALLE_V2/exp_ar_libritts.json"
    cfg = load_config(config_path)
    dataset = VALLEDataset(cfg.dataset)
    # Bound as locals so both caches are at hand at the breakpoint below.
    metadata_cache, trans_cache = dataset.metadata_cache, dataset.trans_cache
    print(trans_cache.head(10))
    breakpoint()


if __name__ == "__main__":
    # Only run the smoke test when this module is executed as a script.
    test()