Spaces:
Runtime error
Runtime error
import math | |
import os | |
import sys | |
from typing import Iterable | |
import numpy as np | |
import copy | |
import itertools | |
import torch | |
import hotr.util.misc as utils | |
import hotr.util.logger as loggers | |
from hotr.data.evaluators.hico_eval import HICOEvaluator | |
@torch.no_grad()
def hico_evaluate(model, postprocessors, data_loader, device, thr, args=None):
    """Run HOI inference on HICO-DET and evaluate the gathered predictions.

    Every batch from ``data_loader`` is passed through ``model`` and the
    ``'hoi'`` post-processor. Predictions and targets are collected from all
    processes via ``utils.all_gather``, reduced to one entry per image id,
    and scored with ``HICOEvaluator``.

    Args:
        model: Network called as ``model(samples)``; switched to eval mode.
        postprocessors: Mapping whose ``'hoi'`` entry is called as
            ``postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr,
            dataset='hico-det', args=args)`` and returns a list of per-image
            result dicts; the first one must carry ``'hoi_recognition_time'``.
        data_loader: Iterable of ``(samples, targets)`` batches. Each target is
            a dict that must contain ``'orig_size'`` and ``'id'``; every value
            except ``'id'`` is moved to ``device``. ``data_loader.dataset``
            must expose ``rare_triplets``, ``non_rare_triplets`` and
            ``correct_mat``, which are handed to ``HICOEvaluator``.
        device: Device the samples and target values are moved to.
        thr: Score threshold forwarded to the post-processor.
        args: Optional extra configuration forwarded to the post-processor.

    Returns:
        Whatever ``HICOEvaluator.evaluate()`` returns.
    """
    # Evaluation never back-propagates; ``@torch.no_grad()`` above stops
    # autograd from retaining a graph for every forward pass.
    model.eval()
    metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
    header = 'Evaluation Inference (HICO-DET)'

    preds = []
    gts = []
    hoi_recognition_time = []

    for samples, targets in metric_logger.log_every(data_loader, 50, header):
        samples = samples.to(device)
        targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det', args=args)
        # Only the first image's timing is sampled per batch; the value is
        # scaled by 1000 and reported in ms (so presumably seconds - confirm).
        hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)

        preds.extend(itertools.chain.from_iterable(utils.all_gather(results)))
        # For avoiding a runtime error, the copy is used
        gts.extend(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))

    # Guard against an empty loader, which would otherwise divide by zero.
    if hoi_recognition_time:
        print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    # Keep only the first occurrence of each image id. Duplicates presumably
    # come from the distributed sampler padding the dataset - TODO confirm.
    img_ids = [img_gts['id'] for img_gts in gts]
    _, first_indices = np.unique(img_ids, return_index=True)
    # A set gives O(1) membership; ``i in ndarray`` scans the whole array.
    keep = set(first_indices.tolist())
    preds = [img_preds for i, img_preds in enumerate(preds) if i in keep]
    gts = [img_gts for i, img_gts in enumerate(gts) if i in keep]

    evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
                              data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)
    stats = evaluator.evaluate()
    return stats