File size: 5,559 Bytes
5e0b9df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# ------------------------------------------------------------------------
# HOTR official code : hotr/util/logger.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
import time
import datetime
import sys
from time import sleep
from collections import defaultdict

from hotr.util.misc import SmoothedValue

def print_params(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted; each contributes
    its ``numel()`` (total element count).

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The trainable-parameter count as an int.
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are not part of the reported count.
        if not param.requires_grad:
            continue
        total += param.numel()
    print('\n[Logger] Number of params: ', total)
    return total

def print_args(args):
    """Print a curated subset of the parsed command-line arguments.

    The DETR-related options are always printed. The HOI-specific options
    are printed afterwards only when ``args.HOIDet`` is truthy. Within each
    section, arguments appear in ``vars(args)`` order, one per line as
    ``\\t<name>: <value>``.

    Args:
        args: Namespace-like object (anything ``vars()`` accepts) that also
            has an ``HOIDet`` attribute.
    """
    def _dump(title, wanted):
        # Print a section title followed by every argument whose name is in `wanted`.
        print(title)
        for key, value in vars(args).items():
            if key in wanted:
                print(f'\t{key}: {value}')

    _dump('\n[Logger] DETR Arguments:', frozenset((
        'lr', 'lr_backbone', 'lr_drop',
        'frozen_weights',
        'backbone', 'dilation',
        'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
        'dataset_file',
    )))

    if args.HOIDet:
        _dump('\n[Logger] DETR_HOI Arguments:', frozenset((
            'freeze_enc',
            'query_flag',
            'hoi_nheads',
            'hoi_dim_feedforward',
            'hoi_dec_layers',
            'hoi_idx_loss_coef',
            'hoi_act_loss_coef',
            'hoi_eos_coef',
            'object_threshold',
        )))

class MetricLogger(object):
    """Collect named scalar metrics and print progress while iterating.

    Each metric lives in a ``SmoothedValue`` that ``defaultdict`` creates on
    first use. Metrics can be read back as attributes, e.g. ``logger.loss``
    after ``logger.update(loss=...)``.

    Args:
        mode: Label shown in the final timing line. When ``"test"``, progress
            reports end with a carriage return and are flushed, so they
            overwrite each other on one console line. Any other value prints
            each report on its own line.
        delimiter: String placed between the fields of a printed line.
    """

    def __init__(self, mode="test", delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.mode = mode

    def update(self, **kwargs):
        """Add one observation per keyword to the meter of the same name.

        ``torch.Tensor`` values are converted with ``.item()`` first, so a
        single-element tensor is expected. After conversion every value must
        be an ``int`` or ``float`` (checked with ``assert``).
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes (``logger.loss`` -> ``meters['loss']``).

        Python only calls this after normal attribute lookup has failed.
        ``meters`` is read straight from ``self.__dict__`` on purpose: on an
        instance whose ``__init__`` has not run (e.g. one being rebuilt by
        ``copy`` or ``pickle``, which probe ``__setstate__`` via
        ``hasattr``/``getattr``), evaluating ``self.meters`` would re-enter
        this method and recurse until ``RecursionError``.

        Raises:
            AttributeError: if ``attr`` is neither a meter nor an instance
                attribute.
        """
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Return ``"name: meter"`` pairs joined by ``self.delimiter``."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes()`` on every stored meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing entry."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while timing and printing progress.

        Two timings are kept per step:
        - ``data`` covers the wait before each item is produced.
        - ``time`` runs from the end of the previous step until control
          returns here after the consumer finishes with the item.

        A full report (counter, ETA, meters, timings, and peak CUDA memory
        when CUDA is available) is printed on every ``print_freq``-th step
        except step 0, and on the last step. Other steps print only the
        ``[i/total]`` counter with a carriage return. A total-time summary
        is printed when iteration finishes.

        Args:
            iterable: Sized iterable; ``len(iterable)`` is used for the
                counter width, the ETA and the per-iteration average.
            print_freq: Report interval in steps; must be non-zero.
            header: Optional prefix for every printed line.
        """
        header = header or ''
        total = len(iterable)
        use_cuda = torch.cuda.is_available()
        is_test = self.mode == 'test'
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')

        # Right-align the step counter to the width of the total, e.g. "[  7/120]".
        counter_fmt = '[{0:' + str(len(str(total))) + 'd}/{1}]'
        fields = [header, counter_fmt, 'eta: {eta}', '{meters}',
                  'time: {time}', 'data: {data}']
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        progress_msg = self.delimiter.join([header, counter_fmt])
        MB = 1024.0 * 1024.0

        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)

            if (i % print_freq == 0 and i != 0) or i == total - 1:
                # Step i has just finished, so total - (i + 1) steps remain.
                eta_seconds = iter_time.global_avg * (total - i - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fmt_kwargs = dict(eta=eta_string, meters=str(self),
                                  time=str(iter_time), data=str(data_time))
                if use_cuda:
                    fmt_kwargs['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i + 1, total, **fmt_kwargs),
                      flush=is_test, end=("\r" if is_test else "\n"))
            else:
                print(progress_msg.format(i + 1, total), flush=True, end="\r")

            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        if is_test:
            # Move off the carriage-return progress line before the summary.
            print("")
        # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
        print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
            self.mode, total_time_str, total_time / max(total, 1)))