|
from transformers import PretrainedConfig |
|
|
|
|
|
|
|
class DaViTConfig(PretrainedConfig):
    """Configuration for a DaViT vision backbone.

    Each constructor argument is stored as an attribute of the same name.
    The model code that interprets these values is not part of this class,
    so descriptions below marked "presumably" are naming-based guesses to be
    confirmed against that implementation.

    The per-stage arguments -- ``depths``, ``patch_size``, ``patch_stride``,
    ``patch_padding``, ``patch_prenorm``, ``embed_dims``, ``num_heads`` and
    ``num_groups`` -- all default to 4-element tuples, presumably one entry
    per backbone stage. If any of them is passed as a ``list`` (which is what
    a config reloaded from JSON receives), it is converted to a ``tuple`` so
    that freshly built and reloaded configs hold the same types. Values that
    are not lists are stored unchanged.

    Args:
        in_chans: Presumably the number of input image channels. Default 3.
        depths: Per-stage values, presumably block counts.
            Default ``(1, 1, 9, 1)``.
        patch_size: Per-stage values, presumably patch-embedding kernel
            sizes. Default ``(7, 3, 3, 3)``.
        patch_stride: Per-stage values, presumably patch-embedding strides.
            Default ``(4, 2, 2, 2)``.
        patch_padding: Per-stage values, presumably patch-embedding paddings.
            Default ``(3, 1, 1, 1)``.
        patch_prenorm: Per-stage booleans, presumably whether to normalize
            before each patch embedding. Default ``(False, True, True, True)``.
        embed_dims: Per-stage values, presumably embedding widths.
            Default ``(256, 512, 1024, 2048)``.
        num_heads: Per-stage values, presumably attention head counts.
            Default ``(8, 16, 32, 64)``.
        num_groups: Per-stage values, presumably group counts for grouped
            attention. Default ``(8, 16, 32, 64)``.
        window_size: Presumably the local attention window size. Default 12.
        mlp_ratio: Presumably the MLP hidden-width multiplier. Default 4.0.
        qkv_bias: Presumably whether QKV projections use a bias.
            Default True.
        drop_path_rate: Presumably the stochastic-depth rate. Default 0.1.
        norm_layer: Name of the normalization layer.
            Default ``"layer_norm"``.
        enable_checkpoint: Presumably toggles activation checkpointing.
            Default False.
        conv_at_attn: Boolean flag, default True; its meaning is defined by
            the model code.
        conv_at_ffn: Boolean flag, default True; its meaning is defined by
            the model code.
        projection_dim: Presumably the output projection width.
            Default 1024.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "davit"

    def __init__(
        self,
        in_chans=3,
        depths=(1, 1, 9, 1),
        patch_size=(7, 3, 3, 3),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 1, 1, 1),
        patch_prenorm=(False, True, True, True),
        embed_dims=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64),
        num_groups=(8, 16, 32, 64),
        window_size=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer="layer_norm",
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
        projection_dim=1024,
        **kwargs
    ):
        super().__init__(**kwargs)

        def _as_tuple(value):
            # JSON has no tuple type: tuples are written as arrays and come
            # back as lists when the config is reloaded. Converting lists to
            # tuples restores the type of the defaults, so a fresh config and
            # its reloaded copy compare equal (tuple != list in Python).
            # Non-list values (including None) pass through untouched.
            return tuple(value) if isinstance(value, list) else value

        self.in_chans = in_chans
        self.depths = _as_tuple(depths)
        self.patch_size = _as_tuple(patch_size)
        self.patch_stride = _as_tuple(patch_stride)
        self.patch_padding = _as_tuple(patch_padding)
        self.patch_prenorm = _as_tuple(patch_prenorm)
        self.embed_dims = _as_tuple(embed_dims)
        self.num_heads = _as_tuple(num_heads)
        self.num_groups = _as_tuple(num_groups)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.drop_path_rate = drop_path_rate
        self.norm_layer = norm_layer
        self.enable_checkpoint = enable_checkpoint
        self.conv_at_attn = conv_at_attn
        self.conv_at_ffn = conv_at_ffn
        self.projection_dim = projection_dim
|
|