File size: 6,371 Bytes
0102e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from typing import Tuple, Optional
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn


class ProfileAug(nn.Module):
    """
    Implement the augmentation for profiles including:
    - Split aug: split one profile into two profiles, i.e., main and inaccurate, labels assigned to main
    - Merge aug: merge two profiles into one, labels are also merged into one, the other set to zero
    - Disturb aug: disturb some profile with others to simulate the inaccurate clustering centroids.

    Tensor layout used throughout (as implied by the indexing in this class):
    - ``profile``: ``(batch, num_speakers, dim)``
    - ``binary_labels``: ``(batch, frames, num_speakers)``
    - ``mask``: ``(batch, num_speakers)``; 1 marks a speaker slot that no
      augmentation has touched yet, 0 marks a slot that was already used.
      Each augmentation only picks slots whose mask is 1 and clears the mask
      of the slots it touches, so one slot is never augmented twice.

    Batch items are selected with ``np.random`` while per-item choices use
    ``torch`` random functions, so both generators must be seeded for
    reproducible results.
    """

    def __init__(
        self,
        apply_split_aug: bool = True,
        split_aug_prob: float = 0.05,
        apply_merge_aug: bool = True,
        merge_aug_prob: float = 0.2,
        apply_disturb_aug: bool = True,
        disturb_aug_prob: float = 0.4,
        disturb_alpha: float = 0.2,
    ) -> None:
        """
        Args:
            apply_split_aug: enable the split augmentation in ``forward``.
            split_aug_prob: per-batch-item probability of applying split aug.
            apply_merge_aug: enable the merge augmentation in ``forward``.
            merge_aug_prob: per-batch-item probability of applying merge aug.
            apply_disturb_aug: enable the disturb augmentation in ``forward``.
            disturb_aug_prob: per-batch-item probability of applying disturb aug.
            disturb_alpha: noise scale for split aug and the upper bound of the
                mixing weight for disturb aug.
        """
        super().__init__()
        self.apply_split_aug = apply_split_aug
        self.split_aug_prob = split_aug_prob
        self.apply_merge_aug = apply_merge_aug
        self.merge_aug_prob = merge_aug_prob
        self.apply_disturb_aug = apply_disturb_aug
        self.disturb_aug_prob = disturb_aug_prob
        self.disturb_alpha = disturb_alpha

    @staticmethod
    def _select_batch_items(bsz: int, aug_prob: float) -> list:
        """Return the batch indices (Python ints) chosen with probability ``aug_prob`` each."""
        prob = np.random.rand(bsz)
        return np.nonzero(prob < aug_prob)[0].tolist()

    def split_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Duplicate an active speaker's profile, with noise, into an inactive slot.

        For each selected batch item, one speaker with non-zero labels is picked
        and a noisy, re-normalized copy of its profile overwrites a randomly
        chosen slot whose labels are all zero. Labels are left untouched, so
        they stay with the original ("main") profile.

        Args:
            profile: speaker profiles, modified in place.
            binary_labels: speaker activity labels (read only here).
            mask: availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects after the update.
        """
        bsz, dim = profile.shape[0], profile.shape[-1]
        # Number of active frames per speaker: (batch, num_speakers).
        spk_count = binary_labels.sum(dim=1)
        for idx in self._select_batch_items(bsz, self.split_aug_prob):
            # Untouched speakers that are active somewhere in the labels.
            valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            # Untouched slots without any active frame; one of them is overwritten.
            pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx])
            if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0:
                continue
            split_spk_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
            to_cover_idx = pad_spk_idx[torch.randint(len(pad_spk_idx), ())]
            # Unit-norm random direction, scaled by disturb_alpha below.
            disturb_vec = torch.randn((dim,)).to(profile)
            disturb_vec = F.normalize(disturb_vec, dim=-1)
            profile[idx, to_cover_idx] = F.normalize(
                profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec,
                dim=-1,
            )
            # Both slots are now used and must not be augmented again.
            mask[idx, split_spk_idx] = 0
            mask[idx, to_cover_idx] = 0
        return profile, binary_labels, mask

    def merge_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Merge two distinct speakers of a batch item into a single one.

        For each selected batch item, two different untouched speakers with a
        non-zero profile are picked. The first one receives the normalized sum
        of both profiles and the union of both label columns; the second one
        gets a zero profile and zero labels. Items with fewer than two eligible
        speakers are skipped, since merging a speaker with itself would only
        erase that speaker's profile and labels.

        Args:
            profile: speaker profiles, modified in place.
            binary_labels: speaker activity labels, modified in place.
            mask: availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects after the update.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        for idx in self._select_batch_items(bsz, self.merge_aug_prob):
            # Untouched speakers that actually have a (non-zero) profile.
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(valid_spk_idx) < 2:
                continue
            # Sample without replacement so the two speakers are always distinct.
            first, second = torch.randperm(len(valid_spk_idx))[:2].tolist()
            spk_idx_1, spk_idx_2 = valid_spk_idx[first], valid_spk_idx[second]
            # merge profile
            profile[idx, spk_idx_1] = F.normalize(
                profile[idx, spk_idx_1] + profile[idx, spk_idx_2], dim=-1
            )
            profile[idx, spk_idx_2] = 0
            # merge binary labels: a frame is active if either speaker was active
            merged_labels = (
                binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2]
            )
            binary_labels[idx, :, spk_idx_1] = (merged_labels > 0).to(binary_labels)
            binary_labels[idx, :, spk_idx_2] = 0

            mask[idx, spk_idx_1] = 0
            mask[idx, spk_idx_2] = 0

        return profile, binary_labels, mask

    def disturb_aug(
        self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Blend an active speaker's profile with another profile of the same item.

        For each selected batch item, one untouched speaker with non-zero labels
        is mixed with a randomly chosen untouched non-zero profile using a
        weight drawn uniformly from ``[0, disturb_alpha)``, then re-normalized.
        The two picks are drawn independently and may coincide, in which case
        the profile is left essentially unchanged. Labels are not modified.

        Args:
            profile: speaker profiles, modified in place.
            binary_labels: speaker activity labels (read only here).
            mask: availability mask, modified in place.

        Returns:
            The same ``(profile, binary_labels, mask)`` objects after the update.
        """
        bsz = profile.shape[0]
        profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False)
        # Number of active frames per speaker: (batch, num_speakers).
        spk_count = binary_labels.sum(dim=1)
        for idx in self._select_batch_items(bsz, self.disturb_aug_prob):
            # Candidates to be disturbed: untouched speakers with active frames.
            pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx])
            # Candidates to disturb with: untouched speakers with a non-zero profile.
            valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx])
            if len(pos_spk_idx) == 0 or len(valid_spk_idx) == 0:
                continue
            to_disturb_idx = pos_spk_idx[torch.randint(len(pos_spk_idx), ())]
            disturb_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())]
            alpha = self.disturb_alpha * torch.rand(()).item()
            mixed = (1 - alpha) * profile[idx, to_disturb_idx] + alpha * profile[
                idx, disturb_idx
            ]
            profile[idx, to_disturb_idx] = F.normalize(mixed, dim=-1)
            # Only the disturbed slot is consumed; the source stays available.
            mask[idx, to_disturb_idx] = 0

        return profile, binary_labels, mask

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        labels_length: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply the enabled augmentations in the order disturb, split, merge.

        Args:
            speech: input features; returned as a clone without further changes.
            speech_lengths: not used by this module.
            profile: speaker profiles; cloned and L2-normalized along the last
                dimension before augmentation.
            profile_lengths: not used by this module.
            binary_labels: speaker activity labels; cloned before augmentation.
            labels_length: not used by this module.

        Returns:
            Tuple ``(speech, profile, binary_labels)`` holding the cloned speech,
            the normalized and augmented profiles and the augmented labels.
            The input tensors themselves are not modified.
        """
        # copy inputs to avoid inplace-operation
        speech, profile, binary_labels = (
            torch.clone(speech),
            torch.clone(profile),
            torch.clone(binary_labels),
        )
        profile = F.normalize(profile, dim=-1)

        # All speaker slots start as available (1) for augmentation.
        profile_mask = torch.ones(profile.shape[:2]).to(profile)
        if self.apply_disturb_aug:
            profile, binary_labels, profile_mask = self.disturb_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_split_aug:
            profile, binary_labels, profile_mask = self.split_aug(
                profile, binary_labels, profile_mask
            )
        if self.apply_merge_aug:
            profile, binary_labels, profile_mask = self.merge_aug(
                profile, binary_labels, profile_mask
            )

        return speech, profile, binary_labels