soumickmj's picture
Upload UNetFTSR
301db00 verified
raw
history blame contribute delete
620 Bytes
from transformers import PreTrainedModel
from .unet3D import UNet
from .UNetFTSRConfig import UNetFTSRConfig
class UNetFTSR(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``UNet``.

    The architecture hyper-parameters are read from the supplied config
    object and forwarded unchanged to the ``UNet`` constructor. The
    resulting network is stored as ``self.model``, and ``forward`` simply
    delegates to it.
    """

    # Config class associated with this model for the transformers
    # loading/saving machinery.
    config_class = UNetFTSRConfig

    def __init__(self, config):
        """Build the wrapped ``UNet`` from ``config``.

        Args:
            config: Configuration object (``UNetFTSRConfig`` per
                ``config_class``). It must expose the attributes
                ``in_channels``, ``n_classes``, ``depth``, ``wf``,
                ``padding``, ``batch_norm``, ``up_mode`` and ``dropout``.
                A missing attribute raises ``AttributeError``.
        """
        super().__init__(config)
        # Each config attribute maps 1:1 onto a ``UNet`` keyword argument of
        # the same name. The order mirrors the original explicit call.
        unet_arg_names = (
            "in_channels",
            "n_classes",
            "depth",
            "wf",
            "padding",
            "batch_norm",
            "up_mode",
            "dropout",
        )
        unet_kwargs = {name: getattr(config, name) for name in unet_arg_names}
        self.model = UNet(**unet_kwargs)

    def forward(self, x):
        """Run the wrapped ``UNet`` on ``x``.

        Args:
            x: Input passed straight to ``self.model``. Its expected type
                and shape are defined by ``UNet``, which is not visible
                here. It is presumably a ``torch.Tensor``; confirm against
                ``unet3D``.

        Returns:
            Whatever ``self.model(x)`` returns, unmodified.
        """
        prediction = self.model(x)
        return prediction