Someshfengde committed on
Commit 31f23f1 · 1 Parent(s): 6b340f0

Upload folder using huggingface_hub

Files changed (1)
  1. script.py +1905 -4
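
The commit message "Upload folder using huggingface_hub" is the default message produced by the `huggingface_hub` folder-upload helper. For context, a minimal sketch of how such an upload is typically issued is shown below; the local path and `repo_id` are placeholders, not values taken from this commit.

```python
from huggingface_hub import HfApi

api = HfApi()  # assumes an auth token is already configured (e.g. via `huggingface-cli login`)
api.upload_folder(
    folder_path="./my_submission",   # placeholder: local folder to upload
    repo_id="username/repo-name",    # placeholder: target repo on the Hub
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```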
script.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pandas as pd
2
  import numpy as np
3
  import os
@@ -6,9 +7,1894 @@ import timm
6
  import torchvision.transforms as T
7
  from PIL import Image
8
  import torch
9
- from create_model import HieraForImageClassification
10
  from transformers import AutoImageProcessor
11
 
12
  def is_gpu_available():
13
  """Check if the python package `onnxruntime-gpu` is installed."""
14
  return torch.cuda.is_available()
@@ -24,8 +1910,8 @@ class PytorchWorker:
24
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
  print(f"Using devide: {self.device}")
26
 
27
- image_processor = AutoImageProcessor.from_pretrained("./hiera_model")
28
- model = HieraForImageClassification.from_pretrained("./hiera_model", num_labels=1784).to(self.device).eval()
29
 
30
  return model, image_processor
31
 
@@ -62,7 +1948,7 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
62
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
63
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
64
 
65
-
66
  if __name__ == "__main__":
67
 
68
  import zipfile
@@ -82,3 +1968,18 @@ if __name__ == "__main__":
82
  model_name=MODEL_NAME
83
  )
84
 
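For readers of the diff, the hunks above boil down to a simple load-and-classify flow: `PytorchWorker` loads an image processor and a `HieraForImageClassification` head (1784 labels) from a local `./hiera_model` checkpoint, and `make_submission` writes one `class_id` per `observation_id`. A minimal, illustrative sketch of that flow is given below; it is not part of the committed script, the input image path is a placeholder, and `HieraForImageClassification` is assumed to be the class defined further down in `script.py`.

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the processor and the classification model from the local checkpoint directory.
image_processor = AutoImageProcessor.from_pretrained("./hiera_model")
model = HieraForImageClassification.from_pretrained(
    "./hiera_model", num_labels=1784
).to(device).eval()

# Preprocess a single image and predict its class id.
image = Image.open("example.jpg").convert("RGB")             # placeholder image path
inputs = image_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits                          # shape: (1, 1784)
class_id = int(logits.argmax(dim=-1))                        # value written to the submission CSV
```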
1
+ #%%
2
  import pandas as pd
3
  import numpy as np
4
  import os
 
7
  import torchvision.transforms as T
8
  from PIL import Image
9
  import torch
 
10
  from transformers import AutoImageProcessor
11
 
12
+ #%%
13
+ # coding=utf-8
14
+ # Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
15
+ #
16
+ # Licensed under the Apache License, Version 2.0 (the "License");
17
+ # you may not use this file except in compliance with the License.
18
+ # You may obtain a copy of the License at
19
+ #
20
+ # http://www.apache.org/licenses/LICENSE-2.0
21
+ #
22
+ # Unless required by applicable law or agreed to in writing, software
23
+ # distributed under the License is distributed on an "AS IS" BASIS,
24
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25
+ # See the License for the specific language governing permissions and
26
+ # limitations under the License.
27
+ """ PyTorch Hiera model."""
28
+
29
+
30
+ import math
31
+ from dataclasses import dataclass
32
+ from typing import Dict, List, Optional, Tuple, Union
33
+
34
+ import torch
35
+ import torch.utils.checkpoint
36
+ from torch import nn
37
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
38
+
39
+ import transformers
40
+
41
+ from transformers.activations import ACT2FN
42
+ from transformers.modeling_outputs import (
43
+ BackboneOutput,
44
+ BaseModelOutput,
45
+ BaseModelOutputWithPooling,
46
+ ImageClassifierOutput,
47
+ ModelOutput,
48
+ )
49
+ from transformers.modeling_utils import PreTrainedModel
50
+ from transformers.utils import (
51
+ add_code_sample_docstrings,
52
+ add_start_docstrings,
53
+ add_start_docstrings_to_model_forward,
54
+ logging,
55
+ replace_return_docstrings,
56
+ )
57
+ from transformers.utils.backbone_utils import BackboneMixin
58
+ # coding=utf-8
59
+ # Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
60
+ #
61
+ # Licensed under the Apache License, Version 2.0 (the "License");
62
+ # you may not use this file except in compliance with the License.
63
+ # You may obtain a copy of the License at
64
+ #
65
+ # http://www.apache.org/licenses/LICENSE-2.0
66
+ #
67
+ # Unless required by applicable law or agreed to in writing, software
68
+ # distributed under the License is distributed on an "AS IS" BASIS,
69
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70
+ # See the License for the specific language governing permissions and
71
+ # limitations under the License.
72
+ """ Hiera model configuration"""
73
+
74
+ from collections import OrderedDict
75
+ from typing import Mapping
76
+
77
+ from packaging import version
78
+
79
+ from transformers.configuration_utils import PretrainedConfig
80
+ from transformers.onnx import OnnxConfig
81
+ from transformers.utils import logging
82
+ from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
83
+
84
+
85
+ logger = logging.get_logger(__name__)
86
+
87
+ HIERA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
88
+ "EduardoPacheco/hiera-tiny-224": "https://huggingface.co/EduardoPacheco/hiera-tiny-224/resolve/main/config.json",
89
+ }
90
+
91
+
92
+ class HieraConfig(BackboneConfigMixin, PretrainedConfig):
93
+ r"""
94
+ This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
95
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
96
+ defaults will yield a similar configuration to that of the Hiera
97
+ [EduardoPacheco/hiera-base-224](https://huggingface.co/EduardoPacheco/hiera-base-224) architecture.
98
+
99
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
100
+ documentation from [`PretrainedConfig`] for more information.
101
+
102
+
103
+ Args:
104
+ embed_dim (`int`, *optional*, defaults to 96):
105
+ Dimensionality of patch embedding.
106
+ input_size (`list(int)`, *optional*, defaults to `[224, 224]`):
107
+ The size (resolution) of input in the format (height, width) for images
108
+ and (frames, height, width) for videos.
109
+ patch_kernel (`list(int)`, *optional*, defaults to `[7, 7]`):
110
+ The size (resolution) of each patch.
111
+ patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
112
+ The stride of the patch.
113
+ patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
114
+ The padding of the patch.
115
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
116
+ The ratio of mlp hidden dim to embedding dim.
117
+ depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
118
+ Depth of each layer in the Transformer encoder.
119
+ initial_num_heads (`int`, *optional*, defaults to 1):
120
+ Initial number of attention heads in the first layer of the Transformer encoder.
121
+ num_head_multiplier (`float`, *optional*, defaults to 2.0):
122
+ The multiplier to the number of attention heads in each layer of the Transformer encoder.
123
+ embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
124
+ The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
125
+ num_query_pool (`int`, *optional*, defaults to 3):
126
+ The number of query pool stages.
127
+ query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
128
+ The stride of the query pool.
129
+ masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
130
+ The size of the masked unit.
131
+ masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
132
+ Whether to use masked unit attention in each layer of the Transformer encoder.
133
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
134
+ The drop path rate.
135
+ sep_pos_embed (`bool`, *optional*, defaults to `False`):
136
+ Whether to use separate position embedding for temporal and spatial dimensions. Must be `True` for videos
137
+ and `False` for images.
138
+ num_channels (`int`, *optional*, defaults to 3):
139
+ The number of input channels.
140
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
141
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
142
+ `"selu"` and `"gelu_new"` are supported.
143
+ initializer_range (`float`, *optional*, defaults to 0.02):
144
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
145
+ the zero_initializer for initializing all bias vectors.
146
+ layer_norm_init (`float`, *optional*, defaults to 1.0):
147
+ The initial weight value for layer normalization layers.
148
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
149
+ The epsilon used by the layer normalization layers.
150
+ decoder_embed_dim (`int`, *optional*):
151
+ Dimensionality of decoder embeddings for MAE pretraining.
152
+ decoder_depth (`int`, *optional*):
153
+ Depth of the decoder for MAE pretraining.
154
+ decoder_num_heads (`int`, *optional*):
155
+ Number of attention heads in each layer of the decoder for MAE pretraining.
156
+ norm_pix_loss (`bool`, *optional*, defaults to `True`):
157
+ Whether to normalize the pixel loss by the number of pixels.
158
+ mask_ratio (`float`, *optional*, defaults to 0.6):
159
+ The ratio of masked tokens in the input.
160
+ out_features (`List[str]`, *optional*):
161
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
162
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
163
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
164
+ same order as defined in the `stage_names` attribute.
165
+ out_indices (`List[int]`, *optional*):
166
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
167
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
168
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
169
+ same order as defined in the `stage_names` attribute.
170
+
171
+
172
+ Example:
173
+
174
+ ```python
175
+ >>> from transformers import HieraConfig, HieraModel
176
+
177
+ >>> # Initializing a Hiera hiera-base-patch16-224 style configuration
178
+ >>> configuration = HieraConfig()
179
+
180
+ >>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
181
+ >>> model = HieraModel(configuration)
182
+
183
+ >>> # Accessing the model configuration
184
+ >>> configuration = model.config
185
+ ```"""
186
+
187
+ model_type = "hiera"
188
+
189
+ attribute_map = {"num_hidden_layers": "num_layers"}
190
+
191
+ def __init__(
192
+ self,
193
+ embed_dim=96,
194
+ input_size=[224, 224],
195
+ patch_kernel=[7, 7],
196
+ patch_stride=[4, 4],
197
+ patch_padding=[3, 3],
198
+ mlp_ratio=4.0,
199
+ depths=[2, 3, 16, 3],
200
+ initial_num_heads=1,
201
+ num_head_multiplier=2.0,
202
+ embed_dim_multiplier=2.0,
203
+ num_query_pool=3,
204
+ query_stride=[2, 2],
205
+ masked_unit_size=[8, 8],
206
+ masked_unit_attention=[True, True, False, False],
207
+ drop_path_rate=0.0,
208
+ sep_pos_embed=False,
209
+ num_channels=3,
210
+ hidden_act="gelu",
211
+ initializer_range=0.02,
212
+ layer_norm_init=1.0,
213
+ layer_norm_eps=1e-6,
214
+ decoder_embed_dim=None,
215
+ decoder_depth=None,
216
+ decoder_num_heads=None,
217
+ norm_pix_loss=True,
218
+ mask_ratio=0.6,
219
+ out_features=None,
220
+ out_indices=None,
221
+ **kwargs,
222
+ ):
223
+ super().__init__(**kwargs)
224
+ if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
225
+ raise ValueError(
226
+ f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
227
+ f"raised to the power of the number of layers ({len(depths) - 1})"
228
+ )
229
+
230
+ if num_query_pool >= len(depths):
231
+ raise ValueError(
232
+ f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
233
+ )
234
+
235
+ self.embed_dim = embed_dim
236
+ self.input_size = input_size
237
+ self.patch_kernel = patch_kernel
238
+ self.patch_stride = patch_stride
239
+ self.patch_padding = patch_padding
240
+ self.mlp_ratio = mlp_ratio
241
+ self.depths = depths
242
+ self.num_layers = len(depths)
243
+ self.initial_num_heads = initial_num_heads
244
+ self.num_head_multiplier = num_head_multiplier
245
+ self.embed_dim_multiplier = embed_dim_multiplier
246
+ self.num_query_pool = num_query_pool
247
+ self.query_stride = query_stride
248
+ self.masked_unit_size = masked_unit_size
249
+ self.masked_unit_attention = masked_unit_attention
250
+ self.drop_path_rate = drop_path_rate
251
+ self.sep_pos_embed = sep_pos_embed
252
+ self.num_channels = num_channels
253
+ self.hidden_act = hidden_act
254
+ self.initializer_range = initializer_range
255
+ self.layer_norm_init = layer_norm_init
256
+ self.layer_norm_eps = layer_norm_eps
257
+ self.decoder_embed_dim = decoder_embed_dim
258
+ self.decoder_depth = decoder_depth
259
+ self.decoder_num_heads = decoder_num_heads
260
+ self.norm_pix_loss = norm_pix_loss
261
+ self.mask_ratio = mask_ratio
262
+ # we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
263
+ # this indicates the channel dimension after the last stage of the model
264
+ self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
265
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
266
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
267
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
268
+ )
269
+
270
+
271
+ class HieraOnnxConfig(OnnxConfig):
272
+ torch_onnx_minimum_version = version.parse("1.11")
273
+
274
+ @property
275
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
276
+ return OrderedDict(
277
+ [
278
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
279
+ ]
280
+ )
281
+
282
+ @property
283
+ def atol_for_validation(self) -> float:
284
+ return 1e-4
285
+
286
+ logger = logging.get_logger(__name__)
287
+
288
+ # General docstring
289
+ _CONFIG_FOR_DOC = "HieraConfig"
290
+
291
+ # Base docstring
292
+ _CHECKPOINT_FOR_DOC = "EduardoPacheco/hiera-tiny-224"
293
+ _EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
294
+
295
+ # Image classification docstring
296
+ _IMAGE_CLASS_CHECKPOINT = "EduardoPacheco/hiera-tiny-224-in1k"
297
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
298
+
299
+
300
+ HIERA_PRETRAINED_MODEL_ARCHIVE_LIST = [
301
+ "EduardoPacheco/hiera-tiny-224",
302
+ # See all Hiera models at https://huggingface.co/models?filter=hiera
303
+ ]
304
+
305
+
306
+ @dataclass
307
+ class HieraEncoderOutput(ModelOutput):
308
+ """
309
+ Hiera encoder's outputs, with potential hidden states and attentions.
310
+
311
+ Args:
312
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
313
+ Sequence of hidden-states at the output of the last layer of the model.
314
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
315
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
316
+ shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
317
+
318
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
319
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
320
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
321
+ sequence_length)`.
322
+
323
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
324
+ heads.
325
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
326
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
327
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
328
+
329
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
330
+ include the spatial dimensions.
331
+ """
332
+
333
+ last_hidden_state: torch.FloatTensor = None
334
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
335
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
336
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
337
+
338
+
339
+ @dataclass
340
+ class HieraModelOutput(ModelOutput):
341
+ """
342
+ Hiera model's outputs that also contains a pooling of the last hidden states.
343
+
344
+ Args:
345
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
346
+ Sequence of hidden-states at the output of the last layer of the model.
347
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
348
+ Average pooling of the last layer hidden-state.
349
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
350
+ Tensor indicating which patches are masked (0) and which are not (1).
351
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
352
+ Tensor containing the original index of the (shuffled) masked patches.
353
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
354
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
355
+ shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
356
+
357
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
358
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
359
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
360
+ sequence_length)`.
361
+
362
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
363
+ heads.
364
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
365
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
366
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
367
+
368
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
369
+ include the spatial dimensions.
370
+ """
371
+
372
+ last_hidden_state: torch.FloatTensor = None
373
+ pooler_output: Optional[torch.FloatTensor] = None
374
+ mask: torch.LongTensor = None
375
+ ids_restore: torch.LongTensor = None
376
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
377
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
378
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
379
+
380
+
381
+ @dataclass
382
+ class HieraForImageClassificationOutput(ImageClassifierOutput):
383
+ """
384
+ Hiera image classification outputs.
385
+
386
+ Args:
387
+ loss (`torch.FloatTensor` of shape `(1,)`, `optional`):
388
+ Classification loss.
389
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
390
+ Prediction scores of the classification head (logits of the output layer).
391
+ hidden_states (`tuple(torch.FloatTensor)`, `optional`):
392
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
393
+ shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
394
+
395
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
396
+ attentions (`tuple(torch.FloatTensor)`, `optional`):
397
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
398
+ sequence_length)`.
399
+
400
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
401
+ heads.
402
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, `optional`):
403
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
404
+ shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
405
+
406
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
407
+ include the spatial dimensions.
408
+ """
409
+
410
+ loss: Optional[torch.FloatTensor] = None
411
+ logits: torch.FloatTensor = None
412
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
413
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
414
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
415
+
416
+
417
+ @dataclass
418
+ class HieraForPreTrainingOutput(ModelOutput):
419
+ """
420
+ Class for [`HieraForPreTraining`]'s outputs, with potential hidden states and attentions.
421
+
422
+ Args:
423
+ loss (`torch.FloatTensor` of shape `(1,)`):
424
+ Pixel reconstruction loss.
425
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
426
+ Pixel reconstruction logits.
427
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
428
+ Tensor indicating which patches are masked (0) and which are not (1).
429
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
430
+ Tensor containing the original index of the (shuffled) masked patches.
431
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
432
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
433
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
434
+ plus the initial embedding outputs.
435
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
436
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
437
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
438
+ the self-attention heads.
439
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
440
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
441
+ shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
442
+ plus the initial embedding outputs reshaped to include the spatial dimensions.
443
+ """
444
+
445
+ loss: Optional[torch.FloatTensor] = None
446
+ logits: torch.FloatTensor = None
447
+ mask: torch.LongTensor = None
448
+ ids_restore: torch.LongTensor = None
449
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
450
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
451
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
452
+
453
+
454
+ # Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L73
455
+ def conv_nd(n: int) -> nn.Module:
456
+ """
457
+ Returns a conv with nd (e.g., Conv2d for n=2). Works up to n=3.
458
+ If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
459
+ """
460
+ return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
461
+
462
+
463
+ # Taken from https://github.com/facebookresearch/hiera/blob/main/hiera/hiera_utils.py#L81
464
+ def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor:
465
+ # Refer to `Unroll` to see how this performs a maxpool-Nd
466
+ return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values
467
+
468
+
469
+ class HieraPatchEmbeddings(nn.Module):
470
+ """
471
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
472
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
473
+ Transformer.
474
+ """
475
+
476
+ def __init__(self, config, is_mae: bool = False):
477
+ super().__init__()
478
+
479
+ # Support any number of spatial dimensions
480
+ self.spatial_dims = len(config.patch_kernel)
481
+ if self.spatial_dims not in (2, 3):
482
+ raise ValueError(
483
+ f"The number of dimensions of the input image should be 2 or 3, but got {self.spatial_dims}."
484
+ )
485
+ self.num_channels = config.num_channels
486
+ self.image_size = config.input_size[-2:]
487
+ self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
488
+ self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
489
+ self.mask_ratio = config.mask_ratio
490
+ self.is_mae = is_mae
491
+
492
+ self.projection = conv_nd(self.spatial_dims)(
493
+ self.num_channels,
494
+ config.embed_dim,
495
+ kernel_size=config.patch_kernel,
496
+ stride=config.patch_stride,
497
+ padding=config.patch_padding,
498
+ )
499
+
500
+ def masked_conv(self, pixel_values: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
501
+ """Zero-out the masked regions of the input before conv.
502
+ Prevents leakage of masked regions when using overlapping kernels.
503
+ """
504
+ if mask is None:
505
+ return self.projection(pixel_values)
506
+
507
+ target_size = pixel_values.shape[2:]
508
+ # Reshape mask to (batch_size, 1, mask_unit_height, mask_unit_width)
509
+ mask = mask.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
510
+
511
+ if len(mask.shape[2:]) != len(target_size):
512
+ raise ValueError(
513
+ f"The length of the spatial dimensions of the mask should match the one from input image, but got {len(mask.shape[2:])} and {len(target_size)}."
514
+ )
515
+
516
+ if mask.shape[2:] != target_size:
517
+ mask = nn.functional.interpolate(mask, size=target_size)
518
+
519
+ return self.projection(pixel_values * mask.bool())
520
+
521
+ def random_masking(self, pixel_values, noise=None):
522
+ """
523
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
524
+ noise.
525
+
526
+ Args:
527
+ pixel_values (`torch.LongTensor` of shape `(batch_size, num_channels, height, width)`)
528
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
529
+ mainly used for testing purposes to control randomness and maintain the reproducibility
530
+ """
531
+ batch_size = pixel_values.shape[0]
532
+ # Tokens selected for masking at mask unit level
533
+ num_windows = math.prod(self.mask_spatial_shape)
534
+ len_keep = int(num_windows * (1 - self.mask_ratio))
535
+
536
+ if noise is None:
537
+ noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
538
+
539
+ # Sort noise for each sample
540
+ ids_shuffle = torch.argsort(noise, dim=1)
541
+ # ascend: small is keep, large is remove
542
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
543
+
544
+ # Generate the binary mask: 1 is *keep*, 0 is *remove*
545
+ # Note this is opposite to original MAE
546
+ mask = torch.zeros([batch_size, num_windows], device=pixel_values.device)
547
+ mask[:, :len_keep] = 1
548
+ # Unshuffle to get the binary mask
549
+ mask = torch.gather(mask, dim=1, index=ids_restore)
550
+
551
+ return mask, ids_restore
552
+
553
+ def forward(
554
+ self,
555
+ pixel_values: torch.Tensor,
556
+ noise: Optional[torch.FloatTensor] = None,
557
+ interpolate_pos_encoding: bool = False,
558
+ ) -> torch.Tensor:
559
+ num_channels = pixel_values.shape[1]
560
+ height, width = pixel_values.shape[-2:]
561
+
562
+ if num_channels != self.num_channels:
563
+ raise ValueError(
564
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
565
+ f" Expected {self.num_channels} but got {num_channels}."
566
+ )
567
+
568
+ if not interpolate_pos_encoding:
569
+ if height != self.image_size[0] or width != self.image_size[1]:
570
+ raise ValueError(
571
+ f"Input image size ({height}*{width}) doesn't match model"
572
+ f" ({self.image_size[0]}*{self.image_size[1]})."
573
+ )
574
+
575
+ (mask, ids_restore) = self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
576
+
577
+ embeddings = self.masked_conv(pixel_values, mask)
578
+ embeddings = embeddings.flatten(2).transpose(2, 1)
579
+
580
+ return embeddings, mask, ids_restore
581
+
582
+
583
+ class HieraEmbeddings(nn.Module):
584
+ """
585
+ Construct position and patch embeddings.
586
+ """
587
+
588
+ def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
589
+ super().__init__()
590
+ self.patch_stride = config.patch_stride
591
+ self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
592
+ self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
593
+ self.num_tokens = math.prod(self.tokens_spatial_shape)
594
+ self.sep_pos_embed = config.sep_pos_embed
595
+ self.is_mae = is_mae
596
+
597
+ self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
598
+
599
+ if self.sep_pos_embed:
600
+ self.position_embeddings_spatial = nn.Parameter(
601
+ torch.zeros(
602
+ 1,
603
+ self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
604
+ config.embed_dim,
605
+ )
606
+ )
607
+ self.position_embeddings_temporal = nn.Parameter(
608
+ torch.zeros(1, self.tokens_spatial_shape[0], config.embed_dim)
609
+ )
610
+ else:
611
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
612
+
613
+ def interpolate_pos_encoding(
614
+ self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
615
+ ) -> torch.Tensor:
616
+ """
617
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
618
+ resolution images.
619
+
620
+ Adapted from:
621
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
622
+ """
623
+
624
+ num_patches = embeddings.shape[1]
625
+ num_positions = pos_embeds.shape[1]
626
+ if num_patches == num_positions and height == width:
627
+ return pos_embeds
628
+ dim = embeddings.shape[-1]
629
+ h0 = height // self.patch_stride[0] if not self.sep_pos_embed else height // self.patch_stride[1]
630
+ w0 = width // self.patch_stride[1] if not self.sep_pos_embed else width // self.patch_stride[2]
631
+ # we add a small number to avoid floating point error in the interpolation
632
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
633
+ h0, w0 = h0 + 0.1, w0 + 0.1
634
+ pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
635
+ pos_embeds = pos_embeds.permute(0, 3, 1, 2)
636
+ pos_embeds = nn.functional.interpolate(
637
+ pos_embeds,
638
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
639
+ mode="bicubic",
640
+ align_corners=False,
641
+ )
642
+ if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]:
643
+ raise ValueError("The interpolated position encoding does not have the right size")
644
+ pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
645
+ return pos_embeds
646
+
647
+ def get_position_embedding(
648
+ self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
649
+ ) -> torch.Tensor:
650
+ if self.sep_pos_embed:
651
+ spatial = self.position_embeddings_spatial
652
+ spatial = (
653
+ self.interpolate_pos_encoding(embeddings, spatial, height, width)
654
+ if interpolate_pos_encoding
655
+ else spatial
656
+ )
657
+ spatial = spatial.repeat(1, self.tokens_spatial_shape[0], 1)
658
+
659
+ temporal = torch.repeat_interleave(
660
+ self.position_embeddings_temporal,
661
+ self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
662
+ dim=1,
663
+ )
664
+
665
+ return spatial + temporal
666
+ else:
667
+ position_embeddings = self.position_embeddings
668
+ position_embeddings = (
669
+ self.interpolate_pos_encoding(embeddings, position_embeddings, height, width)
670
+ if interpolate_pos_encoding
671
+ else position_embeddings
672
+ )
673
+ return position_embeddings
674
+
675
+ def forward(
676
+ self,
677
+ pixel_values: torch.Tensor,
678
+ noise: Optional[torch.FloatTensor] = None,
679
+ interpolate_pos_encoding: bool = False,
680
+ ) -> torch.Tensor:
681
+ if len(self.tokens_spatial_shape) == 2:
682
+ batch_size, num_channels, height, width = pixel_values.shape
683
+ else:
684
+ batch_size, num_channels, depth, height, width = pixel_values.shape
685
+
686
+ embeddings, mask, ids_restore = self.patch_embeddings(
687
+ pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
688
+ )
689
+
690
+ embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
691
+
692
+ return embeddings, mask, ids_restore
693
+
694
+
695
+ class HieraMaskUnitAttention(nn.Module):
696
+ """
697
+ Computes either Mask Unit or Global Attention. It is also able to perform query (q) pooling.
698
+
699
+ Note: this assumes the tokens have already been flattened and unrolled into mask units.
700
+ """
701
+
702
+ def __init__(
703
+ self,
704
+ dim: int,
705
+ dim_out: int,
706
+ num_heads: int,
707
+ query_stride: int = 1,
708
+ window_size: int = 0,
709
+ use_mask_unit_attn: bool = False,
710
+ ):
711
+ super().__init__()
712
+
713
+ self.dim = dim
714
+ self.dim_out = dim_out
715
+ self.num_heads = num_heads
716
+ self.query_stride = query_stride
717
+
718
+ self.head_dim = dim_out // num_heads
719
+ self.scale = (self.head_dim) ** -0.5
720
+
721
+ self.qkv = nn.Linear(dim, 3 * dim_out)
722
+ self.proj = nn.Linear(dim_out, dim_out)
723
+
724
+ self.window_size = window_size
725
+ self.use_mask_unit_attn = use_mask_unit_attn
726
+
727
+ def forward(
728
+ self,
729
+ hidden_states: torch.Tensor,
730
+ head_mask: Optional[torch.FloatTensor] = None,
731
+ output_attentions: bool = False,
732
+ ) -> torch.Tensor:
733
+ """Input should be of shape [batch, tokens, channels]."""
734
+ batch_size, seq_len, _ = hidden_states.shape
735
+
736
+ num_windows = 1
737
+ if self.use_mask_unit_attn:
738
+ num_windows = seq_len // (self.query_stride * self.window_size)
739
+
740
+ qkv = self.qkv(hidden_states)
741
+ qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
742
+ qkv = qkv.permute(3, 0, 4, 2, 1, 5)
743
+
744
+ query, key, value = qkv.unbind(0)
745
+
746
+ if self.query_stride > 1:
747
+ # Refer to Unroll to see how this performs a maxpool-Nd
748
+ query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
749
+ query = query.max(dim=3).values
750
+
751
+ attn_weights = (query * self.scale) @ key.transpose(-1, -2)
752
+ attn_weights = attn_weights.softmax(dim=-1)
753
+
754
+ # Mask heads if we want to
755
+ if head_mask is not None:
756
+ attn_weights = attn_weights * head_mask
757
+
758
+ attn_output = attn_weights @ value
759
+ attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.dim_out)
760
+ attn_output = self.proj(attn_output)
761
+
762
+ return (attn_output, attn_weights) if output_attentions else (attn_output, None)
763
+
764
+
765
+ # Copied from transformers.models.beit.modeling_beit.drop_path
766
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
767
+ """
768
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
769
+
770
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
771
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
772
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
773
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
774
+ argument.
775
+ """
776
+ if drop_prob == 0.0 or not training:
777
+ return input
778
+ keep_prob = 1 - drop_prob
779
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
780
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
781
+ random_tensor.floor_() # binarize
782
+ output = input.div(keep_prob) * random_tensor
783
+ return output
784
+
785
+
786
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
787
+ class HieraDropPath(nn.Module):
788
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
789
+
790
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
791
+ super().__init__()
792
+ self.drop_prob = drop_prob
793
+
794
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
795
+ return drop_path(hidden_states, self.drop_prob, self.training)
796
+
797
+ def extra_repr(self) -> str:
798
+ return "p={}".format(self.drop_prob)
799
+
800
+
801
+ class HieraMlp(nn.Module):
802
+ def __init__(self, config, dim: int):
803
+ super().__init__()
804
+ self.config = config
805
+ self.activation_fn = ACT2FN[config.hidden_act]
806
+ self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
807
+ self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
808
+
809
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
810
+ hidden_states = self.fc1(hidden_states)
811
+ hidden_states = self.activation_fn(hidden_states)
812
+ hidden_states = self.fc2(hidden_states)
813
+ return hidden_states
814
+
815
+
816
+ class HieraLayer(nn.Module):
817
+ def __init__(
818
+ self,
819
+ config,
820
+ dim: int,
821
+ dim_out: int,
822
+ num_heads: int,
823
+ drop_path: float = 0.0,
824
+ query_stride: int = 1,
825
+ window_size: int = 0,
826
+ use_mask_unit_attn: bool = False,
827
+ ):
828
+ super().__init__()
829
+
830
+ self.dim = dim
831
+ self.dim_out = dim_out
832
+ self.query_stride = query_stride
833
+
834
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
835
+ self.attn = HieraMaskUnitAttention(dim, dim_out, num_heads, query_stride, window_size, use_mask_unit_attn)
836
+
837
+ self.layernorm_after = nn.LayerNorm(dim_out, eps=config.layer_norm_eps)
838
+ self.mlp = HieraMlp(config, dim_out)
839
+
840
+ self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
841
+ if dim != dim_out:
842
+ self.proj = nn.Linear(dim, dim_out)
843
+
844
+ def forward(
845
+ self,
846
+ hidden_states: torch.Tensor,
847
+ head_mask: Optional[torch.FloatTensor] = None,
848
+ output_attentions: bool = False,
849
+ ) -> torch.Tensor:
850
+ batch_size, seq_len, _ = hidden_states.shape
851
+ # Attention + Q Pooling
852
+ hidden_states_norm = self.layernorm_before(hidden_states)
853
+ if self.dim != self.dim_out:
854
+ hidden_states = self.proj(hidden_states_norm)
855
+ # Refer to `HieraUnroll` to see how this performs a maxpool-Nd
856
+ hidden_states = hidden_states.view(batch_size, self.query_stride, -1, self.dim_out).max(dim=1).values
857
+
858
+ (hidden_states_norm, attn_weights) = self.attn(
859
+ hidden_states_norm, head_mask, output_attentions=output_attentions
860
+ )
861
+ hidden_states = hidden_states + self.drop_path(hidden_states_norm)
862
+
863
+ residual = hidden_states
864
+ hidden_states = self.layernorm_after(hidden_states)
865
+ hidden_states = self.mlp(hidden_states)
866
+ hidden_states = residual + self.drop_path(hidden_states)
867
+
868
+ return (hidden_states, attn_weights)
869
+
870
+
871
+ class HieraStage(nn.Module):
872
+ def __init__(
873
+ self,
874
+ config,
875
+ depth: int,
876
+ dim: int,
877
+ dim_out: int,
878
+ num_heads: int,
879
+ drop_path: List[float],
880
+ query_stride: List[int],
881
+ window_size: int,
882
+ use_mask_unit_attn: bool,
883
+ stage_num: Optional[int] = None,
884
+ ) -> None:
885
+ super().__init__()
886
+ # we need to know if the previous stage used masked attention
887
+ # mask unit or global attention.
888
+ # lag by 1 layer, so that global attention,
889
+ # applied post pooling on lower resolution
890
+ previous_stage_used_masked_attention = False
891
+ if stage_num is not None:
892
+ previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
893
+ self.layers = nn.ModuleList(
894
+ [
895
+ HieraLayer(
896
+ config=config,
897
+ dim=dim if i == 0 else dim_out,
898
+ dim_out=dim_out,
899
+ num_heads=num_heads,
900
+ drop_path=drop_path[i],
901
+ query_stride=query_stride[i],
902
+ window_size=window_size,
903
+ use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
904
+ )
905
+ for i in range(depth)
906
+ ]
907
+ )
908
+
909
+ def forward(
910
+ self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
911
+ ) -> torch.Tensor:
912
+ for i, layer_module in enumerate(self.layers):
913
+ layer_head_mask = head_mask[i] if head_mask is not None else None
914
+ (hidden_states, attn_weights) = layer_module(
915
+ hidden_states, layer_head_mask, output_attentions=output_attentions
916
+ )
917
+
918
+ return hidden_states, attn_weights
919
+
920
+
921
+ def undo_windowing(hidden_states: torch.Tensor, shape: List[int], mask_unit_shape: List[int]) -> torch.Tensor:
922
+ """
923
+ Restore spatial organization by undoing windowed organization of mask units.
924
+ """
925
+ num_dims = len(shape)
926
+ batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
927
+ # From: [batch_size, num_mask_unit_height*num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
928
+ # To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
929
+ num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
930
+ hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
931
+
932
+ # From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
933
+ # To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
934
+ permute = (
935
+ [0]
936
+ + sum(
937
+ [list(p) for p in zip(range(1, 1 + num_dims), range(1 + num_dims, 1 + 2 * num_dims))],
938
+ [],
939
+ )
940
+ + [len(hidden_states.shape) - 1]
941
+ )
942
+ hidden_states = hidden_states.permute(permute).reshape(batch_size, *shape, hidden_size)
943
+
944
+ return hidden_states
945
+
946
+
947
+ class HieraEncoder(nn.Module):
948
+ def __init__(self, config: HieraConfig) -> None:
949
+ super().__init__()
950
+ self.config = config
951
+
952
+ # stochastic depth decay rule
953
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
954
+ # query strides rule
955
+ stage_ends = [sum(config.depths[:i]) - 1 for i in range(1, len(config.depths) + 1)]
956
+ query_pool_layer = [stage_end + 1 for stage_end in stage_ends[: config.num_query_pool]]
957
+ query_strides = [
958
+ math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(sum(config.depths))
959
+ ]
960
+
961
+ # Transformer blocks
962
+ self.stages = nn.ModuleList()
963
+ embed_dim = config.embed_dim
964
+
965
+ for idx_stage, depth in enumerate(config.depths):
966
+ dim_out = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
967
+
968
+ stage = HieraStage(
969
+ config=config,
970
+ depth=depth,
971
+ dim=embed_dim,
972
+ dim_out=dim_out,
973
+ num_heads=int(config.initial_num_heads * config.num_head_multiplier**idx_stage),
974
+ drop_path=dpr[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
975
+ query_stride=query_strides[sum(config.depths[:idx_stage]) : sum(config.depths[: idx_stage + 1])],
976
+ window_size=int(math.prod(config.masked_unit_size) * math.prod(config.query_stride) ** -idx_stage),
977
+ use_mask_unit_attn=config.masked_unit_attention[idx_stage],
978
+ stage_num=idx_stage,
979
+ )
980
+
981
+ embed_dim = dim_out
982
+ self.stages.append(stage)
983
+
984
+ # Setting reroll schedule
985
+ # The first stage has to reverse everything
986
+ # The next stage has to reverse all but the first unroll, etc.
987
+ stage_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
988
+ unroll_schedule = [config.query_stride] * len(config.depths[:-1])
989
+
990
+ self.schedule = {}
991
+ for idx_stage in range(len(config.depths)):
992
+ self.schedule[idx_stage] = unroll_schedule, stage_size
993
+ if idx_stage < config.num_query_pool:
994
+ stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
995
+ unroll_schedule = unroll_schedule[1:]
996
+
997
+ self.gradient_checkpointing = False
998
+
999
+ def reroll(
1000
+ self, hidden_states: torch.Tensor, stage_idx: int, mask: Optional[torch.BoolTensor] = None
1001
+ ) -> torch.Tensor:
1002
+ """
1003
+ Roll the given tensor back up to spatial order assuming it's from the given block.
1004
+
1005
+ If no mask is provided returns:
1006
+ - [batch_size, height, width, hidden_size] for 2d
1007
+ - [batch_size, frames, height, width, hidden_size] for 3d
1008
+ If a mask is provided returns:
1009
+ - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size] for 2d
1010
+ """
1011
+ schedule, size = self.schedule[stage_idx]
1012
+ batch_size, seq_len, hidden_size = hidden_states.shape
1013
+
1014
+ num_dim = len(size)
1015
+ mask_unit_shape = [1] * num_dim
1016
+
1017
+ for strides in schedule:
1018
+ # Extract the current patch from seq_len
1019
+ hidden_states = hidden_states.view(
1020
+ batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
1021
+ )
1022
+
1023
+ # Move that patch into the current MU
1024
+ # Example in 2d:
1025
+ # Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
1026
+ # Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
1027
+ L = len(hidden_states.shape)
1028
+ permute = (
1029
+ [0, 1 + num_dim]
1030
+ + sum(
1031
+ [list(p) for p in zip(range(1, 1 + num_dim), range(1 + num_dim + 1, L - 1))],
1032
+ [],
1033
+ )
1034
+ + [L - 1]
1035
+ )
1036
+ hidden_states = hidden_states.permute(permute)
1037
+
1038
+ # Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
1039
+ for i in range(num_dim):
1040
+ mask_unit_shape[i] *= strides[i]
1041
+ hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
1042
+ seq_len = hidden_states.shape[1]
1043
+
1044
+ # Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
1045
+ hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
1046
+
1047
+ # If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
1048
+ if mask is not None:
1049
+ return hidden_states
1050
+
1051
+ # If not masked, we can return [batch_size, height, width, hidden_size]
1052
+ hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
1053
+
1054
+ return hidden_states
1055
+
1056
+ def forward(
1057
+ self,
1058
+ hidden_states: torch.Tensor,
1059
+ mask: Optional[torch.BoolTensor] = None,
1060
+ head_mask: Optional[torch.FloatTensor] = None,
1061
+ output_attentions: bool = False,
1062
+ output_hidden_states: bool = False,
1063
+ return_dict: bool = True,
1064
+ ) -> Union[tuple, BaseModelOutput]:
1065
+ all_hidden_states = () if output_hidden_states else None
1066
+ all_reshaped_hidden_states = () if output_hidden_states else None
1067
+ all_self_attentions = () if output_attentions else None
1068
+
1069
+ if output_hidden_states:
1070
+ all_hidden_states = all_hidden_states + (hidden_states,)
1071
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, mask=mask)
1072
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
1073
+
1074
+ for i, stage_module in enumerate(self.stages):
1075
+ layer_head_mask = head_mask[i] if head_mask is not None else None
1076
+
1077
+ if self.gradient_checkpointing and self.training:
1078
+ layer_outputs = self._gradient_checkpointing_func(
1079
+ stage_module.__call__, hidden_states, layer_head_mask, output_attentions
1080
+ )
1081
+ else:
1082
+ layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
1083
+
1084
+ hidden_states = layer_outputs[0]
1085
+
1086
+ if output_attentions:
1087
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1088
+
1089
+ if output_hidden_states:
1090
+ all_hidden_states = all_hidden_states + (hidden_states,)
1091
+ reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, mask=mask)
1092
+ all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
1093
+
1094
+ if not return_dict:
1095
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1096
+ return HieraEncoderOutput(
1097
+ last_hidden_state=hidden_states,
1098
+ hidden_states=all_hidden_states,
1099
+ attentions=all_self_attentions,
1100
+ reshaped_hidden_states=all_reshaped_hidden_states,
1101
+ )
1102
+
1103
+
1104
+ def unroll(hidden_states: torch.Tensor, size: List[int], schedule: List[List[int]]) -> torch.Tensor:
1105
+ """
1106
+ Reorders the tokens such that patches are contiguous in memory.
1107
+ E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
1108
+ [batch_size, (stride, stride, height // stride, width // stride), hidden_size]
1109
+
1110
+ This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
1111
+ Not only is this faster, but it also makes it easy to support inputs of arbitrary
1112
+ dimensions in addition to patch-wise sparsity.
1113
+
1114
+ Performing this operation multiple times in sequence puts entire windows as contiguous
1115
+ in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
1116
+ size 8x8 would be contiguous in memory, allowing operations like mask unit attention to be
1117
+ computed easily and efficiently, while also allowing max to be applied sequentially.
1118
+
1119
+ Note: This means that intermediate values of the model are not in height x width order, so they
1120
+ need to be re-rolled if you want to use the intermediate values as a height x width feature map.
1121
+ The last block of the network is fine though, since by then the strides are all consumed.
1122
+ """
1123
+ batch_size, _, hidden_size = hidden_states.shape
1124
+
1125
+ current_size = size
1126
+ hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
1127
+
1128
+ for strides in schedule:
1129
+ # Move patches with the given strides to the batch dimension
1130
+
1131
+ # Create a view of the tensor with the patch stride as separate dims
1132
+ # For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
1133
+ current_size = [i // s for i, s in zip(current_size, strides)]
1134
+ # initialize new_shape with [height // stride, stride, width // stride, stride]
1135
+ new_shape = [item for pair in zip(current_size, strides) for item in pair]
1136
+ # add batch_size and hidden_size to new_shape
1137
+ new_shape = [batch_size] + new_shape + [hidden_size]
1138
+ hidden_states = hidden_states.view(new_shape)
1139
+
1140
+ # Move the patch stride into the batch dimension
1141
+ # For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
1142
+ num_dims = len(new_shape)
1143
+ permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
1144
+ hidden_states = hidden_states.permute(permute)
1145
+
1146
+ # Now finally flatten the relevant dims into the batch dimension
1147
+ hidden_states = hidden_states.flatten(0, len(strides))
1148
+ batch_size *= math.prod(strides)
1149
+
1150
+ hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
1151
+ return hidden_states
1152
+
1153
+
1154
+ class HieraPreTrainedModel(PreTrainedModel):
1155
+ """
1156
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1157
+ models.
1158
+ """
1159
+
1160
+ config_class = HieraConfig
1161
+ base_model_prefix = "hiera"
1162
+ main_input_name = "pixel_values"
1163
+ supports_gradient_checkpointing = True
1164
+
1165
+ def _init_weights(self, module) -> None:
1166
+ """Initialize the weights"""
1167
+ std = self.config.initializer_range
1168
+
1169
+ if isinstance(module, HieraEmbeddings):
1170
+ if self.config.sep_pos_embed:
1171
+ nn.init.trunc_normal_(module.position_embeddings_spatial, std=std)
1172
+ nn.init.trunc_normal_(module.position_embeddings_temporal, std=std)
1173
+ else:
1174
+ nn.init.trunc_normal_(module.position_embeddings, std=std)
1175
+
1176
+ elif isinstance(module, HieraDecoder):
1177
+ nn.init.trunc_normal_(module.mask_token, std=std)
1178
+ nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
1179
+
1180
+ elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
1181
+ nn.init.trunc_normal_(module.weight, std=std)
1182
+ if module.bias is not None:
1183
+ nn.init.constant_(module.bias, std)
1184
+
1185
+ elif isinstance(module, nn.LayerNorm):
1186
+ nn.init.constant_(module.bias, std)
1187
+ nn.init.constant_(module.weight, self.config.layer_norm_init)
1188
+
1189
+
1190
+ HIERA_START_DOCSTRING = r"""
1191
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
1192
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
1193
+ behavior.
1194
+
1195
+ Parameters:
1196
+ config ([`HieraConfig`]): Model configuration class with all the parameters of the model.
1197
+ Initializing with a config file does not load the weights associated with the model, only the
1198
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1199
+ """
1200
+
1201
+ HIERA_INPUTS_DOCSTRING = r"""
1202
+ Args:
1203
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1204
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
1205
+ for details.
1206
+
1207
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1208
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1209
+
1210
+ - 1 indicates the head is **not masked**,
1211
+ - 0 indicates the head is **masked**.
1212
+
1213
+ output_attentions (`bool`, *optional*):
1214
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1215
+ tensors for more detail.
1216
+ output_hidden_states (`bool`, *optional*):
1217
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1218
+ more detail.
1219
+ interpolate_pos_encoding (`bool`, *optional*):
1220
+ Whether to interpolate the pre-trained position encodings.
1221
+ return_dict (`bool`, *optional*):
1222
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1223
+ """
1224
+
1225
+
1226
+ class HieraPooler(nn.Module):
1227
+ def __init__(self, config: HieraConfig):
1228
+ super().__init__()
1229
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1230
+ self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
1231
+ self.pooler = nn.AdaptiveAvgPool1d(1)
1232
+
1233
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1234
+ hidden_states = hidden_states.transpose(1, 2)
1235
+ pooled_output = self.pooler(hidden_states)
1236
+ pooled_output = torch.flatten(pooled_output, 1)
1237
+ pooled_output = self.layernorm(pooled_output)
1238
+ return pooled_output
1239
+
1240
+
1241
+ @add_start_docstrings(
1242
+ "The bare Hiera Model transformer outputting raw hidden-states without any specific head on top.",
1243
+ HIERA_START_DOCSTRING,
1244
+ """
1245
+ add_pooling_layer (`bool`, *optional*, defaults to `True`):
1246
+ Whether or not to apply a pooling layer.
1247
+ is_mae (`bool`, *optional*, defaults to `False`):
1248
+ Whether or not to run the model in MAE mode.
1249
+ """,
1250
+ )
1251
+ class HieraModel(HieraPreTrainedModel):
1252
+ def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
1253
+ super().__init__(config)
1254
+ self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1255
+
1256
+ self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
1257
+ self.encoder = HieraEncoder(config)
1258
+
1259
+ self.unroll_size = [i // s for i, s in zip(config.input_size, config.patch_stride)]
1260
+ self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
1261
+
1262
+ self.pooler = HieraPooler(config) if add_pooling_layer else None
1263
+
1264
+ # Initialize weights and apply final processing
1265
+ self.post_init()
1266
+
1267
+ def get_input_embeddings(self) -> HieraPatchEmbeddings:
1268
+ return self.embeddings.patch_embeddings
1269
+
1270
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
1271
+ """
1272
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1273
+ class PreTrainedModel
1274
+ """
1275
+ for layer, heads in heads_to_prune.items():
1276
+ self.encoder.layer[layer].attention.prune_heads(heads)
1277
+
1278
+ @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1279
+ @add_code_sample_docstrings(
1280
+ checkpoint=_CHECKPOINT_FOR_DOC,
1281
+ output_type=HieraModelOutput,
1282
+ config_class=_CONFIG_FOR_DOC,
1283
+ modality="vision",
1284
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
1285
+ )
1286
+ def forward(
1287
+ self,
1288
+ pixel_values: Optional[torch.Tensor] = None,
1289
+ noise: Optional[torch.FloatTensor] = None,
1290
+ head_mask: Optional[torch.Tensor] = None,
1291
+ output_attentions: Optional[bool] = None,
1292
+ output_hidden_states: Optional[bool] = None,
1293
+ interpolate_pos_encoding: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1296
+ r"""
1297
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
1298
+ Optional noise tensor, mainly used for testing purposes to control randomness and
1299
+ maintain reproducibility when `is_mae` is set to `True`.
1300
+ """
1301
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1302
+ output_hidden_states = (
1303
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1304
+ )
1305
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1306
+
1307
+ if pixel_values is None:
1308
+ raise ValueError("You have to specify pixel_values")
1309
+
1310
+ # Prepare head mask if needed
1311
+ # 1.0 in head_mask indicate we keep the head
1312
+ # attention_probs has shape bsz x n_heads x N x N
1313
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1314
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1315
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
1316
+
1317
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
1318
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
1319
+ if pixel_values.dtype != expected_dtype:
1320
+ pixel_values = pixel_values.to(expected_dtype)
1321
+
1322
+ embedding_output, mask, ids_restore = self.embeddings(
1323
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
1324
+ )
1325
+
1326
+ hidden_states = unroll(embedding_output, self.unroll_size, self.unroll_schedule)
1327
+
1328
+ # Discard masked tokens if mask is provided
1329
+ if mask is not None:
1330
+ mask_unit_area = math.prod(self.config.masked_unit_size)
1331
+ batch_size, _, hidden_size = hidden_states.shape
1332
+ positions = mask.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
1333
+ positions = positions.bool()
1334
+ hidden_states = hidden_states[positions]
1335
+ hidden_states = hidden_states.view(batch_size, -1, hidden_size)
1336
+
1337
+ encoder_outputs = self.encoder(
1338
+ hidden_states,
1339
+ mask=mask,
1340
+ head_mask=head_mask,
1341
+ output_attentions=output_attentions,
1342
+ output_hidden_states=output_hidden_states,
1343
+ return_dict=return_dict,
1344
+ )
1345
+ sequence_output = encoder_outputs[0]
1346
+ pooled_output = None
1347
+ if self.pooler is not None:
1348
+ pooled_output = self.pooler(sequence_output)
1349
+
1350
+ if not return_dict:
1351
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
1352
+ head_outputs = head_outputs + (mask, ids_restore) if mask is not None else head_outputs
1353
+ return head_outputs + encoder_outputs[1:]
1354
+
1355
+ return HieraModelOutput(
1356
+ last_hidden_state=sequence_output,
1357
+ pooler_output=pooled_output,
1358
+ mask=mask,
1359
+ ids_restore=ids_restore,
1360
+ hidden_states=encoder_outputs.hidden_states,
1361
+ attentions=encoder_outputs.attentions,
1362
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
1363
+ )
1364
+
1365
+
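The token-dropping step in `HieraModel.forward` relies on the layout produced by `unroll`: the token axis ends up ordered as (within-mask-unit offset, mask-unit index), so tiling a `[batch_size, num_mask_units]` keep-mask `mask_unit_area` times along that axis lines up with the tokens of each unit. The sketch below is an editor's illustration under that assumption, with 4 mask units of 4 tokens each and units 0 and 2 kept; it is not used by the script.

```python
def _mask_unit_drop_sketch() -> torch.Tensor:
    hidden_states = torch.arange(16, dtype=torch.float32).reshape(1, 16, 1)  # [batch, tokens, hidden]
    mask = torch.tensor([[1, 0, 1, 0]])  # [batch, num_mask_units]; 1 = visible, 0 = masked
    mask_unit_area, hidden_size = 4, 1
    positions = mask.unsqueeze(-1).tile(1, mask_unit_area, hidden_size).bool()
    # Only tokens belonging to mask units 0 and 2 survive -> shape [1, 8, 1]
    return hidden_states[positions].view(1, -1, 1)
```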
1366
+ class HieraDecoder(nn.Module):
1367
+ def __init__(self, config: HieraConfig):
1368
+ super().__init__()
1369
+ num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
1370
+ self.tokens_spatial_shape = [i // s for i, s in zip(config.input_size, config.patch_stride)]
1371
+ self.tokens_spatial_shape_final = [
1372
+ i // s ** (config.num_query_pool) for i, s in zip(self.tokens_spatial_shape, config.query_stride)
1373
+ ]
1374
+ self.mask_unit_spatial_shape_final = [
1375
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
1376
+ ]
1377
+
1378
+ self.decoder_embeddings = nn.Linear(num_features, config.decoder_embed_dim)
1379
+
1380
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_embed_dim))
1381
+
1382
+ self.decoder_position_embeddings = nn.Parameter(
1383
+ torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_embed_dim)
1384
+ )
1385
+
1386
+ self.decoder_block = HieraStage(
1387
+ config=config,
1388
+ dim=config.decoder_embed_dim,
1389
+ dim_out=config.decoder_embed_dim,
1390
+ num_heads=config.decoder_num_heads,
1391
+ depth=config.decoder_depth,
1392
+ use_mask_unit_attn=False,
1393
+ drop_path=[0.0] * config.decoder_depth,
1394
+ query_stride=[1] * config.decoder_depth,
1395
+ window_size=0,
1396
+ )
1397
+
1398
+ self.decoder_norm = nn.LayerNorm(config.decoder_embed_dim, eps=config.layer_norm_eps)
1399
+
1400
+ # patch stride of prediction
1401
+ self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
1402
+ pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
1403
+
1404
+ self.decoder_pred = nn.Linear(config.decoder_embed_dim, pred_dim)
1405
+
1406
+ def forward(
1407
+ self,
1408
+ encoder_hidden_states: torch.Tensor,
1409
+ mask: torch.BoolTensor,
1410
+ head_mask: Optional[torch.Tensor] = None,
1411
+ output_attentions: bool = False,
1412
+ ) -> torch.Tensor:
1413
+ # Embed tokens
1414
+ hidden_states = self.decoder_embeddings(encoder_hidden_states)
1415
+
1416
+ # Combine visible and mask tokens
1417
+
1418
+ # hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_embed_dim]
1419
+ # mask: [batch_size, num_mask_units]
1420
+ decoder_hidden_states = torch.zeros(
1421
+ *mask.shape, *hidden_states.shape[2:], device=hidden_states.device, dtype=hidden_states.dtype
1422
+ )
1423
+ mask_tokens = self.mask_token.view((1,) * (len(mask.shape) + len(hidden_states.shape[2:-1])) + (-1,))
1424
+ new_mask_shape = mask.shape + (1,) * len(hidden_states.shape[2:])
1425
+ mask = mask.reshape(new_mask_shape)
1426
+ expand_shape = (-1,) * 2 + hidden_states.shape[2:]
1427
+ mask = mask.expand(expand_shape)
1428
+ decoder_hidden_states[mask.bool()] = hidden_states.flatten()
1429
+ decoder_hidden_states = (1 - mask) * mask_tokens + mask * decoder_hidden_states
1430
+
1431
+ # Get back spatial order
1432
+ hidden_states = undo_windowing(
1433
+ decoder_hidden_states,
1434
+ self.tokens_spatial_shape_final,
1435
+ self.mask_unit_spatial_shape_final,
1436
+ )
1437
+ mask = undo_windowing(
1438
+ mask[..., 0:1],
1439
+ self.tokens_spatial_shape_final,
1440
+ self.mask_unit_spatial_shape_final,
1441
+ )
1442
+
1443
+ # Flatten
1444
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
1445
+ mask = mask.view(hidden_states.shape[0], -1)
1446
+
1447
+ # Add pos embed
1448
+ hidden_states = hidden_states + self.decoder_position_embeddings
1449
+
1450
+ # Apply decoder blocks
1451
+ hidden_states, attn_weights = self.decoder_block(
1452
+ hidden_states, head_mask=head_mask, output_attentions=output_attentions
1453
+ )
1454
+ hidden_states = self.decoder_norm(hidden_states)
1455
+
1456
+ # Predictor projection
1457
+ hidden_states = self.decoder_pred(hidden_states)
1458
+
1459
+ return hidden_states, mask
1460
+
1461
+
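As a hedged worked example of the decoder's output shape (editor's note, assuming the configuration implied by the hiera-tiny-224 MAE checkpoint used in the `HieraForPreTraining` docstring below: 224x224 input, patch_stride 4, query_stride 2, num_query_pool 2, 3 channels), the arithmetic reproduces the `[1, 196, 768]` logits shape shown there.

```python
def _decoder_pred_shape_sketch() -> Tuple[int, int]:
    # Assumed hiera-tiny-224 MAE values; they are not read from a real config here
    patch_stride, query_stride, num_query_pool = 4, 2, 2
    num_channels, input_size = 3, 224
    pred_stride = patch_stride * query_stride**num_query_pool  # 4 * 2**2 = 16
    num_tokens = (input_size // pred_stride) ** 2              # (224 // 16)**2 = 196
    pred_dim = pred_stride**2 * num_channels                   # 16 * 16 * 3 = 768
    return num_tokens, pred_dim
```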
1462
+ class HieraMultiScaleHead(nn.Module):
1463
+ def __init__(self, config: HieraConfig):
1464
+ super().__init__()
1465
+ self.mask_unit_spatial_shape_final = [
1466
+ i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
1467
+ ]
1468
+ self.stage_dimensions = [
1469
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
1470
+ ]
1471
+ current_masked_unit_size = config.masked_unit_size
1472
+ self.multi_scale_fusion_heads = nn.ModuleList()
1473
+
1474
+ for idx in range(config.num_query_pool):
1475
+ kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
1476
+ current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
1477
+ self.multi_scale_fusion_heads.append(
1478
+ conv_nd(len(config.query_stride))(
1479
+ self.stage_dimensions[idx],
1480
+ self.stage_dimensions[-1],
1481
+ kernel_size=kernel,
1482
+ stride=kernel,
1483
+ )
1484
+ )
1485
+ self.multi_scale_fusion_heads.append(nn.Identity())
1486
+
1487
+ def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
1488
+ if isinstance(head, nn.Identity):
1489
+ return hidden_states
1490
+
1491
+ batch_size, num_mask_units = hidden_states.shape[0:2]
1492
+ # From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
1493
+ # To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
1494
+ permute = [0] + [len(hidden_states.shape) - 2] + list(range(1, len(hidden_states.shape) - 2))
1495
+ hidden_states = hidden_states.reshape(batch_size * num_mask_units, *hidden_states.shape[2:])
1496
+ hidden_states = hidden_states.permute(permute)
1497
+ hidden_states = head(hidden_states)
1498
+
1499
+ # Restore original layout
1500
+ permute = [0] + list(range(2, len(hidden_states.shape))) + [1]
1501
+ hidden_states = hidden_states.permute(permute)
1502
+ hidden_states = hidden_states.reshape(
1503
+ batch_size, num_mask_units, *hidden_states.shape[1:-1], hidden_states.shape[-1]
1504
+ )
1505
+ return hidden_states
1506
+
1507
+ def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor:
1508
+ # Multi-scale fusion
1509
+ hidden_states = 0.0
1510
+ for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
1511
+ hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
1512
+
1513
+ return hidden_states
1514
+
1515
+
1516
+ @add_start_docstrings(
1517
+ """The Hiera Model transformer with the decoder on top for self-supervised pre-training.
1518
+
1519
+ <Tip>
1520
+
1521
+ Note that we provide a script to pre-train this model on custom data in our [examples
1522
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
1523
+
1524
+ </Tip>
1525
+ """,
1526
+ HIERA_START_DOCSTRING,
1527
+ )
1528
+ class HieraForPreTraining(HieraPreTrainedModel):
1529
+ def __init__(self, config: HieraConfig) -> None:
1530
+ super().__init__(config)
1531
+ # Encoder
1532
+ self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
1533
+ self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
1534
+ # Multi-scale fusion heads
1535
+ self.multiscale_fusion = HieraMultiScaleHead(config)
1536
+ # Decoder
1537
+ self.decoder = HieraDecoder(config)
1538
+ self.pred_stride = self.decoder.pred_stride
1539
+
1540
+ # Initialize weights and apply final processing
1541
+ self.post_init()
1542
+
1543
+ def get_pixel_label_2d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1544
+ # mask (boolean tensor): True means *masked*
1545
+ pixel_values = pixel_values.permute(0, 2, 3, 1)
1546
+
1547
+ size = self.pred_stride
1548
+ label = pixel_values.unfold(1, size, size).unfold(2, size, size)
1549
+ label = label.flatten(1, 2).flatten(2)
1550
+ label = label[mask.bool()]
1551
+ if self.config.norm_pix_loss:
1552
+ mean = label.mean(dim=-1, keepdim=True)
1553
+ var = label.var(dim=-1, keepdim=True)
1554
+ label = (label - mean) / (var + 1.0e-6) ** 0.5
1555
+
1556
+ return label
1557
+
1558
+ def get_pixel_label_3d(self, pixel_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
1559
+ # mask (boolean tensor): True means *masked*
1560
+ pixel_values = pixel_values[:, :, :: self.config.patch_stride[0], :, :]
1561
+
1562
+ size = self.pred_stride
1563
+ label = pixel_values.unfold(3, size, size).unfold(4, size, size)
1564
+ # Different from 2D
1565
+ label = label.permute(0, 2, 3, 4, 5, 6, 1)
1566
+ label = label.flatten(1, 3).flatten(2)
1567
+ label = label[mask.bool()]
1568
+ if self.config.norm_pix_loss:
1569
+ mean = label.mean(dim=-1, keepdim=True)
1570
+ var = label.var(dim=-1, keepdim=True)
1571
+ label = (label - mean) / (var + 1.0e-6) ** 0.5
1572
+
1573
+ return label
1574
+
1575
+ def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, mask: torch.BoolTensor):
1576
+ # We invert the mask such that 1.0 is *masked*
1577
+ mask = 1 - mask
1578
+ if len(self.config.query_stride) == 2:
1579
+ label = self.get_pixel_label_2d(pixel_values, mask)
1580
+ elif len(self.config.query_stride) == 3:
1581
+ label = self.get_pixel_label_3d(pixel_values, mask)
1582
+ else:
1583
+ raise NotImplementedError("Only images and videos are supported")
1584
+
1585
+ logits = logits[mask.bool()]
1586
+ loss = (logits - label) ** 2
1587
+ loss = loss.mean()
1588
+
1589
+ return loss
1590
+
1591
+ @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1592
+ @replace_return_docstrings(output_type=HieraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1593
+ def forward(
1594
+ self,
1595
+ pixel_values: Optional[torch.Tensor] = None,
1596
+ noise: Optional[torch.FloatTensor] = None,
1597
+ head_mask: Optional[torch.Tensor] = None,
1598
+ output_attentions: Optional[bool] = None,
1599
+ output_hidden_states: Optional[bool] = None,
1600
+ interpolate_pos_encoding: Optional[bool] = None,
1601
+ return_dict: Optional[bool] = None,
1602
+ ) -> Union[tuple, HieraForPreTrainingOutput]:
1603
+ r"""
1604
+ noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
1605
+ Optional noise tensor, mainly used for testing purposes to control randomness and
1606
+ maintain reproducibility when `is_mae` is set to `True`.
1607
+
1608
+ Returns:
1609
+
1610
+ Examples:
1611
+ ```python
1612
+ >>> from transformers import AutoImageProcessor, HieraForPreTraining
1613
+ >>> import torch
1614
+ >>> from PIL import Image
1615
+ >>> import requests
1616
+
1617
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1618
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1619
+
1620
+ >>> image_processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
1621
+ >>> model = HieraForPreTraining.from_pretrained("EduardoPacheco/hiera-tiny-224-mae")
1622
+
1623
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1624
+
1625
+ >>> outputs = model(**inputs)
1626
+ >>> logits = outputs.logits
1627
+ >>> list(logits.shape)
1628
+ [1, 196, 768]
1629
+ ```"""
1630
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1631
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1632
+ output_hidden_states = (
1633
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1634
+ )
1635
+
1636
+ outputs = self.hiera(
1637
+ pixel_values,
1638
+ noise=noise,
1639
+ head_mask=head_mask,
1640
+ output_attentions=output_attentions,
1641
+ output_hidden_states=True,
1642
+ interpolate_pos_encoding=interpolate_pos_encoding,
1643
+ return_dict=True,
1644
+ )
1645
+
1646
+ feature_maps = outputs.reshaped_hidden_states
1647
+ mask = outputs.mask
1648
+ ids_to_restore = outputs.ids_restore
1649
+ # Take only the query pooled and last hidden states
1650
+ feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
1651
+ fused_hidden_states = self.multiscale_fusion(feature_maps)
1652
+ fused_hidden_states = self.encoder_norm(fused_hidden_states)
1653
+
1654
+ # Reconstruct pixel values
1655
+ logits, mask = self.decoder(
1656
+ fused_hidden_states,
1657
+ mask=mask,
1658
+ head_mask=head_mask,
1659
+ output_attentions=output_attentions,
1660
+ )
1661
+
1662
+ loss = self.forward_loss(pixel_values, logits, mask)
1663
+
1664
+ if not return_dict:
1665
+ output = (logits, mask, ids_to_restore)
1666
+ if output_hidden_states:
1667
+ output = output + (outputs.hidden_states,)
1668
+ if output_attentions:
1669
+ output = output + (outputs.attentions,)
1670
+ if output_hidden_states:
1671
+ output = output + (outputs.reshaped_hidden_states,)
1672
+ return ((loss,) + output) if loss is not None else output
1673
+
1674
+ return HieraForPreTrainingOutput(
1675
+ loss=loss,
1676
+ logits=logits,
1677
+ mask=mask,
1678
+ ids_restore=ids_to_restore,
1679
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1680
+ attentions=outputs.attentions,
1681
+ reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
1682
+ )
1683
+
1684
+
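For intuition about the regression target, here is a shape walkthrough of the double `unfold` in `get_pixel_label_2d` (editor's sketch, reusing the assumed pred_stride of 16 from the note above): a 224x224 RGB image becomes one flattened 16*16*3 = 768-dimensional pixel vector per prediction location, which the `[1, 196, 768]` logits are regressed against on the masked locations only.

```python
def _pixel_label_shape_sketch() -> torch.Size:
    pixels = torch.randn(1, 3, 224, 224).permute(0, 2, 3, 1)  # [batch, H, W, C]
    label = pixels.unfold(1, 16, 16).unfold(2, 16, 16)        # [batch, 14, 14, C, 16, 16]
    label = label.flatten(1, 2).flatten(2)                    # [batch, 196, 768]
    return label.shape
```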
1685
+ @add_start_docstrings(
1686
+ """
1687
+ Hiera Model transformer with an image classification head on top (a linear layer on top of the average-pooled final
1688
+ hidden state), e.g. for ImageNet.
1689
+
1690
+ <Tip>
1691
+
1692
+ Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
1693
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
1694
+ position embeddings to the higher resolution.
1695
+
1696
+ </Tip>
1697
+ """,
1698
+ HIERA_START_DOCSTRING,
1699
+ )
1700
+ class HieraForImageClassification(HieraPreTrainedModel):
1701
+ def __init__(self, config: HieraConfig) -> None:
1702
+ super().__init__(config)
1703
+
1704
+ self.num_labels = config.num_labels
1705
+ self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
1706
+
1707
+ # Classifier head
1708
+ self.classifier = (
1709
+ nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
1710
+ )
1711
+
1712
+ # Initialize weights and apply final processing
1713
+ self.post_init()
1714
+
1715
+ @add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
1716
+ @add_code_sample_docstrings(
1717
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1718
+ output_type=HieraForImageClassificationOutput,
1719
+ config_class=_CONFIG_FOR_DOC,
1720
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1721
+ )
1722
+ def forward(
1723
+ self,
1724
+ pixel_values: Optional[torch.Tensor] = None,
1725
+ head_mask: Optional[torch.Tensor] = None,
1726
+ labels: Optional[torch.Tensor] = None,
1727
+ output_attentions: Optional[bool] = None,
1728
+ output_hidden_states: Optional[bool] = None,
1729
+ interpolate_pos_encoding: Optional[bool] = None,
1730
+ return_dict: Optional[bool] = None,
1731
+ ) -> Union[tuple, HieraForImageClassificationOutput]:
1732
+ r"""
1733
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1734
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1735
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1736
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1737
+ """
1738
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1739
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1740
+ output_hidden_states = (
1741
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1742
+ )
1743
+
1744
+ outputs = self.hiera(
1745
+ pixel_values,
1746
+ head_mask=head_mask,
1747
+ output_attentions=output_attentions,
1748
+ output_hidden_states=output_hidden_states,
1749
+ interpolate_pos_encoding=interpolate_pos_encoding,
1750
+ return_dict=return_dict,
1751
+ )
1752
+
1753
+ pooled_output = outputs[1]
1754
+
1755
+ logits = self.classifier(pooled_output)
1756
+
1757
+ loss = None
1758
+ if labels is not None:
1759
+ # move labels to correct device to enable model parallelism
1760
+ labels = labels.to(logits.device)
1761
+ if self.config.problem_type is None:
1762
+ if self.num_labels == 1:
1763
+ self.config.problem_type = "regression"
1764
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1765
+ self.config.problem_type = "single_label_classification"
1766
+ else:
1767
+ self.config.problem_type = "multi_label_classification"
1768
+
1769
+ if self.config.problem_type == "regression":
1770
+ loss_fct = MSELoss()
1771
+ if self.num_labels == 1:
1772
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1773
+ else:
1774
+ loss = loss_fct(logits, labels)
1775
+ elif self.config.problem_type == "single_label_classification":
1776
+ loss_fct = CrossEntropyLoss()
1777
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1778
+ elif self.config.problem_type == "multi_label_classification":
1779
+ loss_fct = BCEWithLogitsLoss()
1780
+ loss = loss_fct(logits, labels)
1781
+
1782
+ if not return_dict:
1783
+ # mask is None in classification mode, so encoder extras (hidden states, attentions) start at index 2
+ output = (logits,) + outputs[2:]
1784
+ return ((loss,) + output) if loss is not None else output
1785
+
1786
+ return HieraForImageClassificationOutput(
1787
+ loss=loss,
1788
+ logits=logits,
1789
+ hidden_states=outputs.hidden_states,
1790
+ attentions=outputs.attentions,
1791
+ reshaped_hidden_states=outputs.reshaped_hidden_states,
1792
+ )
1793
+
1794
+
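A minimal usage sketch for the classification head (editor's illustration, not called anywhere): it mirrors how `PytorchWorker` further below loads a converted checkpoint from the local `./hiera_model/` directory with 1784 labels; the helper name and the presence of that directory are assumptions.

```python
def _classification_usage_sketch(image: Image.Image) -> int:
    processor = AutoImageProcessor.from_pretrained("./hiera_model/")
    model = HieraForImageClassification.from_pretrained("./hiera_model/", num_labels=1784).eval()
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return int(logits.argmax(-1).item())
```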
1795
+ @add_start_docstrings(
1796
+ """
1797
+ Hiera backbone, to be used with frameworks like DETR and MaskFormer.
1798
+ """,
1799
+ HIERA_START_DOCSTRING,
1800
+ )
1801
+ class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
1802
+ def __init__(self, config: HieraConfig):
1803
+ super().__init__(config)
1804
+ super()._init_backbone(config)
1805
+
1806
+ self.num_features = [config.embed_dim] + [
1807
+ int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
1808
+ ]
1809
+ self.embeddings = HieraEmbeddings(config, is_mae=False)
1810
+ self.encoder = HieraEncoder(config)
1811
+
1812
+ # Add layer norms to hidden states of out_features
1813
+ hidden_states_norms = {}
1814
+ for stage, num_channels in zip(self._out_features, self.channels):
1815
+ hidden_states_norms[stage] = nn.LayerNorm(num_channels)
1816
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
1817
+
1818
+ # Initialize weights and apply final processing
1819
+ self.post_init()
1820
+
1821
+ def get_input_embeddings(self):
1822
+ return self.embeddings.patch_embeddings
1823
+
1824
+ def forward(
1825
+ self,
1826
+ pixel_values: torch.Tensor,
1827
+ output_hidden_states: Optional[bool] = None,
1828
+ output_attentions: Optional[bool] = None,
1829
+ return_dict: Optional[bool] = None,
1830
+ ) -> BackboneOutput:
1831
+ """
1832
+ Returns:
1833
+
1834
+ Examples:
1835
+
1836
+ ```python
1837
+ >>> from transformers import AutoImageProcessor, AutoBackbone
1838
+ >>> import torch
1839
+ >>> from PIL import Image
1840
+ >>> import requests
1841
+
1842
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1843
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1844
+
1845
+ >>> processor = AutoImageProcessor.from_pretrained("EduardoPacheco/hiera-tiny-224")
1846
+ >>> model = AutoBackbone.from_pretrained(
1847
+ ... "EduardoPacheco/hiera-tiny-224", out_features=["stage1", "stage2", "stage3", "stage4"]
1848
+ ... )
1849
+
1850
+ >>> inputs = processor(image, return_tensors="pt")
1851
+ >>> outputs = model(**inputs)
1852
+ >>> feature_maps = outputs.feature_maps
1853
+ >>> list(feature_maps[-1].shape)
1854
+ [1, 768, 7, 7]
1855
+ ```"""
1856
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1857
+ output_hidden_states = (
1858
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1859
+ )
1860
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1861
+
1862
+ embedding_output, _, _ = self.embeddings(pixel_values)
1863
+
1864
+ outputs = self.encoder(
1865
+ embedding_output,
1866
+ head_mask=None,
1867
+ output_attentions=output_attentions,
1868
+ output_hidden_states=True,
1869
+ return_dict=True,
1870
+ )
1871
+
1872
+ hidden_states = outputs.reshaped_hidden_states
1873
+
1874
+ feature_maps = ()
1875
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1876
+ if stage in self.out_features:
1877
+ batch_size, height, width, num_channels = hidden_state.shape
1878
+ hidden_state = hidden_state.view(batch_size, height * width, num_channels)
1879
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
1880
+ hidden_state = hidden_state.view(batch_size, height, width, num_channels)
1881
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1882
+ feature_maps += (hidden_state,)
1883
+
1884
+ if not return_dict:
1885
+ output = (feature_maps,)
1886
+ if output_hidden_states:
1887
+ output += (outputs.hidden_states,)
1888
+ return output
1889
+
1890
+ return BackboneOutput(
1891
+ feature_maps=feature_maps,
1892
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1893
+ attentions=outputs.attentions,
1894
+ )
1895
+ # %%
1896
+
1897
+
1898
  def is_gpu_available():
1899
  """Check if the python package `onnxruntime-gpu` is installed."""
1900
  return torch.cuda.is_available()
 
1910
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1911
  print(f"Using devide: {self.device}")
1912
 
1913
+ image_processor = AutoImageProcessor.from_pretrained("./hiera_model/")
1914
+ model = HieraForImageClassification.from_pretrained("./hiera_model/", num_labels=1784).to(self.device).eval()
1915
 
1916
  return model, image_processor
1917
 
 
1948
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
1949
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
1950
 
1951
+ #%%
1952
  if __name__ == "__main__":
1953
 
1954
  import zipfile
 
1968
  model_name=MODEL_NAME
1969
  )
1970
 
1971
+
1972
+ # import requests
1973
+ # image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
1974
+ # # %%
1975
+ # image
1976
+ # # %%
1977
+ # model = PytorchWorker()
1978
+ # # %%
1979
+ # output = model.predict_image(image)
1980
+ # # %%
1981
+ # output
1982
+ # # %%
1983
+ # import numpy as np
1984
+ # np.argmax(output)
1985
+ # # %%