Spaces:
Sleeping
Sleeping
File size: 2,183 Bytes
749745d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from collections import OrderedDict
import torch
from torch import nn
from maskrcnn_benchmark.modeling import registry
from . import bert_model
from . import rnn_model
from . import clip_model
from . import word_utils
from . import roberta_fused_model
from . import roberta_fused_model_v2
from . import roberta_fused_model_tiny
@registry.LANGUAGE_BACKBONES.register("bert-base-uncased")
def build_bert_backbone(cfg):
    """Build the language backbone registered as ``"bert-base-uncased"``.

    Args:
        cfg: config object, forwarded unchanged to ``bert_model.BertEncoder``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the ``BertEncoder`` instance.
    """
    modules = OrderedDict()
    modules["body"] = bert_model.BertEncoder(cfg)
    return nn.Sequential(modules)
# Distinct function name so this definition no longer shadows
# ``build_bert_backbone`` at module level; the registry entry is keyed by the
# decorator string ("roberta-base"), which is unchanged.
@registry.LANGUAGE_BACKBONES.register("roberta-base")
def build_roberta_backbone(cfg):
    """Build the language backbone registered as ``"roberta-base"``.

    Uses the same ``bert_model.BertEncoder`` class as the
    ``"bert-base-uncased"`` entry.
    NOTE(review): presumably BertEncoder selects RoBERTa weights/tokenizer
    from ``cfg`` — confirm in bert_model.

    Args:
        cfg: config object, forwarded unchanged to ``bert_model.BertEncoder``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the encoder instance.
    """
    body = bert_model.BertEncoder(cfg)
    return nn.Sequential(OrderedDict([("body", body)]))
@registry.LANGUAGE_BACKBONES.register("rnn")
def build_rnn_backbone(cfg):
    """Build the language backbone registered as ``"rnn"``.

    Args:
        cfg: config object, forwarded unchanged to the RNN encoder class.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the RNN encoder instance.
    """
    # NOTE(review): "RNNEnoder" must match the class name in rnn_model; it
    # looks like a typo for "RNNEncoder" — confirm there before renaming.
    encoder = rnn_model.RNNEnoder(cfg)
    layers = OrderedDict()
    layers["body"] = encoder
    return nn.Sequential(layers)
@registry.LANGUAGE_BACKBONES.register("clip")
def build_clip_backbone(cfg):
    """Build the language backbone registered as ``"clip"``.

    Args:
        cfg: config object, forwarded unchanged to
            ``clip_model.CLIPTransformer``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the ``CLIPTransformer`` instance.
    """
    named_children = OrderedDict()
    named_children["body"] = clip_model.CLIPTransformer(cfg)
    return nn.Sequential(named_children)
# Distinct function name: this definition previously reused
# ``build_clip_backbone`` and shadowed the CLIP builder at module level. The
# registry entry is keyed by the decorator string ("roberta-fused"), which is
# unchanged.
@registry.LANGUAGE_BACKBONES.register("roberta-fused")
def build_roberta_fused_backbone(cfg):
    """Build the language backbone registered as ``"roberta-fused"``.

    Args:
        cfg: config object, forwarded unchanged to
            ``roberta_fused_model.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the ``RobertaFusedEncoder`` instance.
    """
    body = roberta_fused_model.RobertaFusedEncoder(cfg)
    return nn.Sequential(OrderedDict([("body", body)]))
# Distinct function name: this definition previously reused
# ``build_clip_backbone`` and shadowed the CLIP builder at module level. The
# registry entry is keyed by the decorator string ("roberta-fused-v2"), which
# is unchanged.
@registry.LANGUAGE_BACKBONES.register("roberta-fused-v2")
def build_roberta_fused_v2_backbone(cfg):
    """Build the language backbone registered as ``"roberta-fused-v2"``.

    Args:
        cfg: config object, forwarded unchanged to
            ``roberta_fused_model_v2.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the ``RobertaFusedEncoder`` (v2) instance.
    """
    body = roberta_fused_model_v2.RobertaFusedEncoder(cfg)
    return nn.Sequential(OrderedDict([("body", body)]))
# Distinct function name: this definition previously reused
# ``build_clip_backbone``, so that module-level name resolved to this tiny
# RoBERTa builder instead of the CLIP one. The registry entry is keyed by the
# decorator string ("roberta-fused-tiny"), which is unchanged.
@registry.LANGUAGE_BACKBONES.register("roberta-fused-tiny")
def build_roberta_fused_tiny_backbone(cfg):
    """Build the language backbone registered as ``"roberta-fused-tiny"``.

    Args:
        cfg: config object, forwarded unchanged to
            ``roberta_fused_model_tiny.RobertaFusedEncoder``.

    Returns:
        nn.Sequential: a container whose single child, named ``"body"``, is
        the ``RobertaFusedEncoder`` (tiny) instance.
    """
    body = roberta_fused_model_tiny.RobertaFusedEncoder(cfg)
    return nn.Sequential(OrderedDict([("body", body)]))
def build_backbone(cfg):
    """Instantiate the language backbone selected by the config.

    Looks up ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE`` in
    ``registry.LANGUAGE_BACKBONES`` and calls the registered builder with
    ``cfg``.

    Args:
        cfg: config node; must provide ``MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``.

    Returns:
        Whatever the registered builder returns (the builders registered in
        this module return an ``nn.Sequential``).

    Raises:
        KeyError: if the configured model type has not been registered.
    """
    model_type = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
    # Explicit check rather than ``assert`` so validation is not stripped
    # under ``python -O``.
    if model_type not in registry.LANGUAGE_BACKBONES:
        raise KeyError(
            "cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE: {} is not registered "
            "in registry".format(model_type)
        )
    return registry.LANGUAGE_BACKBONES[model_type](cfg)
|