Spaces:
Runtime error
Runtime error
File size: 2,016 Bytes
5e0b9df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import math
import os
import sys
from typing import Iterable
import numpy as np
import copy
import itertools
import torch
import hotr.util.misc as utils
import hotr.util.logger as loggers
from hotr.data.evaluators.hico_eval import HICOEvaluator
@torch.no_grad()
def hico_evaluate(model, postprocessors, data_loader, device, thr, args=None):
model.eval()
metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
header = 'Evaluation Inference (HICO-DET)'
preds = []
gts = []
indices = []
hoi_recognition_time = []
for samples, targets in metric_logger.log_every(data_loader, 50, header):
samples = samples.to(device)
targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]
outputs = model(samples)
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det', args=args)
hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)
preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
# For avoiding a runtime error, the copy is used
gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))
print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")
# gather the stats from all processes
metric_logger.synchronize_between_processes()
img_ids = [img_gts['id'] for img_gts in gts]
_, indices = np.unique(img_ids, return_index=True)
preds = [img_preds for i, img_preds in enumerate(preds) if i in indices]
gts = [img_gts for i, img_gts in enumerate(gts) if i in indices]
evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)
stats = evaluator.evaluate()
return stats |