|
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel

from .configuration_fcn4flare import FCN4FlareConfig

|
class MaskDiceLoss(nn.Module):
    r"""
    Computes the Mask Dice loss between predicted probabilities and binary targets.

    Predictions below ``maskdice_threshold`` are zeroed out before the Dice
    coefficient is computed, so only confident activations contribute:

    $$
    \text{loss} = 1 - \frac{1}{N}\sum_{i=1}^{N}
    \frac{2\sum_{t} p_{i,t}\, y_{i,t} + \epsilon}
         {\sum_{t} p_{i,t} + \sum_{t} y_{i,t} + \epsilon N}
    $$

    where $N$ is the batch size and $\epsilon$ is a small smoothing constant.

    Args:
        maskdice_threshold (float): Threshold below which predictions are masked to zero.
    """
    def __init__(self, maskdice_threshold):
        super().__init__()
        self.maskdice_threshold = maskdice_threshold

    def forward(self, inputs, targets):
        """
        Computes the Mask Dice loss.

        Args:
            inputs (torch.Tensor): Predicted probabilities of shape `(batch_size, sequence_length)`.
            targets (torch.Tensor): Binary targets of the same shape.

        Returns:
            torch.Tensor: Scalar Mask Dice loss.
        """
        n = targets.size(0)
        smooth = 1e-8

        # Zero out predictions below the threshold before measuring overlap.
        inputs_act = torch.gt(inputs, self.maskdice_threshold)
        inputs_act = inputs_act.long()
        inputs = inputs * inputs_act

        # Soft Dice coefficient per sample, averaged over the batch.
        intersection = inputs * targets
        dice_coeff = (2 * intersection.sum(1) + smooth) / (inputs.sum(1) + targets.sum(1) + smooth * n)
        loss = 1 - dice_coeff.mean()
        return loss
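
# Illustrative usage sketch (not part of the original module): applying
# MaskDiceLoss to sigmoid probabilities and binary flare labels of shape
# (batch, sequence_length). The threshold and shapes below are assumptions.
def _maskdice_loss_example():
    loss_fct = MaskDiceLoss(maskdice_threshold=0.5)
    probs = torch.sigmoid(torch.randn(4, 128))       # per-point flare probabilities
    labels = (torch.rand(4, 128) > 0.9).float()      # sparse binary flare labels
    return loss_fct(probs, labels)                   # scalar tensor, 1 - mean Dice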
|
|
|
|
|
class NaNMask(nn.Module):
    """
    Replaces NaNs in the input with zeros and appends a binary NaN-indicator
    channel per feature, so the network can tell real zeros from missing values.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # 1 where the input was NaN, 0 elsewhere.
        nan_mask = torch.isnan(inputs).float()
        # Replace NaNs so downstream convolutions receive finite values.
        inputs = torch.nan_to_num(inputs, nan=0.0)
        return torch.cat([inputs, nan_mask], dim=-1)
|
class SamePadConv(nn.Module):
    """
    1D convolution that preserves the sequence length ("same" padding),
    followed by batch normalization and GELU.
    """
    def __init__(self, input_dim, output_dim, kernel_size, dilation=1):
        super().__init__()
        self.receptive_field = (kernel_size - 1) * dilation + 1
        padding = self.receptive_field // 2
        self.conv = nn.Conv1d(
            input_dim, output_dim, kernel_size,
            padding=padding,
            dilation=dilation
        )
        self.batchnorm = nn.BatchNorm1d(output_dim)
        # With an even receptive field the padded output is one step too long.
        self.remove = 1 if self.receptive_field % 2 == 0 else 0

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = F.gelu(x)
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
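
# Quick shape check (illustrative, not from the original module): SamePadConv
# preserves the temporal length for both odd and even receptive fields.
def _same_pad_conv_example():
    x = torch.randn(2, 8, 200)                              # (batch, channels, length)
    odd = SamePadConv(8, 16, kernel_size=3, dilation=4)     # receptive field 9 (odd)
    even = SamePadConv(8, 16, kernel_size=4, dilation=1)    # receptive field 4 (even, one step trimmed)
    return odd(x).shape, even(x).shape                      # both torch.Size([2, 16, 200])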
|
|
|
|
|
class ConvBlock(nn.Module):
    """
    Residual block of two dilated SamePadConv layers. The skip connection
    requires ``input_dim == output_dim``.
    """
    def __init__(self, input_dim, output_dim, kernel_size, dilation):
        super().__init__()
        self.conv1 = SamePadConv(input_dim, output_dim, kernel_size, dilation=dilation)
        self.conv2 = SamePadConv(output_dim, output_dim, kernel_size, dilation=dilation)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual
|
class Backbone(nn.Module):
    """
    Stack of dilated ConvBlocks; block ``i`` uses dilation ``dilation[i]``.
    """
    def __init__(self, input_dim, dim_list, dilation, kernel_size):
        super().__init__()
        self.net = nn.Sequential(*[
            ConvBlock(
                dim_list[i - 1] if i > 0 else input_dim,
                dim_list[i],
                kernel_size=kernel_size,
                dilation=dilation[i]
            )
            for i in range(len(dim_list))
        ])

    def forward(self, x):
        return self.net(x)
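
# Receptive-field note (illustrative, not from the original code): with kernel
# size 3, each ConvBlock stacks two dilated convolutions and therefore adds
# 2 * 2 * dilation[i] points of temporal context, so a dilation schedule such
# as [1, 2, 4, 8] gives the whole backbone a receptive field of
# 1 + 4 * (1 + 2 + 4 + 8) = 61 points.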
|
|
|
|
|
class LightCurveEncoder(nn.Module):
    """
    Maps a light curve of shape `(batch, time, features)` to hidden
    representations of shape `(batch, output_dim, time)`.
    """
    def __init__(self, input_dim, output_dim, depth, dilation):
        super().__init__()
        # +1 input channel for the NaN-mask feature appended by NaNMask.
        self.mapping = nn.Conv1d(input_dim + 1, output_dim, 1)
        self.backbone = Backbone(
            output_dim,
            [output_dim] * depth,
            dilation,
            kernel_size=3
        )
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        x = x.transpose(1, 2)       # (batch, time, features) -> (batch, features, time)
        x = self.mapping(x)
        x = self.backbone(x)
        x = self.repr_dropout(x)
        return x
|
class SegHead(nn.Module):
    """
    Segmentation head: one SamePadConv followed by a pointwise projection to
    per-point logits of shape `(batch, time, output_dim)`.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.conv = SamePadConv(input_dim, input_dim, 3)
        self.projector = nn.Conv1d(input_dim, output_dim, 1)

    def forward(self, x):
        x = self.conv(x)
        x = self.projector(x)
        x = x.transpose(1, 2)       # (batch, output_dim, time) -> (batch, time, output_dim)
        return x
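
# End-to-end shape sketch for the encoder + head pair (illustrative; the
# dimensions below are assumptions, not defaults from the original code).
def _encoder_seghead_example():
    encoder = LightCurveEncoder(input_dim=1, output_dim=64, depth=3, dilation=[1, 2, 4])
    head = SegHead(input_dim=64, output_dim=1)
    x = torch.randn(2, 500, 2)       # flux plus the NaN-mask channel from NaNMask
    hidden = encoder(x)              # (2, 64, 500)
    return head(hidden).shape        # torch.Size([2, 500, 1])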
|
|
|
|
|
class FCN4FlarePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """
    config_class = FCN4FlareConfig
    base_model_prefix = "fcn4flare"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
|
@dataclass
class FCN4FlareOutput(ModelOutput):
    """
    Output type of FCN4Flare.

    Args:
        loss (`Optional[torch.FloatTensor]` of shape `(1,)`, *optional*):
            Mask Dice loss if labels are provided, `None` otherwise.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, output_dim)`):
            Prediction scores of the model.
        hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
            Hidden states from the encoder.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: torch.FloatTensor = None
|
class FCN4FlareModel(FCN4FlarePreTrainedModel):
    def __init__(self, config: FCN4FlareConfig):
        super().__init__(config)

        self.nan_mask = NaNMask()
        self.encoder = LightCurveEncoder(
            config.input_dim,
            config.hidden_dim,
            config.depth,
            config.dilation
        )
        self.seghead = SegHead(config.hidden_dim, config.output_dim)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_features,
        sequence_mask=None,
        labels=None,
        return_dict=True,
    ):
        """
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_dim)`):
                Input light curves; NaN values are allowed and handled by `NaNMask`.
            sequence_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask of valid points (1) versus padding (0), applied when computing the loss.
            labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Binary per-point flare labels used to compute the Mask Dice loss.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `FCN4FlareOutput` instead of a plain tuple.
        """
        # Replace NaNs and append the NaN-indicator channel.
        inputs_with_mask = self.nan_mask(input_features)

        # Encode and project to per-point logits.
        outputs = self.encoder(inputs_with_mask)
        logits = self.seghead(outputs)

        # Compute the Mask Dice loss when labels are given.
        loss = None
        if labels is not None:
            loss_fct = MaskDiceLoss(self.config.maskdice_threshold)
            logits_sigmoid = torch.sigmoid(logits).squeeze(-1)

            if sequence_mask is not None:
                # Ignore padded points in both predictions and labels.
                labels_for_loss = labels.clone()
                labels_for_loss = torch.nan_to_num(labels_for_loss, nan=0.0)
                labels_for_loss = labels_for_loss * sequence_mask
                logits_sigmoid = logits_sigmoid * sequence_mask
                loss = loss_fct(logits_sigmoid, labels_for_loss)
            else:
                loss = loss_fct(logits_sigmoid, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return FCN4FlareOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs
        )
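
# Minimal smoke test (illustrative sketch, not part of the original module).
# The FCN4FlareConfig keyword arguments below are assumptions inferred from how
# the config is read in this file; adjust them to the real configuration class.
if __name__ == "__main__":
    config = FCN4FlareConfig(
        input_dim=1,
        hidden_dim=64,
        depth=3,
        dilation=[1, 2, 4],
        output_dim=1,
        maskdice_threshold=0.5,
    )
    model = FCN4FlareModel(config)
    flux = torch.randn(2, 500, 1)                   # (batch, time, features)
    flux[0, 10:20, 0] = float("nan")                # simulate missing cadences
    labels = (torch.rand(2, 500) > 0.95).float()    # sparse binary flare labels
    outputs = model(input_features=flux, labels=labels)
    print(outputs.logits.shape, outputs.loss)       # torch.Size([2, 500, 1]) and a scalar loss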
|
|