RADIO / adaptor_registry.py
gheinrich's picture
Upload model
be257a4 verified
raw
history blame
1.37 kB
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import Dict, Any
import torch
from .adaptor_generic import GenericAdaptor, AdaptorBase
# Alias for a generic string-keyed dictionary; used below as the type of the
# per-adaptor configuration passed to ``AdaptorRegistry.create_adaptor``.
dict_t = Dict[str, Any]
# Alias for a string-keyed mapping of tensors, used as the ``state`` argument
# of ``create_adaptor`` (presumably a state_dict-style slice of checkpoint
# weights for the adaptor — confirm against callers).
state_t = Dict[str, torch.Tensor]
class AdaptorRegistry:
    """Name-keyed registry of adaptor factory functions.

    Factories are added with the :meth:`register_adaptor` decorator and are
    invoked by :meth:`create_adaptor`. Requesting a name that has no
    registered factory falls back to ``GenericAdaptor``.
    """

    def __init__(self):
        # Maps adaptor name -> factory callable invoked as
        # ``factory(main_config, adaptor_config, state)``.
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that registers a factory under ``name``.

        The decorated function is stored and then returned unchanged, so it
        stays usable as a normal function after decoration.

        Args:
            name: Key under which the factory is registered.

        Returns:
            A decorator taking the factory function.

        Raises:
            ValueError: When the decorator is applied and ``name`` is already
                registered (duplicate registrations are rejected rather than
                silently overwriting the earlier factory).
        """
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Instantiate the adaptor registered under ``name``.

        Args:
            name: Registry key of the adaptor factory.
            main_config: Top-level configuration namespace, forwarded as-is.
            adaptor_config: Per-adaptor configuration, forwarded as-is.
            state: Tensor mapping for the adaptor, forwarded as-is.

        Returns:
            The factory's result. If ``name`` is not registered,
            ``GenericAdaptor(main_config, adaptor_config, state)`` is returned
            instead of raising.
        """
        # Single lookup: unregistered names use GenericAdaptor as the factory.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)
# Shared module-level registry instance. Factories are registered on it with
# ``@adaptor_registry.register_adaptor("<name>")`` (presumably from adaptor
# modules elsewhere in the package — confirm) and built via
# ``adaptor_registry.create_adaptor(...)``.
adaptor_registry = AdaptorRegistry()