# ------------------------------------------------------------------------
# HOTR official code : hotr/data/datasets/hico.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
from pathlib import Path
from PIL import Image
import json
from collections import defaultdict
import numpy as np

import torch
import torch.utils.data
import torchvision

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T


class HICODetection(torch.utils.data.Dataset):
    """HICO-DET human-object-interaction dataset.

    Each entry of the annotation JSON is read as a dict with keys
    ``file_name``, ``annotations`` (list of dicts with ``bbox`` and
    ``category_id``) and ``hoi_annotation`` (list of dicts with
    ``subject_id``, ``object_id`` and ``category_id``; the two ids index into
    ``annotations``).

    Args:
        img_set: 'train', 'val' or 'test'. In 'train' mode ``__getitem__``
            builds per-pair targets; otherwise it returns raw boxes/labels and
            ``hois`` triplets for evaluation.
        img_folder: directory containing the images (``Path`` or ``str``).
        anno_file: path to the JSON annotation file.
        action_list_file: text file of ``<verb_id> <verb_name>`` lines; the
            first two lines are skipped as a header.
        transforms: callable ``(img, target) -> (img, target)`` or ``None``.
        num_queries: maximum number of GT boxes kept per training image.
    """

    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
        self.img_set = img_set
        self.img_folder = img_folder

        with open(anno_file, 'r') as f:
            self.annotations = json.load(f)
        with open(action_list_file, 'r') as f:
            self.action_lines = f.readlines()

        self._transforms = transforms
        self.num_queries = num_queries

        self.get_metadata()

        if img_set == 'train':
            # Keep only training images whose HOIs all reference existing boxes.
            self.ids = []
            for idx, img_anno in enumerate(self.annotations):
                num_boxes = len(img_anno['annotations'])
                if all(hoi['subject_id'] < num_boxes and hoi['object_id'] < num_boxes
                       for hoi in img_anno['hoi_annotation']):
                    self.ids.append(idx)
        else:
            self.ids = list(range(len(self.annotations)))

    ############################################################################
    # Number Method
    ############################################################################
    def get_metadata(self):
        """Load object-class metadata and parse the verb list.

        Sets ``COCO_CLASSES``, ``_valid_obj_ids`` (dataset object ids in the
        order used for contiguous label indices), ``_valid_verb_ids`` and
        ``_valid_verb_names``.
        """
        meta = builtin_meta._get_coco_instances_meta()
        self.COCO_CLASSES = meta['coco_classes']
        self._valid_obj_ids = list(meta['thing_dataset_id_to_contiguous_id'].keys())

        self._valid_verb_ids, self._valid_verb_names = [], []
        # The first two lines of the action list are a header.
        for action_line in self.action_lines[2:]:
            fields = action_line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            act_id, act_name = fields
            self._valid_verb_ids.append(int(act_id))
            self._valid_verb_names.append(act_name)

    def get_valid_obj_ids(self):
        """Return the list of dataset object-category ids."""
        return self._valid_obj_ids

    def get_actions(self):
        """Return the verb names in label-index order."""
        return self._valid_verb_names

    def num_category(self):
        """Return the number of object categories."""
        return len(self.COCO_CLASSES)

    def num_action(self):
        """Return the number of verb (action) categories."""
        return len(self._valid_verb_ids)
    ############################################################################

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Load one image and build its target dict.

        Train mode returns ``boxes``, ``labels``, ``iscrowd``, ``area`` and the
        pair targets ``pair_targets`` / ``pair_actions`` / ``sub_boxes`` /
        ``obj_boxes`` (one row per unique subject-object pair, with a multi-hot
        verb vector). Other modes return ``boxes``, ``labels``, ``id`` and
        ``hois`` as ``(subject_id, object_id, verb_index)`` rows.
        """
        img_anno = self.annotations[self.ids[idx]]

        with Image.open(Path(self.img_folder) / img_anno['file_name']) as raw_img:
            img = raw_img.convert('RGB')
        w, h = img.size

        # Cut out the GTs that exceed the number of object queries. Truncate a
        # local reference so the cached annotations are not modified.
        annotations = img_anno['annotations']
        if self.img_set == 'train' and len(annotations) > self.num_queries:
            annotations = annotations[:self.num_queries]

        boxes = [obj['bbox'] for obj in annotations]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.img_set == 'train':
            # Pair each class with its original box index so we can tell which
            # boxes survive the image transformation.
            classes = [(i, self._valid_obj_ids.index(obj['category_id']))
                       for i, obj in enumerate(annotations)]
            # Keep an (N, 2) layout even for N == 0 so column indexing works.
            classes = torch.tensor(classes, dtype=torch.int64).reshape(-1, 2)
        else:
            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in annotations]
            classes = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        target['size'] = torch.as_tensor([int(h), int(w)])

        if self.img_set == 'train':
            # Clamping x to [0, w] and y to [0, h] treats boxes as absolute
            # (x0, y0, x1, y1); degenerate boxes are then dropped.
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]

            target['boxes'] = boxes
            target['labels'] = classes
            target['iscrowd'] = torch.zeros(boxes.shape[0], dtype=torch.int64)
            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            if self._transforms is not None:
                img, target = self._transforms(img, target)

            # Column 0 holds the original box index, column 1 the class label.
            kept_box_indices = target['labels'][:, 0].tolist()
            target['labels'] = target['labels'][:, 1]
            # original box index -> row in the (possibly filtered) targets
            box_pos = {box_idx: pos for pos, box_idx in enumerate(kept_box_indices)}

            num_verbs = len(self._valid_verb_ids)
            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
            pair_slot = {}  # (subject_id, object_id) -> row in the pair lists
            for hoi in img_anno['hoi_annotation']:
                sub_pos = box_pos.get(hoi['subject_id'])
                obj_pos = box_pos.get(hoi['object_id'])
                if sub_pos is None or obj_pos is None:
                    # One of the pair's boxes was removed (truncation/crop).
                    continue

                verb_idx = self._valid_verb_ids.index(hoi['category_id'])
                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
                if sub_obj_pair in pair_slot:
                    # Same pair seen before: add this verb to its multi-hot vector.
                    verb_labels[pair_slot[sub_obj_pair]][verb_idx] = 1
                    continue

                pair_slot[sub_obj_pair] = len(verb_labels)
                verb_label = [0] * num_verbs
                verb_label[verb_idx] = 1
                obj_labels.append(target['labels'][obj_pos])
                verb_labels.append(verb_label)
                sub_boxes.append(target['boxes'][sub_pos])
                obj_boxes.append(target['boxes'][obj_pos])

            if not pair_slot:
                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
                target['pair_actions'] = torch.zeros((0, num_verbs), dtype=torch.float32)
                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            else:
                target['pair_targets'] = torch.stack(obj_labels)
                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
                target['sub_boxes'] = torch.stack(sub_boxes)
                target['obj_boxes'] = torch.stack(obj_boxes)
        else:
            target['boxes'] = boxes
            target['labels'] = classes
            target['id'] = idx

            if self._transforms is not None:
                img, _ = self._transforms(img, None)

            hois = [(hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))
                    for hoi in img_anno['hoi_annotation']]
            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)

        return img, target

    def set_rare_hois(self, anno_file):
        """Split (subject, object, verb) triplets into rare / non-rare.

        Counts triplet occurrences in ``anno_file``. Triplets seen fewer than
        10 times go to ``self.rare_triplets``; the rest go to
        ``self.non_rare_triplets``. Each triplet is a tuple of contiguous
        indices ``(subject_class, object_class, verb)``.
        """
        with open(anno_file, 'r') as f:
            annotations = json.load(f)

        counts = defaultdict(int)
        for img_anno in annotations:
            hois = img_anno['hoi_annotation']
            bboxes = img_anno['annotations']
            for hoi in hois:
                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
                           self._valid_verb_ids.index(hoi['category_id']))
                counts[triplet] += 1

        self.rare_triplets = []
        self.non_rare_triplets = []
        for triplet, count in counts.items():
            if count < 10:
                self.rare_triplets.append(triplet)
            else:
                self.non_rare_triplets.append(triplet)

    def load_correct_mat(self, path):
        """Load the ``.npy`` matrix at ``path`` into ``self.correct_mat``."""
        self.correct_mat = np.load(path)


# Add color jitter to coco transforms
def make_hico_transforms(image_set):
    """Build the transform pipeline for ``image_set``.

    'train' applies horizontal flip, color jitter and a random choice between
    a multi-scale resize and a resize-crop-resize chain. 'val' and 'test' both
    resize with the single scale 800 (max side 1333). Every pipeline ends with
    ToTensor plus mean/std normalization.

    Raises:
        ValueError: if ``image_set`` is not 'train', 'val' or 'test'.
    """
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    if image_set in ('val', 'test'):
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build(image_set, args):
    """Build a ``HICODetection`` dataset for ``image_set``.

    Reads ``args.data_path`` (dataset root) and ``args.num_queries``. For
    'val' and 'test' it also computes rare/non-rare triplets from the training
    annotations and loads ``annotations/corre_hico.npy``.
    """
    root = Path(args.data_path)
    assert root.exists(), f'provided HOI path {root} does not exist'
    PATHS = {
        'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
        'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
        'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
    }
    CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
    action_list_file = root / 'list_action.txt'

    img_folder, anno_file = PATHS[image_set]
    dataset = HICODetection(image_set, img_folder, anno_file, action_list_file,
                            transforms=make_hico_transforms(image_set),
                            num_queries=args.num_queries)
    if image_set in ('val', 'test'):
        dataset.set_rare_hois(PATHS['train'][1])
        dataset.load_correct_mat(CORRECT_MAT_PATH)
    return dataset