KingNish committed on
Commit
f0bddf6
·
verified ·
1 Parent(s): c09ee2a

Upload ./vocos/dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocos/dataset.py +73 -0
vocos/dataset.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from pytorch_lightning import LightningDataModule
7
+ from torch.utils.data import Dataset, DataLoader
8
+
9
+ torch.set_num_threads(1)
10
+
11
+
12
@dataclass
class DataConfig:
    """Configuration for one dataset split (train or validation)."""

    # Path to a text file listing one audio-file path per line.
    filelist_path: str
    # Target sampling rate; audio loaded at a different rate is resampled.
    sampling_rate: int
    # Fixed segment length (in samples) each item is cropped/padded to.
    num_samples: int
    # DataLoader batch size.
    batch_size: int
    # Number of DataLoader worker processes.
    num_workers: int
19
+
20
+
21
class VocosDataModule(LightningDataModule):
    """Lightning data module exposing train/validation ``VocosDataset`` loaders.

    Args:
        train_params: configuration for the training split.
        val_params: configuration for the validation split.
    """

    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloader(self, cfg: DataConfig, train: bool) -> DataLoader:
        """Build a DataLoader for one split.

        Shuffles only when ``train`` is True; pins memory to speed up
        host-to-device transfers.
        """
        dataset = VocosDataset(cfg, train=train)
        return DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )

    # Backward-compatible alias for the original (misspelled) helper name, in
    # case external code called it directly.
    _get_dataloder = _get_dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_config, train=False)
39
+
40
+
41
class VocosDataset(Dataset):
    """Dataset of fixed-length mono audio segments read from a filelist.

    Each item is loaded from disk, mixed down to mono, loudness-normalized
    via a sox ``norm`` effect, resampled to the configured rate, and then
    cropped/padded to exactly ``num_samples`` samples.

    Args:
        cfg: split configuration (filelist path, sampling rate, segment length).
        train: if True, apply a random gain and a random crop; otherwise use a
            fixed -3 dB gain and the first segment so validation is deterministic.
    """

    def __init__(self, cfg: DataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return a 1-D float tensor of exactly ``self.num_samples`` samples."""
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # Mix multi-channel audio down to mono.
            y = y.mean(dim=0, keepdim=True)
        # Random gain augmentation for training; fixed -3 dB for validation.
        # FIX: the original passed np.random.uniform(-1, -6); NumPy documents
        # the result as officially undefined when low > high, so the bounds are
        # now given in ascending order (same intended interval, [-6, -1]).
        gain = np.random.uniform(-6, -1) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            # Too short: tile the clip and append just enough to reach the target.
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            # Random crop as training-time augmentation; high bound is inclusive
            # of the last valid start offset because randint's high is exclusive.
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, take always the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]