Pytorch weights for Kornia ViT converted from the original google JAX vision-transformer repo.

Using it with kornia:

from kornia.contrib import VisionTransformer

vit_model = VisionTransformer.from_config('vit_l/16', pretrained=True)
...

Original weights from AugReg, as recommended by the Google Research vision-transformer repo: this checkpoint is the AugReg ViT-L/16 pretrained on ImageNet-21k.

Weights converted to PyTorch for Kornia ViT implementation (by @gau-nernst in kornia/kornia#2786)

The following function converts a JAX checkpoint into a Kornia ViT state dict:
def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
    """Convert a JAX/Flax ViT (AugReg) checkpoint to a Kornia ``VisionTransformer`` state dict.

    Args:
        np_state_dict: flat mapping from JAX parameter names
            (e.g. ``"Transformer/encoderblock_0/LayerNorm_0/scale"``) to numpy arrays.

    Returns:
        Mapping from Kornia ``VisionTransformer`` parameter names to torch tensors,
        suitable for ``model.load_state_dict``.
    """

    def get_weight(key: str) -> torch.Tensor:
        # Zero-copy wrap of the numpy array; dtype is preserved.
        return torch.from_numpy(np_state_dict[key])

    state_dict: dict[str, torch.Tensor] = {}
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    # Conv kernel layout: JAX stores (H, W, in_ch, out_ch); PyTorch Conv2d wants (out_ch, in_ch, H, W).
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    # Drop the leading batch axis of the (1, num_positions, dim) position embedding.
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # Walk encoder blocks until the checkpoint runs out of them (depth is not known up front).
    for i in range(100):
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        # JAX stores per-head q/k/v kernels as (dim, heads, head_dim) and biases as
        # (heads, head_dim). Fuse them into a single nn.Linear qkv: concatenating the
        # heads axis keeps output rows ordered q, then k, then v; the final .T converts
        # the (in, out) layout into nn.Linear's (out, in).
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        # Output projection kernel is (heads, head_dim, dim): merge the contracted
        # (heads, head_dim) axes, then transpose for nn.Linear's (out, in) layout.
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        # JAX Dense kernels are (in, out); nn.Linear weights are (out, in), hence the .T.
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.