# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

import re
import contextlib
import functools
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Re-seed torch & numpy random generators based on the given arguments.

def set_random_seed(*args):
    seed = hash(args) % (1 << 31)
    torch.manual_seed(seed)
    np.random.seed(seed)
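
# Usage sketch (not part of the original file): seeding from several values,
# e.g. a base seed plus a process rank, gives each process a distinct but
# reproducible stream (int arguments hash deterministically across runs).
# `base_seed` and `rank` are hypothetical variables.
#
#   set_random_seed(base_seed, rank)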

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
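
# Usage sketch (not part of the original file; assumes a CUDA device is
# available): the second call hits the cache and returns the very same tensor
# object, so no new CPU=>GPU copy takes place.
#
#   a = constant(0.5, shape=[1, 3, 1, 1], device=torch.device('cuda'))
#   b = constant(0.5, shape=[1, 3, 1, 1], device=torch.device('cuda'))
#   assert a is b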

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
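
# Usage sketch (not part of the original file; `x` is a hypothetical tensor):
# the constant is created on x's device with x's dtype, so it broadcasts
# without any implicit transfer.
#
#   y = x + const_like(x, np.pi)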

#----------------------------------------------------------------------------
# Cached construction of temporary tensors in pinned CPU memory.

@functools.lru_cache(None) # cache by (shape, dtype) so repeated calls reuse the same pinned allocation; shape must be hashable (e.g. a tuple)
def pinned_buf(shape, dtype):
    return torch.empty(shape, dtype=dtype).pin_memory()
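
# Usage sketch (not part of the original file; `images` is a hypothetical CPU
# tensor): stage the data in the cached pinned buffer, then issue an
# asynchronous host-to-device copy. The shape is passed as a tuple so the
# cache key above is hashable.
#
#   buf = pinned_buf(tuple(images.shape), images.dtype)
#   buf.copy_(images)
#   gpu = buf.to('cuda', non_blocking=True)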

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager # required: the generator below only acts as a context manager once decorated
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)
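
# Usage sketch (not part of the original file; `model` and `example_inputs`
# are hypothetical): silences TracerWarnings for the duration of the trace.
#
#   with suppress_tracer_warnings():
#       traced = torch.jit.trace(model, example_inputs)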

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
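
# Usage sketch (not part of the original file; `x` is a hypothetical NCHW
# batch): None leaves the batch dimension unconstrained.
#
#   assert_shape(x, [None, 3, 256, 256])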

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator
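
# Usage sketch (not part of the original file): the wrapped call shows up as a
# named range in autograd profiler / torch.profiler traces.
#
#   @profiled_function
#   def training_step(batch):  # hypothetical function
#       ...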

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
        super().__init__(dataset)
        self.dataset_size = len(dataset)
        self.start_idx = start_idx + rank
        self.stride = num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        idx = self.start_idx
        epoch = None
        while True:
            if epoch != idx // self.dataset_size:
                epoch = idx // self.dataset_size
                order = np.arange(self.dataset_size)
                if self.shuffle:
                    np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
            yield int(order[idx % self.dataset_size])
            idx += self.stride
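
# Usage sketch (not part of the original file; `dataset`, `rank`, and
# `world_size` are hypothetical): each rank draws an endless, strided stream
# of indices, so the DataLoader never raises StopIteration.
#
#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size, seed=0)
#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=32)
#   images, labels = next(iter(loader))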

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad() # in-place copy_() into leaf parameters that require grad is only legal outside autograd
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name])
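
# Usage sketch (not part of the original file; `net` and `ema` are hypothetical
# modules with identical structure): typical use is initializing an EMA copy of
# the network.
#
#   copy_params_and_buffers(src_module=net, dst_module=ema, require_all=True)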

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager # required so that `with ddp_sync(...)` works on this generator
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield
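
# Usage sketch (not part of the original file; `ddp_net`, `micro_batches`, and
# `compute_loss` are hypothetical): when accumulating gradients over several
# micro-batches, only the last backward pass needs the gradient all-reduce.
#
#   for i, micro in enumerate(micro_batches):
#       with ddp_sync(ddp_net, sync=(i == len(micro_batches) - 1)):
#           compute_loss(ddp_net, micro).backward()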

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = torch.nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
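
# Usage sketch (not part of the original file; `net` is a hypothetical module):
# call from every rank after initialization; the assert fires on any rank whose
# weights have drifted from rank 0. The regex is purely illustrative.
#
#   check_ddp_consistency(net, ignore_regex=r'.*\.running_(mean|var)')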

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs # return the network outputs so the caller can reuse the forward pass
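
# Usage sketch (not part of the original file; `net`, `x`, and `labels` are
# hypothetical, with [x, labels] matching net's forward() signature):
#
#   print_module_summary(net, [x, labels], max_nesting=2)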

#----------------------------------------------------------------------------
# Tile a batch of images into a 2D grid.

def tile_images(x, w, h):
    assert x.ndim == 4 # NCHW => CHW
    return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
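
# Usage sketch (not part of the original file): 64 images tiled into an 8x8
# grid; the batch size must equal w * h.
#
#   batch = torch.randn(64, 3, 32, 32)
#   grid = tile_images(batch, w=8, h=8)  # -> shape [3, 256, 256]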

#----------------------------------------------------------------------------