Upload ./vocos/dataset.py with huggingface_hub
vocos/dataset.py
ADDED
@@ -0,0 +1,73 @@
from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

# Limit intra-op threads; audio decoding parallelism comes from DataLoader workers.
torch.set_num_threads(1)


@dataclass
class DataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class VocosDataModule(LightningDataModule):
    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloader(self, cfg: DataConfig, train: bool):
        dataset = VocosDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloader(self.val_config, train=False)


class VocosDataset(Dataset):
    def __init__(self, cfg: DataConfig, train: bool):
        # One audio file path per line.
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        # Random peak gain in [-6, -1] dBFS for training; fixed -3 dBFS for validation.
        # (np.random.uniform expects low <= high; its behavior is undefined otherwise.)
        gain = np.random.uniform(-6, -1) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            # Clip is too short: tile it until it covers num_samples, then trim.
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            # Random crop of num_samples during training.
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, always take the first segment for determinism.
            y = y[:, : self.num_samples]

        return y[0]
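For context, here is a minimal usage sketch of the module above. The filelist paths and parameter values are illustrative assumptions, not taken from this commit; only DataConfig and VocosDataModule come from the uploaded file.

# Hypothetical usage sketch: wiring DataConfig into VocosDataModule.
# All paths and numeric values below are assumptions for illustration.
from vocos.dataset import DataConfig, VocosDataModule

train_cfg = DataConfig(
    filelist_path="filelist.train",  # one audio file path per line (assumed layout)
    sampling_rate=24000,
    num_samples=16384,               # crop length in samples
    batch_size=16,
    num_workers=4,
)
val_cfg = DataConfig(
    filelist_path="filelist.val",
    sampling_rate=24000,
    num_samples=16384,
    batch_size=16,
    num_workers=4,
)

dm = VocosDataModule(train_params=train_cfg, val_params=val_cfg)
loader = dm.train_dataloader()
batch = next(iter(loader))           # tensor of shape (batch_size, num_samples)

Since __getitem__ returns y[0], each item is a 1-D waveform of exactly num_samples samples, so the default collate function stacks a batch into a (batch_size, num_samples) tensor with no custom collate_fn needed.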