File size: 1,192 Bytes
db40549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torch import nn

from timm.models import create_model

from .enable_cpe_support import enable_cpe


def create_model_from_args(args) -> nn.Module:
    """Build a model via ``timm.models.create_model`` from parsed arguments.

    Args:
        args: Namespace-like object. Attributes read here: ``model``,
            ``pretrained``, ``in_chans``, ``input_size``, ``num_classes``,
            ``drop``, ``drop_path``, ``drop_block``, ``gp``,
            ``bn_momentum``, ``bn_eps``, ``torchscript``,
            ``initial_checkpoint``, ``model_kwargs``,
            ``cls_token_per_teacher``, ``cpe_max_size``,
            ``register_multiple``, and (only when
            ``cls_token_per_teacher`` is truthy and CPE is enabled)
            ``teachers``.

    Returns:
        The module created by ``create_model``. When ``args.cpe_max_size``
        is not None, ``enable_cpe`` is first called on it. Its return value
        is ignored, so it presumably patches the model in place (confirm).

    Raises:
        ValueError: if ``args.cls_token_per_teacher`` is truthy while
            ``args.cpe_max_size`` is None.
    """
    # Validate flag combinations before constructing the model. This way a
    # bad configuration fails immediately instead of after a (possibly
    # pretrained) model has been built. The check is an explicit raise
    # rather than an `assert`, so it still runs under `python -O`.
    if args.cls_token_per_teacher and args.cpe_max_size is None:
        raise ValueError("CPE must be enabled for multiple CLS tokens!")

    # Channel count: an explicit `args.in_chans` wins. Otherwise use the
    # first element of `args.input_size` (presumably (C, H, W); confirm
    # against the argument parser). Failing both, default to 3.
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        # Tolerate `model_kwargs=None`; unpacking None would raise TypeError.
        **(args.model_kwargs or {}),
    )

    if args.cpe_max_size is not None:
        # Use one CLS token per teacher when requested, otherwise a single token.
        num_cls_tokens = len(args.teachers) if args.cls_token_per_teacher else 1
        enable_cpe(
            model,
            args.cpe_max_size,
            num_cls_tokens=num_cls_tokens,
            register_multiple=args.register_multiple,
        )

    return model