# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import functools
import dnnlib

#----------------------------------------------------------------------------

_version = 6 # internal version number
_decorators = set() # {decorator_class, ...}
_import_hooks = [] # [hook_function, ...]
_module_to_src_dict = dict() # {module: src, ...}
_src_to_module_dict = dict() # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. This feature can be disabled on a per-instance basis
    by setting `self._record_init_args = False` in the constructor.

    A typical use case is to first unpickle a previous instance of a
    persistent class, and then upgrade it to use the latest version of
    the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)

    Args:
        orig_class: The class to decorate. Must be a class whose module is
            present in `sys.modules`.

    Returns:
        A subclass of `orig_class` that pickles its module source, or
        `orig_class` itself if it is already persistent.
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    # `updated=()` copies the name/doc metadata without touching `__dict__`.
    @functools.wraps(orig_class, updated=())
    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # The wrapped constructor may opt out by setting `_record_init_args = False`.
            record_init_args = getattr(self, '_record_init_args', True)
            self._init_args = copy.deepcopy(args) if record_init_args else None
            self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
            # The class must be reachable by name in its module, because
            # unpickling looks it up as `module.__dict__[class_name]`.
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            """Deep copy of the positional constructor arguments."""
            assert self._init_args is not None
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            """Deep copy of the keyword constructor arguments, as an `EasyDict`."""
            assert self._init_kwargs is not None
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0) # pad to (func, args, state)
            if fields[0] is not _reconstruct_persistent_obj:
                # The object state travels inside `meta`, so the regular
                # state slot is cleared below.
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.

    Args:
        obj: Either a class or an instance.

    Returns:
        `True` if `obj` is a registered decorator class, or if the exact
        type of `obj` is one; `False` otherwise.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        # Unhashable objects cannot be looked up in a set; fall through to
        # the type-based check.
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src  Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta

    Args:
        hook: Callable taking and returning `meta`.

    Returns:
        The same `hook` object, so that using this function as a decorator
        leaves the decorated name bound to the function instead of `None`.
    """
    assert callable(hook)
    _import_hooks.append(hook)
    return hook

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.

    NOTE: this executes the source code embedded in the pickle (via
    `_src_to_module`), so it must only be used on trusted data.

    Args:
        meta: Dict produced by `Decorator.__reduce__` with the fields
            `type`, `version`, `module_src`, `class_name`, and `state`.

    Returns:
        A new instance of the persistent class with `state` restored.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None, 'import hook must return the (modified) meta'

    assert meta.version == _version, f'unsupported persistence version {meta.version!r}, expected {_version!r}'
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class', f'unsupported persistent object type {meta.type!r}'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    # Bypass `__init__`; the instance is populated from the pickled state.
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.

    The result is cached in both directions, so that a module and its
    source text map to each other on later lookups.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.

    A new module is created under a unique `_imported_module_<hex>` name,
    registered in `sys.modules` and in the module/source caches, and then
    populated by executing `src`. If execution fails, the registrations are
    rolled back so that a half-initialized module is never returned by a
    later call, and the original exception is re-raised.

    NOTE: `src` is executed with `exec`; only call this on trusted source.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        # Register before executing: classes decorated inside `src` look up
        # their module in `sys.modules` and its source in the cache.
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        try:
            exec(src, module.__dict__) # pylint: disable=exec-used
        except BaseException:
            # Undo the registration so the broken module is not cached.
            sys.modules.pop(module_name, None)
            _module_to_src_dict.pop(module, None)
            _src_to_module_dict.pop(src, None)
            raise
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.

    Values known to be pickleable are replaced with `None` before a trial
    `pickle.dump` into an in-memory buffer, so only the remaining objects
    are actually serialized.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        # Compared by qualified name so that numpy/torch need not be imported here.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------