# Page-capture residue, not part of the program: "Spaces: Runtime error"
# ------------------------------------------------------------------------
# HOTR official code : main.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import argparse | |
import datetime | |
import json | |
import random | |
import time | |
import multiprocessing | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, DistributedSampler | |
import hotr.data.datasets as datasets | |
import hotr.util.misc as utils | |
from hotr.engine.arg_parser import get_args_parser | |
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset | |
from hotr.engine.trainer import train_one_epoch | |
from hotr.engine import hoi_evaluator, hoi_accumulator | |
from hotr.models import build_model | |
import wandb | |
from hotr.util.logger import print_params, print_args | |
def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename):
    """Save a training checkpoint to ``<args.output_dir>/<filename>.pth``.

    The checkpoint is a dict with the keys ``'model'``, ``'optimizer'``,
    ``'lr_scheduler'`` (each the object's ``state_dict()``), ``'epoch'`` and
    ``'args'``. Writing is delegated to ``utils.save_on_master``.

    Args:
        args: Parsed arguments; ``args.output_dir`` selects the target
            directory. When it is empty or ``None`` nothing is saved.
        model_without_ddp: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        lr_scheduler: Scheduler whose ``state_dict()`` is stored.
        epoch: Epoch value recorded in the checkpoint.
        filename: Checkpoint file name without the ``.pth`` suffix.
    """
    # Check the guard before building the Path: Path(None) raises TypeError,
    # which would turn "no output directory configured" into a crash.
    if not args.output_dir:
        return

    checkpoint_path = Path(args.output_dir) / f'{filename}.pth'
    utils.save_on_master({
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'args': args,
    }, checkpoint_path)
def _load_checkpoint(location):
    """Load a checkpoint from an ``https`` URL or a local path, mapped to CPU.

    Args:
        location: Checkpoint source. Strings starting with ``'https'`` are
            fetched through ``torch.hub``; anything else goes to ``torch.load``.

    Returns:
        The deserialised checkpoint object (``main`` indexes it by key).
    """
    if location.startswith('https'):
        return torch.hub.load_state_dict_from_url(
            location, map_location='cpu', check_hash=True)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    return torch.load(location, map_location='cpu')


def main(args):
    """Run HOTR training, or evaluation only when ``args.eval`` is set.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    train/val datasets and loaders, build the model and optimizer, optionally
    load frozen detector weights and/or a resume checkpoint, then either
    evaluate once and return, or train for ``args.start_epoch .. args.epochs``
    with optional per-epoch validation and checkpointing.

    Args:
        args: Parsed argument namespace. It is mutated in place: dataset
            statistics (``num_classes``, ``num_actions``, ...), ``device`` and
            ``start_epoch`` may be overwritten.

    Raises:
        ValueError: In HOI evaluation mode when ``args.dataset_file`` is
            neither ``'vcoco'`` nor ``'hico-det'``.
    """
    utils.init_distributed_mode(args)
    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    # Fall back to CPU when CUDA is unavailable.
    if not torch.cuda.is_available():
        args.device = 'cpu'
    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Data Setup
    dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args)
    assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits"
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc:
        args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec:
        args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics
        object_label_idx = np.array(dataset_train.get_object_label_idx())
        args.valid_ids = object_label_idx.nonzero()[0]
        args.invalid_ids = np.argwhere(object_label_idx == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train, shuffle=True)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)
    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    # Model Setup
    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    print_params(model)

    # Backbone parameters get their own learning rate (args.lr_backbone).
    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    # Weight Setup
    if args.frozen_weights is not None:
        checkpoint = _load_checkpoint(args.frozen_weights)
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    if args.resume:
        checkpoint = _load_checkpoint(args.resume)
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): the scheduler restore was disabled in the original
            # code, so lr_scheduler keeps its freshly constructed state after a
            # resume. Confirm this is intended before re-enabling:
            # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        # test only mode
        if args.HOIDet:
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                hoi_accumulator(args, total_res, True, False)
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}')
                print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}')
                print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}')
            else:
                raise ValueError(f'dataset {args.dataset_file} is not supported.')
            return

        # Detection-only evaluation. `base_ds` and `output_dir` were referenced
        # here without ever being defined; derive them from what is in scope.
        # NOTE(review): `evaluate_coco` is not among the imports visible in this
        # file; it must be imported before this branch can run.
        base_ds = get_coco_api_from_dataset(dataset_val)
        _, coco_evaluator = evaluate_coco(model, criterion, postprocessors,
                                          data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, Path(args.output_dir) / "eval.pth")
        return

    # stats: best validation scores seen so far
    scenario1, scenario2 = 0, 0
    best_mAP, best_rare, best_non_rare = 0, 0, 0

    # Only rank 0 talks to Weights & Biases.
    if args.wandb and utils.get_rank() == 0:
        wandb.init(
            project=args.project_name,
            group=args.group_name,
            name=args.run_name,
            config=args
        )
        wandb.watch(model)

    # Training starts here!
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch, args.epochs,
            args.ramp_up_epoch, args.ramp_down_epoch, args.hoi_consistency_loss_coef,
            args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb)
        lr_scheduler.step()

        # Validation
        if args.validate:
            print('-'*100)
            if args.dataset_file == 'vcoco':
                total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb)
                    # "Best" is decided by scenario #1 alone; scenario #2 is
                    # recorded from the same epoch.
                    if sc1 > scenario1:
                        scenario1 = sc1
                        scenario2 = sc2
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})')
                    print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})')
            elif args.dataset_file == 'hico-det':
                test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
                if utils.get_rank() == 0:
                    # "Best" is decided by the full mAP alone.
                    if test_stats['mAP'] > best_mAP:
                        best_mAP = test_stats['mAP']
                        best_rare = test_stats['mAP rare']
                        best_non_rare = test_stats['mAP non-rare']
                        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
                    print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})')
                    print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})')
                    print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})')
                    if args.wandb:
                        wandb.log({
                            'mAP': test_stats['mAP'],
                            'mAP rare': test_stats['mAP rare'],
                            'mAP non-rare': test_stats['mAP non-rare']
                        })
            print('-'*100)

        # Rolling checkpoint every epoch, plus a numbered snapshot every
        # args.lr_drop epochs.
        save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint')
        if (epoch + 1) % args.lr_drop == 0:
            save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='checkpoint_'+str(epoch))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Final summary of the best validation scores.
    if args.dataset_file == 'vcoco':
        print(f'| Scenario #1 mAP : {scenario1:.2f}')
        print(f'| Scenario #2 mAP : {scenario2:.2f}')
    elif args.dataset_file == 'hico-det':
        print(f'| mAP (full)\t\t: {best_mAP:.2f}')
        print(f'| mAP (rare)\t\t: {best_rare:.2f}')
        print(f'| mAP (non-rare)\t: {best_non_rare:.2f}')
if __name__ == '__main__':
    # Build the command-line interface on top of the shared HOTR parser.
    cli_parser = argparse.ArgumentParser(
        'End-to-End Human Object Interaction training and evaluation script',
        parents=[get_args_parser()],
    )
    args = cli_parser.parse_args()

    # When an output directory is given, nest this run under
    # <output_dir>/<group_name>/<run_name>/ and create it up front.
    if args.output_dir:
        run_suffix = f"/{args.group_name}/{args.run_name}/"
        args.output_dir = args.output_dir + run_suffix
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)