mranzinger commited on
Commit
28c5370
·
verified ·
1 Parent(s): abc42a0
config.json CHANGED
@@ -354,7 +354,7 @@
354
  432
355
  ],
356
  "torch_dtype": "bfloat16",
357
- "transformers_version": "4.37.2",
358
  "version": "radio_v2.1",
359
  "vitdet_window_size": null
360
  }
 
354
  432
355
  ],
356
  "torch_dtype": "bfloat16",
357
+ "transformers_version": "4.40.1",
358
  "version": "radio_v2.1",
359
  "vitdet_window_size": null
360
  }
enable_spectral_reparam.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import getLogger
2
+ import math
3
+ import os
4
+ from typing import Union, Tuple
5
+ from types import MethodType
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.utils import parametrize
11
+ from torch.nn.utils.parametrizations import _SpectralNorm
12
+
13
+ from timm.models.vision_transformer import Attention, Mlp
14
+
15
+ _EPS = 1e-5
16
+
17
+
18
+ class _SNReweight(_SpectralNorm):
19
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
20
+ super().__init__(weight, *args, **kwargs)
21
+
22
+ self.alpha = alpha
23
+ self.version = version
24
+ self.register_buffer('_sn_version', torch.tensor(version))
25
+
26
+ if init_norm_to_current:
27
+ # This will set the numerator to match the denominator, which should preserve the original values
28
+ init_scale = self._get_sigma(weight).item()
29
+ else:
30
+ init_scale = 1.0
31
+
32
+ if version == 1:
33
+ init_value = init_scale
34
+ elif version == 2:
35
+ t = init_scale - alpha
36
+ if t < _EPS:
37
+ getLogger("spectral_reparam").warn(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
38
+ t = _EPS
39
+
40
+ init_value = math.log(math.exp(t) - 1)
41
+ else:
42
+ raise ValueError(f'Unsupported version: {version}')
43
+
44
+ # Make 2D so that weight decay gets applied
45
+ self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))
46
+
47
+ # Re-implementing this because we need to make division by sigma safe
48
+ def _get_sigma(self, weight: torch.Tensor) -> torch.Tensor:
49
+ if weight.ndim == 1:
50
+ # Faster and more exact path, no need to approximate anything
51
+ sigma = weight.norm()
52
+ else:
53
+ weight_mat = self._reshape_weight_to_matrix(weight)
54
+ if self.training:
55
+ self._power_method(weight_mat, self.n_power_iterations)
56
+ # See above on why we need to clone
57
+ u = self._u.clone(memory_format=torch.contiguous_format)
58
+ v = self._v.clone(memory_format=torch.contiguous_format)
59
+ # The proper way of computing this should be through F.bilinear, but
60
+ # it seems to have some efficiency issues:
61
+ # https://github.com/pytorch/pytorch/issues/58093
62
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
63
+
64
+ return sigma + self.eps
65
+
66
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
67
+ dtype = weight.dtype
68
+ sigma = self._get_sigma(weight, *args, **kwargs)
69
+
70
+ if self.version == 1:
71
+ scale = self.scale
72
+ elif self.version == 2:
73
+ scale = F.softplus(self.scale) + self.alpha
74
+ else:
75
+ raise ValueError(f'Unsupported version: {self.version}')
76
+
77
+ scale = scale.float() / sigma.float()
78
+
79
+ y = weight * scale
80
+
81
+ if dtype in (torch.float16, torch.bfloat16):
82
+ y = y.to(dtype)
83
+ return y
84
+
85
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
86
+ version_key = f'{prefix}_sn_version'
87
+ if version_key not in state_dict:
88
+ self.version = 1
89
+ state_dict[version_key] = torch.tensor(1)
90
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
91
+
92
+
93
+ class _AttnSNReweight(nn.Module):
94
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
95
+ super().__init__()
96
+
97
+ parts = weight.split(weight.shape[0] // 3, dim=0)
98
+
99
+ ct = 2 if not renorm_values else 3
100
+
101
+ self.parts = nn.ModuleList([
102
+ _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs) if i < ct else nn.Identity()
103
+ for i, p in enumerate(parts)
104
+ ])
105
+
106
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
107
+ parts = weight.split(weight.shape[0] // 3, dim=0)
108
+
109
+ parts = [
110
+ fn(p)
111
+ for fn, p in zip(self.parts, parts)
112
+ ]
113
+
114
+ return torch.cat(parts, dim=0)
115
+
116
+
117
+ def enable_spectral_reparam(model: nn.Module,
118
+ n_power_iterations: int = 1,
119
+ eps: float = 1e-6,
120
+ init_norm_to_current: bool = False,
121
+ renorm_values: bool = True,
122
+ renorm_mlp: bool = True):
123
+ # print('Enabling spectral reparametrization')
124
+ for mod in model.modules():
125
+ if isinstance(mod, Attention):
126
+ parametrize.register_parametrization(
127
+ mod.qkv,
128
+ 'weight',
129
+ _AttnSNReweight(mod.qkv.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current, renorm_values=renorm_values),
130
+ )
131
+ pass
132
+ elif isinstance(mod, Mlp) and renorm_mlp:
133
+ parametrize.register_parametrization(
134
+ mod.fc1,
135
+ 'weight',
136
+ _SNReweight(mod.fc1.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
137
+ )
138
+ parametrize.register_parametrization(
139
+ mod.fc2,
140
+ 'weight',
141
+ _SNReweight(mod.fc2.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
142
+ )
143
+ pass
144
+
145
+
146
+ def configure_spectral_reparam_from_args(model: nn.Module, args):
147
+ spectral_reparam = getattr(args, 'spectral_reparam', False)
148
+ if isinstance(spectral_reparam, bool) and spectral_reparam:
149
+ enable_spectral_reparam(model, init_norm_to_current=args.pretrained)
150
+ elif isinstance(spectral_reparam, dict):
151
+ enable_spectral_reparam(
152
+ model,
153
+ n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
154
+ eps=spectral_reparam.get('eps', 1e-12),
155
+ init_norm_to_current=args.pretrained,
156
+ )
157
+
158
+
159
+ def disable_spectral_reparam(model: nn.Module):
160
+ for mod in model.modules():
161
+ if isinstance(mod, Attention):
162
+ parametrize.remove_parametrizations(mod.qkv, 'weight')
163
+ pass
164
+ elif isinstance(mod, Mlp):
165
+ parametrize.remove_parametrizations(mod.fc1, 'weight')
166
+ parametrize.remove_parametrizations(mod.fc2, 'weight')
167
+ pass
168
+
169
+
170
+ if __name__ == '__main__':
171
+ import argparse
172
+ from . import radio_model as create_model
173
+
174
+ parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
175
+ parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
176
+ parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
177
+ parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
178
+ parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
179
+
180
+ args = parser.parse_args()
181
+
182
+ if not args.output:
183
+ chk_dir, chk_name = os.path.split(args.checkpoint)
184
+ args.output = os.path.join(chk_dir, f'clean_{chk_name}')
185
+ print(f'Set output to "{args.output}"')
186
+
187
+ chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)
188
+
189
+ model = create_model.create_model_from_args(chk['args'])
190
+
191
+ key = 'base_model.'
192
+ mod_state = dict()
193
+ extra_state = dict()
194
+ for k, v in chk['state_dict'].items():
195
+ if k.startswith(key):
196
+ mod_state[k[len(key):]] = v
197
+ else:
198
+ extra_state[k] = v
199
+
200
+ chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
201
+ if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
202
+ print(chk_load_info)
203
+
204
+ if chk['args'].spectral_reparam:
205
+ disable_spectral_reparam(model)
206
+
207
+ if hasattr(chk['args'], 'dtype'):
208
+ model.to(dtype=chk['args'].dtype)
209
+
210
+ mod_state = model.state_dict()
211
+ final_state = dict()
212
+ final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
213
+ final_state.update(extra_state)
214
+
215
+ chk['state_dict'] = final_state
216
+ chk['args'].spectral_reparam = False
217
+
218
+ if args.release:
219
+ chk = {
220
+ 'arch': chk['arch'],
221
+ 'epoch': chk['epoch'],
222
+ 'state_dict': chk['state_dict'],
223
+ 'args': chk['args'],
224
+ }
225
+
226
+ torch.save(chk, args.output)
227
+ pass
eradio_model.py CHANGED
@@ -1162,6 +1162,9 @@ class FasterViT(nn.Module):
1162
  return {'rpb'}
1163
 
1164
  def forward_features(self, x):
 
 
 
1165
  x = self.patch_embed(x)
1166
  full_features = None
1167
  for il, level in enumerate(self.levels):
 
1162
  return {'rpb'}
1163
 
1164
  def forward_features(self, x):
1165
+ _, _, H, W = x.shape
1166
+ if H % 32 != 0 or W % 32 != 0:
1167
+ raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}")
1168
  x = self.patch_embed(x)
1169
  full_features = None
1170
  for il, level in enumerate(self.levels):
hf_model.py CHANGED
@@ -12,7 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from collections import namedtuple
15
- from typing import Optional, List, Union
16
 
17
  from timm.models import VisionTransformer
18
  import torch
@@ -20,6 +20,7 @@ from transformers import PretrainedConfig, PreTrainedModel
20
 
21
 
22
  from .common import RESOURCE_MAP, DEFAULT_VERSION
 
23
  # Force import of eradio_model in order to register it.
24
  from .eradio_model import eradio
25
  from .radio_model import create_model_from_args
@@ -122,5 +123,14 @@ class RADIOModel(PreTrainedModel):
122
  def input_conditioner(self) -> InputConditioner:
123
  return self.radio_model.input_conditioner
124
 
 
 
 
 
 
 
 
 
 
125
  def forward(self, x: torch.Tensor):
126
  return self.radio_model.forward(x)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from collections import namedtuple
15
+ from typing import Callable, Optional, List, Union
16
 
17
  from timm.models import VisionTransformer
18
  import torch
 
20
 
21
 
22
  from .common import RESOURCE_MAP, DEFAULT_VERSION
23
+
24
  # Force import of eradio_model in order to register it.
25
  from .eradio_model import eradio
26
  from .radio_model import create_model_from_args
 
123
  def input_conditioner(self) -> InputConditioner:
124
  return self.radio_model.input_conditioner
125
 
126
+ @input_conditioner.setter
127
+ def input_conditioner(self, v: InputConditioner):
128
+ self.radio_model.input_conditioner = v
129
+
130
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
131
+ ret = self.input_conditioner
132
+ self.input_conditioner = nn.Identity()
133
+ return ret
134
+
135
  def forward(self, x: torch.Tensor):
136
  return self.radio_model.forward(x)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:df75c4351ef558af885acbf0d21ad53fd273e3720b5ae3d1e7d4a23df1ca9ed1
3
- size 1306581088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03534ca8b7a26b0cbf69073b944fdd47f41aedad1b3b01c1e387c27191abc8de
3
+ size 1304018880
radio_model.py CHANGED
@@ -18,6 +18,7 @@ from .input_conditioner import InputConditioner
18
  from . import extra_timm_models
19
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
20
  from . import eradio_model
 
21
 
22
 
23
  class Resolution(NamedTuple):
@@ -106,6 +107,12 @@ class RADIOModel(nn.Module):
106
  fn()
107
 
108
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
 
 
 
 
 
 
109
  x = self.input_conditioner(x)
110
  y = self.model.forward_features(x)
111
 
@@ -180,6 +187,11 @@ def create_model_from_args(args) -> nn.Module:
180
  **args.model_kwargs,
181
  )
