Spaces:
Runtime error
Runtime error
import logging | |
import kaldiio | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
from torch.utils.data import Dataset | |
def custom_collate(batch):
    """Collate dataset samples into per-field lists of tensors.

    Samples are kept as Python lists (one tensor per sample) rather than
    being stacked into a single tensor, so samples of different sizes can
    share a batch.

    Args:
        batch: Sequence of ``(key, speech, speaker_label, order)`` tuples.
            The last three items must be array-likes accepted by
            ``np.copy``.

    Returns:
        A ``(keys, batch)`` pair. ``keys`` is a tuple with the key of every
        sample. ``batch`` is a dict with the entries ``"speech"`` (float32
        tensors), ``"speaker_labels"`` (float32 tensors) and ``"orders"``
        (int64 tensors); each entry is a list with one tensor per sample.
        An empty ``batch`` yields ``()`` and a dict of empty lists.
    """
    if not batch:
        # zip(*[]) produces nothing to unpack, so build the empty result
        # explicitly instead of failing with an opaque ValueError.
        return (), dict(speech=[], speaker_labels=[], orders=[])

    keys, speech, speaker_labels, orders = zip(*batch)

    # np.copy hands torch.from_numpy a fresh array that owns its data, so the
    # resulting tensors never alias the arrays supplied by the dataset.
    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
    speaker_labels = [
        torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels
    ]
    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]

    batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders)
    return keys, batch
class EENDOLADataset(Dataset):
    """Map-style dataset driven by a whitespace-separated list file.

    Each non-blank line of the list file is expected to hold three fields:
    ``key speech_path speaker_label_path``. Both paths are read with
    ``kaldiio.load_mat`` when a sample is requested.
    """

    def __init__(
        self,
        data_file,
    ):
        """Read the sample list from ``data_file``.

        Args:
            data_file: Path of the list file. Blank and whitespace-only
                lines are ignored.
        """
        self.data_file = data_file
        with open(data_file) as f:
            # Drop blank lines: they would otherwise become empty samples
            # that inflate __len__ and fail to unpack in __getitem__.
            self.samples = [
                fields for fields in (line.split() for line in f) if fields
            ]
        logging.info("total samples: %d", len(self.samples))

    def __len__(self):
        """Return the number of samples listed in the data file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A ``(key, speech, speaker_label, order)`` tuple. ``speech`` is
            the matrix loaded from the speech path. ``speaker_label`` is the
            loaded label matrix reshaped to ``(speech.shape[0], -1)``.
            ``order`` is a random permutation of
            ``0 .. speech.shape[0] - 1`` drawn from NumPy's global RNG.

        Raises:
            ValueError: If the list-file line does not have exactly three
                fields.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        # Align the labels with the first axis of the speech matrix
        # (presumably the frame axis -- confirm against the data preparation).
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(
            speech.shape[0], -1
        )
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)
        return key, speech, speaker_label, order
class EENDOLADataLoader:
    """Thin wrapper that builds a ``DataLoader`` over ``EENDOLADataset``."""

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Create the dataset and the underlying ``DataLoader``.

        Args:
            data_file: Path of the list file passed to ``EENDOLADataset``.
            batch_size: Number of samples per batch.
            shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` argument.
            num_workers: Forwarded to ``DataLoader`` as its ``num_workers``
                argument.
        """
        dataset = EENDOLADataset(data_file)
        # custom_collate keeps samples as lists of tensors instead of
        # stacking them into one tensor per field.
        self.data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def build_iter(self, epoch):
        """Return the iterable to loop over for one epoch.

        Args:
            epoch: Accepted but not used; the same ``DataLoader`` object is
                returned on every call.

        Returns:
            The ``DataLoader`` created in ``__init__``.
        """
        return self.data_loader