from transformers import PreTrainedModel

from .unet3D import UNet
from .UNetFTSRConfig import UNetFTSRConfig


class UNetFTSR(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``UNet``.

    The wrapped network is built entirely from fields of a
    ``UNetFTSRConfig`` and stored as ``self.model``. The attribute name
    determines the ``model.`` prefix of the state-dict keys, so it must stay
    stable for saved checkpoints to load.
    """

    # Tells transformers which configuration class belongs to this model
    # (used e.g. when loading via ``from_pretrained``).
    config_class = UNetFTSRConfig

    def __init__(self, config):
        """Build the inner ``UNet`` from ``config``.

        Args:
            config: A ``UNetFTSRConfig``. Each attribute named below is
                passed unchanged, under the same name, as a keyword argument
                to ``UNet``. A missing attribute raises ``AttributeError``.
        """
        super().__init__(config)
        # Config attribute names; each becomes a keyword argument of UNet
        # with the identical name. They are read in this order.
        forwarded = (
            "in_channels",
            "n_classes",
            "depth",
            "wf",
            "padding",
            "batch_norm",
            "up_mode",
            "dropout",
        )
        unet_kwargs = {name: getattr(config, name) for name in forwarded}
        self.model = UNet(**unet_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped ``UNet`` and return its output as-is.

        NOTE(review): the ``unet3D`` module name suggests volumetric input,
        but the expected tensor shape is defined by ``UNet``; confirm there.
        """
        return self.model(x)