AiOS / util /misc.py
ttxskk
update
d7e58f0
raw
history blame
22 kB
"""Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
from mmcv.runner import get_dist_info, init_dist
import os
import random
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
import datetime
import pickle
from typing import Optional, List
import socket
import json, time
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
import logging
import colorsys
import torchvision
# True when the installed torchvision is older than 0.7. Only then are the
# private helpers `_new_empty_tensor` / `_output_size` imported (they are used
# by `interpolate` for empty inputs).
# The (major, minor) pair is compared as integers so that a release such as
# "1.2" is not mistaken for "older than 0.7" just because its minor is < 7.
__torchvision_need_compat_flag = tuple(
    int(part) for part in torchvision.__version__.split('.')[:2]) < (0, 7)
if __torchvision_need_compat_flag:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size
def is_free_port(port: int) -> bool:
    """Report whether ``port`` looks unused on every local address.

    The addresses probed are those returned by resolving this machine's
    hostname, plus ``'localhost'``.

    Args:
        port: TCP port number to probe.

    Returns:
        ``False`` as soon as a connection attempt to any address succeeds,
        ``True`` if none of them does.
    """
    addresses = socket.gethostbyname_ex(socket.gethostname())[-1]
    addresses.append('localhost')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        for address in addresses:
            # connect_ex returns 0 when something accepted the connection,
            # i.e. the port is already in use on that address.
            if probe.connect_ex((address, port)) == 0:
                return False
        return True
def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port.

    Returns:
        The port number (an ``int``) the OS assigned to a temporary socket.
    """
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(('', 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    All statistics return 0 while no value has been recorded, so formatting
    an empty meter never raises.
    """

    def __init__(self, window_size=20, fmt=None):
        """
        Args:
            window_size (int): number of most recent values kept for the
                windowed statistics (median / avg / max / value).
            fmt (str | None): format string used by ``__str__``; it may
                reference ``median``, ``avg``, ``global_avg``, ``max`` and
                ``value``.
        """
        if fmt is None:
            fmt = '{median:.4f} ({global_avg:.4f})'
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` as observed ``n`` times.

        The window stores ``value`` once; ``n`` only weights the global
        total and count.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64,
                         device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the windowed values (0 when the window is empty)."""
        if not self.deque:
            return 0
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the windowed values (0.0 when the window is empty)."""
        if not self.deque:
            return 0.0
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update (0.0 before the first update)."""
        if self.count == 0:
            return 0.0
        return self.total / self.count

    @property
    def max(self):
        """Largest windowed value (0 when the window is empty)."""
        if not self.deque:
            return 0
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value (0 when the window is empty)."""
        if not self.deque:
            return 0
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(median=self.median,
                               avg=self.avg,
                               global_avg=self.global_avg,
                               max=self.max,
                               value=self.value)
def all_gather(data):
    """Gather an arbitrary picklable object from every rank.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    # NOTE: the gathered bytes are fed to pickle.loads, so every rank must be
    # trusted.
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize the object into a byte tensor on the GPU.
    payload = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(payload)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Exchange payload sizes first; the tensor gather below is given
    # equally-sized buffers, so each rank pads up to the largest payload.
    local_size = torch.tensor([tensor.numel()], device='cuda')
    size_list = [torch.tensor([0], device='cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = [
        torch.empty((max_size, ), dtype=torch.uint8, device='cuda')
        for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size, ),
                              dtype=torch.uint8,
                              device='cuda')
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # Strip the padding from each received buffer before unpickling it.
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:size])
        for size, received in zip(size_list, tensor_list)
    ]
