Spaces:
Runtime error
Runtime error
File size: 6,371 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from typing import Tuple, Optional
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
class ProfileAug(nn.Module):
    """
    Implement the augmentation for profiles including:
    - Split aug: split one profile into two profiles, i.e., main and inaccurate;
      labels stay assigned to the main one.
    - Merge aug: merge two profiles into one; their labels are merged into one
      and the other profile/label slot is set to zero.
    - Disturb aug: disturb some profile with another one to simulate inaccurate
      clustering centroids.

    Shapes implied by the indexing below (presumably; confirm against callers):
    ``profile`` is (batch, num_spk, dim) and ``binary_labels`` is
    (batch, frames, num_spk).

    Every augmentation method takes a ``mask`` of shape (batch, num_spk) that
    ``forward`` initialises to ones. Whenever a slot is modified it is set to
    zero, and candidates are only drawn from slots whose mask is non-zero, so
    each slot is touched by at most one augmentation per call.

    Randomness: the per-utterance "apply or not" decision uses
    ``numpy.random``, while the choices inside an utterance use the torch RNG;
    seed both for reproducibility.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        """
        Args:
            apply_split_aug: enable the split augmentation.
            split_aug_prob: per-utterance probability of applying split aug.
            apply_merge_aug: enable the merge augmentation.
            merge_aug_prob: per-utterance probability of applying merge aug.
            apply_disturb_aug: enable the disturb augmentation.
            disturb_aug_prob: per-utterance probability of applying disturb aug.
            disturb_alpha: scale of the random perturbation used by split aug,
                and upper bound of the mixing weight used by disturb aug.
        """
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    @staticmethod
    def _pick(candidates: torch.Tensor) -> int:
        """Return one element of the non-empty 1-D index tensor ``candidates``, chosen uniformly."""
        return int(candidates[torch.randint(len(candidates), ())])

    def split_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ):
        """Copy a labelled speaker's profile, perturbed, into an unlabelled slot.

        For each selected utterance, a speaker slot with at least one active
        label (the "main" profile, which keeps its labels) is chosen. A slot with
        no active labels is then overwritten with a perturbed, re-normalised copy
        of it (the "inaccurate" profile, which receives no labels). Both slots
        are masked out afterwards.

        ``profile`` and ``mask`` are modified in place.

        Returns:
            The ``(profile, binary_labels, mask)`` triple.
        """
        bsz, dim = profile.shape[0], profile.shape[-1]
        spk_count = binary_labels.sum(dim=1)  # number of active labels per slot
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.split_aug_prob)[0]
        for idx in batch_indices.tolist():
            valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx]).flatten()
            pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx]).flatten()
            if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0:
                continue
            split_spk_idx = self._pick(valid_spk_idx)
            to_cover_idx = self._pick(pad_spk_idx)
            # Unit-norm random direction, scaled by disturb_alpha below.
            disturb_vec = F.normalize(torch.randn((dim,)).to(profile), dim=-1)
            profile[idx, to_cover_idx] = F.normalize(
                profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec,
                dim=-1,
            )
            mask[idx, split_spk_idx] = 0
            mask[idx, to_cover_idx] = 0
        return profile, binary_labels, mask

    def merge_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ):
        """Merge two distinct speakers into one slot and clear the other.

        For each selected utterance, two *different* slots with a non-zero
        profile are chosen. The first slot receives the re-normalised sum of
        both profiles and the union (logical OR) of both label tracks. The
        second slot's profile and labels are zeroed. Both slots are masked out.
        Utterances with fewer than two candidate slots are left untouched.

        ``profile``, ``binary_labels`` and ``mask`` are modified in place.

        Returns:
            The ``(profile, binary_labels, mask)`` triple.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.merge_aug_prob)[0]
        for idx in batch_indices.tolist():
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx]).flatten()
            # Two distinct slots are required: "merging" a slot with itself would
            # zero out both its profile and its labels, deleting the speaker.
            if len(valid_spk_idx) < 2:
                continue
            order = torch.randperm(len(valid_spk_idx))[:2].tolist()
            spk_idx_1 = int(valid_spk_idx[order[0]])
            spk_idx_2 = int(valid_spk_idx[order[1]])
            # merge profile
            profile[idx, spk_idx_1] = F.normalize(
                profile[idx, spk_idx_1] + profile[idx, spk_idx_2], dim=-1
            )
            profile[idx, spk_idx_2] = 0
            # merge binary labels (union of the two tracks)
            merged = binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            binary_labels[idx, :, spk_idx_1] = (merged > 0).to(binary_labels)
            binary_labels[idx, :, spk_idx_2] = 0
            mask[idx, spk_idx_1] = 0
            mask[idx, spk_idx_2] = 0
        return profile, binary_labels, mask

    def disturb_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ):
        """Blend a labelled speaker's profile with another speaker's profile.

        For each selected utterance, a slot with at least one active label is
        chosen and replaced by ``normalize((1 - a) * p_target + a * p_other)``.
        Here ``a`` is drawn uniformly from ``[0, disturb_alpha)`` and
        ``p_other`` is a different slot with a non-zero profile. Only the
        target slot is masked out; the labels are unchanged.

        ``profile`` and ``mask`` are modified in place.

        Returns:
            The ``(profile, binary_labels, mask)`` triple.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        spk_count = binary_labels.sum(dim=1)  # number of active labels per slot
        prob = np.random.rand(bsz)
        batch_indices = np.nonzero(prob < self.disturb_aug_prob)[0]
        for idx in batch_indices.tolist():
            pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx]).flatten()
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx]).flatten()
            if len(pos_spk_idx) == 0:
                continue
            to_disturb_idx = self._pick(pos_spk_idx)
            # Disturb with a *different* profile; mixing a profile with itself
            # is a no-op that would still consume the slot's mask.
            other_spk_idx = valid_spk_idx[valid_spk_idx != to_disturb_idx]
            if len(other_spk_idx) == 0:
                continue
            disturb_idx = self._pick(other_spk_idx)
            alpha = self.disturb_alpha * torch.rand(()).item()
            profile[idx, to_disturb_idx] = F.normalize(
                (1 - alpha) * profile[idx, to_disturb_idx]
                + alpha * profile[idx, disturb_idx],
                dim=-1,
            )
            mask[idx, to_disturb_idx] = 0
        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: Optional[torch.Tensor] = None,
        profile: Optional[torch.Tensor] = None,
        profile_lengths: Optional[torch.Tensor] = None,
        binary_labels: Optional[torch.Tensor] = None,
        labels_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled profile augmentations to copies of the inputs.

        ``speech_lengths``, ``profile_lengths`` and ``labels_length`` are
        accepted for interface compatibility but are not used.

        Returns:
            ``(speech, profile, binary_labels)``. These are clones of the inputs,
            so the caller's tensors are never modified. When both ``profile`` and
            ``binary_labels`` are given, the returned profile is L2-normalised
            along the last axis and then disturbed, split and merged, in that
            order. If either is ``None``, no augmentation is applied and the
            (cloned) inputs are returned as-is.
        """
        # copy inputs to avoid inplace-operation on the caller's tensors
        speech = torch.clone(speech)
        if profile is None or binary_labels is None:
            # Every augmentation needs both profiles and labels.
            return (
                speech,
                None if profile is None else torch.clone(profile),
                None if binary_labels is None else torch.clone(binary_labels),
            )
        profile = torch.clone(profile)
        binary_labels = torch.clone(binary_labels)

        profile = F.normalize(profile, dim=-1)
        profile_mask = torch.ones(profile.shape[:2]).to(profile)
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )
        return speech, profile, binary_labels
|