amaye15 committed on
Commit
c433e44
1 Parent(s): db6eb2b

Upload folder using huggingface_hub

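For context, the commit message says the files were pushed with huggingface_hub. A minimal sketch of such an upload (the local path and repo id are placeholders, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./davit",          # placeholder: local folder holding the files listed below
    repo_id="your-username/DaViT",  # placeholder repo id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)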
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (143 Bytes).

__pycache__/configuration_davit.cpython-312.pyc ADDED
Binary file (1.67 kB).

__pycache__/modeling_davit.cpython-312.pyc ADDED
Binary file (29.1 kB).

config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "architectures": [
+     "DaViTModel"
+   ],
+   "conv_at_attn": true,
+   "conv_at_ffn": true,
+   "depths": [
+     1,
+     1,
+     9,
+     1
+   ],
+   "drop_path_rate": 0.1,
+   "embed_dims": [
+     256,
+     512,
+     1024,
+     2048
+   ],
+   "enable_checkpoint": false,
+   "in_chans": 3,
+   "mlp_ratio": 4.0,
+   "model_type": "davit",
+   "norm_layer": "layer_norm",
+   "num_groups": [
+     8,
+     16,
+     32,
+     64
+   ],
+   "num_heads": [
+     8,
+     16,
+     32,
+     64
+   ],
+   "patch_padding": [
+     3,
+     1,
+     1,
+     1
+   ],
+   "patch_prenorm": [
+     false,
+     true,
+     true,
+     true
+   ],
+   "patch_size": [
+     7,
+     3,
+     3,
+     3
+   ],
+   "patch_stride": [
+     4,
+     2,
+     2,
+     2
+   ],
+   "projection_dim": 1024,
+   "qkv_bias": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.43.3",
+   "window_size": 12
+ }
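The four-element lists above are per-stage settings, one entry per DaViT stage. A small sanity-check sketch, assuming config.json has been downloaded into the working directory:

import json

with open("config.json") as f:
    cfg = json.load(f)

# Each index is one DaViT stage: block depth, embedding width, spatial heads, channel groups.
for i, (d, e, h, g) in enumerate(
    zip(cfg["depths"], cfg["embed_dims"], cfg["num_heads"], cfg["num_groups"])
):
    print(f"stage {i}: depth={d}, embed_dim={e}, heads={h}, groups={g}")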
configuration_davit.py ADDED
@@ -0,0 +1,50 @@
+ from transformers import PretrainedConfig
+ 
+ 
+ # Define configuration class
+ class DaViTConfig(PretrainedConfig):
+     model_type = "davit"
+ 
+     def __init__(
+         self,
+         in_chans=3,
+         # num_classes=1000,
+         depths=(1, 1, 9, 1),
+         patch_size=(7, 3, 3, 3),
+         patch_stride=(4, 2, 2, 2),
+         patch_padding=(3, 1, 1, 1),
+         patch_prenorm=(False, True, True, True),
+         embed_dims=(256, 512, 1024, 2048),
+         num_heads=(8, 16, 32, 64),
+         num_groups=(8, 16, 32, 64),
+         window_size=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_path_rate=0.1,
+         norm_layer="layer_norm",
+         enable_checkpoint=False,
+         conv_at_attn=True,
+         conv_at_ffn=True,
+         projection_dim=1024,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.in_chans = in_chans
+         # self.num_classes = num_classes  # Classes removed for AutoModel
+         self.depths = depths
+         self.patch_size = patch_size
+         self.patch_stride = patch_stride
+         self.patch_padding = patch_padding
+         self.patch_prenorm = patch_prenorm
+         self.embed_dims = embed_dims
+         self.num_heads = num_heads
+         self.num_groups = num_groups
+         self.window_size = window_size
+         self.mlp_ratio = mlp_ratio
+         self.qkv_bias = qkv_bias
+         self.drop_path_rate = drop_path_rate
+         self.norm_layer = norm_layer
+         self.enable_checkpoint = enable_checkpoint
+         self.conv_at_attn = conv_at_attn
+         self.conv_at_ffn = conv_at_ffn
+         self.projection_dim = projection_dim
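A usage sketch for this class (the plain local import and the argument values are illustrative assumptions, not part of the commit): instantiate with a few overrides and serialize back to a config.json like the one above.

from configuration_davit import DaViTConfig  # assumes the file is importable from the working directory

# Unspecified arguments keep the defaults defined above.
config = DaViTConfig(window_size=7, drop_path_rate=0.2)
config.save_pretrained("davit-custom")  # writes davit-custom/config.json with model_type "davit"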
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:315a18b687f56bdc95c40f81607a910b7ccbf24f01edb7f1dfca00f4ba5afaee
+ size 1442592416
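This entry is a Git LFS pointer; the actual weights are about 1.44 GB, which at float32 corresponds to roughly 360M parameters. A sketch for inspecting the checkpoint once it has been pulled from LFS (assumes the safetensors package is installed):

from safetensors import safe_open

# Assumes model.safetensors has been fetched (not just the LFS pointer file).
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    keys = list(f.keys())
    n_params = sum(f.get_tensor(k).numel() for k in keys)

print(f"{len(keys)} tensors, {n_params:,} parameters")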
modeling_davit.py ADDED
@@ -0,0 +1,665 @@
+ # coding=utf-8
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ """ PyTorch DaViT model."""
+ 
+ 
+ import math
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from collections import OrderedDict
+ from einops import rearrange
+ from timm.models.layers import DropPath, trunc_normal_
+ 
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+ 
+ # Ensure ConvEmbed, SpatialBlock, ChannelBlock, MySequential, etc., are defined before using them
+ from .configuration_davit import DaViTConfig
+ 
+ from transformers import AutoModel, AutoConfig
+ 
+ logger = logging.get_logger(__name__)
+ 
+ 
+ class LearnedAbsolutePositionEmbedding2D(nn.Module):
+     """
+     This module learns positional embeddings up to a fixed maximum size.
+     """
+ 
+     def __init__(self, embedding_dim=256, num_pos=50):
+         super().__init__()
+         self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
+         self.column_embeddings = nn.Embedding(
+             num_pos, embedding_dim - (embedding_dim // 2)
+         )
+ 
+     def forward(self, pixel_values):
+         """
+         pixel_values: (batch_size, height, width, num_channels)
+         returns: (batch_size, height, width, embedding_dim * 2)
+         """
+         if len(pixel_values.shape) != 4:
+             raise ValueError("pixel_values must be a 4D tensor")
+         height, width = pixel_values.shape[1:3]
+         width_values = torch.arange(width, device=pixel_values.device)
+         height_values = torch.arange(height, device=pixel_values.device)
+         x_emb = self.column_embeddings(width_values)
+         y_emb = self.row_embeddings(height_values)
+         # (height, width, embedding_dim * 2)
+         pos = torch.cat(
+             [
+                 x_emb.unsqueeze(0).repeat(height, 1, 1),
+                 y_emb.unsqueeze(1).repeat(1, width, 1),
+             ],
+             dim=-1,
+         )
+         # (embedding_dim * 2, height, width)
+         pos = pos.permute(2, 0, 1)
+         pos = pos.unsqueeze(0)
+         # (batch_size, embedding_dim * 2, height, width)
+         pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+         # (batch_size, height, width, embedding_dim * 2)
+         pos = pos.permute(0, 2, 3, 1)
+         return pos
+ 
+ 
+ class PositionalEmbeddingCosine1D(nn.Module):
+     """
+     This class implements a very simple positional encoding. It follows closely
+     the encoder from the link below:
+     https://pytorch.org/tutorials/beginner/translation_transformer.html
+ 
+     Args:
+         embed_dim: The dimension of the embeddings.
+         max_seq_len: The maximum length to precompute the positional encodings.
+     """
+ 
+     def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
+         super(PositionalEmbeddingCosine1D, self).__init__()
+         self.embed_dim = embed_dim
+         self.max_seq_len = max_seq_len
+         # Generate the sinusoidal arrays.
+         factor = math.log(10000)
+         denominator = torch.exp(
+             -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim
+         )
+         # Matrix where rows correspond to a positional embedding as a function
+         # of the position index (i.e., the row index).
+         frequencies = (
+             torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator
+         )
+         pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
+         # Populate uneven entries.
+         pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
+         pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
+         # Save the positional embeddings in a constant buffer.
+         self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
+ 
+     def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             seq_embeds: The sequence embeddings in order. Allowed size:
+                 1. [T, D], where T is the length of the sequence, and D is the
+                 frame embedding dimension.
+                 2. [B, T, D], where B is the batch size and T and D are the
+                 same as above.
+ 
+         Returns a tensor with the same dimensions as the input: i.e.,
+         [1, T, D] or [T, D].
+         """
+         shape_len = len(seq_embeds.shape)
+         assert 2 <= shape_len <= 3
+         len_seq = seq_embeds.size(-2)
+         assert len_seq <= self.max_seq_len
+         pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :]
+         # Adapt pre-computed positional embeddings to the input.
+         if shape_len == 3:
+             pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
+         return pos_embeds
+ 
+ 
+ class LearnedAbsolutePositionEmbedding1D(nn.Module):
+     """
+     Learnable absolute positional embeddings for 1D sequences.
+ 
+     Args:
+         embedding_dim: The dimension of the embeddings.
+         num_pos: The maximum length to precompute the positional encodings.
+     """
+ 
+     def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
+         super(LearnedAbsolutePositionEmbedding1D, self).__init__()
+         self.embeddings = nn.Embedding(num_pos, embedding_dim)
+         self.num_pos = num_pos
+ 
+     def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             seq_embeds: The sequence embeddings in order. Allowed size:
+                 1. [T, D], where T is the length of the sequence, and D is the
+                 frame embedding dimension.
+                 2. [B, T, D], where B is the batch size and T and D are the
+                 same as above.
+ 
+         Returns a tensor with the same dimensions as the input: i.e.,
+         [1, T, D] or [T, D].
+         """
+         shape_len = len(seq_embeds.shape)
+         assert 2 <= shape_len <= 3
+         len_seq = seq_embeds.size(-2)
+         assert len_seq <= self.num_pos
+         # [T, D]
+         pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
+         # Adapt pre-computed positional embeddings to the input.
+         if shape_len == 3:
+             pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
+         return pos_embeds
+ 
+ 
+ class MySequential(nn.Sequential):
+     def forward(self, *inputs):
+         for module in self._modules.values():
+             if isinstance(inputs, tuple):
+                 inputs = module(*inputs)
+             else:
+                 inputs = module(inputs)
+         return inputs
+ 
+ 
+ class PreNorm(nn.Module):
+     def __init__(self, norm, fn, drop_path=None):
+         super().__init__()
+         self.norm = norm
+         self.fn = fn
+         self.drop_path = drop_path
+ 
+     def forward(self, x, *args, **kwargs):
+         shortcut = x
+         if self.norm is not None:
+             x, size = self.fn(self.norm(x), *args, **kwargs)
+         else:
+             x, size = self.fn(x, *args, **kwargs)
+ 
+         if self.drop_path:
+             x = self.drop_path(x)
+ 
+         x = shortcut + x
+ 
+         return x, size
+ 
+ 
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.net = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("fc1", nn.Linear(in_features, hidden_features)),
+                     ("act", act_layer()),
+                     ("fc2", nn.Linear(hidden_features, out_features)),
+                 ]
+             )
+         )
+ 
+     def forward(self, x, size):
+         return self.net(x), size
+ 
+ 
+ class DepthWiseConv2d(nn.Module):
+     def __init__(
+         self,
+         dim_in,
+         kernel_size,
+         padding,
+         stride,
+         bias=True,
+     ):
+         super().__init__()
+         self.dw = nn.Conv2d(
+             dim_in,
+             dim_in,
+             kernel_size=kernel_size,
+             padding=padding,
+             groups=dim_in,
+             stride=stride,
+             bias=bias,
+         )
+ 
+     def forward(self, x, size):
+         B, N, C = x.shape
+         H, W = size
+         assert N == H * W
+ 
+         x = self.dw(x.transpose(1, 2).view(B, C, H, W))
+         size = (x.size(-2), x.size(-1))
+         x = x.flatten(2).transpose(1, 2)
+         return x, size
+ 
+ 
+ class ConvEmbed(nn.Module):
+     """Image to Patch Embedding"""
+ 
+     def __init__(
+         self,
+         patch_size=7,
+         in_chans=3,
+         embed_dim=64,
+         stride=4,
+         padding=2,
+         norm_layer=None,
+         pre_norm=True,
+     ):
+         super().__init__()
+         self.patch_size = patch_size
+ 
+         self.proj = nn.Conv2d(
+             in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding
+         )
+ 
+         dim_norm = in_chans if pre_norm else embed_dim
+         self.norm = norm_layer(dim_norm) if norm_layer else None
+ 
+         self.pre_norm = pre_norm
+ 
+     def forward(self, x, size):
+         H, W = size
+         if len(x.size()) == 3:
+             if self.norm and self.pre_norm:
+                 x = self.norm(x)
+             x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
+ 
+         x = self.proj(x)
+ 
+         _, _, H, W = x.shape
+         x = rearrange(x, "b c h w -> b (h w) c")
+         if self.norm and not self.pre_norm:
+             x = self.norm(x)
+ 
+         return x, (H, W)
+ 
+ 
+ class ChannelAttention(nn.Module):
+ 
+     def __init__(self, dim, groups=8, qkv_bias=True):
+         super().__init__()
+ 
+         self.groups = groups
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dim, dim)
+ 
+     def forward(self, x, size):
+         B, N, C = x.shape
+ 
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.groups, C // self.groups)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv[0], qkv[1], qkv[2]
+ 
+         q = q * (float(N) ** -0.5)
+         attention = q.transpose(-1, -2) @ k
+         attention = attention.softmax(dim=-1)
+         x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         return x, size
+ 
+ 
+ class ChannelBlock(nn.Module):
+ 
+     def __init__(
+         self,
+         dim,
+         groups,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_path_rate=0.0,
+         act_layer=nn.GELU,
+         norm_layer=nn.LayerNorm,
+         conv_at_attn=True,
+         conv_at_ffn=True,
+     ):
+         super().__init__()
+ 
+         drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ 
+         self.conv1 = (
+             PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+         )
+         self.channel_attn = PreNorm(
+             norm_layer(dim),
+             ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
+             drop_path,
+         )
+         self.conv2 = (
+             PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+         )
+         self.ffn = PreNorm(
+             norm_layer(dim),
+             Mlp(
+                 in_features=dim,
+                 hidden_features=int(dim * mlp_ratio),
+                 act_layer=act_layer,
+             ),
+             drop_path,
+         )
+ 
+     def forward(self, x, size):
+         if self.conv1:
+             x, size = self.conv1(x, size)
+         x, size = self.channel_attn(x, size)
+ 
+         if self.conv2:
+             x, size = self.conv2(x, size)
+         x, size = self.ffn(x, size)
+ 
+         return x, size
+ 
+ 
+ def window_partition(x, window_size: int):
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     windows = (
+         x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+     )
+     return windows
+ 
+ 
+ def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
+     B = batch_size
+     # Computing B as int(windows.shape[0] / (H * W / window_size / window_size)) would
+     # break ONNX export with dynamic axes, because it would be treated as a constant.
+     x = windows.view(
+         B, H // window_size, W // window_size, window_size, window_size, -1
+     )
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+     return x
+ 
+ 
+ class WindowAttention(nn.Module):
+     def __init__(self, dim, num_heads, window_size, qkv_bias=True):
+ 
+         super().__init__()
+         self.dim = dim
+         self.window_size = window_size
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = float(head_dim) ** -0.5
+ 
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dim, dim)
+ 
+         self.softmax = nn.Softmax(dim=-1)
+ 
+     def forward(self, x, size):
+ 
+         H, W = size
+         B, L, C = x.shape
+         assert L == H * W, "input feature has wrong size"
+ 
+         x = x.view(B, H, W, C)
+ 
+         pad_l = pad_t = 0
+         pad_r = (self.window_size - W % self.window_size) % self.window_size
+         pad_b = (self.window_size - H % self.window_size) % self.window_size
+         x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+         _, Hp, Wp, _ = x.shape
+ 
+         x = window_partition(x, self.window_size)
+         x = x.view(-1, self.window_size * self.window_size, C)
+ 
+         # W-MSA/SW-MSA
+         # attn_windows = self.attn(x_windows)
+ 
+         B_, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv[0], qkv[1], qkv[2]
+ 
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)
+         attn = self.softmax(attn)
+ 
+         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+         x = self.proj(x)
+ 
+         # merge windows
+         x = x.view(-1, self.window_size, self.window_size, C)
+         x = window_reverse(x, B, self.window_size, Hp, Wp)
+ 
+         if pad_r > 0 or pad_b > 0:
+             x = x[:, :H, :W, :].contiguous()
+ 
+         x = x.view(B, H * W, C)
+ 
+         return x, size
+ 
+ 
+ class SpatialBlock(nn.Module):
+ 
+     def __init__(
+         self,
+         dim,
+         num_heads,
+         window_size,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_path_rate=0.0,
+         act_layer=nn.GELU,
+         norm_layer=nn.LayerNorm,
+         conv_at_attn=True,
+         conv_at_ffn=True,
+     ):
+         super().__init__()
+ 
+         drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+ 
+         self.conv1 = (
+             PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+         )
+         self.window_attn = PreNorm(
+             norm_layer(dim),
+             WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
+             drop_path,
+         )
+         self.conv2 = (
+             PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+         )
+         self.ffn = PreNorm(
+             norm_layer(dim),
+             Mlp(
+                 in_features=dim,
+                 hidden_features=int(dim * mlp_ratio),
+                 act_layer=act_layer,
+             ),
+             drop_path,
+         )
+ 
+     def forward(self, x, size):
+         if self.conv1:
+             x, size = self.conv1(x, size)
+         x, size = self.window_attn(x, size)
+ 
+         if self.conv2:
+             x, size = self.conv2(x, size)
+         x, size = self.ffn(x, size)
+         return x, size
+ 
+ 
+ # Define DaViT model class
+ class DaViTModel(PreTrainedModel):
+     config_class = DaViTConfig
+ 
+     def __init__(self, config: DaViTConfig):
+         super().__init__(config)
+ 
+         # self.num_classes = config.num_classes
+         self.embed_dims = config.embed_dims
+         self.num_heads = config.num_heads
+         self.num_groups = config.num_groups
+         self.num_stages = len(self.embed_dims)
+         self.enable_checkpoint = config.enable_checkpoint
+         assert self.num_stages == len(self.num_heads) == len(self.num_groups)
+ 
+         num_stages = len(config.embed_dims)
+         dpr = [
+             x.item()
+             for x in torch.linspace(0, config.drop_path_rate, sum(config.depths) * 2)
+         ]
+ 
+         depth_offset = 0
+         convs = []
+         blocks = []
+         for i in range(num_stages):
+             conv_embed = ConvEmbed(
+                 patch_size=config.patch_size[i],
+                 stride=config.patch_stride[i],
+                 padding=config.patch_padding[i],
+                 in_chans=config.in_chans if i == 0 else self.embed_dims[i - 1],
+                 embed_dim=self.embed_dims[i],
+                 norm_layer=(
+                     nn.LayerNorm
+                     if config.norm_layer == "layer_norm"
+                     else nn.BatchNorm2d
+                 ),
+                 pre_norm=config.patch_prenorm[i],
+             )
+             convs.append(conv_embed)
+ 
+             block = MySequential(
+                 *[
+                     MySequential(
+                         OrderedDict(
+                             [
+                                 (
+                                     "spatial_block",
+                                     SpatialBlock(
+                                         self.embed_dims[i],
+                                         self.num_heads[i],
+                                         config.window_size,
+                                         drop_path_rate=dpr[depth_offset + j * 2],
+                                         qkv_bias=config.qkv_bias,
+                                         mlp_ratio=config.mlp_ratio,
+                                         conv_at_attn=config.conv_at_attn,
+                                         conv_at_ffn=config.conv_at_ffn,
+                                     ),
+                                 ),
+                                 (
+                                     "channel_block",
+                                     ChannelBlock(
+                                         self.embed_dims[i],
+                                         self.num_groups[i],
+                                         drop_path_rate=dpr[depth_offset + j * 2 + 1],
+                                         qkv_bias=config.qkv_bias,
+                                         mlp_ratio=config.mlp_ratio,
+                                         conv_at_attn=config.conv_at_attn,
+                                         conv_at_ffn=config.conv_at_ffn,
+                                     ),
+                                 ),
+                             ]
+                         )
+                     )
+                     for j in range(config.depths[i])
+                 ]
+             )
+             blocks.append(block)
+             depth_offset += config.depths[i] * 2
+ 
+         self.convs = nn.ModuleList(convs)
+         self.blocks = nn.ModuleList(blocks)
+ 
+         self.norms = (
+             nn.LayerNorm(self.embed_dims[-1])
+             if config.norm_layer == "layer_norm"
+             else nn.BatchNorm2d(self.embed_dims[-1])
+         )
+         self.avgpool = nn.AdaptiveAvgPool1d(1)
+         # self.head = (
+         #     nn.Linear(self.embed_dims[-1], self.num_classes)
+         #     if self.num_classes > 0
+         #     else nn.Identity()
+         # )
+ 
+         self.apply(self._init_weights)
+ 
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.normal_(m.weight, std=0.02)
+             for name, _ in m.named_parameters():
+                 if name in ["bias"]:
+                     nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.weight, 1.0)
+             nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.weight, 1.0)
+             nn.init.constant_(m.bias, 0)
+ 
+     def forward_features_unpool(self, x):
+         """
+         Forward pass up to (but not including) average pooling.
+         Args:
+             x (torch.Tensor): input image tensor
+         """
+         input_size = (x.size(2), x.size(3))
+         for conv, block in zip(self.convs, self.blocks):
+             x, input_size = conv(x, input_size)
+             if self.enable_checkpoint:
+                 x, input_size = checkpoint.checkpoint(block, x, input_size)
+             else:
+                 x, input_size = block(x, input_size)
+         return x
+ 
+     def forward_features(self, x):
+         x = self.forward_features_unpool(x)
+ 
+         # (batch_size, num_tokens, token_dim)
+         x = self.avgpool(x.transpose(1, 2))
+         # (batch_size, token_dim, 1)
+         x = torch.flatten(x, 1)
+         x = self.norms(x)
+ 
+         return x
+ 
+     def forward(self, x):
+         x = self.forward_features(x)
+         # x = self.head(x)
+         return x
+ 
+ 
+ # Register the configuration and model
+ AutoConfig.register("davit", DaViTConfig)
+ AutoModel.register(DaViTConfig, DaViTModel)
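The spatial blocks alternate window attention with channel attention, and the window_partition/window_reverse pair above does the reshaping into non-overlapping windows and back. A standalone round-trip check (the two helpers are restated here so the snippet runs without importing the package; the feature-map size is chosen as a multiple of window_size=12):

import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows, batch_size, window_size, H, W):
    # Inverse of window_partition, given the original batch size and spatial size.
    x = windows.view(batch_size, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, H, W, -1)


x = torch.randn(2, 24, 24, 96)     # (B, H, W, C) with H and W divisible by 12
windows = window_partition(x, 12)  # (2 * 2 * 2, 12, 12, 96)
assert torch.equal(window_reverse(windows, 2, 12, 24, 24), x)  # partition then reverse is the identity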
test_davit_model.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from transformers import AutoConfig, AutoModel
+ from .configuration_davit import DaViTConfig
+ from .modeling_davit import DaViTModel
+ 
+ # Register the configuration and model
+ AutoConfig.register("davit", DaViTConfig)
+ AutoModel.register(DaViTConfig, DaViTModel)
+ 
+ # Step 1: Create a configuration object
+ config = DaViTConfig()
+ 
+ # Step 2: Create a model object
+ model = AutoModel.from_config(config)
+ 
+ # Step 3: Run a forward pass
+ # Generate a random sample input tensor with shape (batch_size, channels, height, width)
+ batch_size = 2
+ channels = 3
+ height = 224
+ width = 224
+ sample_input = torch.randn(batch_size, channels, height, width)
+ 
+ # Pass the sample input through the model
+ output = model(sample_input)
+ 
+ # Print the output shape
+ print(f"Output shape: {output.shape}")
+ 
+ # Expected output shape: (batch_size, embed_dims[-1]), i.e. (2, 2048) for the default config
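Not part of this commit, but a possible follow-up sketch: for the custom classes to be loadable from the Hub with trust_remote_code=True, they can also be registered for auto-class serialization before pushing (the repo id is a placeholder, and the imports assume the same package layout as the test above):

from .configuration_davit import DaViTConfig
from .modeling_davit import DaViTModel

DaViTConfig.register_for_auto_class()
DaViTModel.register_for_auto_class("AutoModel")

model = DaViTModel(DaViTConfig())
model.push_to_hub("your-username/DaViT")  # placeholder repo id; ships weights plus custom-code references

# Later, from any environment:
# from transformers import AutoModel
# model = AutoModel.from_pretrained("your-username/DaViT", trust_remote_code=True)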