mrfakename's picture
Super-squash branch 'main' using huggingface_hub
0102e16 verified
raw
history blame
6.37 kB
from typing import Tuple, Optional
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
class ProfileAug(nn.Module):
    """
    Implement the augmentation for profiles including:
    - Split aug: split one profile into two profiles, i.e., main and inaccurate, labels assigned to main
    - Merge aug: merge two profiles into one, labels are also merged into one, the other set to zero
    - Disturb aug: disturb some profile with others to simulate the inaccurate clustering centroids.

    Shapes, as implied by the indexing used in this class (B = batch,
    N = profile slots, T = frames, D = embedding size):
        profile:        (B, N, D)
        binary_labels:  (B, T, N)
        mask:           (B, N); 1 means the slot may still be picked by an
                        augmentation, 0 means an earlier augmentation already
                        touched it.

    Batch items are selected with NumPy's global RNG, speakers and noise with
    torch's global RNG; seed both for reproducible results.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        """Store the augmentation switches and hyper-parameters.

        Args:
            apply_split_aug: Enable the split augmentation in ``forward``.
            split_aug_prob: Per-batch-item probability of applying split aug.
            apply_merge_aug: Enable the merge augmentation in ``forward``.
            merge_aug_prob: Per-batch-item probability of applying merge aug.
            apply_disturb_aug: Enable the disturb augmentation in ``forward``.
            disturb_aug_prob: Per-batch-item probability of applying disturb aug.
            disturb_alpha: Weight of the random noise in split aug and upper
                bound of the mixing weight in disturb aug.
        """
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    @staticmethod
    def _sample_batch_indices(batch_size: int, prob: float) -> np.ndarray:
        """Return indices of batch items selected independently with probability ``prob``."""
        return np.nonzero(np.random.rand(batch_size) < prob)[0]

    @staticmethod
    def _pick_one(candidates: torch.Tensor) -> torch.Tensor:
        """Pick one random row of a ``torch.nonzero`` result (keeps shape ``(1,)``)."""
        return candidates[torch.randint(len(candidates), ())]

    def split_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Copy a noisy version of an active speaker's profile into an unused slot.

        The labels stay with the original ("main") slot; the new slot only
        receives an inaccurate duplicate of the profile. ``profile`` and
        ``mask`` are modified in place.

        Returns:
            The (mutated) ``profile``, ``binary_labels`` and ``mask``.
        """
        bsz, dim = profile.shape[0], profile.shape[-1]
        # Number of active frames per speaker slot: (B, N).
        spk_count = binary_labels.sum(dim=1)
        for idx in self._sample_batch_indices(bsz, self.split_aug_prob):
            # Slots with at least one active frame that are still untouched.
            valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            # Slots without any active frame that are still untouched.
            pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx])
            if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0:
                continue
            split_spk_idx = self._pick_one(valid_spk_idx)
            to_cover_idx = self._pick_one(pad_spk_idx)
            # Unit-norm random direction used to perturb the duplicated profile.
            disturb_vec = F.normalize(torch.randn((dim,)).to(profile), dim=-1)
            profile[idx, to_cover_idx] = F.normalize(
                profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec,
                dim=-1,
            )
            # Exclude both slots from any later augmentation.
            mask[idx, split_spk_idx] = 0
            mask[idx, to_cover_idx] = 0
        return profile, binary_labels, mask

    def merge_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge two distinct speaker slots into one.

        The first slot receives the normalized sum of both profiles and the
        union of both label tracks; the second slot's profile and labels are
        set to zero. All three tensors are modified in place.

        Batch items with fewer than two eligible slots are skipped: merging a
        slot with itself would otherwise zero out that speaker's profile and
        labels.

        Returns:
            The (mutated) ``profile``, ``binary_labels`` and ``mask``.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        for idx in self._sample_batch_indices(bsz, self.merge_aug_prob):
            # Slots holding a non-zero profile that are still untouched.
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(valid_spk_idx) < 2:
                continue
            # Draw two *different* slots (sampling without replacement).
            to_merge = torch.randperm(len(valid_spk_idx))[:2]
            spk_idx_1, spk_idx_2 = (
                valid_spk_idx[to_merge[0]],
                valid_spk_idx[to_merge[1]],
            )
            # merge profile
            profile[idx, spk_idx_1] = F.normalize(
                profile[idx, spk_idx_1] + profile[idx, spk_idx_2], dim=-1
            )
            profile[idx, spk_idx_2] = 0
            # merge binary labels: union of both tracks, clipped back to {0, 1}
            merged_labels = (
                binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            )
            binary_labels[idx, :, spk_idx_1] = (merged_labels > 0).to(binary_labels)
            binary_labels[idx, :, spk_idx_2] = 0
            # Exclude both slots from any later augmentation.
            mask[idx, spk_idx_1] = 0
            mask[idx, spk_idx_2] = 0
        return profile, binary_labels, mask

    def disturb_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Blend a small amount of another profile into an active speaker's profile.

        The mixing weight is drawn uniformly from ``[0, disturb_alpha)``.
        ``profile`` and ``mask`` are modified in place.

        Returns:
            The (mutated) ``profile``, ``binary_labels`` and ``mask``.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        # Number of active frames per speaker slot: (B, N).
        spk_count = binary_labels.sum(dim=1)
        for idx in self._sample_batch_indices(bsz, self.disturb_aug_prob):
            # Targets: slots with at least one active frame, still untouched.
            pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            # Donors: slots holding a non-zero profile, still untouched.
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(pos_spk_idx) == 0 or len(valid_spk_idx) == 0:
                continue
            to_disturb_idx = self._pick_one(pos_spk_idx)
            # The donor is sampled independently and may coincide with the
            # target, in which case the profile is effectively unchanged.
            disturb_idx = self._pick_one(valid_spk_idx)
            alpha = self.disturb_alpha * torch.rand(()).item()
            blended = (1 - alpha) * profile[idx, to_disturb_idx] + alpha * profile[
                idx, disturb_idx
            ]
            profile[idx, to_disturb_idx] = F.normalize(blended, dim=-1)
            # Exclude the disturbed slot from any later augmentation.
            mask[idx, to_disturb_idx] = 0
        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        labels_length: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled augmentations in the order disturb, split, merge.

        Inputs are cloned first, so the caller's tensors are never modified.
        ``speech`` is returned as an unmodified copy. The ``*_lengths``
        arguments are accepted but not used here.

        If ``profile`` or ``binary_labels`` is ``None`` no augmentation can be
        performed; whatever was provided is returned (profile L2-normalized,
        labels cloned) and the missing entries stay ``None``.

        Returns:
            Tuple ``(speech, profile, binary_labels)`` where ``profile`` is
            L2-normalized along its last dimension.
        """
        # copy inputs to avoid inplace-operation
        speech = torch.clone(speech)
        if profile is None or binary_labels is None:
            if profile is not None:
                # F.normalize returns a new tensor, so the input stays intact.
                profile = F.normalize(profile, dim=-1)
            if binary_labels is not None:
                binary_labels = torch.clone(binary_labels)
            return speech, profile, binary_labels

        profile = F.normalize(torch.clone(profile), dim=-1)
        binary_labels = torch.clone(binary_labels)
        # One flag per (batch item, slot); each augmentation clears the slots
        # it touches so that later augmentations leave them alone.
        profile_mask = torch.ones(profile.shape[:2]).to(profile)

        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )
        return speech, profile, binary_labels