IlayMalinyak committed · Commit b3fb4dd
1 Parent(s): 192ac3b
first commit
Browse files
- tasks/Modules/ResNet18.py +69 -0
- tasks/Modules/__init__.py +0 -0
- tasks/Modules/cnn.py +58 -0
- tasks/Modules/conformer.py +584 -0
- tasks/Modules/mhsa_pro.py +231 -0
- tasks/audio.py +37 -6
- tasks/config.yaml +66 -0
- tasks/data.py +43 -0
- tasks/data_utils.py +63 -0
- tasks/models.py +114 -0
- tasks/train.py +293 -0
tasks/Modules/ResNet18.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/samcw/ResNet18-Pytorch
class ResBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv1d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = F.relu(out)

        return out

class ResNet18(nn.Module):
    def __init__(self, args):
        super(ResNet18, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.pred_layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 1),
        )
        if getattr(args, 'mean_label', False):
            self.pred_layer[3].bias.data.fill_(args.mean_label)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = self.conv1(x)
        out = F.max_pool1d(out, 3, 2, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.mean(-1)
        out = self.pred_layer(out)
        return out
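A minimal usage sketch of the 1-D ResNet18 above (assumptions: `args` only needs the optional `mean_label` attribute read in `__init__`, and the input is a raw signal of shape (batch, length)):

import torch
from types import SimpleNamespace

# Hypothetical args; mean_label is optional and only used to initialize the output bias.
args = SimpleNamespace(mean_label=0.0)
model = ResNet18(args)

x = torch.randn(4, 6000)   # (batch, signal_length)
out = model(x)             # (batch, 1) scalar prediction per signal
print(out.shape)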
tasks/Modules/__init__.py
ADDED
File without changes
tasks/Modules/cnn.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ConvBlockDecoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=args.decoder_dim,
                      out_channels=args.decoder_dim,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.decoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)
        return self.layers(x).transpose(1, 2)

class ResNetLayer(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv1d(in_channels=args.encoder_dim,
                      out_channels=args.encoder_dim,
                      kernel_size=3,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=args.encoder_dim),
            nn.SiLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_layer(x) + x


class ResNetBlock(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.layers = nn.Sequential(*[ResNetLayer(args) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
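For orientation, the ConvBlock above works on channels-last sequences and transposes internally; a small sketch (the `args` object is hypothetical, only `encoder_dim` and `kernel_size` are read):

import torch
from types import SimpleNamespace

# Hypothetical args mirroring the Conformer section of tasks/config.yaml.
args = SimpleNamespace(encoder_dim=128, kernel_size=3)
block = ConvBlock(args)

x = torch.randn(4, 100, 128)   # (batch, time, dim), channels-last like the conformer blocks
y = block(x)                   # transposed to (batch, dim, time) for Conv1d, then back
print(y.shape)                 # torch.Size([4, 100, 128])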
tasks/Modules/conformer.py
ADDED
@@ -0,0 +1,584 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.nn.init as init
import math

from .mhsa_pro import MHA_rotary, MHA_decoder
from .cnn import ConvBlock, ConvBlockDecoder

from typing import Optional, Tuple

class ResidualConnectionModule(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = 1
        self.input_factor = 1

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return (self.module(inputs, **kwargs) * self.module_factor) + (inputs * self.input_factor)

class PostNorm(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """
    def __init__(self, module: nn.Module, dims, args):
        super(PostNorm, self).__init__()
        self.module = module
        input_factor = torch.FloatTensor(args.alpha) if getattr(args, 'alpha', None) else torch.tensor(1.)
        self.register_buffer('input_factor', input_factor)
        self.norm = nn.LayerNorm(dims)

    def forward(self, inputs: Tensor, **kwargs) -> Tensor:
        return self.norm(self.module(inputs, **kwargs) + (inputs * self.input_factor))

class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear
    Weight initialize by xavier initialization and bias initialize to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class View(nn.Module):
    """ Wrapper class of torch.view() for Sequential module. """
    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: Tensor) -> Tensor:
        if self.contiguous:
            x = x.contiguous()

        return x.view(*self.shape)


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)

class FeedForwardModule(nn.Module):
    """
    Conformer Feed Forward Module follow pre-norm residual units and apply layer normalization within the residual unit
    and on the input before the first linear layer. This module also apply Swish activation and dropout, which helps
    regularizing the network.
    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of feed forward module.
        dropout_p (float): Ratio of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produces by feed forward module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        expansion_factor = 4
        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Linear(args.encoder_dim, args.encoder_dim * expansion_factor, bias=True),
            nn.SiLU(),
            nn.Dropout(p=args.dropout_p),
            Linear(args.encoder_dim * expansion_factor, args.encoder_dim, bias=True),
            nn.Dropout(p=args.dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs)

class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """
    When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution.
    This operation often used to match dimensions.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """
    Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.
    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
        dropout_p (float, optional): probability of dropout
    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): Tensor produces by conformer convolution module.
    """
    def __init__(
            self,
            args,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (args.kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        expansion_factor = 2
        dropout_p = 0.1

        self.sequential = nn.Sequential(
            nn.LayerNorm(args.encoder_dim),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim * expansion_factor, stride=1, padding=0, bias=True),
            nn.GLU(dim=1),
            DepthwiseConv1d(args.encoder_dim, args.encoder_dim, args.kernel_size, stride=1, padding=(args.kernel_size - 1) // 2),
            nn.BatchNorm1d(args.encoder_dim),
            nn.SiLU(),
            PointwiseConv1d(args.encoder_dim, args.encoder_dim, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)

class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since transformer contains no recurrence and no convolution, in order for the model to make
    use of the order of the sequence, we must add some positional information.
    "Attention Is All You Need" use sine and cosine functions of different frequencies:
        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    """
    def __init__(self, d_model: int = 128, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]

class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs**: Tensor produces by relative multi head attention module.
    """
    def __init__(
            self,
            encoder_dim: int = 128,
            num_heads: int = 8,
            dropout_p: float = 0.1
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert encoder_dim % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = encoder_dim
        self.d_head = int(encoder_dim / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(encoder_dim)

        self.query_proj = Linear(encoder_dim, encoder_dim)
        self.key_proj = Linear(encoder_dim, encoder_dim)
        self.value_proj = Linear(encoder_dim, encoder_dim)
        self.pos_proj = Linear(encoder_dim, encoder_dim, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(encoder_dim, encoder_dim)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        query = query.view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # content_score = torch.matmul((query).transpose(1, 2), key.transpose(2, 3))
        # pos_score = torch.matmul((query).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Q(B,numheads,length,d_head)*PE(B,numheads,d_heads,length) = posscore(B,num_heads,length,length)
        pos_score = self._relative_shift(pos_score)
        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        score = F.softmax(score, -1)
        attn = self.dropout(score)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        # shift position score a unit along length axis and leave a blank row.
        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better on different input length and the resulting encoder is more robust to the variance of
    the utterance length. Conformer use prenorm residual units with dropout which helps training
    and regularizing deeper models.
    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module.
    """
    def __init__(self, args):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        dropout_p = 0.1
        self.positional_encoding = PositionalEncoding(args.encoder_dim)
        self.layer_norm = nn.LayerNorm(args.encoder_dim)
        self.attention = RelativeMultiHeadAttention(args.encoder_dim, args.num_heads, args.dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)

        inputs = self.layer_norm(inputs)
        outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)

        return self.dropout(outputs)

class ConformerBlock(nn.Module):
    """
    Conformer block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module
    and the Convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
    the original feed-forward layer in the Transformer block into two half-step feed-forward layers,
    one before the attention layer and one after.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indication whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produces by conformer block.
    """
    def __init__(
            self,
            args
    ):
        super(ConformerBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'conv': ConvBlock,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList(
            [norm_dict[args.norm](block_dict[block](args), args.encoder_dim, args) for block in args.encoder]
        )

    def forward(self, x: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask)
            else:
                x = m(x)
        return x


class DecoderBlock(nn.Module):
    """
    Decoder block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module
    and the Convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
    the original feed-forward layer in the Transformer block into two half-step feed-forward layers,
    one before the attention layer and one after.
    Args:
        encoder_dim (int, optional): Dimension of conformer encoder
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indication whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, time, dim): Tensor produces by conformer block.
    """
    def __init__(
            self,
            args
    ):
        super(DecoderBlock, self).__init__()

        norm_dict = {
            'shortcut': ResidualConnectionModule,
            'postnorm': PostNorm
        }
        block_dict = {
            'ffn': FeedForwardModule,
            'mhsa': MultiHeadedSelfAttentionModule,
            'mhsa_pro': MHA_rotary,
            'mhsa_decoder': MHA_decoder,
            'conv': ConvBlockDecoder,
            'conformerconv': ConformerConvModule
        }

        self.modlist = nn.ModuleList(
            [norm_dict[args.norm](block_dict[block](args), args.decoder_dim, args) for block in args.decoder]
        )

    def forward(self, x: Tensor, memory: Tensor, RoPE, key_padding_mask=None) -> Tensor:
        for m in self.modlist:
            if isinstance(m.module, MHA_decoder):
                x = m(x, memory=memory, RoPE=RoPE, key_padding_mask=key_padding_mask)
            elif isinstance(m.module, MHA_rotary):
                x = m(x, RoPE=RoPE, key_padding_mask=key_padding_mask).transpose(0, 1)
            else:
                x = m(x)
        return x


class ConformerEncoder(nn.Module):
    """
    Conformer encoder first processes the input with a convolution subsampling layer and then
    with a number of conformer blocks.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indication whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **input_lengths** (batch): list of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, out_channels, time): Tensor produces by conformer encoder.
        - **output_lengths** (batch): list of sequence output lengths
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerEncoder, self).__init__()
        self.blocks = nn.ModuleList([ConformerBlock(
            args) for _ in range(args.num_layers)])

    def forward(self, x: Tensor, RoPE=None, key_padding_mask=None) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` for encoder training.
        Args:
            inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            (Tensor, Tensor)
            * outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            * output_lengths (torch.LongTensor): The length of output tensor. ``(batch)``
        """
        for block in self.blocks:
            x = block(x, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x

class ConformerDecoder(nn.Module):
    """
    Conformer encoder first processes the input with a convolution subsampling layer and then
    with a number of conformer blocks.
    Args:
        input_dim (int, optional): Dimension of input vector
        encoder_dim (int, optional): Dimension of conformer encoder
        num_layers (int, optional): Number of conformer blocks
        num_attention_heads (int, optional): Number of attention heads
        feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
        conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
        feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
        attention_dropout_p (float, optional): Probability of attention module dropout
        conv_dropout_p (float, optional): Probability of conformer convolution module dropout
        conv_kernel_size (int or tuple, optional): Size of the convolving kernel
        half_step_residual (bool): Flag indication whether to use half step residual or not
        device (torch.device): torch device (cuda or cpu)
    Inputs: inputs, input_lengths
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **input_lengths** (batch): list of sequence input lengths
    Returns: outputs, output_lengths
        - **outputs** (batch, out_channels, time): Tensor produces by conformer encoder.
        - **output_lengths** (batch): list of sequence output lengths
    """
    def __init__(
            self,
            args,
    ):
        super(ConformerDecoder, self).__init__()
        self.blocks = nn.ModuleList([DecoderBlock(
            args) for _ in range(args.num_decoder_layers)])

    def forward(self, x: Tensor, memory: Tensor, RoPE=None, key_padding_mask=None) -> Tuple[Tensor, Tensor]:
        """
        Forward propagate a `inputs` for encoder training.
        Args:
            inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded
                `FloatTensor` of size ``(batch, seq_length, dimension)``.
            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
        Returns:
            (Tensor, Tensor)
            * outputs (torch.FloatTensor): A output sequence of encoder. `FloatTensor` of size
                ``(batch, seq_length, dimension)``
            * output_lengths (torch.LongTensor): The length of output tensor. ``(batch)``
        """
        for block in self.blocks:
            x = block(x, memory, RoPE=RoPE, key_padding_mask=key_padding_mask)

        return x
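A minimal sketch of driving the encoder above with the rotary tables from Modules.mhsa_pro (the `args` values are hypothetical but mirror the Conformer section of tasks/config.yaml; MHA_rotary rotates only half of each head's dimensions, hence the RoPE size below):

import torch
from types import SimpleNamespace
from Modules.conformer import ConformerEncoder
from Modules.mhsa_pro import RotaryEmbedding

# Hypothetical args; encoder lists the sub-blocks of each ConformerBlock.
args = SimpleNamespace(encoder=["mhsa_pro", "conv"], norm="postnorm",
                       num_layers=2, encoder_dim=128, num_heads=8,
                       kernel_size=3, dropout_p=0.2, timeshift=False)

encoder = ConformerEncoder(args)
x = torch.randn(4, 100, args.encoder_dim)           # (batch, time, dim)

# MHA_rotary applies RoPE to head_size // 2 dims of each head.
rotary = RotaryEmbedding(dim=(args.encoder_dim // args.num_heads) // 2)
RoPE = rotary(x, seq_len=x.shape[1])                 # stacked (cos, sin) tables

out = encoder(x, RoPE=RoPE, key_padding_mask=None)   # (batch, time, dim)
print(out.shape)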
tasks/Modules/mhsa_pro.py
ADDED
@@ -0,0 +1,231 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from typing import Optional, Tuple
import math
import logging

logger = logging.getLogger(__name__)


rwkv_emb_scale = 0.4   # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0 # decay weights in higher layers. try 0.5 ~ 1.0.

class AttentionConfig:
    def __init__(self, ctx_len=100, **kwargs):
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return torch.stack([self.cos_cached, self.sin_cached])

class ContinuousRotaryEmbedding(torch.nn.Module):
    '''Continuous rotary position embedding'''
    def __init__(self, dim, sequence_scale):
        super().__init__()
        base = 10000
        self.sequence_scale = sequence_scale
        self.register_buffer('inv_freq', 1. / (base ** (torch.arange(0, dim, 2))))

    def forward(self, t):
        t = (t + 0.5) * self.sequence_scale
        freqs = torch.einsum('ij,k->ijk', t, self.inv_freq)    # freqs: [B, L, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(1)   # emb: [B, 1, L, dim], 1 for broadcast in head_num dim
        return torch.stack([emb.cos(), emb.sin()])

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)

@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class MHA_rotary(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.encoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.key = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.value = nn.Linear(args.encoder_dim, args.encoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.encoder_dim, args.encoder_dim)

    def forward(self, x, RoPE, key_padding_mask=None):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)   # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)     # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)   # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)   # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))   # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]   # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)   # softmax

        x = att @ v   # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)   # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class MHA_decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.decoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.key = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.value = nn.Linear(args.decoder_dim, args.decoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)

        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.decoder_dim, args.decoder_dim)

    def forward(self, x, memory, RoPE, key_padding_mask=None):
        B, T, C = x.size()
        _, L, M = memory.size()

        # print("x size: ", x.size(), 'memory size: ', memory.size())
        # print('B, T, C: ', B, T, C, 'L: ', L)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)   # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)     # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)   # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]

        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)   # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))   # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]   # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)   # softmax

        x = att @ v
        # print("after attention vals: ", x.shape)   # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)   # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        # x = self.output(x)

        # print("after linear: ", x.shape)   # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)


        # cross attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)        # (B, T, C) -> (B, nh, T, hs)
        k = self.key(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)     # (B, T, C) -> (B, nh, T, hs)
        v = self.value(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)   # (B, T, C) -> (B, nh, T, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))   # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        # print("att size: ", att.size())
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]   # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)   # softmax

        x = att @ v   # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        # print("x deocder size: ", x.size())
        x = x.transpose(1, 2).contiguous().view(B, T, -1)   # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # print("x deocder size transposed: ", x.size())
        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x

class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift=False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
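As a standalone illustration of the rotary scheme used by MHA_rotary and MHA_decoder above (shapes here are made up; only the first rotary_ndims of each head are rotated, the remainder passes through unchanged):

import torch

B, nh, T, hs = 2, 8, 50, 16
rotary_ndims = int(hs * 0.5)

q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)

rope = RotaryEmbedding(rotary_ndims)
cos, sin = rope(q, seq_len=T)                       # per-position cos/sin tables

q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
k_rot, k_pass = k[..., :rotary_ndims], k[..., rotary_ndims:]
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)

q = torch.cat((q_rot, q_pass), dim=-1)
k = torch.cat((k_rot, k_pass), dim=-1)
print(q.shape, k.shape)                             # shapes unchanged: (B, nh, T, hs)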
tasks/audio.py
CHANGED
@@ -2,11 +2,19 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
+import numpy as np
 import random
 import os
+import torch
+from torch.utils.data import DataLoader
 
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
+from data import FFTDataset
+from models import DualEncoder
+from train import Trainer
+from data_utils import collate_fn, Container
+import yaml
 
 from dotenv import load_dotenv
 load_dotenv()
@@ -43,20 +51,43 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     # Split dataset
     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
     test_dataset = train_test["test"]
-
+
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
-    #--------------------------------------------------------------------------------------------
-
+    #--------------------------------------------------------------------------------------------
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args_path = 'config.yaml'
+    data_args = Container(**yaml.safe_load(open(args_path, 'r'))['Data'])
+    model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
+    model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
+    conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+
+    test_dataset = FFTDataset(test_dataset)
+    test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+    model = DualEncoder(model_args, model_args_f, conformer_args)
+    model = model.to(device)
+    missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+
+    loss_fn = torch.nn.BCEWithLogitsLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+    trainer = Trainer(model=model, optimizer=optimizer,
+                      criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
+                      scheduler=None, train_dataloader=None,
+                      val_dataloader=None, device=device,
+                      exp_num='test', log_path=None,
+                      range_update=None,
+                      accumulation_step=1, max_iter=np.inf,
+                      exp_name=f"frugal_cnnencoder_inference")
+    predictions, acc = trainer.predict(test_dl, device=device)
     # Make random predictions (placeholder for actual model inference)
     true_labels = test_dataset["label"]
-
-
+
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
     #--------------------------------------------------------------------------------------------
tasks/config.yaml
ADDED
@@ -0,0 +1,66 @@
Data:
  # Basics
  log_dir: '/data/frugal/logs'
  # Data
  dataset: "KeplerDataset"
  data_dir: '/data/lightPred/data'
  model_name: "CNNEncoder"
  batch_size: 16
  num_epochs: 1000
  exp_num: 2
  max_len_spectra: 4096
  max_days_lc: 270
  lc_freq: 0.0208
  create_umap: True

CNNEncoder:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128,256]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True
  checkpoint_path: 'logs/frugal_2025-01-10/frugal_cnnencoder_2.pth'

CNNEncoder_f:
  # Model
  in_channels: 1
  num_layers: 4
  stride: 1
  encoder_dims: [32,64,128]
  kernel_size: 3
  dropout_p: 0.3
  output_dim: 2
  beta: 1
  load_checkpoint: True
  checkpoint_num: 1
  activation: "silu"
  sine_w0: 1.0
  avg_output: True

Conformer:
  encoder: ["mhsa_pro", "conv"]
  timeshift: false
  num_layers: 8
  encoder_dim: 128
  num_heads: 8
  kernel_size: 3
  dropout_p: 0.2
  norm: "postnorm"

Optimization:
  # Optimization
  max_lr: 1e-5
  weight_decay: 5e-6
  warmup_pct: 0.3
  steps_per_epoch: 3500
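The new tasks/audio.py consumes this file by loading each top-level section into a plain attribute container; a minimal sketch of that pattern (assuming config.yaml is in the working directory):

import yaml
from data_utils import Container  # simple attribute container defined in tasks/data_utils.py

# Each top-level YAML section becomes an object with attribute access,
# e.g. conformer_args.encoder_dim == 128 and conformer_args.norm == "postnorm".
with open('config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
conformer_args = Container(**cfg['Conformer'])
print(conformer_args.encoder_dim, conformer_args.norm)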
tasks/data.py
ADDED
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
from itertools import tee
import random
import torchaudio.transforms as T


class SplitDataset(IterableDataset):
    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        count = 0
        for item in self.dataset:
            # For first train_ratio portion of items, yield to train
            # For remaining items, yield to validation
            is_train_item = count < int(self.train_ratio * 100)
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100


class FFTDataset(IterableDataset):
    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=6000):
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)

    def __iter__(self):
        for item in self.dataset:
            # Assuming your audio data is in item['audio']
            # Modify this based on your actual data structure
            audio_data = torch.tensor(item['audio']['array']).float()
            if len(audio_data) == 0:
                continue
            resampled_audio = self.resampler(audio_data)
            fft_data = fft(resampled_audio)

            # Update the item with FFT data
            item['audio']['fft'] = fft_data
            yield item
ADDED
@@ -0,0 +1,63 @@
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
from torch.nn.utils.rnn import pad_sequence
|
5 |
+
|
6 |
+
def collate_fn(batch):
|
7 |
+
# Extract audio arrays and FFT data from the batch of dictionaries
|
8 |
+
audio_arrays = [torch.tensor(item['audio']['array']) for item in batch]
|
9 |
+
fft_arrays = [torch.tensor(item['audio']['fft']) for item in batch]
|
10 |
+
labels = [torch.tensor(item['label']) for item in batch]
|
11 |
+
|
12 |
+
# Pad both sequences
|
13 |
+
padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
|
14 |
+
padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)
|
15 |
+
|
16 |
+
# Return as dictionary with the same structure
|
17 |
+
return {
|
18 |
+
'audio': {
|
19 |
+
'array': padded_audio,
|
20 |
+
'fft': padded_fft
|
21 |
+
},
|
22 |
+
'label': torch.stack(labels)
|
23 |
+
|
24 |
+
}
|
25 |
+
|
26 |
+
class Container(object):
|
27 |
+
'''A container class that can be used to store any attributes.'''
|
28 |
+
def __init__(self, **kwargs):
|
29 |
+
self.__dict__.update(kwargs)
|
30 |
+
|
31 |
+
def load_dict(self, dict):
|
32 |
+
for key, value in dict.items():
|
33 |
+
if getattr(self, key, None) is None:
|
34 |
+
setattr(self, key, value)
|
35 |
+
|
36 |
+
def print_attributes(self):
|
37 |
+
for key, value in vars(self).items():
|
38 |
+
print(f"{key}: {value}")
|
39 |
+
|
40 |
+
def get_dict(self):
|
41 |
+
return self.__dict__
|
42 |
+
|
43 |
+
def setup():
|
44 |
+
"""
|
45 |
+
Setup the distributed training environment.
|
46 |
+
"""
|
47 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
48 |
+
rank = int(os.environ["SLURM_PROCID"])
|
49 |
+
jobid = int(os.environ["SLURM_JOBID"])
|
50 |
+
gpus_per_node = torch.cuda.device_count()
|
51 |
+
print('jobid ', jobid)
|
52 |
+
print('gpus per node ', gpus_per_node)
|
53 |
+
print(f"Hello from rank {rank} of {world_size} where there are" \
|
54 |
+
f" {gpus_per_node} allocated GPUs per node. ", flush=True)
|
55 |
+
|
56 |
+
# initialize the process group
|
57 |
+
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
58 |
+
|
59 |
+
if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)
|
60 |
+
local_rank = rank - gpus_per_node * (rank // gpus_per_node)
|
61 |
+
torch.cuda.set_device(local_rank)
|
62 |
+
print(f"rank: {rank}, local_rank: {local_rank}")
|
63 |
+
return local_rank, world_size, gpus_per_node
|
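collate_fn pads the variable-length waveforms and their FFTs to a common length per batch. A short sketch of how it plugs into a DataLoader, assuming train_ds and val_ds are the FFTDataset wrappers from tasks/data.py above; the batch size is arbitrary:

from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=32, collate_fn=collate_fn)
val_dl = DataLoader(val_ds, batch_size=32, collate_fn=collate_fn)

batch = next(iter(train_dl))
print(batch['audio']['array'].shape, batch['audio']['fft'].shape, batch['label'].shape)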
tasks/models.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
from Modules.conformer import ConformerEncoder, ConformerDecoder
from Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding

class ConvBlock(nn.Module):
    def __init__(self, args, num_layer) -> None:
        super().__init__()
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        out_channels = args.encoder_dims[num_layer] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

class CNNEncoder(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder with activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(nn.Conv1d(in_channels=args.in_channels,
                                                 kernel_size=3, out_channels=args.encoder_dims[0],
                                                 stride=1, padding='same', bias=False),
                                       nn.BatchNorm1d(args.encoder_dims[0]),
                                       self.activation,
                                       )

        self.layers = nn.ModuleList([ConvBlock(args, i+1)
                                     for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        # Store backbone output in a separate tensor
        backbone_out = self.backbone(x)

        # Create x_enc from backbone_out
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            x_enc = backbone_out.permute(0, 2, 1).clone()

        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)

        if len(x_enc.shape) == 3:
            if self.avg_output:
                x_enc = x_enc.sum(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)

        # Return x_enc and the original backbone output
        return x_enc, backbone_out

class DualEncoder(nn.Module):
    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        total_output_dim = args_x.encoder_dims[-1] + args_f.encoder_dims[-1]
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        logits = torch.cat([x1, x2], dim=-1)
        return self.regressor(logits).squeeze()
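A sketch of how the two-branch model can be assembled from the config sections loaded earlier. The 'CNNEncoder' section name for the raw-waveform branch, the dummy input length, and the assumption that the Conformer section carries everything ConformerEncoder needs are all illustrative, not confirmed by this commit:

args_x = Container(**cfg['CNNEncoder'])        # assumed section name for the raw-waveform branch
model = DualEncoder(args_x, args_f, conformer_args)

dummy = torch.randn(8, 4096)                   # (batch, samples); length is illustrative
logits = model(dummy)                          # one logit per example after .squeeze()
print(logits.shape)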
tasks/train.py
ADDED
@@ -0,0 +1,293 @@
import torch
from torch.cuda.amp import autocast
import numpy as np
import time
import os
import yaml
from matplotlib import pyplot as plt
import glob
from collections import OrderedDict
from tqdm import tqdm
import torch.distributed as dist
import umap

class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.
    """
    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
                 update_func=lambda x: x):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func
        # if log_path is not None:
        #     self.logger = SummaryWriter(f'{self.log_path}/exp{self.exp_num}')
        #     # print(f"logger path: {self.log_path}/exp{self.exp_num}")

        # print("logger is: ", self.logger)

    def get_sampler_from_dataloader(self, dataloader):
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler

        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler

        return None

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        # self.optim_params['lr_history'] = []
        epochs_without_improvement = 0
        main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'

        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_proccess, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()
            plot = (self.plot_every is not None) and (epoch % self.plot_every == 0)
            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())

            current_objective = global_val_loss if best == 'loss' else global_val_accuracy.mean()
            improved = False

            if best == 'loss':
                if current_objective < min_loss:
                    min_loss = current_objective
                    improved = True
            else:
                if current_objective > best_acc:
                    best_acc = current_objective
                    improved = True

            if improved:
                model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
                print(f"saving model at {model_name}...")
                torch.save(self.model.state_dict(), model_name)
                self.best_state_dict = self.model.state_dict()
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                         else self.scheduler.get_last_lr()[0]

            lrs.append(current_lr)

            print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss: '
                  f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                  f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()}, '
                  f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600} hr', flush=True)
            if epoch % 10 == 0:
                print(os.system('nvidia-smi'))

            if epochs_without_improvement == early_stopping:
                print('early stopping!', flush=True)
                break
            if time.time() - global_time > (23.83 * 3600):
                print("time limit reached")
                break

        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}

    def process_loss(self, acc, loss_mean):
        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = torch.tensor(acc).cuda()  # Convert accuracy to a tensor on the GPU
            torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
            global_loss = torch.tensor(loss_mean).cuda()  # Convert loss to a tensor on the GPU
            torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)

            # Divide both loss and accuracy by world size
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        else:
            global_loss = torch.tensor(loss_mean)
            global_accuracy = torch.tensor(acc)
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        data_dir = f'{self.log_path}/exp{self.exp_num}'
        # data_dir = f'{self.log_path}/exp29' # for debugging

        state_dict_files = glob.glob(data_dir + '/*.pth')
        print("loading model from ", state_dict_files[-1])

        state_dict = torch.load(state_dict_files[-1]) if to_ddp else torch.load(state_dict_files[0], map_location=self.device)

        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                if key.startswith('module.'):
                    while key.startswith('module.'):
                        key = key[7:]
                new_state_dict[key] = value
            state_dict = new_state_dict
        # print("state_dict: ", state_dict.keys())
        # print("model: ", self.model.state_dict().keys())

        self.model.load_state_dict(state_dict, strict=False)

    def check_gradients(self):
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > 10:
                    print(f"Large gradient in {name}: {grad_norm}")

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.
        """
        if self.train_sampler is not None:
            try:
                self.train_sampler.set_epoch(epoch)
            except AttributeError:
                pass
        self.model.train()
        train_loss = []
        train_acc = 0
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i > self.max_iter:
                break
        print("number of train_accs: ", train_acc)
        return train_loss, all_accs/total

    def train_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        y_pred = self.model(fft)
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # get predicted classes
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.
        """
        self.model.eval()
        val_loss = []
        val_acc = 0
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return val_loss, all_accs/total

    def eval_batch(self, batch, batch_idx, device):
        x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
        x = x.to(device).float()
        fft = fft.to(device).float()
        y = y.to(device).float()
        with torch.no_grad():
            y_pred = self.model(fft)
            loss = self.criterion(y_pred, y)
            probs = torch.sigmoid(y_pred)
            cls_pred = (probs > 0.5).float()
            acc = (cls_pred == y).sum()
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
            x = x.to(device).float()
            fft = fft.to(device).float()
            y = y.to(device).float()
            with torch.no_grad():
                y_pred = self.model(fft)
                loss = self.criterion(y_pred, y)
                probs = torch.sigmoid(y_pred)
                cls_pred = (probs > 0.5).float()
                acc = (cls_pred == y).sum()
            predictions.append(cls_pred)
            all_accs += acc
            total += len(y)
        return predictions, all_accs/total
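An end-to-end, single-process sketch of how the Trainer might be driven, reusing the dataloaders and DualEncoder from the sketches above; the optimizer, loss, checkpoint directory, and experiment name are illustrative (the learning rate and weight decay echo the Optimization section of the config):

import os
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=5e-6)
criterion = nn.BCEWithLogitsLoss()

os.makedirs('logs/0', exist_ok=True)           # fit() saves checkpoints under {log_path}/{exp_num}/
trainer = Trainer(model, optimizer, criterion, train_dl, device=device,
                  output_dim=1, val_dataloader=val_dl, max_iter=100,
                  exp_num=0, log_path='logs', exp_name='frugal_dual_encoder')
history = trainer.fit(num_epochs=10, device=device, early_stopping=5, best='loss')
print(history['val_loss'][-1])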