File size: 1,381 Bytes
c985ba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from networks.encoders.mobilenetv2 import MobileNetV2
from networks.encoders.mobilenetv3 import MobileNetV3Large
from networks.encoders.resnet import ResNet101, ResNet50
from networks.encoders.resnest import resnest
from networks.encoders.swin import build_swin_model
from networks.layers.normalization import FrozenBatchNorm2d
from torch import nn
def build_encoder(name, frozen_bn=True, freeze_at=-1):
    """Construct a backbone encoder by name.

    Args:
        name: Encoder identifier. Exact matches are 'mobilenetv2',
            'mobilenetv3', 'resnet50', 'resnet101', 'resnest50' and
            'resnest101'; any name containing the substring 'swin' is
            forwarded to ``build_swin_model``.
        frozen_bn: If True, use ``FrozenBatchNorm2d`` as the normalization
            layer; otherwise use ``torch.nn.BatchNorm2d``. The swin branch
            does not receive this layer.
        freeze_at: Forwarded unchanged to the selected encoder constructor.

    Returns:
        The encoder instance built by the matching constructor.

    Raises:
        NotImplementedError: If ``name`` matches no supported encoder.
    """
    BatchNorm = FrozenBatchNorm2d if frozen_bn else nn.BatchNorm2d

    # NOTE(review): the positional 16 passed to the MobileNet/ResNet
    # constructors looks like an output stride — confirm against their
    # signatures.
    if name == 'mobilenetv2':
        return MobileNetV2(16, BatchNorm, freeze_at=freeze_at)
    elif name == 'mobilenetv3':
        return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at)
    elif name == 'resnet50':
        return ResNet50(16, BatchNorm, freeze_at=freeze_at)
    elif name == 'resnet101':
        return ResNet101(16, BatchNorm, freeze_at=freeze_at)
    elif name == 'resnest50':
        # NOTE(review): dilation=2 is presumably chosen to match the stride
        # used by the other backbones — confirm in resnest.
        return resnest.resnest50(norm_layer=BatchNorm,
                                 dilation=2,
                                 freeze_at=freeze_at)
    elif name == 'resnest101':
        return resnest.resnest101(norm_layer=BatchNorm,
                                  dilation=2,
                                  freeze_at=freeze_at)
    elif 'swin' in name:
        # Substring match: every swin variant name is delegated as-is.
        return build_swin_model(name, freeze_at=freeze_at)
    else:
        # Report the bad name so config typos are easy to diagnose.
        raise NotImplementedError(
            f"Unsupported encoder {name!r}; expected one of "
            "'mobilenetv2', 'mobilenetv3', 'resnet50', 'resnet101', "
            "'resnest50', 'resnest101', or a name containing 'swin'.")
|