# ------------------------------------------------------------------------
# HOTR official code : hotr/util/logger.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import time
import datetime
from collections import defaultdict

import torch

from hotr.util.misc import SmoothedValue

def print_params(model):
    """Print and return the number of trainable parameters in `model`."""
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('\n[Logger] Number of params: ', n_parameters)
    return n_parameters
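
# Example (illustrative, not from the HOTR repo): after building a model,
#   n_params = print_params(model)
# prints "[Logger] Number of params:  <count>" and returns the count.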

def print_args(args):
    """Print the subset of parsed arguments relevant to DETR and, when HOI
    detection is enabled (`args.HOIDet`), the HOI-specific arguments."""
    print('\n[Logger] DETR Arguments:')
    for k, v in vars(args).items():
        if k in [
            'lr', 'lr_backbone', 'lr_drop',
            'frozen_weights',
            'backbone', 'dilation',
            'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
            'dataset_file']:
            print(f'\t{k}: {v}')

    if args.HOIDet:
        print('\n[Logger] DETR_HOI Arguments:')
        for k, v in vars(args).items():
            if k in [
                'freeze_enc',
                'query_flag',
                'hoi_nheads',
                'hoi_dim_feedforward',
                'hoi_dec_layers',
                'hoi_idx_loss_coef',
                'hoi_act_loss_coef',
                'hoi_eos_coef',
                'object_threshold']:
                print(f'\t{k}: {v}')
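
# Example (illustrative): with a flat argparse namespace, print_args emits one
# tab-indented `key: value` line per whitelisted argument, e.g.
#   [Logger] DETR Arguments:
#       lr: 0.0001
#       backbone: resnet50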

class MetricLogger(object):
    """Track a set of named SmoothedValue meters and log training/eval progress."""

    def __init__(self, mode="test", delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.mode = mode

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Fall back to the meters dict so meters read as attributes,
        # e.g. `logger.loss` returns the 'loss' SmoothedValue.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Aggregate meter statistics across distributed workers.
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter
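
    # Example (illustrative sketch): meters are created on first update and
    # render through __str__, so a typical loop body looks like
    #   logger = MetricLogger(mode="train")
    #   logger.update(loss=loss_tensor, class_error=20.0)  # tensors are .item()'d
    #   print(logger)  # -> "loss: ...<TAB>class_error: ..."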
    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress, ETA, meter values,
        per-iteration timing, and (if available) peak CUDA memory every
        `print_freq` iterations."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if (i % print_freq == 0 and i != 0) or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i+1, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB),
                        flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
                else:
                    print(log_msg.format(
                        i+1, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)),
                        flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
            else:
                # Between full log lines, show a lightweight progress counter.
                # (The original branched on CUDA availability here, but both
                # branches printed the identical message, so they are merged.)
                log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]'])
                print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if self.mode == 'test':
            print("")  # move past the carriage-return progress line
        print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
            self.mode, total_time_str, total_time / len(iterable)))
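

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # HOTR code): iterate over a dummy range and log a synthetic loss.
    # Assumes hotr.util.misc is importable so that SmoothedValue resolves.
    logger = MetricLogger(mode="train", delimiter="  ")
    for step in logger.log_every(range(20), print_freq=5, header='[Demo]'):
        logger.update(loss=0.1 * step)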