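# NOTE(review): the following looks like residue from the page this file was copied from; it is not code:
# Spaces:
# Runtime error
# Runtime error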
from typing import Tuple, Optional | |
import numpy as np | |
import torch | |
from torch.nn import functional as F | |
import torch.nn as nn | |
class ProfileAug(nn.Module):
    """
    Augmentations for speaker profiles:
    - Split aug: split one profile into two profiles, i.e., main and inaccurate; labels stay with main.
    - Merge aug: merge two profiles into one; their labels are merged too and the other slot is zeroed.
    - Disturb aug: disturb a profile with another one to simulate inaccurate clustering centroids.

    Tensor layout implied by the indexing below (NOTE(review): confirm against callers):
        profile:       (batch, num_spk, dim)
        binary_labels: (batch, frames, num_spk), presumably 0/1 activity per frame
        mask:          (batch, num_spk), 1 = slot still eligible for augmentation
    The per-method augmentations modify their tensor arguments in place and
    also return them; ``forward`` works on copies.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        """
        Args:
            apply_split_aug: whether ``forward`` runs :meth:`split_aug`.
            split_aug_prob: per-sample probability of applying split aug.
            apply_merge_aug: whether ``forward`` runs :meth:`merge_aug`.
            merge_aug_prob: per-sample probability of applying merge aug.
            apply_disturb_aug: whether ``forward`` runs :meth:`disturb_aug`.
            disturb_aug_prob: per-sample probability of applying disturb aug.
            disturb_alpha: perturbation strength. It scales the random unit
                vector added in split aug and is the upper bound of the random
                mixing weight used in disturb aug.
        """
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    @staticmethod
    def _sample_batch_indices(bsz: int, prob: float) -> list:
        """Return the batch indices selected independently with probability ``prob``."""
        return np.nonzero(np.random.rand(bsz) < prob)[0].tolist()

    @staticmethod
    def _nonzero_indices(values: torch.Tensor) -> torch.Tensor:
        """Return a 1-D tensor with the positions of the non-zero entries of ``values``."""
        return torch.nonzero(values).flatten()

    @staticmethod
    def _choice(indices: torch.Tensor) -> int:
        """Pick one element of the non-empty 1-D ``indices`` uniformly at random."""
        return int(indices[torch.randint(len(indices), ())])

    def split_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Duplicate an active speaker's profile, slightly perturbed, into a silent slot.

        For each selected sample, a speaker with label activity (the "main"
        profile) is copied into an unmasked slot without label activity. The
        copy is shifted by ``disturb_alpha`` times a random unit vector and
        re-normalised. Labels are untouched, so the activity stays on the main
        profile. Both slots are masked out from later augmentations.
        Samples lacking an active or a silent unmasked slot are skipped.
        """
        bsz, dim = profile.shape[0], profile.shape[-1]
        spk_count = binary_labels.sum(dim=1)  # label activity per speaker slot
        for idx in self._sample_batch_indices(bsz, self.split_aug_prob):
            active = self._nonzero_indices(spk_count[idx] * mask[idx])
            silent = self._nonzero_indices((spk_count[idx] == 0) * mask[idx])
            if len(active) == 0 or len(silent) == 0:
                continue
            src = self._choice(active)
            dst = self._choice(silent)
            noise = F.normalize(torch.randn((dim,)).to(profile), dim=-1)
            profile[idx, dst] = F.normalize(
                profile[idx, src] + self.disturb_alpha * noise, dim=-1
            )
            mask[idx, src] = 0
            mask[idx, dst] = 0
        return profile, binary_labels, mask

    def merge_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge two distinct speakers of a sample into one.

        The kept slot receives the normalised sum of both profiles and the
        logical OR of both label tracks. The other slot's profile and labels
        are zeroed. Only unmasked slots with a non-zero profile are eligible.
        Samples with fewer than two such slots are skipped. Both slots are
        masked out from later augmentations.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1)
        for idx in self._sample_batch_indices(bsz, self.merge_aug_prob):
            candidates = self._nonzero_indices(profile_norm[idx] * mask[idx])
            # Two *distinct* slots are required: "merging" a slot with itself
            # would zero its profile and labels, i.e. delete the speaker.
            if len(candidates) < 2:
                continue
            keep, drop = candidates[torch.randperm(len(candidates))[:2]].tolist()
            # merge profile
            profile[idx, keep] = F.normalize(
                profile[idx, keep] + profile[idx, drop], dim=-1
            )
            profile[idx, drop] = 0
            # merge binary labels (union of both activity tracks)
            merged = binary_labels[idx, :, keep] + binary_labels[idx, :, drop]
            binary_labels[idx, :, keep] = (merged > 0).to(binary_labels)
            binary_labels[idx, :, drop] = 0
            mask[idx, keep] = 0
            mask[idx, drop] = 0
        return profile, binary_labels, mask

    def disturb_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Blend an active speaker's profile toward another speaker's profile.

        The target (an unmasked slot with label activity) becomes
        ``normalize((1 - a) * target + a * other)`` with ``a`` drawn uniformly
        from ``[0, disturb_alpha)``. ``other`` is a different unmasked slot with
        a non-zero profile. Samples without such a pair are skipped. The target
        slot is masked out from later augmentations.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1)
        spk_count = binary_labels.sum(dim=1)
        for idx in self._sample_batch_indices(bsz, self.disturb_aug_prob):
            targets = self._nonzero_indices(spk_count[idx] * mask[idx])
            sources = self._nonzero_indices(profile_norm[idx] * mask[idx])
            if len(targets) == 0 or len(sources) == 0:
                continue
            target = self._choice(targets)
            # Mixing a profile with itself is a no-op; disturb with *another* one.
            sources = sources[sources != target]
            if len(sources) == 0:
                continue
            source = self._choice(sources)
            alpha = self.disturb_alpha * torch.rand(()).item()
            mixed = (1 - alpha) * profile[idx, target] + alpha * profile[idx, source]
            profile[idx, target] = F.normalize(mixed, dim=-1)
            mask[idx, target] = 0
        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: Optional[torch.Tensor] = None,
        profile: Optional[torch.Tensor] = None,
        profile_lengths: Optional[torch.Tensor] = None,
        binary_labels: Optional[torch.Tensor] = None,
        labels_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled augmentations in the order disturb -> split -> merge.

        ``profile`` is L2-normalised along its last dim before augmenting.
        ``speech_lengths``, ``profile_lengths`` and ``labels_length`` are
        accepted but not used.

        Returns:
            ``(speech, profile, binary_labels)`` as new tensors; the caller's
            tensors are never modified. If ``profile`` or ``binary_labels`` is
            None, no augmentation is possible and they are returned unchanged.
        """
        # copy inputs to avoid in-place modification of the caller's tensors
        speech = torch.clone(speech)
        if profile is None or binary_labels is None:
            return speech, profile, binary_labels
        profile = F.normalize(torch.clone(profile), dim=-1)
        binary_labels = torch.clone(binary_labels)
        # Slots touched by one augmentation are masked so later ones skip them.
        profile_mask = profile.new_ones(profile.shape[:2])
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )
        return speech, profile, binary_labels