def reduce_dict(input_dict, average=True):
    """Reduce the values of a dict across all processes.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum

    Returns:
        dict: a dict with the same keys as ``input_dict`` holding the reduced
        values. With fewer than two processes ``input_dict`` itself is
        returned unchanged.

    Raises:
        Exception: whatever ``dist.all_reduce`` raises. The failure is logged
        with the rank and re-raised, because continuing would return
        unreduced local values (divided by the world size when ``average`` is
        set) as if they were valid.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        try:
            dist.all_reduce(values)
        except Exception:
            logging.exception('Rank %s: all_reduce failed in reduce_dict',
                              dist.get_rank())
            raise
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def setup_logging():
    """Configure root logging at INFO level and log this process's rank.

    The rank is obtained through ``get_rank()``, which returns 0 when
    torch.distributed is unavailable or not initialized, so this can be called
    before (or without) distributed initialization.
    """
    logging.basicConfig(level=logging.INFO)
    rank = get_rank()
    logging.info(f'Rank {rank} before all_reduce')
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic logging."""

    def __init__(self, delimiter='\t'):
        """
        Args:
            delimiter (str): separator placed between fields of a log line.
        """
        # Meters are created on first use with SmoothedValue's defaults.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``; every value must end up
        being a float or an int.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes, e.g. ``logger.loss``.

        ``meters`` is read from the instance ``__dict__`` directly: going
        through ``self.meters`` would re-enter ``__getattr__`` and recurse
        forever whenever the attribute is not set yet (for example while the
        object is being copied or unpickled).
        """
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Join ``name: meter`` for every meter that has been updated."""
        loss_str = []
        for name, meter in self.meters.items():
            if meter.count > 0:
                loss_str.append('{}: {}'.format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None):
        """Yield the items of ``iterable`` while periodically logging progress.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq (int): a line is logged every ``print_freq`` items and
                for the last item.
            header (str | None): prefix of every log line.
            logger: object whose ``info`` method receives the lines; when
                ``None`` the lines are printed.

        Yields:
            The items of ``iterable``, unchanged.
        """
        print_func = print if logger is None else logger.info
        if not header:
            header = ''
        total = len(iterable)
        cuda_available = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the running index to the width of the total count.
        space_fmt = ':' + str(len(str(total))) + 'd'
        fields = [
            header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
            'time: {time}', 'data: {data}'
        ]
        if cuda_available:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting for the item to be produced.
            data_time.update(time.time() - end)
            yield obj
            # Time for the whole iteration, including the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                values = dict(eta=eta_string,
                              meters=str(self),
                              time=str(iter_time),
                              data=str(data_time))
                if cuda_available:
                    values['memory'] = torch.cuda.max_memory_allocated() / MB
                print_func(log_msg.format(i, total, **values))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(total, 1) avoids a ZeroDivisionError for an empty iterable.
        print_func('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(total, 1)))
class LogBuffer:
    """Accumulate named values with counts and compute weighted averages."""

    def __init__(self):
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        self.output = OrderedDict()
        self.ready = False

    def clear(self) -> None:
        """Drop all recorded history and any computed output."""
        for history in (self.val_history, self.n_history):
            history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        """Drop the computed averages and mark the buffer as not ready."""
        self.ready = False
        self.output.clear()

    def update(self, vars: dict, count: int = 1) -> None:
        """Append each value in ``vars`` to its history with weight ``count``."""
        assert isinstance(vars, dict)
        for name, value in vars.items():
            self.val_history.setdefault(name, []).append(value)
            self.n_history.setdefault(name, []).append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values."""
        assert n >= 0
        for name, history in self.val_history.items():
            # With n == 0 the slice [-0:] selects the whole history.
            recent_values = np.array(history[-n:])
            recent_counts = np.array(self.n_history[name][-n:])
            weighted_sum = np.sum(recent_values * recent_counts)
            self.output[name] = weighted_sum / np.sum(recent_counts)
        self.ready = True
def get_sha():
    """Describe the git state of the directory containing this file.

    Returns:
        str: ``'sha: <sha>, status: <status>, branch: <branch>'``. Fields that
        could not be determined keep their defaults (``N/A`` / ``clean``);
        any error raised while querying git is ignored.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        raw = subprocess.check_output(command, cwd=repo_dir)
        return raw.decode('ascii').strip()

    sha, diff, branch = 'N/A', 'clean', 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        # Output of this call is discarded; only its success matters here.
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        changed = _run(['git', 'diff-index', 'HEAD'])
        diff = 'has uncommited changes' if changed else 'clean'
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        # Best effort: report whatever was collected before the failure.
        pass
    return f'sha: {sha}, status: {diff}, branch: {branch}'
def collate_fn(batch):
    """Transpose a batch of samples and pad the first field into one tensor.

    Args:
        batch: sequence of per-sample tuples.

    Returns:
        tuple: one entry per sample field; the first entry is the result of
        ``nested_tensor_from_tensor_list`` applied to the samples' first
        fields, the remaining entries are tuples of the raw field values.
    """
    columns = [*zip(*batch)]
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the element-wise maximum over a list of equally long lists.

    Args:
        the_list: non-empty list of integer lists.

    Returns:
        A new list whose i-th entry is the largest i-th entry of all
        sublists. The input is left untouched.
    """
    # Copy the first sublist: updating it in place would silently modify the
    # caller's data.
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes
class NestedTensor(object):
    """A batch tensor bundled with a boolean mask.

    The methods of this class treat ``False`` mask entries as valid positions
    (see ``imgsize`` and ``to_img_list_single``, which count ``~mask``).
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        """
        Args:
            tensors: the data tensor.
            mask: a mask tensor, ``None``, or the string ``'auto'``. With
                ``'auto'`` an all-False mask is derived from ``tensors``
                (3-dim or 4-dim only) by collapsing its dim 0 or dim 1
                respectively.

        Raises:
            ValueError: if ``mask == 'auto'`` and ``tensors`` is neither
                3-dim nor 4-dim.
        """
        self.tensors = tensors
        self.mask = mask
        # Check the type first so a Tensor mask is never compared with a str.
        if isinstance(mask, str) and mask == 'auto':
            self.mask = torch.zeros_like(tensors).to(tensors.device)
            if self.mask.dim() == 3:
                self.mask = self.mask.sum(0).to(bool)
            elif self.mask.dim() == 4:
                self.mask = self.mask.sum(1).to(bool)
            else:
                raise ValueError(
                    'tensors dim must be 3 or 4 but {}({})'.format(
                        self.tensors.dim(), self.tensors.shape))

    def imgsize(self):
        """Return one ``[maxH, maxW]`` tensor per batch entry.

        The values are computed from the number of unmasked positions along
        each mask axis.
        """
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to
        ``device``."""
        cast_tensor = self.tensors.to(device)
        cast_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(self, tensor, mask):
        """Crop one 3-dim ``tensor`` to the unmasked extent given by ``mask``."""
        assert tensor.dim() == 3, 'dim of tensor should be 3 but {}'.format(
            tensor.dim())
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img

    def to_img_list(self):
        """Remove the padding and convert to an image list.

        Returns:
            A single cropped tensor when ``tensors`` is 3-dim, otherwise a
            list with one cropped tensor per batch entry.
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        return [
            self.to_img_list_single(self.tensors[i], self.mask[i])
            for i in range(self.tensors.shape[0])
        ]

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        """Shapes of both members; ``'mask.shape'`` is None without a mask."""
        return {
            'tensors.shape': self.tensors.shape,
            'mask.shape': None if self.mask is None else self.mask.shape
        }
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of 3-dim tensors to a common size and batch them.

    Args:
        tensor_list: tensors whose first element is 3-dimensional.

    Returns:
        NestedTensor: a zero-padded batch tensor and a boolean mask that is
        ``False`` over each tensor's own extent and ``True`` over padding.

    Raises:
        ValueError: if the first tensor is not 3-dimensional.
    """
    # TODO make this more general
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError('not supported')

    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX
        # call _onnx_nested_tensor_from_tensor_list() instead
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for index, img in enumerate(tensor_list):
        # Copy the tensor into the top-left corner of its padded slot and
        # mark that region as valid.
        tensor[index, :img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
        mask[index, :img.shape[1], :img.shape[2]] = False
    return NestedTensor(tensor, mask)
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(
        tensor_list: List[Tensor]) -> NestedTensor:
    """Build a NestedTensor from ``tensor_list`` using stack/pad operations.

    This is the variant ``nested_tensor_from_tensor_list`` dispatches to when
    ``torchvision._is_tracing()`` is true.

    Args:
        tensor_list: tensors to batch; the padding below indexes three
            dimensions per tensor.

    Returns:
        NestedTensor holding the stacked padded tensors and a boolean mask
        that is True over the padded region.
    """
    # Per-dimension maximum size over all tensors.
    # NOTE(review): torch.stack needs tensors, so this presumably relies on
    # shape entries being tensors while tracing - confirm.
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list
                         ]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed at the end of each dimension.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs starting from the last dimension.
        padded_img = torch.nn.functional.pad(
            img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)
        # The mask starts as zeros over the tensor's own extent and is padded
        # with ones, so padding ends up True after the bool conversion.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m,
                                              (0, padding[2], 0, padding[1]),
                                              'constant', 1)
        padded_masks.append(padded_mask.to(torch.bool))
    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)
    return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
    """Replace the builtin ``print`` so that only the master process prints.

    After this call ``print`` is silent unless ``is_master`` is true or the
    call passes ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        # `force` is consumed here and never forwarded to the real print.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the number of processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only.

    Other processes do nothing.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Fill the distributed fields of ``args`` and start the process group.

    Three cases are distinguished by the environment:

    * ``WORLD_SIZE`` set and non-empty: ``args.world_size`` and ``args.rank``
      are scaled by ``WORLD_SIZE``, and ``LOCAL_RANK`` is used as the local
      rank / gpu index.
    * ``SLURM_PROCID`` set: rank, local rank and world size are read from the
      ``SLURM_*`` variables.
    * otherwise: ``args`` is marked as non-distributed and the function
      returns without touching torch.distributed.

    In the two distributed cases the CUDA device is set to the local rank, the
    process group is initialized with the ``nccl`` backend using
    ``args.dist_url``, all processes synchronize on a barrier, and ``print``
    is silenced on every rank except rank 0.

    Args:
        args: namespace-like object. ``world_size``, ``rank`` and ``dist_url``
            are read; ``world_size``, ``rank``, ``local_rank``, ``gpu``,
            ``distributed`` and ``dist_backend`` are written depending on the
            branch taken.
    """
    if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '':
        # NOTE(review): the incoming args.world_size / args.rank are multiplied
        # by WORLD_SIZE here, so they presumably hold node count / node index
        # on entry - confirm against the launcher.
        local_world_size = int(os.environ['WORLD_SIZE'])
        args.world_size = args.world_size * local_world_size
        args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
        args.rank = args.rank * local_world_size + args.local_rank
        print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank))
        # NOTE(review): this dumps the whole environment to stdout, which may
        # include credentials.
        print(json.dumps(dict(os.environ), indent=2))
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NPROCS'])
        print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
        print("os.environ['SLURM_JOB_NODELIST']:", os.environ['SLURM_JOB_NODELIST'])
        # NOTE(review): same full-environment dump as in the branch above.
        print(json.dumps(dict(os.environ), indent=2))
        print('args:')
        print(json.dumps(vars(args), indent=2))
    else:
        # Single-process fallback; args.gpu is not set in this branch.
        print('Not using distributed mode')
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return
    print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")
    # From here on only rank 0 prints (unless print is called with force=True).
    setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; the top-k is taken along dim 1.
        target: tensor of target class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        list: one 0-dim tensor per k holding the percentage of rows whose
        target is among the k highest scores. When ``target`` is empty a
        single-element list containing a zero tensor is returned.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be
        # non-contiguous, in which case view(-1) raises; reshape copies only
        # when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def interpolate(input,
                size=None,
                scale_factor=None,
                mode='nearest',
                align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """Equivalent to nn.functional.interpolate, but with support for empty
    batch sizes.

    With torchvision older than 0.7 an empty input is handled here by
    computing the output shape and returning an empty tensor of that shape;
    otherwise the call is forwarded to ``torch.nn.functional.interpolate``.

    Args:
        input: tensor to resample.
        size, scale_factor, mode, align_corners: forwarded unchanged to
            ``torch.nn.functional.interpolate``.

    Returns:
        The resampled tensor.
    """
    # The flag is a bool. Comparing it against 0.7 inverted the branch and
    # sent empty inputs to helpers that are only imported when the flag is set.
    if __torchvision_need_compat_flag:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(input, size, scale_factor,
                                                   mode, align_corners)
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    return torch.nn.functional.interpolate(input, size, scale_factor, mode,
                                           align_corners)
class color_sys():
    """Palette of ``num_colors`` RGB colors with evenly spaced hues.

    Lightness and saturation are drawn at random per color (from 50-60% and
    90-100% respectively), so the palette differs between instances unless
    NumPy's global random state is seeded.
    """

    def __init__(self, num_colors) -> None:
        """
        Args:
            num_colors (int): number of colors to generate.
        """
        self.num_colors = num_colors
        colors = []
        # Iterate over an integer range: np.arange with a fractional step can
        # produce one element too many because of floating-point rounding.
        for i in range(num_colors):
            hue = i / num_colors
            lightness = (50 + np.random.rand() * 10) / 100.
            saturation = (90 + np.random.rand() * 10) / 100.
            colors.append(
                tuple(
                    int(j * 255)
                    for j in colorsys.hls_to_rgb(hue, lightness, saturation)))
        self.colors = colors

    def __call__(self, idx):
        """Return the ``(r, g, b)`` tuple stored at position ``idx``."""
        return self.colors[idx]
def inverse_sigmoid(x, eps=1e-3):
    """Compute ``log(x / (1 - x))`` with clamping for numerical safety.

    ``x`` is first clamped to [0, 1]; numerator and denominator are then each
    clamped to at least ``eps`` so the logarithm stays finite.
    """
    clipped = x.clamp(min=0, max=1)
    numerator = clipped.clamp(min=eps)
    denominator = (1 - clipped).clamp(min=eps)
    return torch.log(numerator / denominator)
def clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` with one leading ``module.`` removed
    from each key.

    Keys without that prefix are kept as they are; only the first occurrence
    of the prefix is stripped. The result is an ``OrderedDict`` in the input's
    iteration order.
    """
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())