|
from networks.encoders.mobilenetv2 import MobileNetV2 |
|
from networks.encoders.mobilenetv3 import MobileNetV3Large |
|
from networks.encoders.resnet import ResNet101, ResNet50 |
|
from networks.encoders.resnest import resnest |
|
from networks.encoders.swin import build_swin_model |
|
from networks.layers.normalization import FrozenBatchNorm2d |
|
from torch import nn |
|
|
|
|
|
def build_encoder(name, frozen_bn=True, freeze_at=-1):
    """Construct the encoder backbone selected by ``name``.

    Args:
        name: Encoder identifier. Exact matches: ``'mobilenetv2'``,
            ``'mobilenetv3'``, ``'resnet50'``, ``'resnet101'``,
            ``'resnest50'``, ``'resnest101'``. Any other name containing
            the substring ``'swin'`` is forwarded to ``build_swin_model``.
        frozen_bn: When true, ``FrozenBatchNorm2d`` is handed to the encoder
            as its normalization layer; otherwise ``nn.BatchNorm2d`` is.
            The swin branch does not receive a normalization layer, so this
            flag has no effect there.
        freeze_at: Passed through unchanged to the selected encoder
            constructor as ``freeze_at``.

    Returns:
        Whatever the selected encoder constructor returns.

    Raises:
        NotImplementedError: If ``name`` matches none of the supported
            encoders. The message includes the offending name.
    """
    norm_layer = FrozenBatchNorm2d if frozen_bn else nn.BatchNorm2d

    # NOTE(review): the leading positional 16 is presumably the output
    # stride of the backbone -- confirm against the encoder constructors.
    if name == 'mobilenetv2':
        return MobileNetV2(16, norm_layer, freeze_at=freeze_at)
    if name == 'mobilenetv3':
        return MobileNetV3Large(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnet50':
        return ResNet50(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnet101':
        return ResNet101(16, norm_layer, freeze_at=freeze_at)
    if name == 'resnest50':
        return resnest.resnest50(norm_layer=norm_layer,
                                 dilation=2,
                                 freeze_at=freeze_at)
    if name == 'resnest101':
        return resnest.resnest101(norm_layer=norm_layer,
                                  dilation=2,
                                  freeze_at=freeze_at)
    # Substring match: the full name is handed to the swin factory, which
    # decides which variant to build.
    if 'swin' in name:
        return build_swin_model(name, freeze_at=freeze_at)

    # Same exception type as before, but say which name was rejected so a
    # typo in a config is easy to spot.
    raise NotImplementedError(
        'Unsupported encoder name: {!r}'.format(name))
|
|