# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from collections import OrderedDict

import torch
import torch.nn as nn

from .dinov2.hub.backbones import (
    dinov2_vitb14,
    dinov2_vitg14,
    dinov2_vitl14,
    dinov2_vits14,
)


class FrozenDinoV2ImageEmbedder(nn.Module):
    """
    Wraps the dinov2 image encoder with camera modulation.
    Note: despite the name, the encoder is not frozen by default; to keep
    it frozen, set cond_stage_trainable=False in the config.
    """

    def __init__(
        self,
        version='dinov2_vitb14',
        ckpt_path=None,
        lrm_mode='plain_lrm',
    ):
        super().__init__()
        self.lrm_mode = lrm_mode
        # Dispatch on the requested backbone size instead of hard-coding vitb14.
        backbones = {
            'dinov2_vits14': dinov2_vits14,
            'dinov2_vitb14': dinov2_vitb14,
            'dinov2_vitl14': dinov2_vitl14,
            'dinov2_vitg14': dinov2_vitg14,
        }
        assert version in backbones, f'unknown dinov2 version: {version}'
        self.model = backbones[version](pretrained=False)
        if ckpt_path is not None:
            self.load_pretrained(ckpt_path)
        else:
            print('No pretrained weights provided for dinov2 encoder ...')

    def load_pretrained(self, ckpt_path):
        # load_state_dict(strict=False) does not raise on key mismatches, so a
        # try/except cannot distinguish the two checkpoint layouts; check for
        # the wrapping 'state_dict' key explicitly instead.
        print('Loading dinov2 encoder ...')
        orig_state_dict = torch.load(ckpt_path, map_location='cpu')
        if 'state_dict' in orig_state_dict:
            # Training checkpoint: weights are nested under 'state_dict' and
            # prefixed with 'img_encoder.model.'; strip the prefix first.
            new_state_dict = OrderedDict()
            for k, v in orig_state_dict['state_dict'].items():
                if 'img_encoder' in k:
                    new_state_dict[k.replace('img_encoder.model.', '')] = v
            ret = self.model.load_state_dict(new_state_dict, strict=False)
        else:
            # Plain backbone state dict: load directly.
            ret = self.model.load_state_dict(orig_state_dict, strict=False)
        print(ret)
        print('Successfully loaded dinov2 state dict')
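
    # Sketch, not in the original file: a manual freeze helper for setups
    # where the surrounding config does not expose cond_stage_trainable.
    # Uses only the standard nn.Module API.
    def freeze(self):
        # Put the backbone in eval mode and disable gradients so the
        # "Frozen" in the class name actually holds.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False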

    def forward(self, x, *args, **kwargs):
        # Camera-conditioning arguments are passed through to the modulated backbone.
        ret = self.model.forward_features_with_camera(x, *args, **kwargs)
        # Prepend the class token to the patch tokens: (B, 1 + N, C).
        output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
        return output
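

# Minimal usage sketch, not part of the original module. Run with
# `python -m <package>.<module>` so the relative dinov2 import resolves.
# The exact camera-conditioning kwargs accepted by
# forward_features_with_camera are defined by the patched dinov2 backbone
# and are elided here; a real call would supply them.
if __name__ == '__main__':
    embedder = FrozenDinoV2ImageEmbedder(version='dinov2_vitb14')
    embedder.freeze()
    # With 14x14 patches, a 224x224 image yields 16*16 = 256 patch tokens,
    # plus the class token -> (B, 257, C) after forward().
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        tokens = embedder(images)  # camera kwargs would be passed here
    print(tokens.shape)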