182
 
 
 
 
 
 
183
  assert (
184
  not args.cls_token_per_teacher or args.cpe_max_size is not None
185
  ), "CPE must be enabled for multiple CLS tokens!"
@@ -192,4 +204,7 @@ def create_model_from_args(args) -> nn.Module:
192
  register_multiple=args.register_multiple,
193
  )
194
 
 
 
 
195
  return model
 
18
  from . import extra_timm_models
19
  from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
20
  from . import eradio_model
21
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
22
 
23
 
24
  class Resolution(NamedTuple):
 
107
  fn()
108
 
109
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
110
+ res_step = self.min_resolution_step
111
+ if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
112
+ raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
113
+ '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
114
+ f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
115
+
116
  x = self.input_conditioner(x)
117
  y = self.model.forward_features(x)
118
 
 
187
  **args.model_kwargs,
188
  )
189
 
190
+ if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
191
+ model.norm = nn.Identity()
192
+
193
+ model.head = nn.Identity()
194
+
195
  assert (
196
  not args.cls_token_per_teacher or args.cpe_max_size is not None
197
  ), "CPE must be enabled for multiple CLS tokens!"
 
204
  register_multiple=args.register_multiple,
205
  )
206
 
207
+ if args.spectral_reparam:
208
+ configure_spectral_reparam_from_args(model, args)
209
+
210
  return model