# NOTE(review): the following lines are web file-viewer residue accidentally
# pasted into the source (they are not Python and broke parsing); preserved
# here as a comment: root / initial commit / 5e0b9df / raw / history blame / 2.02 kB
import math
import os
import sys
from typing import Iterable
import numpy as np
import copy
import itertools
import torch
import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.data.evaluators.hico_eval import HICOEvaluator
@torch.no_grad()
def hico_evaluate(model, postprocessors, data_loader, device, thr, args=None):
    """Run inference over the HICO-DET test set and compute detection stats.

    Iterates the data loader, runs the model and the 'hoi' postprocessor on
    every batch, gathers predictions and ground truths across all distributed
    processes, de-duplicates images that distributed sampling may have
    repeated, and finally scores everything with HICOEvaluator.

    Args:
        model: the HOI detection model (set to eval mode here).
        postprocessors: dict of postprocessors; only 'hoi' is used.
        data_loader: test loader; its dataset must expose rare_triplets,
            non_rare_triplets and correct_mat.
        device: torch device the samples/targets are moved to.
        thr: score threshold forwarded to the postprocessor.
        args: optional namespace forwarded to the postprocessor.

    Returns:
        The stats dict produced by HICOEvaluator.evaluate().
    """
    model.eval()

    metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
    header = 'Evaluation Inference (HICO-DET)'

    preds = []
    gts = []
    hoi_recognition_time = []

    for samples, targets in metric_logger.log_every(data_loader, 50, header):
        samples = samples.to(device)
        # 'id' is kept off-device: it is only used below for de-duplication.
        targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det', args=args)

        # Recognition time is reported in milliseconds.
        hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)

        preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
        # For avoiding a runtime error, the copy is used
        gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))

    # Guard against an empty data loader: the original unconditional mean
    # raised ZeroDivisionError when no batches were processed.
    if hoi_recognition_time:
        print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    # Distributed samplers may pad/repeat images across ranks; keep only the
    # first occurrence of each image id.
    img_ids = [img_gts['id'] for img_gts in gts]
    _, indices = np.unique(img_ids, return_index=True)
    # Membership tests against a set are O(1); against the ndarray they were O(n).
    keep = set(indices.tolist())
    preds = [img_preds for i, img_preds in enumerate(preds) if i in keep]
    gts = [img_gts for i, img_gts in enumerate(gts) if i in keep]

    evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
                              data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)
    stats = evaluator.evaluate()

    return stats