# ------------------------------------------------------------------------
# HOTR official code : hotr/data/datasets/hico.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
import json
from collections import defaultdict
import numpy as np
import torch
import torch.utils.data
import torchvision
from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T

class HICODetection(torch.utils.data.Dataset):
    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
        self.img_set = img_set
        self.img_folder = img_folder
        with open(anno_file, 'r') as f:
            self.annotations = json.load(f)
        with open(action_list_file, 'r') as f:
            self.action_lines = f.readlines()
        self._transforms = transforms
        self.num_queries = num_queries

        self.get_metadata()

        if img_set == 'train':
            # Keep only images whose HOI annotations all reference valid box
            # indices; the for/else appends idx only when no `break` fired.
            self.ids = []
            for idx, img_anno in enumerate(self.annotations):
                for hoi in img_anno['hoi_annotation']:
                    if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
                        break
                else:
                    self.ids.append(idx)
        else:
            self.ids = list(range(len(self.annotations)))
    ############################################################################
    # Metadata & Count Methods
    ############################################################################
    def get_metadata(self):
        meta = builtin_meta._get_coco_instances_meta()
        self.COCO_CLASSES = meta['coco_classes']
        self._valid_obj_ids = list(meta['thing_dataset_id_to_contiguous_id'].keys())

        self._valid_verb_ids, self._valid_verb_names = [], []
        for action_line in self.action_lines[2:]:  # the first two lines of list_action.txt are headers
            act_id, act_name = action_line.split()
            self._valid_verb_ids.append(int(act_id))
            self._valid_verb_names.append(act_name)

    def get_valid_obj_ids(self):
        return self._valid_obj_ids

    def get_actions(self):
        return self._valid_verb_names

    def num_category(self):
        return len(self.COCO_CLASSES)

    def num_action(self):
        return len(self._valid_verb_ids)
    ############################################################################
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_anno = self.annotations[self.ids[idx]]

        img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
        w, h = img.size

        # Cut out the GTs that exceed the number of object queries
        # (note: this truncation persists in self.annotations).
        if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
            img_anno['annotations'] = img_anno['annotations'][:self.num_queries]

        boxes = [obj['bbox'] for obj in img_anno['annotations']]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.img_set == 'train':
            # Pair each class label with its original box index so we can tell
            # which boxes survive the image transformation below.
            classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])]
        else:
            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']]
        classes = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        target['size'] = torch.as_tensor([int(h), int(w)])
        if self.img_set == 'train':
            # Clip boxes to the image and drop degenerate (empty) ones.
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]
            target['boxes'] = boxes
            target['labels'] = classes
            target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            if self._transforms is not None:
                img, target = self._transforms(img, target)

            # Column 0 holds the original box index, column 1 the class label.
            kept_box_indices = [label[0] for label in target['labels']]
            target['labels'] = target['labels'][:, 1]

            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
            sub_obj_pairs = []
            for hoi in img_anno['hoi_annotation']:
                if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
                    continue
                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
                if sub_obj_pair in sub_obj_pairs:
                    # Merge multiple verbs on the same pair into one multi-hot row.
                    verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
                else:
                    sub_obj_pairs.append(sub_obj_pair)
                    obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
                    verb_label = [0 for _ in range(len(self._valid_verb_ids))]
                    verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
                    sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
                    obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
                    verb_labels.append(verb_label)
                    sub_boxes.append(sub_box)
                    obj_boxes.append(obj_box)

            if len(sub_obj_pairs) == 0:
                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
                target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            else:
                target['pair_targets'] = torch.stack(obj_labels)
                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
                target['sub_boxes'] = torch.stack(sub_boxes)
                target['obj_boxes'] = torch.stack(obj_boxes)
        else:
            target['boxes'] = boxes
            target['labels'] = classes
            target['id'] = idx

            if self._transforms is not None:
                img, _ = self._transforms(img, None)

            hois = []
            for hoi in img_anno['hoi_annotation']:
                hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)

        return img, target
    def set_rare_hois(self, anno_file):
        with open(anno_file, 'r') as f:
            annotations = json.load(f)

        # Count every (subject class, object class, verb) triplet in the train
        # annotations; triplets seen fewer than 10 times form the "rare" split.
        counts = defaultdict(int)
        for img_anno in annotations:
            hois = img_anno['hoi_annotation']
            bboxes = img_anno['annotations']
            for hoi in hois:
                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
                           self._valid_verb_ids.index(hoi['category_id']))
                counts[triplet] += 1
        self.rare_triplets = []
        self.non_rare_triplets = []
        for triplet, count in counts.items():
            if count < 10:
                self.rare_triplets.append(triplet)
            else:
                self.non_rare_triplets.append(triplet)

    def load_correct_mat(self, path):
        # corre_hico.npy is the verb-object co-occurrence mask used at
        # evaluation time to suppress infeasible verb-object combinations.
        self.correct_mat = np.load(path)

# Add color jitter to coco transforms
def make_hico_transforms(image_set):
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set in ('val', 'test'):
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')
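
# ------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream file): a minimal sanity
# check of the eval-time pipeline, assuming the DETR-style
# `T.Compose(image, target)` call signature already used in __getitem__.
# ------------------------------------------------------------------------
# from PIL import Image
# tfm = make_hico_transforms('val')
# img, _ = tfm(Image.new('RGB', (640, 480)), None)  # -> normalized CHW tensor
# assert img.shape[0] == 3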

def build(image_set, args):
    root = Path(args.data_path)
    assert root.exists(), f'provided HOI path {root} does not exist'
    PATHS = {
        'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
        'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
        'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
    }
    CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
    action_list_file = root / 'list_action.txt'

    img_folder, anno_file = PATHS[image_set]
    dataset = HICODetection(image_set, img_folder, anno_file, action_list_file,
                            transforms=make_hico_transforms(image_set),
                            num_queries=args.num_queries)
    if image_set in ('val', 'test'):
        # Rare/non-rare splits are defined by triplet frequency in the train set.
        dataset.set_rare_hois(PATHS['train'][1])
        dataset.load_correct_mat(CORRECT_MAT_PATH)
    return dataset
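
# ------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the upstream file): `args` only
# needs the two attributes read by `build`; the argparse flag names below
# are hypothetical, for illustration.
# ------------------------------------------------------------------------
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--data_path', default='hico_20160224_det')  # hypothetical flag
# parser.add_argument('--num_queries', type=int, default=100)      # hypothetical flag
# args = parser.parse_args([])
# train_set = build('train', args)
# img, target = train_set[0]
# # target['pair_actions'] is a multi-hot tensor of shape (num_pairs, num_verbs)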