# Create training indexes (per-source segment metadata) for source separation training.
import argparse
import os
import pickle
from typing import NoReturn

import h5py

from bytesep.utils import read_yaml
def create_indexes(args) -> None:
    r"""Create and write out training indexes into disk. The indexes may contain
    information from multiple datasets. During training, training indexes will
    be shuffled and iterated for selecting segments to be mixed. E.g., the
    training indexes_dict looks like: {
        'vocals': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
            ...
        ]
        'accompaniment': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
            ...
        ]
    }

    Args:
        args: argparse namespace with attributes:
            workspace: str, directory of workspace.
            config_yaml: str, path of user defined config file.

    Returns:
        None. Side effect: pickles the indexes dict to the path given by
        configs['train']['indexes'] (relative to the workspace).
    """
    # Arguments & parameters.
    workspace = args.workspace
    config_yaml = args.config_yaml

    # Only create indexes for training, because evaluation is on entire pieces.
    split = "train"

    # Read config file.
    configs = read_yaml(config_yaml)
    sample_rate = configs["sample_rate"]
    segment_samples = int(configs["segment_seconds"] * sample_rate)

    # Path to write out index.
    indexes_path = os.path.join(workspace, configs[split]["indexes"])
    os.makedirs(os.path.dirname(indexes_path), exist_ok=True)

    source_types = configs[split]["source_types"].keys()
    # E.g., ['vocals', 'accompaniment']

    indexes_dict = {source_type: [] for source_type in source_types}

    # Get training indexes for each source type.
    for source_type in source_types:
        # E.g., 'vocals', 'bass', ...
        print("--- {} ---".format(source_type))

        dataset_types = configs[split]["source_types"][source_type]
        # E.g., {'musdb18': {...}, ...}

        # Each source can come from multiple datasets.
        for dataset_type in dataset_types:
            # Hoist the per-dataset config lookup out of the repeated accesses.
            dataset_configs = dataset_types[dataset_type]

            hdf5s_dir = os.path.join(workspace, dataset_configs["hdf5s_directory"])
            hop_samples = int(dataset_configs["hop_seconds"] * sample_rate)
            key_in_hdf5 = dataset_configs["key_in_hdf5"]
            # E.g., 'vocals'

            # Sort for a deterministic traversal order across runs.
            hdf5_names = sorted(os.listdir(hdf5s_dir))
            print("Hdf5 files num: {}".format(len(hdf5_names)))

            # Traverse all packed hdf5 files of a dataset.
            for n, hdf5_name in enumerate(hdf5_names):
                print(n, hdf5_name)
                hdf5_path = os.path.join(hdf5s_dir, hdf5_name)

                with h5py.File(hdf5_path, "r") as hf:
                    # Slide a window of segment_samples with stride hop_samples
                    # over the last axis of the stored audio.
                    bgn_sample = 0
                    while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': bgn_sample,
                            'end_sample': bgn_sample + segment_samples,
                        }
                        indexes_dict[source_type].append(meta)
                        bgn_sample += hop_samples

                    # If the audio length is shorter than the segment length,
                    # then use the entire audio as a segment (end_sample may
                    # exceed the audio length; presumably the loader pads —
                    # TODO confirm against the training data loader).
                    if bgn_sample == 0:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': 0,
                            'end_sample': segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

        print(
            "Total indexes for {}: {}".format(
                source_type, len(indexes_dict[source_type])
            )
        )

    # Use a context manager so the file handle is closed (the original
    # pickle.dump(..., open(...)) leaked the handle).
    with open(indexes_path, "wb") as indexes_file:
        pickle.dump(indexes_dict, indexes_file)
    print("Write index dict to {}".format(indexes_path))
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--workspace", type=str, required=True, help="Directory of workspace." | |
) | |
parser.add_argument( | |
"--config_yaml", type=str, required=True, help="User defined config file." | |
) | |
# Parse arguments. | |
args = parser.parse_args() | |
# Create training indexes. | |
create_indexes(args) | |