"""Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import colorsys
import datetime
import json
import logging
import os
import pickle
import random  # noqa: F401  -- unused in this module
import socket
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
from typing import List, Optional

# Unused in this module; presumably kept so other modules can import them
# from here -- verify before removing.
from mmcv.runner import get_dist_info, init_dist  # noqa: F401
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from torch import Tensor

# True when torchvision's *minor* version is below 7 (the major version is not
# inspected). Only those releases need the private helpers imported below so
# that ``interpolate`` can handle empty batches.
__torchvision_need_compat_flag = float(
    torchvision.__version__.split('.')[1]) < 7
if __torchvision_need_compat_flag:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


def is_free_port(port: int) -> bool:
    """Return True if no address of this host accepts TCP connections on ``port``.

    Probes every IPv4 address the local hostname resolves to, plus
    ``localhost``.
    """
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    ips.append('localhost')

    def _is_open(ip):
        # A fresh socket per probe: after a failed connect a socket may be
        # unusable for further connects on some platforms, which would make
        # later probes report a spurious "free".
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex((ip, port)) == 0

    return not any(_is_open(ip) for ip in ips)


def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port and return its number."""
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py  # noqa: E501
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(('', 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average."""

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = '{median:.4f} ({global_avg:.4f})'
        self.deque = deque(maxlen=window_size)  # last ``window_size`` values
        self.total = 0.0  # weighted sum of every value ever added
        self.count = 0  # total weight (sum of ``n``) ever added
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` with weight ``n``.

        The window stores ``value`` once; only ``count``/``total`` use ``n``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` over all processes (CUDA all-reduce).

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64,
                         device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the window, or 0 when the window is empty."""
        d = torch.tensor(list(self.deque))
        if d.shape[0] == 0:
            return 0
        return d.median().item()

    @property
    def avg(self):
        """Mean of the window (NaN for an empty window)."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of every value ever added, or 0 before any update."""
        if self.count == 0:
            return 0
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the window (raises on an empty window)."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently added value (raises on an empty window)."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(median=self.median,
                               avg=self.avg,
                               global_avg=self.global_avg,
                               max=self.max,
                               value=self.value)


def all_gather(data):
    """Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # obtain Tensor size of each rank so every rank can allocate receive
    # buffers of the largest size
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], device='cuda')
    size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # all_gather requires equally sized tensors, so pad the local payload.
    tensor_list = [
        torch.empty((max_size, ), dtype=torch.uint8, device='cuda')
        for _ in size_list
    ]
    if local_numel != max_size:
        padding = torch.empty(size=(max_size - local_numel, ),
                              dtype=torch.uint8,
                              device='cuda')
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        buffer = gathered.cpu().numpy().tobytes()[:size]
        # NOTE(security): pickle.loads can execute arbitrary code; this is only
        # safe because the payloads come from peer ranks of the same job.
        data_list.append(pickle.loads(buffer))
    return data_list


def reduce_dict(input_dict, average=True):
    """Reduce the values in the dictionary from all processes.

    Args:
        input_dict (dict): all the values will be reduced; values are stacked
            with ``torch.stack`` so they must be tensors of equal shape.
        average (bool): whether to do average or sum

    Returns:
        dict: same keys as ``input_dict`` with the reduced values. Returned
        unchanged when running with fewer than 2 processes.

    Raises:
        Exception: whatever ``dist.all_reduce`` raises is logged with the
        rank and re-raised; continuing would report purely local values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        try:
            dist.all_reduce(values)
        except Exception as e:
            # logging (not print) so the message is not muted on non-master
            # ranks by setup_for_distributed().
            logging.error(f'Exception in rank {dist.get_rank()}: {e}')
            raise
        logging.info(f'Rank {dist.get_rank()} after all_reduce')
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


def setup_logging():
    """Configure root logging at INFO level and log this process's rank.

    Requires ``torch.distributed`` to be initialized (calls ``dist.get_rank``).
    """
    logging.basicConfig(level=logging.INFO)
    rank = dist.get_rank()
    logging.info(f'Rank {rank} before all_reduce')


class MetricLogger(object):
    """A name -> SmoothedValue collection with progress logging helpers."""

    def __init__(self, delimiter='\t'):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Add one sample per keyword; tensors are converted with ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. ``meters`` is read through
        # ``__dict__`` so an instance without it (e.g. created by copy/pickle
        # before __init__ ran) raises AttributeError instead of recursing.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = [
            '{}: {}'.format(name, str(meter))
            for name, meter in self.meters.items() if meter.count > 0
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None):
        """Yield items from ``iterable`` while periodically logging progress.

        Every ``print_freq`` iterations (and on the last one) a line with the
        ETA, the current meters, iteration time, data time and, when CUDA is
        available, peak allocated GPU memory in MB is emitted.

        Args:
            iterable: a sized iterable (``len()`` is used for progress/ETA).
            print_freq (int): log every this many iterations.
            header (str, optional): prefix for every log line.
            logger (logging.Logger, optional): if given, ``logger.info`` is
                used instead of ``print``.
        """
        print_func = print if logger is None else logger.info
        if not header:
            header = ''
        num_iters = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        fields = [
            header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
            'time: {time}', 'data: {data}'
        ]
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            # Measured after the consumer resumes, so this includes the
            # caller's work on ``obj`` as well as data loading.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fmt_kwargs = dict(eta=eta_string,
                                  meters=str(self),
                                  time=str(iter_time),
                                  data=str(data_time))
                if use_cuda:
                    fmt_kwargs['memory'] = (torch.cuda.max_memory_allocated() /
                                            MB)
                print_func(log_msg.format(i, num_iters, **fmt_kwargs))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        print_func('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))


class LogBuffer:
    """Per-key history of values and weights with weighted averaging."""

    def __init__(self):
        self.val_history = OrderedDict()  # key -> list of values
        self.n_history = OrderedDict()  # key -> list of weights (counts)
        self.output = OrderedDict()  # key -> last computed average
        self.ready = False  # True once average() has filled ``output``

    def clear(self) -> None:
        """Drop all histories and the computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        """Drop the computed output only."""
        self.output.clear()
        self.ready = False

    def update(self, vars: dict, count: int = 1) -> None:
        """Append each value in ``vars`` with weight ``count``."""
        assert isinstance(vars, dict)
        for key, var in vars.items():
            if key not in self.val_history:
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values (n == 0), weighted by count."""
        assert n >= 0
        for key in self.val_history:
            # ``[-0:]`` is the whole list, which is how n == 0 means "all".
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            avg = np.sum(values * nums) / np.sum(nums)
            self.output[key] = avg
        self.ready = True


def get_sha():
    """Return a one-line summary of the git state of this file's directory.

    Failures (git missing, not a repository, ...) are ignored; fields not
    gathered before the failure keep their 'N/A' / 'clean' defaults.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command,
                                       cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = 'clean'
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        # Output unused; presumably run to refresh the index before
        # diff-index -- TODO confirm.
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = 'has uncommited changes' if diff else 'clean'
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        # Best effort: report whatever was gathered before the failure.
        pass
    message = f'sha: {sha}, status: {diff}, branch: {branch}'
    return message


def collate_fn(batch):
    """Transpose a list of sample tuples into per-field tuples.

    The first field (presumably the images) is padded into a NestedTensor.
    """
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Element-wise maximum over equally long integer lists.

    Works on a copy so the caller's first sublist is not mutated.
    """
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    """A padded batch tensor plus a boolean mask (True = padding)."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask
        if isinstance(mask, str) and mask == 'auto':
            # Build an all-False ("no padding") mask over the spatial dims.
            self.mask = torch.zeros_like(tensors).to(tensors.device)
            if self.mask.dim() == 3:
                self.mask = self.mask.sum(0).to(bool)
            elif self.mask.dim() == 4:
                self.mask = self.mask.sum(1).to(bool)
            else:
                raise ValueError(
                    'tensors dim must be 3 or 4 but {}({})'.format(
                        self.tensors.dim(), self.tensors.shape))

    def imgsize(self):
        """Return, per batch item, a tensor [H, W] of the unpadded size."""
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a copy with tensors (and mask, if any) moved to ``device``."""
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        cast_mask = mask.to(device) if mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(self, tensor, mask):
        """Crop one 3-D tensor to the unpadded region given by ``mask``."""
        assert tensor.dim() == 3, 'dim of tensor should be 3 but {}'.format(
            tensor.dim())
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img

    def to_img_list(self):
        """Remove the padding and convert to an image list.

        Returns:
            A single cropped tensor when ``tensors`` is 3-D, otherwise a list
            with one cropped tensor per batch item.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        res = []
        for i in range(self.tensors.shape[0]):
            tensor_i = self.tensors[i]
            mask_i = self.mask[i]
            res.append(self.to_img_list_single(tensor_i, mask_i))
        return res

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        return {
            'tensors.shape': self.tensors.shape,
            'mask.shape': self.mask.shape if self.mask is not None else None
        }


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad 3-D tensors to a common size and build a NestedTensor.

    The mask has shape (batch, dim1, dim2) and is True on padded positions.
    Inputs are presumably (C, H, W) images -- only ndim == 3 is supported.

    Raises:
        ValueError: if the tensors are not 3-D.
    """
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
            m[:img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)


@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(
        tensor_list: List[Tensor]) -> NestedTensor:
    """Tracing-friendly variant of nested_tensor_from_tensor_list (ONNX)."""
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list
                         ]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(
            img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]),
                                              'constant', 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """This function disables printing when not in master process.

    Replaces ``builtins.print``; pass ``force=True`` to print from any rank.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """True when torch.distributed is available and a group is initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Number of processes, or 1 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """``torch.save`` only on rank 0."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Fill distributed fields on ``args`` from the environment and init NCCL.

    Uses ``WORLD_SIZE``/``LOCAL_RANK`` when set, else SLURM variables; otherwise
    sets ``args.distributed = False`` and returns without initializing.
    Reads ``args.world_size``, ``args.rank`` (launcher branch) and
    ``args.dist_url``.
    """
    if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '':
        local_world_size = int(os.environ['WORLD_SIZE'])
        args.world_size = args.world_size * local_world_size
        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
        args.rank = args.rank * local_world_size + args.local_rank
        print('world size: {}, rank: {}, local rank: {}'.format(
            args.world_size, args.rank, args.local_rank))
        # NOTE(review): dumps the full environment, which may contain secrets.
        print(json.dumps(dict(os.environ), indent=2))
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NPROCS'])

        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'
              .format(args.world_size, args.rank, args.local_rank,
                      torch.cuda.device_count()))
        print("os.environ['SLURM_JOB_NODELIST']:",
              os.environ['SLURM_JOB_NODELIST'])
        print(json.dumps(dict(os.environ), indent=2))
        print('args:')
        # default=str so non-JSON values in args do not crash the launch.
        print(json.dumps(vars(args), indent=2, default=str))
    else:
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(
        args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url),
          flush=True)

    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)

    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Returns one 0-dim tensor per k: the percentage of rows of ``output`` whose
    top-k scores (along dim 1) contain the ``target`` label. An empty
    ``target`` yields a single zero tensor.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: ``correct`` may be non-contiguous after pred.t().
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def interpolate(input,
                size=None,
                scale_factor=None,
                mode='nearest',
                align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """Equivalent to nn.functional.interpolate, but with support for empty
    batch sizes.

    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if __torchvision_need_compat_flag:
        # Old torchvision: handle empty input with its private helpers
        # (imported at module level only in this case).
        if input.numel() > 0:
            return torch.nn.functional.interpolate(input, size, scale_factor,
                                                   mode, align_corners)
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    # Newer torchvision releases dropped ops.misc.interpolate; fall back to
    # torch's own implementation when it is missing.
    misc_interpolate = getattr(torchvision.ops.misc, 'interpolate', None)
    if misc_interpolate is None:
        return torch.nn.functional.interpolate(input, size, scale_factor, mode,
                                               align_corners)
    return misc_interpolate(input, size, scale_factor, mode, align_corners)


class color_sys():
    """Palette of ``num_colors`` RGB tuples (0-255) with evenly spaced hues.

    Lightness and saturation are randomized slightly with ``np.random``.
    """

    def __init__(self, num_colors) -> None:
        self.num_colors = num_colors
        colors = []
        for i in np.arange(0., 360., 360. / num_colors):
            hue = i / 360.
            lightness = (50 + np.random.rand() * 10) / 100.
            saturation = (90 + np.random.rand() * 10) / 100.
            colors.append(
                tuple([
                    int(j * 255)
                    for j in colorsys.hls_to_rgb(hue, lightness, saturation)
                ]))
        self.colors = colors

    def __call__(self, idx):
        return self.colors[idx]


def inverse_sigmoid(x, eps=1e-3):
    """Numerically safe logit: log(x / (1 - x)) with both terms clamped to eps."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def clean_state_dict(state_dict):
    """Return an OrderedDict copy with a leading 'module.' stripped from keys."""
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict