File size: 1,787 Bytes
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging

import kaldiio
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


def custom_collate(batch):
    """Collate ``(key, speech, speaker_label, order)`` samples without padding.

    Each sample's arrays are copied and converted to tensors individually, so
    the result holds per-sample lists rather than stacked tensors.

    Args:
        batch: sequence of 4-tuples ``(key, speech, speaker_label, order)``
            whose last three items are numpy arrays.

    Returns:
        A pair ``(keys, tensors)``. ``keys`` is a tuple of the sample keys.
        ``tensors`` is a dict with the lists ``"speech"`` (float32),
        ``"speaker_labels"`` (float32) and ``"orders"`` (int64).
    """
    keys, speech, speaker_labels, orders = zip(*batch)

    def _as_tensors(arrays, dtype):
        # np.copy gives each tensor its own buffer instead of sharing
        # memory with the source array via torch.from_numpy.
        return [torch.from_numpy(np.copy(arr)).to(dtype) for arr in arrays]

    tensors = {
        "speech": _as_tensors(speech, torch.float32),
        "speaker_labels": _as_tensors(speaker_labels, torch.float32),
        "orders": _as_tensors(orders, torch.int64),
    }
    return keys, tensors


class EENDOLADataset(Dataset):
    """Map-style dataset backed by a whitespace-separated list file.

    Each non-blank line of ``data_file`` must contain exactly three fields::

        key speech_path speaker_label_path

    Both paths are passed to ``kaldiio.load_mat`` when a sample is accessed.
    """

    def __init__(
        self,
        data_file,
    ):
        """Read and validate the sample list.

        Args:
            data_file: path to the list file described in the class docstring.

        Raises:
            ValueError: if a non-blank line does not have exactly three fields.
        """
        self.data_file = data_file
        self.samples = []
        with open(data_file) as f:
            for lineno, line in enumerate(f, start=1):
                # str.split() with no argument also strips surrounding whitespace.
                fields = line.split()
                if not fields:
                    # Skip blank lines (e.g. a trailing newline). Otherwise they
                    # would count toward len() and fail to unpack in __getitem__.
                    continue
                if len(fields) != 3:
                    raise ValueError(
                        "{}:{}: expected 3 fields (key, speech_path, "
                        "speaker_label_path), got {}".format(
                            data_file, lineno, len(fields)
                        )
                    )
                self.samples.append(fields)
        logging.info("total samples: {}".format(len(self.samples)))

    def __len__(self):
        """Return the number of valid (non-blank) entries in the list file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(key, speech, speaker_label, order)``, where:

            - ``speech`` is the matrix loaded from ``speech_path``.
            - ``speaker_label`` is the matrix from ``speaker_label_path``,
              reshaped to have ``speech.shape[0]`` rows.
            - ``order`` is a random permutation of ``range(speech.shape[0])``.
        """
        key, speech_path, speaker_label_path = self.samples[idx]
        speech = kaldiio.load_mat(speech_path)
        # Align label rows with speech rows; assumes the label data was stored
        # row-major with the same row count as speech — TODO confirm with producer.
        speaker_label = kaldiio.load_mat(speaker_label_path).reshape(
            speech.shape[0], -1
        )

        # Uses the global numpy RNG, so the permutation follows np.random seeding.
        order = np.arange(speech.shape[0])
        np.random.shuffle(order)

        return key, speech, speaker_label, order


class EENDOLADataLoader:
    """Build a ``torch.utils.data.DataLoader`` over an ``EENDOLADataset``.

    Samples are batched with ``custom_collate``, so each batch is
    ``(keys, dict_of_tensor_lists)`` rather than stacked tensors.
    """

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Create the dataset from ``data_file`` and wrap it in a DataLoader.

        Args:
            data_file: list file consumed by ``EENDOLADataset``.
            batch_size: number of samples per batch.
            shuffle: whether the DataLoader reshuffles sample order.
            num_workers: number of DataLoader worker processes.
        """
        loader_options = {
            "batch_size": batch_size,
            "collate_fn": custom_collate,
            "shuffle": shuffle,
            "num_workers": num_workers,
        }
        self.data_loader = DataLoader(EENDOLADataset(data_file), **loader_options)

    def build_iter(self, epoch):
        """Return the underlying DataLoader.

        ``epoch`` is not used here; presumably it is accepted so callers can
        treat this like other loaders that take an epoch — confirm with callers.
        """
        return self.data_loader