Maxwell-Jia committed on
Commit
0d2aee9
·
verified ·
1 Parent(s): 37777d0

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_fcn4flare.py +32 -0
  2. modeling_fcn4flare.py +242 -0
configuration_fcn4flare.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
class FCN4FlareConfig(PretrainedConfig):
    """
    Configuration class for FCN4Flare model.

    Args:
        input_dim (int): Number of input feature channels.
        hidden_dim (int): Channel width used inside the network.
        output_dim (int): Number of output channels.
        depth (int): Number of backbone blocks.
        dilation (Sequence[int], optional): Per-block dilation; block ``i``
            reads ``dilation[i]``. Defaults to ``[1, 2, 4, 8]``.
        maskdice_threshold (float): Threshold applied to predictions by the
            Mask Dice loss.
        dropout_rate (float): Dropout probability stored on the config.
        kernel_size (int): Convolution kernel size stored on the config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``dilation`` has fewer entries than ``depth``.
    """
    model_type = "fcn4flare"

    def __init__(
        self,
        input_dim=3,
        hidden_dim=64,
        output_dim=1,
        depth=4,
        dilation=None,
        maskdice_threshold=0.5,
        dropout_rate=0.1,
        kernel_size=3,
        **kwargs
    ):
        """Initialize FCN4FlareConfig."""
        super().__init__(**kwargs)

        # A list literal as the default would be one object shared by every
        # instance; build a fresh list per config instead.
        if dilation is None:
            dilation = [1, 2, 4, 8]
        # Copy so later mutation of the caller's sequence cannot leak in, and
        # so tuples are stored as a plain list.
        dilation = list(dilation)
        if len(dilation) < depth:
            # One dilation is consumed per block, so fail here with a clear
            # message rather than with an IndexError later.
            raise ValueError(
                f"`dilation` must provide at least `depth` ({depth}) values, "
                f"got {len(dilation)}."
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.depth = depth
        self.dilation = dilation
        self.maskdice_threshold = maskdice_threshold
        self.dropout_rate = dropout_rate
        self.kernel_size = kernel_size
modeling_fcn4flare.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers.modeling_outputs import ModelOutput
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ from .configuration_fcn4flare import FCN4FlareConfig
11
+
12
+
13
class MaskDiceLoss(nn.Module):
    r"""
    Computes the Mask Dice Loss between the predicted and target tensors.

    Predictions that do not exceed ``maskdice_threshold`` are zeroed before the
    Dice coefficient is computed per sample (summing over dimension 1) and
    averaged over the batch:
    $$
    \text{loss} = 1 - \frac{2 \times \text{intersection} + \epsilon}{\text{predicted} + \text{target} + \epsilon}
    $$

    Args:
        maskdice_threshold (float): Threshold value for the predicted tensor.

    Returns:
        loss (torch.Tensor): Computed Mask Dice Loss (scalar tensor).
    """
    def __init__(self, maskdice_threshold):
        super().__init__()
        self.maskdice_threshold = maskdice_threshold

    def forward(self, inputs, targets):
        """
        Computes the forward pass of the Mask Dice Loss.

        Args:
            inputs (torch.Tensor): Predicted tensor; reduced over dimension 1.
            targets (torch.Tensor): Target tensor, same layout as ``inputs``.

        Returns:
            loss (torch.Tensor): ``1 - mean(dice)`` over the batch.
        """
        smooth = 1e-8

        # Zero out predictions at or below the threshold; the surviving values
        # keep their magnitude so gradients still flow through them.
        keep = torch.gt(inputs, self.maskdice_threshold).to(inputs.dtype)
        inputs = inputs * keep

        intersection = (inputs * targets).sum(1)
        # The same epsilon is used in numerator and denominator so that an
        # empty prediction against an empty target scores a Dice of exactly 1
        # regardless of batch size.
        dice = (2 * intersection + smooth) / (inputs.sum(1) + targets.sum(1) + smooth)
        return 1 - dice.mean()
53
+
54
+
55
class NaNMask(nn.Module):
    """
    Replace NaNs with zeros and append a single missing-value indicator channel.

    The indicator is 1.0 at every position whose feature vector (last
    dimension) contains at least one NaN and 0.0 elsewhere, so the output has
    exactly one more feature than the input. For single-feature input this is
    the element-wise NaN mask.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): Tensor with features on the last dimension.

        Returns:
            torch.Tensor: ``inputs`` with NaNs set to 0 and one extra trailing
            feature holding the NaN indicator.
        """
        # Collapse the per-feature NaN flags into one flag per position so the
        # result always has C + 1 features, never 2 * C.
        missing = torch.isnan(inputs).any(dim=-1, keepdim=True).float()
        # Replace NaNs with 0 in the input tensor
        filled = torch.nan_to_num(inputs, nan=0.0)
        # Concatenate the cleaned input with the NaN indicator
        return torch.cat([filled, missing], dim=-1)
66
+
67
+
68
class SamePadConv(nn.Module):
    """
    Dilated 1-D convolution followed by batch norm and GELU that keeps the
    sequence length unchanged.

    Args:
        input_dim (int): Number of input channels.
        output_dim (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        dilation (int): Convolution dilation.
    """
    def __init__(self, input_dim, output_dim, kernel_size, dilation=1):
        super().__init__()
        span = dilation * (kernel_size - 1) + 1
        self.receptive_field = span
        self.conv = nn.Conv1d(
            input_dim,
            output_dim,
            kernel_size,
            padding=span // 2,
            dilation=dilation,
        )
        self.batchnorm = nn.BatchNorm1d(output_dim)
        # With an even receptive field the symmetric padding yields one extra
        # output step, which forward() trims from the end.
        self.remove = int(span % 2 == 0)

    def forward(self, x):
        """Apply conv -> batch norm -> GELU, trimming any surplus trailing step."""
        activated = F.gelu(self.batchnorm(self.conv(x)))
        if self.remove == 0:
            return activated
        return activated[:, :, :-self.remove]
88
+
89
+
90
class ConvBlock(nn.Module):
    """
    Two stacked ``SamePadConv`` layers wrapped in a skip connection.

    The input is added to the output without any projection, so the two must
    have compatible shapes.

    Args:
        input_dim (int): Channels entering the first convolution.
        output_dim (int): Channels produced by both convolutions.
        kernel_size (int): Kernel size shared by both convolutions.
        dilation (int): Dilation shared by both convolutions.
    """
    def __init__(self, input_dim, output_dim, kernel_size, dilation):
        super().__init__()
        self.conv1 = SamePadConv(
            input_dim, output_dim, kernel_size, dilation=dilation
        )
        self.conv2 = SamePadConv(
            output_dim, output_dim, kernel_size, dilation=dilation
        )

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        return self.conv2(self.conv1(x)) + x
101
+
102
+
103
class Backbone(nn.Module):
    """
    Sequential stack of ``ConvBlock`` modules.

    Args:
        input_dim (int): Channels entering the first block.
        dim_list (Sequence[int]): Output channels of each block, in order.
        dilation (Sequence[int]): Dilation of each block; indexed in step with
            ``dim_list``.
        kernel_size (int): Kernel size shared by every block.
    """
    def __init__(self, input_dim, dim_list, dilation, kernel_size):
        super().__init__()
        blocks = []
        in_channels = input_dim
        for idx, out_channels in enumerate(dim_list):
            blocks.append(
                ConvBlock(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation[idx],
                )
            )
            # The next block consumes what this one produces.
            in_channels = out_channels
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.net(x)
118
+
119
+
120
class LightCurveEncoder(nn.Module):
    """
    Encoder: 1x1 channel mapping, dilated convolutional backbone, then dropout.

    Args:
        input_dim (int): Number of raw feature channels; the mapping layer
            expects one additional channel for the NaN mask.
        output_dim (int): Channel width of the mapping and of every backbone
            block.
        depth (int): Number of backbone blocks.
        dilation (Sequence[int]): Per-block dilations passed to the backbone.
        kernel_size (int): Kernel size of the backbone convolutions.
            Defaults to 3.
        dropout_rate (float): Dropout probability applied to the backbone
            output. Defaults to 0.1.
    """
    def __init__(self, input_dim, output_dim, depth, dilation, kernel_size=3, dropout_rate=0.1):
        super().__init__()
        self.mapping = nn.Conv1d(input_dim + 1, output_dim, 1)  # +1 for NaN mask
        self.backbone = Backbone(
            output_dim,
            [output_dim] * depth,
            dilation,
            kernel_size=kernel_size
        )
        self.repr_dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Encode ``x``; dimensions 1 and 2 are swapped first so channels precede time."""
        x = x.transpose(1, 2)  # B x Ci x T
        x = self.mapping(x)  # B x Ch x T
        x = self.backbone(x)  # B x Co x T
        x = self.repr_dropout(x)
        return x
138
+
139
+
140
class SegHead(nn.Module):
    """
    Segmentation head: one width-preserving ``SamePadConv`` (kernel size 3)
    followed by a 1x1 convolution to ``output_dim`` channels.

    Args:
        input_dim (int): Channels of the incoming representation.
        output_dim (int): Channels of the per-step prediction.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv = SamePadConv(input_dim, input_dim, 3)
        self.projector = nn.Conv1d(input_dim, output_dim, 1)

    def forward(self, x):
        """Map a channels-first tensor to predictions with channels last."""
        # x arrives as B x Ci x T
        projected = self.projector(self.conv(x))  # B x Co x T
        return projected.transpose(1, 2)  # B x T x Co
152
+
153
+
154
class FCN4FlarePreTrainedModel(PreTrainedModel):
    """
    Abstract base that binds the FCN4Flare config class and defines how
    weights are initialised, on top of the pretrained-model loading interface
    inherited from ``PreTrainedModel``.
    """
    config_class = FCN4FlareConfig
    base_model_prefix = "fcn4flare"
    # NOTE(review): no checkpointing logic is visible in this file; confirm
    # this flag is intended.
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``module`` in place; other module types are left untouched."""
        if isinstance(module, nn.BatchNorm1d):
            # Identity affine transform: scale 1, shift 0.
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
            return
        if isinstance(module, nn.Conv1d):
            # Only the weight is re-initialised; the bias keeps its default.
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
168
+
169
+
170
@dataclass
class FCN4FlareOutput(ModelOutput):
    """
    Container returned by FCN4Flare's forward pass.

    Args:
        loss (`torch.FloatTensor`, *optional*):
            Mask Dice loss; present only when labels were supplied.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, output_dim)`):
            Raw (pre-sigmoid) prediction scores from the segmentation head.
        hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
            Encoder output that was fed to the segmentation head.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
186
+
187
+
188
class FCN4FlareModel(FCN4FlarePreTrainedModel):
    """
    FCN4Flare model: NaN masking, convolutional encoder and segmentation head,
    with an optional Mask Dice loss when labels are supplied.
    """
    def __init__(self, config: FCN4FlareConfig):
        super().__init__(config)

        self.nan_mask = NaNMask()
        # NOTE(review): config.kernel_size and config.dropout_rate are not
        # forwarded to the encoder here - confirm whether that is intended.
        self.encoder = LightCurveEncoder(
            config.input_dim,
            config.hidden_dim,
            config.depth,
            config.dilation
        )
        self.seghead = SegHead(config.hidden_dim, config.output_dim)

        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(self, logits, labels, sequence_mask=None):
        """Return the Mask Dice loss of ``sigmoid(logits)`` against ``labels``."""
        # squeeze(-1) only drops the trailing axis when it has size 1.
        probs = torch.sigmoid(logits).squeeze(-1)

        # Treat NaN labels as 0 on every path, not only when a mask is given;
        # otherwise a single NaN label makes the whole loss NaN.
        if labels.is_floating_point():
            labels = torch.nan_to_num(labels, nan=0.0)

        if sequence_mask is not None:
            # Zero both labels and predictions at masked-out positions so they
            # contribute nothing to the Dice sums.
            labels = labels * sequence_mask
            probs = probs * sequence_mask

        loss_fct = MaskDiceLoss(self.config.maskdice_threshold)
        return loss_fct(probs, labels)

    def forward(
        self,
        input_features,
        sequence_mask=None,
        labels=None,
        return_dict=True,
    ):
        """
        Run the model and optionally compute the loss.

        Args:
            input_features (torch.Tensor): Input features; may contain NaNs,
                which are masked before encoding.
            sequence_mask (torch.Tensor, optional): Multiplicative mask applied
                to both labels and predictions in the loss.
            labels (torch.Tensor, optional): Targets; when given, the Mask Dice
                loss is computed. NaN labels are treated as 0.
            return_dict (bool): Return an ``FCN4FlareOutput`` when true,
                otherwise a tuple ``(loss, logits)`` or ``(logits,)``.

        Returns:
            FCN4FlareOutput or tuple: Loss (if labels were given), logits and,
            in the dict form, the encoder output as ``hidden_states``.
        """
        # Apply NaN masking
        inputs_with_mask = self.nan_mask(input_features)

        # Encoder and segmentation head
        outputs = self.encoder(inputs_with_mask)
        logits = self.seghead(outputs)

        # Loss calculation
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, sequence_mask)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return FCN4FlareOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs
        )