# ------------------------------------------------------------------------
# HOTR official code : hotr/data/datasets/hico.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
import json
from collections import defaultdict
import numpy as np

import torch
import torch.utils.data
import torchvision

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T
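
# HICODetection wraps the HICO-DET annotations as a torch Dataset: each item is an
# RGB image plus a target dict with object boxes/labels and, for the train split,
# per-pair (subject, object) boxes with multi-hot verb labels.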
class HICODetection(torch.utils.data.Dataset):
    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
        self.img_set = img_set
        self.img_folder = img_folder
        with open(anno_file, 'r') as f:
            self.annotations = json.load(f)
        with open(action_list_file, 'r') as f:
            self.action_lines = f.readlines()
        self._transforms = transforms
        self.num_queries = num_queries
        self.get_metadata()
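
        # For the train split, keep only images whose HOI annotations all reference
        # valid box indices; the for/else appends the image index only when the inner
        # loop finishes without hitting an out-of-range subject or object id.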
        if img_set == 'train':
            self.ids = []
            for idx, img_anno in enumerate(self.annotations):
                for hoi in img_anno['hoi_annotation']:
                    if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
                        break
                else:
                    self.ids.append(idx)
        else:
            self.ids = list(range(len(self.annotations)))

    ############################################################################
    # Number Method
    ############################################################################
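    # get_metadata caches the COCO class names and valid object ids from
    # builtin_meta, and parses list_action.txt (skipping its two header lines)
    # into parallel lists of verb ids and verb names.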
    def get_metadata(self):
        meta = builtin_meta._get_coco_instances_meta()
        self.COCO_CLASSES = meta['coco_classes']
        self._valid_obj_ids = [id for id in meta['thing_dataset_id_to_contiguous_id'].keys()]

        self._valid_verb_ids, self._valid_verb_names = [], []
        for action_line in self.action_lines[2:]:
            act_id, act_name = action_line.split()
            self._valid_verb_ids.append(int(act_id))
            self._valid_verb_names.append(act_name)

    def get_valid_obj_ids(self):
        return self._valid_obj_ids

    def get_actions(self):
        return self._valid_verb_names

    def num_category(self):
        return len(self.COCO_CLASSES)

    def num_action(self):
        return len(self._valid_verb_ids)
    ############################################################################

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_anno = self.annotations[self.ids[idx]]

        img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
        w, h = img.size

        # cut out the GTs that exceed the number of object queries
        if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
            img_anno['annotations'] = img_anno['annotations'][:self.num_queries]

        boxes = [obj['bbox'] for obj in img_anno['annotations']]
        # guard against images with no boxes by reshaping to (-1, 4)
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.img_set == 'train':
            # Add index for confirming which boxes are kept after image transformation
            classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])]
        else:
            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']]
        classes = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        target['size'] = torch.as_tensor([int(h), int(w)])

        if self.img_set == 'train':
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]
            target['boxes'] = boxes
            target['labels'] = classes
            target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            if self._transforms is not None:
                img, target = self._transforms(img, target)
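
            # Random cropping may drop boxes; column 0 of the (index, class) labels
            # recovers which original annotation indices survived, and column 1
            # restores the plain class labels.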
            kept_box_indices = [label[0] for label in target['labels']]
            target['labels'] = target['labels'][:, 1]
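
            # Build pair-wise targets: one row per distinct (subject, object) pair,
            # with a multi-hot verb vector; repeated pairs only set extra verb bits.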
            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
            sub_obj_pairs = []
            for hoi in img_anno['hoi_annotation']:
                if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
                    continue
                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
                if sub_obj_pair in sub_obj_pairs:
                    verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
                else:
                    sub_obj_pairs.append(sub_obj_pair)
                    obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
                    verb_label = [0 for _ in range(len(self._valid_verb_ids))]
                    verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
                    sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
                    obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
                    verb_labels.append(verb_label)
                    sub_boxes.append(sub_box)
                    obj_boxes.append(obj_box)

            if len(sub_obj_pairs) == 0:
                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
                target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            else:
                target['pair_targets'] = torch.stack(obj_labels)
                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
                target['sub_boxes'] = torch.stack(sub_boxes)
                target['obj_boxes'] = torch.stack(obj_boxes)
        else:
            target['boxes'] = boxes
            target['labels'] = classes
            target['id'] = idx

            if self._transforms is not None:
                img, _ = self._transforms(img, None)

            hois = []
            for hoi in img_anno['hoi_annotation']:
                hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)

        return img, target
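
    # Evaluation helpers: set_rare_hois splits HOI triplets into rare (< 10 training
    # instances) and non-rare sets using the training annotations, and load_correct_mat
    # loads the corre_hico matrix of valid verb-object combinations used for evaluation.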
    def set_rare_hois(self, anno_file):
        with open(anno_file, 'r') as f:
            annotations = json.load(f)

        counts = defaultdict(lambda: 0)
        for img_anno in annotations:
            hois = img_anno['hoi_annotation']
            bboxes = img_anno['annotations']
            for hoi in hois:
                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
                           self._valid_verb_ids.index(hoi['category_id']))
                counts[triplet] += 1
        self.rare_triplets = []
        self.non_rare_triplets = []
        for triplet, count in counts.items():
            if count < 10:
                self.rare_triplets.append(triplet)
            else:
                self.non_rare_triplets.append(triplet)

    def load_correct_mat(self, path):
        self.correct_mat = np.load(path)


# Add color jitter to coco transforms
def make_hico_transforms(image_set):

    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    if image_set == 'test':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')
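
# Expected layout under args.data_path: images/{train2015,test2015}/,
# annotations/{trainval_hico.json,test_hico.json,corre_hico.npy}, and list_action.txt.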
def build(image_set, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided HOI path {root} does not exist'
    PATHS = {
        'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
        'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
        'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
    }
    CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
    action_list_file = root / 'list_action.txt'

    img_folder, anno_file = PATHS[image_set]
    dataset = HICODetection(image_set, img_folder, anno_file, action_list_file, transforms=make_hico_transforms(image_set),
                            num_queries=args.num_queries)
    if image_set == 'val' or image_set == 'test':
        dataset.set_rare_hois(PATHS['train'][1])
        dataset.load_correct_mat(CORRECT_MAT_PATH)
    return dataset
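
# A minimal usage sketch (hypothetical values; the real entry point passes parsed
# command-line args):
#
#   from argparse import Namespace
#   args = Namespace(data_path='path/to/hico_det', num_queries=100)  # assumed values
#   train_dataset = build('train', args)
#   img, target = train_dataset[0]
#   # target['pair_actions'] -> [num_pairs, num_verbs] multi-hot verb labels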