Update dataset.py
Browse files- dataset.py +3 -6
dataset.py
CHANGED
@@ -209,19 +209,16 @@ class TrainDataset(Dataset):
|
|
209 |
|
210 |
sig = sig.reshape(-1).astype(np.float32)
|
211 |
|
212 |
-
sig = sig.reshape((1, -1))
|
213 |
target = torch.tensor(sig.copy())
|
214 |
p_size = random.choice(self.p_sizes)
|
215 |
|
216 |
sig = np.reshape(sig, (-1, p_size))
|
217 |
mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
|
218 |
sig *= mask
|
219 |
-
sig = torch.tensor(sig.copy())
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
target = torch.stft(target.squeeze(0), self.chunk_len, self.stride, window=self.hann,
|
224 |
return_complex=False).permute(2, 0, 1).float()
|
225 |
-
sig = torch.stft(sig
|
226 |
sig = sig.permute(2, 0, 1).float()
|
227 |
return sig, target
|
|
|
209 |
|
210 |
sig = sig.reshape(-1).astype(np.float32)
|
211 |
|
|
|
212 |
target = torch.tensor(sig.copy())
|
213 |
p_size = random.choice(self.p_sizes)
|
214 |
|
215 |
sig = np.reshape(sig, (-1, p_size))
|
216 |
mask = self.mask_generator.gen_mask(len(sig), seed=index)[:, np.newaxis]
|
217 |
sig *= mask
|
218 |
+
sig = torch.tensor(sig.copy()).reshape(-1)
|
219 |
|
220 |
+
target = torch.stft(target, self.chunk_len, self.stride, window=self.hann,
|
|
|
|
|
221 |
return_complex=False).permute(2, 0, 1).float()
|
222 |
+
sig = torch.stft(sig, self.chunk_len, self.stride, window=self.hann, return_complex=False)
|
223 |
sig = sig.permute(2, 0, 1).float()
|
224 |
return sig, target
|