IlayMalinyak committed
Commit b3fb4dd · 1 Parent(s): 192ac3b

first commit

tasks/Modules/ResNet18.py ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# https://github.com/samcw/ResNet18-Pytorch
class ResBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv1d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)

        return out

class ResNet18(nn.Module):
    def __init__(self, args):
        super(ResNet18, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.pred_layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 1),
        )
        if getattr(args, 'mean_label', False):
            self.pred_layer[3].bias.data.fill_(args.mean_label)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = self.conv1(x)
        out = F.max_pool1d(out, 3, 2, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.mean(-1)
        out = self.pred_layer(out)
        return out
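Usage note: a minimal sketch (illustrative, not from the committed files) of how the 1-D ResNet18 above can be called, assuming the tasks/ directory is on the Python path as the other modules expect; the args object only needs an optional mean_label attribute.

from types import SimpleNamespace
import torch
from Modules.ResNet18 import ResNet18

args = SimpleNamespace()          # hypothetical args; only an optional mean_label is read
model = ResNet18(args)
x = torch.randn(8, 4096)          # (batch, length) raw 1-D signal
out = model(x)                    # forward unsqueezes to (batch, 1, length) internally
print(out.shape)                  # torch.Size([8, 1])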
tasks/Modules/__init__.py ADDED
File without changes
tasks/Modules/cnn.py ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ConvBlockDecoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.decoder_dim,
                      out_channels=args.decoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.decoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ResNetLayer(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=3,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_layer(x) + x


class ResNetBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(*[ResNetLayer(args) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
tasks/Modules/conformer.py ADDED
@@ -0,0 +1,584 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.nn.init as init
import math

from .mhsa_pro import MHA_rotary, MHA_decoder
from .cnn import ConvBlock, ConvBlockDecoder

from typing import Optional, Tuple

class ResidualConnectionModule(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = 1
        self.input_factor = 1

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return (self.module(inputs, **kwargs) * self.module_factor) + (inputs * self.input_factor)

class PostNorm(nn.Module):
    """
    Post-norm residual connection module.
    outputs = LayerNorm(module(inputs) + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(PostNorm, self).__init__()
        self.module = module
        input_factor = torch.FloatTensor(args.alpha) if getattr(args, 'alpha', None) else torch.tensor(1.)
        self.register_buffer('input_factor', input_factor)
        self.norm = nn.LayerNorm(dims)

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return self.norm(self.module(inputs, **kwargs) + (inputs * self.input_factor))

class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases are initialized to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class View(nn.Module):
    """ Wrapper class of torch.view() for Sequential module. """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        if self.contiguous:
            x = x.contiguous()

        return x.view(*self.shape)


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)

class FeedForwardModule(nn.Module):
    """
    The Conformer feed-forward module follows pre-norm residual units and applies layer normalization within the
    residual unit and on the input before the first linear layer. It also applies Swish activation and dropout,
    which helps regularize the network.
    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of feed forward module.
        dropout_p (float): Ratio of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input sequences
    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        expansion_factor = 4
        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Linear(args.encoder_dim, args.encoder_dim * expansion_factor, bias=True),
            nn.SiLU(),
            nn.Dropout(p=args.dropout_p),
            Linear(args.encoder_dim * expansion_factor, args.encoder_dim, bias=True),
            nn.Dropout(p=args.dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs)

class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    A conv1d with kernel size == 1 is termed in the literature as pointwise convolution.
    This operation is often used to match dimensions.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """
    The Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor containing input sequences
    Outputs: outputs
        outputs (batch, time, dim): Tensor produced by the conformer convolution module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (args.kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        expansion_factor = 2
        dropout_p = 0.1

        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim * expansion_factor, stride=1, padding=0, bias=True),
            nn.GLU(dim=1),
            DepthwiseConv1d(args.encoder_dim, args.encoder_dim, args.kernel_size, stride=1, padding=(args.kernel_size - 1) // 2),
            nn.BatchNorm1d(args.encoder_dim),
            nn.SiLU(),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)

class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since the transformer contains no recurrence and no convolution, in order for the model to make
    use of the order of the sequence, we must add some positional information.
    "Attention Is All You Need" uses sine and cosine functions of different frequencies:
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    """
    def __init__(self, d_model: int = 128, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]

class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing the query vector
        - **key** (batch, time, dim): Tensor containing the key vector
        - **value** (batch, time, dim): Tensor containing the value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module.
    """
    def __init__(
            self,
            encoder_dim: int = 128,
            num_heads: int = 8,
            dropout_p: float = 0.1
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert encoder_dim % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = encoder_dim
        self.d_head = int(encoder_dim / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(encoder_dim)

        self.query_proj = Linear(encoder_dim, encoder_dim)
        self.key_proj = Linear(encoder_dim, encoder_dim)
        self.value_proj = Linear(encoder_dim, encoder_dim)
        self.pos_proj = Linear(encoder_dim, encoder_dim, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(encoder_dim, encoder_dim)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        query = query.view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # content_score = torch.matmul((query).transpose(1, 2), key.transpose(2, 3))
        # pos_score = torch.matmul((query).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Q (B, num_heads, length, d_head) * PE (B, num_heads, d_head, length) = pos_score (B, num_heads, length, length)
        pos_score = self._relative_shift(pos_score)
        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        score = F.softmax(score, -1)
        attn = self.dropout(score)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        # shift the position score one unit along the length axis, leaving a blank row.
        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better on different input lengths, and the resulting encoder is more robust to the variance of
    the utterance length. Conformer uses pre-norm residual units with dropout, which helps in training
    and regularizing deeper models.
    Args:
        d_model (int): The dimension of the model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing the input vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
    """
    def __init__(self, args):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        dropout_p = 0.1
        self.positional_encoding = PositionalEncoding(args.encoder_dim)
        self.layer_norm = nn.LayerNorm(args.encoder_dim)
        self.attention = RelativeMultiHeadAttention(args.encoder_dim, args.num_heads, args.dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)

        inputs = self.layer_norm(inputs)
        outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)

        return self.dropout(outputs)

class ConformerBlock(nn.Module):
    """
    A Conformer block contains two feed-forward modules sandwiching the multi-headed self-attention module
    and the convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
    the original feed-forward layer in the Transformer block with two half-step feed-forward layers,
    one before the attention layer and one after.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the conformer block.
    """
    def __init__(
            self,
            args
    ):
        super(ConformerBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'conv': ConvBlock,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList([norm_dict[args.norm](block_dict[block](args), args.encoder_dim, args)
                                      for block in args.encoder])

    def forward(self, x: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask)
            else:
                x = m(x)
        return x


class DecoderBlock(nn.Module):
    """
    A decoder block contains two feed-forward modules sandwiching the multi-headed self-attention module
    and the convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
    the original feed-forward layer in the Transformer block with two half-step feed-forward layers,
    one before the attention layer and one after.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produced by the decoder block.
    """
    def __init__(
            self,
            args
    ):
        super(DecoderBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'mhsa_decoder': MHA_decoder,
            'conv': ConvBlockDecoder,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList([norm_dict[args.norm](block_dict[block](args), args.decoder_dim, args)
                                      for block in args.decoder])

    def forward(self, x: Tensor, memory: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_decoder):
                x = m(x, memory=memory, RoPE=RoPE, key_padding_mask=key_padding_mask)
            elif isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask).transpose(0, 1)
            else:
                x = m(x)
        return x


class ConformerEncoder(nn.Module):
    """
    The Conformer encoder first processes the input with a convolution subsampling layer and then
    with a number of conformer blocks.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing the input vector
        - **input_lengths** (batch): list of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, out_channels, time): Tensor produced by the conformer encoder.
        - **output_lengths** (batch): list of sequence output lengths
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerEncoder, self).__init__()
        self.blocks = nn.ModuleList([ConformerBlock(
            args) for _ in range(args.num_layers)])

    def forward(self, x: Tensor, RoPE=None, key_padding_mask=None) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate `inputs` for encoder training.
        Args:
            inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            (Tensor, Tensor)
            * outputs (torch.FloatTensor): An output sequence of the encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            * output_lengths (torch.LongTensor): The length of the output tensor. ``(batch)``
        """
        for block in self.blocks:
            x = block(x, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x

class ConformerDecoder(nn.Module):
    """
    The Conformer decoder processes the target sequence with a number of decoder blocks that attend to
    the encoder memory.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indicating whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing the input vector
        - **input_lengths** (batch): list of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, out_channels, time): Tensor produced by the conformer decoder.
        - **output_lengths** (batch): list of sequence output lengths
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerDecoder, self).__init__()
        self.blocks = nn.ModuleList([DecoderBlock(
            args) for _ in range(args.num_decoder_layers)])

    def forward(self, x: Tensor, memory: Tensor, RoPE=None, key_padding_mask=None) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate `inputs` for decoder training.
        Args:
            inputs (torch.FloatTensor): An input sequence passed to the decoder. Typically this will be a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            (Tensor, Tensor)
            * outputs (torch.FloatTensor): An output sequence of the decoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            * output_lengths (torch.LongTensor): The length of the output tensor. ``(batch)``
        """
        for block in self.blocks:
            x = block(x, memory, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x
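Usage note: a sketch (illustrative, not from the committed files) of how the encoder above is assembled, assuming the Conformer settings from tasks/config.yaml; with norm "postnorm" and encoder ["mhsa_pro", "conv"], each ConformerBlock wraps an MHA_rotary attention and a ConvBlock in PostNorm residual units, and the rotary cos/sin tensors are produced the same way MultiEncoder in tasks/models.py produces them.

import torch
from data_utils import Container
from Modules.conformer import ConformerEncoder
from Modules.mhsa_pro import RotaryEmbedding

conformer_args = Container(encoder=["mhsa_pro", "conv"], norm="postnorm",
                           num_layers=8, encoder_dim=128, num_heads=8,
                           kernel_size=3, dropout_p=0.2, timeshift=False)
encoder = ConformerEncoder(conformer_args)

x = torch.randn(4, 256, 128)                                   # (batch, time, encoder_dim)
head_size = conformer_args.encoder_dim // conformer_args.num_heads
rope = RotaryEmbedding(int(head_size * 0.5))(x, seq_len=256)   # stacked (cos, sin)
out = encoder(x, RoPE=rope)                                    # (batch, time, encoder_dim)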
tasks/Modules/mhsa_pro.py ADDED
@@ -0,0 +1,231 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from typing import Optional, Tuple
import math
import logging

logger = logging.getLogger(__name__)


rwkv_emb_scale = 0.4  # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0  # decay weights in higher layers. try 0.5 ~ 1.0.

class AttentionConfig:
    def __init__(self, ctx_len=100, **kwargs):
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return torch.stack([self.cos_cached, self.sin_cached])

class ContinuousRotaryEmbedding(torch.nn.Module):
    '''Continuous rotary position embedding'''
    def __init__(self, dim, sequence_scale):
        super().__init__()
        base = 10000
        self.sequence_scale = sequence_scale
        self.register_buffer('inv_freq', 1. / (base ** (torch.arange(0, dim, 2))))

    def forward(self, t):
        t = (t + 0.5) * self.sequence_scale
        freqs = torch.einsum('ij,k->ijk', t, self.inv_freq)   # freqs: [B, L, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(1)  # emb: [B, 1, L, dim], 1 for broadcast in head_num dim
        return torch.stack([emb.cos(), emb.sin()])

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MHA_rotary(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.encoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.key = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.value = nn.Linear(args.encoder_dim, args.encoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.encoder_dim, args.encoder_dim)

    def forward(self, x, RoPE, key_padding_mask=None):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v  # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class MHA_decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.decoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.key = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.value = nn.Linear(args.decoder_dim, args.decoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.decoder_dim, args.decoder_dim)

    def forward(self, x, memory, RoPE, key_padding_mask=None):
        B, T, C = x.size()
        _, L, M = memory.size()

        # print("x size: ", x.size(), 'memory size: ', memory.size())
        # print('B, T, C: ', B, T, C, 'L: ', L)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v
        # print("after attention vals: ", x.shape)  # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        # x = self.output(x)

        # print("after linear: ", x.shape)


        # cross attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)        # (B, T, C) -> (B, nh, T, hs)
        k = self.key(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)     # (B, L, M) -> (B, nh, L, hs)
        v = self.value(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)   # (B, L, M) -> (B, nh, L, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # cross-attention: (B, nh, T, hs) * (B, nh, hs, L) -> (B, nh, T, L)
        # print("att size: ", att.size())
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]  # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # softmax

        x = att @ v  # (B, nh, T, L) * (B, nh, L, hs) -> (B, nh, T, hs)
        # print("x decoder size: ", x.size())
        x = x.transpose(1, 2).contiguous().view(B, T, -1)  # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # print("x decoder size transposed: ", x.size())
        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift=False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
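Usage note: a small sketch (illustrative, not from the committed files) of the rotary-encoding path used by MHA_rotary above, with toy shapes; only the first rotary_ndims channels of each head are rotated, the remaining channels pass through unchanged.

import torch

B, nh, T, hs = 1, 2, 5, 8
rotary_ndims = int(hs * 0.5)                    # as in MHA_rotary: int(head_size * 0.5)
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)

cos, sin = RotaryEmbedding(rotary_ndims)(q, seq_len=T)      # each (T, rotary_ndims)
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
k_rot, k_pass = k[..., :rotary_ndims], k[..., rotary_ndims:]
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
q = torch.cat((q_rot, q_pass), dim=-1)          # back to (B, nh, T, hs)
k = torch.cat((k_rot, k_pass), dim=-1)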
tasks/audio.py CHANGED
@@ -2,11 +2,19 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
+import numpy as np
 import random
 import os
+import torch
+from torch.utils.data import DataLoader
 
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
+from data import FFTDataset
+from models import DualEncoder
+from train import Trainer
+from data_utils import collate_fn, Container
+import yaml
 
 from dotenv import load_dotenv
 load_dotenv()
@@ -43,20 +51,43 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     # Split dataset
     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
     test_dataset = train_test["test"]
-
+
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
-
+    #--------------------------------------------------------------------------------------------
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args_path = 'config.yaml'
+    data_args = Container(**yaml.safe_load(open(args_path, 'r'))['Data'])
+    model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
+    model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
+    conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+
+    test_dataset = FFTDataset(test_dataset)
+    test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+    model = DualEncoder(model_args, model_args_f, conformer_args)
+    model = model.to(device)
+    missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+
+    loss_fn = torch.nn.BCEWithLogitsLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+    trainer = Trainer(model=model, optimizer=optimizer,
+                      criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
+                      scheduler=None, train_dataloader=None,
+                      val_dataloader=None, device=device,
+                      exp_num='test', log_path=None,
+                      range_update=None,
+                      accumulation_step=1, max_iter=np.inf,
+                      exp_name=f"frugal_cnnencoder_inference")
+    predictions, acc = trainer.predict(test_dl, device=device)
     # Make random predictions (placeholder for actual model inference)
     true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 1) for _ in range(len(true_labels))]
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
tasks/config.yaml ADDED
@@ -0,0 +1,66 @@
Data:
  # Basics
  log_dir: '/data/frugal/logs'
  # Data
  dataset: "KeplerDataset"
  data_dir: '/data/lightPred/data'
  model_name: "CNNEncoder"
  batch_size: 16
  num_epochs: 1000
  exp_num: 2
  max_len_spectra: 4096
  max_days_lc: 270
  lc_freq: 0.0208
  create_umap: True

CNNEncoder:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128,256]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True
  checkpoint_path: 'logs/frugal_2025-01-10/frugal_cnnencoder_2.pth'

CNNEncoder_f:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True


Conformer:
  encoder: ["mhsa_pro", "conv"]
  timeshift: false
  num_layers: 8
  encoder_dim: 128
  num_heads: 8
  kernel_size: 3
  dropout_p: 0.2
  norm: "postnorm"


Optimization:
  # Optimization
  max_lr: 1e-5
  weight_decay: 5e-6
  warmup_pct: 0.3
  steps_per_epoch: 3500
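Usage note: this YAML is consumed through the Container helper in tasks/data_utils.py, one section per component, exactly as tasks/audio.py does:

import yaml
from data_utils import Container

cfg = yaml.safe_load(open('config.yaml', 'r'))
data_args = Container(**cfg['Data'])            # data_args.batch_size == 16
model_args = Container(**cfg['CNNEncoder'])     # model_args.encoder_dims == [32, 64, 128, 256]
conformer_args = Container(**cfg['Conformer'])  # conformer_args.norm == 'postnorm'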
tasks/data.py ADDED
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
from itertools import tee
import random
import torchaudio.transforms as T


class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # For the first train_ratio portion of items, yield to train;
            # for the remaining items, yield to validation
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100


class FFTDataset(IterableDataset):
    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=6000):
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)

    def __iter__(self):
        for item in self.dataset:
            # Assuming your audio data is in item['audio']
            # Modify this based on your actual data structure
            audio_data = torch.tensor(item['audio']['array']).float()
            if len(audio_data) == 0:
                continue
            resampled_audio = self.resampler(audio_data)
            fft_data = fft(resampled_audio)

            # Update the item with FFT data
            item['audio']['fft'] = fft_data
            yield item
tasks/data_utils.py ADDED
@@ -0,0 +1,63 @@
import os
import torch
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # Extract audio arrays and FFT data from the batch of dictionaries
    audio_arrays = [torch.tensor(item['audio']['array']) for item in batch]
    fft_arrays = [torch.tensor(item['audio']['fft']) for item in batch]
    labels = [torch.tensor(item['label']) for item in batch]

    # Pad both sequences
    padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
    padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)

    # Return as dictionary with the same structure
    return {
        'audio': {
            'array': padded_audio,
            'fft': padded_fft
        },
        'label': torch.stack(labels)

    }

class Container(object):
    '''A container class that can be used to store any attributes.'''
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def load_dict(self, dict):
        for key, value in dict.items():
            if getattr(self, key, None) is None:
                setattr(self, key, value)

    def print_attributes(self):
        for key, value in vars(self).items():
            print(f"{key}: {value}")

    def get_dict(self):
        return self.__dict__

def setup():
    """
    Setup the distributed training environment.
    """
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["SLURM_PROCID"])
    jobid = int(os.environ["SLURM_JOBID"])
    gpus_per_node = torch.cuda.device_count()
    print('jobid ', jobid)
    print('gpus per node ', gpus_per_node)
    print(f"Hello from rank {rank} of {world_size} where there are"
          f" {gpus_per_node} allocated GPUs per node. ", flush=True)

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)
    local_rank = rank - gpus_per_node * (rank // gpus_per_node)
    torch.cuda.set_device(local_rank)
    print(f"rank: {rank}, local_rank: {local_rank}")
    return local_rank, world_size, gpus_per_node
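Usage note: a sketch (illustrative, not from the committed files) of how collate_fn pads a batch produced by FFTDataset from tasks/data.py; raw_dataset here is a hypothetical iterable of Hugging Face audio items with 'audio' and 'label' fields.

from torch.utils.data import DataLoader
from data import FFTDataset
from data_utils import collate_fn

ds = FFTDataset(raw_dataset)                  # raw_dataset: hypothetical source dataset
dl = DataLoader(ds, batch_size=16, collate_fn=collate_fn)
batch = next(iter(dl))
print(batch['audio']['array'].shape)          # (16, longest waveform in the batch)
print(batch['audio']['fft'].shape)            # (16, longest FFT in the batch), complex-valued
print(batch['label'].shape)                   # (16,)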
tasks/models.py ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
from Modules.conformer import ConformerEncoder, ConformerDecoder
from Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding

class ConvBlock(nn.Module):
    def __init__(self, args, num_layer) -> None:
        super().__init__()
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        out_channels = args.encoder_dims[num_layer] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

class CNNEncoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder with activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(nn.Conv1d(in_channels=args.in_channels,
                                                 kernel_size=3, out_channels=args.encoder_dims[0], stride=1, padding='same', bias=False),
                                       nn.BatchNorm1d(args.encoder_dims[0]),
                                       self.activation,
                                       )

        self.layers = nn.ModuleList([ConvBlock(args, i+1)
                                     for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        # Store backbone output in a separate tensor
        backbone_out = self.backbone(x)

        # Create x_enc from backbone_out
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            x_enc = backbone_out.permute(0, 2, 1).clone()

        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)

        if len(x_enc.shape) == 3:
            if self.avg_output:
                x_enc = x_enc.sum(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)

        # Return x_enc and the original backbone output
        return x_enc, backbone_out

class DualEncoder(nn.Module):
    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        total_output_dim = args_x.encoder_dims[-1] + args_f.encoder_dims[-1]
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim//2),
            nn.BatchNorm1d(total_output_dim//2),
            nn.SiLU(),
            nn.Linear(total_output_dim//2, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        logits = torch.cat([x1, x2], dim=-1)
        return self.regressor(logits).squeeze()
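Usage note: a sketch (illustrative, not from the committed files) of the dual-encoder wiring, assuming the CNNEncoder, CNNEncoder_f and Conformer sections of tasks/config.yaml; as written, DualEncoder.forward feeds the same input tensor to both branches and concatenates their pooled features before the regressor.

import torch
import yaml
from data_utils import Container
from models import DualEncoder

cfg = yaml.safe_load(open('config.yaml', 'r'))
model = DualEncoder(Container(**cfg['CNNEncoder']),
                    Container(**cfg['CNNEncoder_f']),
                    Container(**cfg['Conformer']))
x = torch.randn(4, 6000)          # (batch, length) signal
out = model(x)                    # (batch,) logits, paired with BCEWithLogitsLoss in tasks/audio.py
print(out.shape)                  # torch.Size([4])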
tasks/train.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
+ from torch.cuda.amp import autocast
+ import numpy as np
+ import time
+ import os
+ import yaml
+ from matplotlib import pyplot as plt
+ import glob
+ from collections import OrderedDict
+ from tqdm import tqdm
+ import torch.distributed as dist
+ import umap
+ 
+ class Trainer(object):
+     """
+     A class that encapsulates the training loop for a PyTorch model.
+     """
+     def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
+                  scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
+                  grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
+                  cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
+                  update_func=lambda x: x):
+         self.model = model
+         self.optimizer = optimizer
+         self.criterion = criterion
+         self.scaler = scaler
+         self.grad_clip = grad_clip
+         self.cos_inc = cos_inc
+         self.output_dim = output_dim
+         self.scheduler = scheduler
+         self.train_dl = train_dataloader
+         self.val_dl = val_dataloader
+         self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
+         self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
+         self.max_iter = max_iter
+         self.device = device
+         self.world_size = world_size
+         self.exp_num = exp_num
+         self.exp_name = exp_name
+         self.log_path = log_path
+         self.best_state_dict = None
+         self.plot_every = plot_every
+         self.logger = None
+         self.range_update = range_update
+         self.accumulation_step = accumulation_step
+         self.wandb = wandb_log
+         self.num_quantiles = num_quantiles
+         self.update_func = update_func
+         # if log_path is not None:
+         #     self.logger = SummaryWriter(f'{self.log_path}/exp{self.exp_num}')
+         #     # print(f"logger path: {self.log_path}/exp{self.exp_num}")
+ 
+         # print("logger is: ", self.logger)
+ 
+     def get_sampler_from_dataloader(self, dataloader):
+         if hasattr(dataloader, 'sampler'):
+             if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
+                 return dataloader.sampler
+             elif hasattr(dataloader.sampler, 'sampler'):
+                 return dataloader.sampler.sampler
+ 
+         if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
+             return dataloader.batch_sampler.sampler
+ 
+         return None
+ 
+     def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
+         """
+         Fits the model for the given number of epochs.
+         """
+         min_loss = np.inf
+         best_acc = 0
+         train_loss, val_loss = [], []
+         train_acc, val_acc = [], []
+         lrs = []
+         # self.optim_params['lr_history'] = []
+         epochs_without_improvement = 0
+         main_process = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
+ 
+         print(f"Starting training for {num_epochs} epochs")
+         print("is main process: ", main_process, flush=True)
+         global_time = time.time()
+         self.epoch = 0
+         for epoch in range(num_epochs):
+             self.epoch = epoch
+             start_time = time.time()
+             plot = (self.plot_every is not None) and (epoch % self.plot_every == 0)
+             t_loss, t_acc = self.train_epoch(device, epoch=epoch)
+             t_loss_mean = np.nanmean(t_loss)
+             train_loss.extend(t_loss)
+             global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
+             if main_process:  # Only perform this on the master GPU
+                 train_acc.append(global_train_accuracy.mean().item())
+ 
+             v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
+             v_loss_mean = np.nanmean(v_loss)
+             val_loss.extend(v_loss)
+             global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
+             if main_process:  # Only perform this on the master GPU
+                 val_acc.append(global_val_accuracy.mean().item())
+ 
+             current_objective = global_val_loss if best == 'loss' else global_val_accuracy.mean()
+             improved = False
+ 
+             if best == 'loss':
+                 if current_objective < min_loss:
+                     min_loss = current_objective
+                     improved = True
+             else:
+                 if current_objective > best_acc:
+                     best_acc = current_objective
+                     improved = True
+ 
+             if improved:
+                 model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
+                 print(f"saving model at {model_name}...")
+                 torch.save(self.model.state_dict(), model_name)
+                 self.best_state_dict = self.model.state_dict()
+                 epochs_without_improvement = 0
+             else:
+                 epochs_without_improvement += 1
+ 
+             current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
+                 else self.scheduler.get_last_lr()[0]
+ 
+             lrs.append(current_lr)
+ 
+             print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss: '
+                   f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
+                   f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()}, '
+                   f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time) / 3600:.2f} hr', flush=True)
+             if epoch % 10 == 0:
+                 print(os.system('nvidia-smi'))
+ 
+             if epochs_without_improvement == early_stopping:
+                 print('early stopping!', flush=True)
+                 break
+             if time.time() - global_time > (23.83 * 3600):
+                 print("time limit reached")
+                 break
+ 
+         return {"num_epochs": num_epochs, "train_loss": train_loss,
+                 "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}
+ 
+     def process_loss(self, acc, loss_mean):
+         if torch.cuda.is_available() and torch.distributed.is_initialized():
+             global_accuracy = torch.tensor(acc).cuda()  # Convert accuracy to a tensor on the GPU
+             torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
+             global_loss = torch.tensor(loss_mean).cuda()  # Convert loss to a tensor on the GPU
+             torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
+ 
+             # Divide both loss and accuracy by world size
+             world_size = torch.distributed.get_world_size()
+             global_loss /= world_size
+             global_accuracy /= world_size
+         else:
+             global_loss = torch.tensor(loss_mean)
+             global_accuracy = torch.tensor(acc)
+         return global_accuracy, global_loss
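Note that process_loss reduces onto rank 0 and then divides by the world size, so only the master rank ends up holding the cross-rank average; the other ranks keep their local value scaled by 1/world_size. If every rank needed the same global mean (for example to drive a metric-based scheduler identically on all ranks), an all_reduce variant along these lines could be used instead (a sketch, not part of the commit):

import torch
import torch.distributed as dist

def process_loss_all_reduce(acc, loss_mean, device):
    # Average metrics across all ranks so every process sees the same value.
    global_acc = torch.as_tensor(acc, device=device, dtype=torch.float32)
    global_loss = torch.as_tensor(loss_mean, device=device, dtype=torch.float32)
    if dist.is_initialized():
        dist.all_reduce(global_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
        world_size = dist.get_world_size()
        global_acc /= world_size
        global_loss /= world_size
    return global_acc, global_loss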
+ 
+     def load_best_model(self, to_ddp=True, from_ddp=True):
+         # note: fit() saves checkpoints under f'{self.log_path}/{self.exp_num}/',
+         # so the 'exp' prefix used here must match the actual directory layout on disk
+         data_dir = f'{self.log_path}/exp{self.exp_num}'
+         # data_dir = f'{self.log_path}/exp29'  # for debugging
+ 
+         state_dict_files = glob.glob(data_dir + '/*.pth')
+         print("loading model from ", state_dict_files[-1])
+ 
+         state_dict = torch.load(state_dict_files[-1]) if to_ddp \
+             else torch.load(state_dict_files[0], map_location=self.device)
+ 
+         if from_ddp:
+             print("loading distributed model")
+             # Remove "module." from keys
+             new_state_dict = OrderedDict()
+             for key, value in state_dict.items():
+                 if key.startswith('module.'):
+                     while key.startswith('module.'):
+                         key = key[7:]
+                 new_state_dict[key] = value
+             state_dict = new_state_dict
+         # print("state_dict: ", state_dict.keys())
+         # print("model: ", self.model.state_dict().keys())
+ 
+         self.model.load_state_dict(state_dict, strict=False)
+ 
+     def check_gradients(self):
+         for name, param in self.model.named_parameters():
+             if param.grad is not None:
+                 grad_norm = param.grad.norm().item()
+                 if grad_norm > 10:
+                     print(f"Large gradient in {name}: {grad_norm}")
+ 
+     def train_epoch(self, device, epoch):
+         """
+         Trains the model for one epoch.
+         """
+         if self.train_sampler is not None:
+             try:
+                 self.train_sampler.set_epoch(epoch)
+             except AttributeError:
+                 pass
+         self.model.train()
+         train_loss = []
+         train_acc = 0
+         total = 0
+         all_accs = torch.zeros(self.output_dim, device=device)
+         pbar = tqdm(self.train_dl)
+         for i, batch in enumerate(pbar):
+             if self.optimizer is not None:
+                 self.optimizer.zero_grad()
+             loss, acc, y = self.train_batch(batch, i, device)
+             train_loss.append(loss.item())
+             all_accs = all_accs + acc
+             total += len(y)
+             pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
+             if i > self.max_iter:
+                 break
+         print("number of train_accs: ", train_acc)
+         return train_loss, all_accs / total
+ 
+     def train_batch(self, batch, batch_idx, device):
+         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
+         x = x.to(device).float()
+         fft = fft.to(device).float()
+         y = y.to(device).float()
+         y_pred = self.model(fft)
+         loss = self.criterion(y_pred, y)
+         loss.backward()
+         self.optimizer.step()
+         if self.scheduler is not None:
+             self.scheduler.step()
+         # get predicted classes
+         probs = torch.sigmoid(y_pred)
+         cls_pred = (probs > 0.5).float()
+         acc = (cls_pred == y).sum()
+         return loss, acc, y
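The constructor accepts scaler, grad_clip and accumulation_step, but train_batch above uses none of them. A hedged sketch of how they could be wired in, assuming self.scaler is a torch.cuda.amp.GradScaler when mixed precision is desired; note that train_epoch's per-iteration zero_grad() would also have to move into the accumulation branch, and max_norm=1.0 is a placeholder value.

import torch
from torch.cuda.amp import autocast

def train_batch_amp(self, batch, batch_idx, device):
    # Drop-in variant of Trainer.train_batch with AMP, gradient clipping
    # and gradient accumulation (the raw waveform 'array' is not used here,
    # matching the forward pass above).
    fft = batch['audio']['fft'].to(device).float()
    y = batch['label'].to(device).float()

    with autocast(enabled=self.scaler is not None):
        y_pred = self.model(fft)
        loss = self.criterion(y_pred, y) / self.accumulation_step

    if self.scaler is not None:
        self.scaler.scale(loss).backward()
    else:
        loss.backward()

    if (batch_idx + 1) % self.accumulation_step == 0:
        if self.grad_clip:
            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()
        if self.scheduler is not None:
            self.scheduler.step()

    probs = torch.sigmoid(y_pred)
    cls_pred = (probs > 0.5).float()
    acc = (cls_pred == y).sum()
    return loss * self.accumulation_step, acc, y   # return the unscaled loss for logging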
+ 
+     def eval_epoch(self, device, epoch):
+         """
+         Evaluates the model for one epoch.
+         """
+         self.model.eval()
+         val_loss = []
+         val_acc = 0
+         total = 0
+         all_accs = torch.zeros(self.output_dim, device=device)
+         pbar = tqdm(self.val_dl)
+         for i, batch in enumerate(pbar):
+             loss, acc, y = self.eval_batch(batch, i, device)
+             val_loss.append(loss.item())
+             all_accs = all_accs + acc
+             total += len(y)
+             pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
+             if i > self.max_iter:
+                 break
+         return val_loss, all_accs / total
+ 
+     def eval_batch(self, batch, batch_idx, device):
+         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
+         x = x.to(device).float()
+         fft = fft.to(device).float()
+         y = y.to(device).float()
+         with torch.no_grad():
+             y_pred = self.model(fft)
+             loss = self.criterion(y_pred, y)
+         probs = torch.sigmoid(y_pred)
+         cls_pred = (probs > 0.5).float()
+         acc = (cls_pred == y).sum()
+         return loss, acc, y
+ 
+     def predict(self, test_dataloader, device):
+         """
+         Returns the predictions of the model on the given dataset.
+         """
+         self.model.eval()
+         total = 0
+         all_accs = 0
+         predictions = []
+         pbar = tqdm(test_dataloader)
+         for i, batch in enumerate(pbar):
+             x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
+             x = x.to(device).float()
+             fft = fft.to(device).float()
+             y = y.to(device).float()
+             with torch.no_grad():
+                 y_pred = self.model(fft)
+                 loss = self.criterion(y_pred, y)
+             probs = torch.sigmoid(y_pred)
+             cls_pred = (probs > 0.5).float()
+             acc = (cls_pred == y).sum()
+             predictions.append(cls_pred)
+             all_accs += acc
+             total += len(y)
+         return predictions, all_accs / total
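Continuing the earlier sketch, inference with the best checkpoint might look like this (trainer, test_dl and device are the placeholders from above; load_best_model expects the checkpoints under f'{log_path}/exp{exp_num}'):

import torch

trainer.load_best_model(to_ddp=False, from_ddp=True)
preds, test_acc = trainer.predict(test_dl, device)
preds = torch.cat(preds)                    # list of per-batch 0/1 tensors -> one tensor
print(f"test accuracy: {test_acc.item():.4f}")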