gheinrich committed on
Commit
988c610
·
1 Parent(s): f1d53f2

Upload model

Browse files
Files changed (3) hide show
  1. hf_model.py +19 -27
  2. model.py +54 -6
  3. pytorch_model.bin +2 -2
hf_model.py CHANGED
@@ -20,6 +20,7 @@ from transformers import PretrainedConfig, PreTrainedModel
20
 
21
 
22
  from .model import create_model_from_args
 
23
  from .input_conditioner import get_default_conditioner, InputConditioner
24
 
25
 
@@ -42,7 +43,11 @@ class RADIOConfig(PretrainedConfig):
42
 
43
 
44
  class RADIOModel(PreTrainedModel):
45
- """Pretrained Hugging Face model for RADIO."""
 
 
 
 
46
 
47
  config_class = RADIOConfig
48
 
@@ -52,32 +57,19 @@ class RADIOModel(PreTrainedModel):
52
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
53
  args = RADIOArgs(**config.args)
54
  self.config = config
55
- self.model = create_model_from_args(args)
56
- self.input_conditioner: InputConditioner = get_default_conditioner()
57
-
58
- def forward(self, x: torch.Tensor):
59
- x = self.input_conditioner(x)
60
 
61
- y = self.model.forward_features(x)
 
 
 
 
 
62
 
63
- if isinstance(y, (list, tuple)):
64
- summary, all_feat = y
65
- elif isinstance(self.model, VisionTransformer):
66
- patch_gen = getattr(self.model, "patch_generator", None)
67
- if patch_gen is not None:
68
- summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
69
- all_feat = y[:, patch_gen.num_skip :]
70
- elif self.model.global_pool == "avg":
71
- summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
72
- all_feat = y
73
- else:
74
- summary = y[:, 0]
75
- all_feat = y[:, 1:]
76
- else:
77
- raise ValueError("Unsupported model type")
78
 
79
- if self.config.return_summary and self.config.return_spatial_features:
80
- return summary, all_feat
81
- elif self.config.return_summary:
82
- return summary
83
- return all_feat
 
20
 
21
 
22
  from .model import create_model_from_args
23
+ from .model import RADIOModel as RADIOModelBase
24
  from .input_conditioner import get_default_conditioner, InputConditioner
25
 
26
 
 
43
 
44
 
45
  class RADIOModel(PreTrainedModel):
46
+ """Pretrained Hugging Face model for RADIO.
47
+
48
+ This class inherits from PreTrainedModel, which provides
49
+ HuggingFace's functionality for loading and saving models.
50
+ """
51
 
52
  config_class = RADIOConfig
53
 
 
57
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
58
  args = RADIOArgs(**config.args)
59
  self.config = config
60
+ model = create_model_from_args(args)
61
+ input_conditioner: InputConditioner = get_default_conditioner()
 
 
 
62
 
63
+ self.radio_model = RADIOModelBase(
64
+ model,
65
+ input_conditioner,
66
+ config.return_summary,
67
+ config.return_spatial_features,
68
+ )
69
 
70
+ @property
71
+ def model(self) -> VisionTransformer:
72
+ return self.radio_model.model
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ def forward(self, x: torch.Tensor):
75
+ return self.radio_model.forward(x)
 
 
 
model.py CHANGED
@@ -6,11 +6,56 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
 
9
  from torch import nn
10
 
11
- from timm.models import create_model
12
 
13
  from .enable_cpe_support import enable_cpe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  def create_model_from_args(args) -> nn.Module:
@@ -36,13 +81,16 @@ def create_model_from_args(args) -> nn.Module:
36
  **args.model_kwargs,
37
  )
38
 
39
- assert not args.cls_token_per_teacher or args.cpe_max_size is not None, "CPE must be enabled for multiple CLS tokens!"
 
 
40
 
41
  if args.cpe_max_size is not None:
42
- enable_cpe(model,
43
- args.cpe_max_size,
44
- num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
45
- register_multiple=args.register_multiple,
 
46
  )
47
 
48
  return model
 
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ import torch
10
  from torch import nn
11
 
12
+ from timm.models import create_model, VisionTransformer
13
 
14
  from .enable_cpe_support import enable_cpe
15
+ from .input_conditioner import InputConditioner
16
+
17
+
18
+ class RADIOModel(nn.Module):
19
+ def __init__(
20
+ self,
21
+ model: nn.Module,
22
+ input_conditioner: InputConditioner,
23
+ return_summary: bool,
24
+ return_spatial_features: bool,
25
+ ):
26
+ super().__init__()
27
+
28
+ self.model = model
29
+ self.input_conditioner = input_conditioner
30
+ self.return_summary = return_summary
31
+ self.return_spatial_features = return_spatial_features
32
+
33
+ def forward(self, x: torch.Tensor):
34
+ x = self.input_conditioner(x)
35
+
36
+ y = self.model.forward_features(x)
37
+
38
+ if isinstance(y, (list, tuple)):
39
+ summary, all_feat = y
40
+ elif isinstance(self.model, VisionTransformer):
41
+ patch_gen = getattr(self.model, "patch_generator", None)
42
+ if patch_gen is not None:
43
+ summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
44
+ all_feat = y[:, patch_gen.num_skip :]
45
+ elif self.model.global_pool == "avg":
46
+ summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
47
+ all_feat = y
48
+ else:
49
+ summary = y[:, 0]
50
+ all_feat = y[:, 1:]
51
+ else:
52
+ raise ValueError("Unsupported model type")
53
+
54
+ if self.return_summary and self.return_spatial_features:
55
+ return summary, all_feat
56
+ elif self.return_summary:
57
+ return summary
58
+ return all_feat
59
 
60
 
61
  def create_model_from_args(args) -> nn.Module:
 
81
  **args.model_kwargs,
82
  )
83
 
84
+ assert (
85
+ not args.cls_token_per_teacher or args.cpe_max_size is not None
86
+ ), "CPE must be enabled for multiple CLS tokens!"
87
 
88
  if args.cpe_max_size is not None:
89
+ enable_cpe(
90
+ model,
91
+ args.cpe_max_size,
92
+ num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
93
+ register_multiple=args.register_multiple,
94
  )
95
 
96
  return model
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:242360b04b7f78204b535ce8a96e28ef3316520d55be43e6873fd45696fb9d61
3
- size 2662619441
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad369b92359d9a42f93f6bbb9be2191f79b4b6fc923fdd31d992ca32336f608d
3
+ size 2662624177