# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Optional, Tuple, Union

from timm.models import VisionTransformer
import torch
from transformers import PretrainedConfig, PreTrainedModel

from .model import create_model_from_args
from .input_conditioner import get_default_conditioner, InputConditioner


# Maps a released model name to the URL of its checkpoint archive.
# The host is written without a trailing dot ("huggingface.co", not
# "huggingface.co.") so that TLS host-name checks see the canonical name.
resource_map = {
    'radio_v1': 'https://huggingface.co/nvidia/RADIO/raw/main/radio_v1.pth.tar'
}


class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models.

    Args:
        args: Mapping of model-construction arguments. ``RADIOModel`` turns
            it into a namedtuple and hands it to ``create_model_from_args``.
            May be ``None`` for a bare config, but a model cannot be built
            from such a config.
        version: Version tag stored on the config.
        return_summary: Whether ``RADIOModel.forward`` returns the summary
            tensor.
        return_spatial_features: Whether ``RADIOModel.forward`` returns the
            spatial feature tensor alongside the summary.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = "v1",
        return_summary: Optional[bool] = True,
        return_spatial_features: Optional[bool] = True,
        **kwargs,
    ):
        # Attributes are assigned before delegating to the base class,
        # matching the original ordering.
        self.args = args
        self.version = version
        self.return_summary = return_summary
        self.return_spatial_features = return_spatial_features
        super().__init__(**kwargs)


class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO."""

    # Associates this model with its configuration class.
    config_class = RADIOConfig

    def __init__(self, config):
        """Build the wrapped model and input conditioner from ``config``.

        Args:
            config: Configuration whose ``args`` mapping describes the model
                to create.

        Raises:
            ValueError: If ``config.args`` is ``None``.
        """
        super().__init__(config)

        if config.args is None:
            raise ValueError(
                "RADIOConfig.args must be provided to construct a RADIOModel"
            )

        # create_model_from_args receives attribute-style arguments, so the
        # plain mapping is wrapped in an ad-hoc namedtuple.
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)

        self.model = create_model_from_args(args)
        self.input_conditioner: InputConditioner = get_default_conditioner()

        # forward() branches on these flags; they previously were never set,
        # which made every forward() call fail with AttributeError. getattr
        # keeps config objects that lack the fields working.
        self.return_summary = getattr(config, "return_summary", True)
        self.return_spatial_features = getattr(
            config, "return_spatial_features", True
        )

    def _split_features(self, y) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the output of ``forward_features`` into (summary, features).

        Raises:
            ValueError: If ``y`` is not a list/tuple and the wrapped model is
                not a timm ``VisionTransformer``.
        """
        # Some models already return the pair directly.
        if isinstance(y, (list, tuple)):
            summary, all_feat = y
            return summary, all_feat

        if not isinstance(self.model, VisionTransformer):
            raise ValueError("Unsupported model type")

        # NOTE(review): the slicing below assumes y is laid out as
        # (batch, tokens, dim) - confirm against the wrapped model.
        patch_gen = getattr(self.model, 'patch_generator', None)
        if patch_gen is not None:
            # Leading num_cls_tokens tokens form the summary; everything
            # after the first num_skip tokens is returned as features.
            summary = y[:, :patch_gen.num_cls_tokens].flatten(1)
            all_feat = y[:, patch_gen.num_skip:]
        elif self.model.global_pool == 'avg':
            # Summary is the mean over tokens after the prefix tokens; the
            # feature tensor is returned unsliced in this branch.
            summary = y[:, self.model.num_prefix_tokens:].mean(dim=1)
            all_feat = y
        else:
            summary = y[:, 0]
            all_feat = y[:, 1:]
        return summary, all_feat

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the input conditioner and the wrapped model on ``x``.

        Args:
            x: Input tensor, passed through ``self.input_conditioner`` before
                reaching ``self.model.forward_features``.

        Returns:
            ``(summary, features)`` when both ``return_summary`` and
            ``return_spatial_features`` are truthy, only ``summary`` when just
            ``return_summary`` is truthy, and only ``features`` otherwise.

        Raises:
            ValueError: If the model output cannot be split (see
                ``_split_features``).
        """
        x = self.input_conditioner(x)
        y = self.model.forward_features(x)

        summary, all_feat = self._split_features(y)

        if self.return_summary and self.return_spatial_features:
            return summary, all_feat
        elif self.return_summary:
            return summary
        return all_feat