# train_helper.py
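"""Training helper for an optimal-transport-based crowd counter.

Builds the datasets and dataloaders, a VGG19 density-map model, and the
combined OT + counting + total-variation loss, then runs the train/val loop.
"""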
import os
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import numpy as np
from datetime import datetime
from datasets.crowd import Crowd_qnrf, Crowd_nwpu, Crowd_sh
from models import vgg19
from losses.ot_loss import OT_Loss
from utils.pytorch_utils import Save_Handle, AverageMeter
import utils.log_utils as log_utils


def train_collate(batch):
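    """Collate a batch whose per-image point annotations vary in length.

    Images and discretized ground-truth maps stack into batch tensors; the
    point annotations are returned as a list because their lengths differ.
    """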
transposed_batch = list(zip(*batch))
images = torch.stack(transposed_batch[0], 0)
    points = transposed_batch[1]  # variable-length point annotations, kept as a list of tensors
gt_discretes = torch.stack(transposed_batch[2], 0)
return images, points, gt_discretes


class Trainer(object):
def __init__(self, args):
self.args = args

    def setup(self):
args = self.args
sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
args.crop_size, args.wot, args.wtv, args.reg, args.num_of_iter_in_ot, args.norm_cood)
self.save_dir = os.path.join('ckpts', sub_dir)
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
self.logger = log_utils.get_logger(os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
log_utils.print_config(vars(args), self.logger)
if torch.cuda.is_available():
self.device = torch.device("cuda")
self.device_count = torch.cuda.device_count()
            assert self.device_count == 1  # multi-GPU training is not supported
            self.logger.info('using {} gpus'.format(self.device_count))
        else:
            raise RuntimeError("GPU is not available")
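        # Ground-truth maps are discretized at 1/8 of the input resolution,
        # matching the output stride of the VGG19 density head.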
downsample_ratio = 8
if args.dataset.lower() == 'qnrf':
self.datasets = {x: Crowd_qnrf(os.path.join(args.data_dir, x),
args.crop_size, downsample_ratio, x) for x in ['train', 'val']}
elif args.dataset.lower() == 'nwpu':
self.datasets = {x: Crowd_nwpu(os.path.join(args.data_dir, x),
args.crop_size, downsample_ratio, x) for x in ['train', 'val']}
        elif args.dataset.lower() in ('sha', 'shb'):
self.datasets = {'train': Crowd_sh(os.path.join(args.data_dir, 'train_data'),
args.crop_size, downsample_ratio, 'train'),
'val': Crowd_sh(os.path.join(args.data_dir, 'test_data'),
args.crop_size, downsample_ratio, 'val'),
}
else:
raise NotImplementedError
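        # Validation images keep their original sizes, so the val loader uses
        # the default collate with batch size 1 (enforced in val_epoch).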
        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate if x == 'train' else default_collate),
                                          batch_size=(args.batch_size if x == 'train' else 1),
                                          shuffle=(x == 'train'),
                                          num_workers=args.num_workers * self.device_count,
                                          pin_memory=(x == 'train'))
                            for x in ['train', 'val']}
self.model = vgg19()
self.model.to(self.device)
self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
self.start_epoch = 0
if args.resume:
self.logger.info('loading pretrained model from ' + args.resume)
suf = args.resume.rsplit('.', 1)[-1]
if suf == 'tar':
checkpoint = torch.load(args.resume, self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.start_epoch = checkpoint['epoch'] + 1
elif suf == 'pth':
self.model.load_state_dict(torch.load(args.resume, self.device))
else:
self.logger.info('random initialization')
        self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood,
                               self.device, args.num_of_iter_in_ot, args.reg)
self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
self.mse = nn.MSELoss().to(self.device)
self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)  # retains at most one recent checkpoint file
self.best_mae = np.inf
self.best_mse = np.inf
self.best_count = 0

    def train(self):
        """Run the training loop, validating every args.val_epoch epochs once
        args.val_start is reached."""
args = self.args
for epoch in range(self.start_epoch, args.max_epoch + 1):
self.logger.info('-' * 5 + 'Epoch {}/{}'.format(epoch, args.max_epoch) + '-' * 5)
self.epoch = epoch
            self.train_epoch()
if epoch % args.val_epoch == 0 and epoch >= args.val_start:
self.val_epoch()

    def train_epoch(self):
epoch_ot_loss = AverageMeter()
epoch_ot_obj_value = AverageMeter()
epoch_wd = AverageMeter()
epoch_count_loss = AverageMeter()
epoch_tv_loss = AverageMeter()
epoch_loss = AverageMeter()
epoch_mae = AverageMeter()
epoch_mse = AverageMeter()
epoch_start = time.time()
self.model.train() # Set model to training mode
for step, (inputs, points, gt_discrete) in enumerate(self.dataloaders['train']):
inputs = inputs.to(self.device)
gd_count = np.array([len(p) for p in points], dtype=np.float32)
points = [p.to(self.device) for p in points]
gt_discrete = gt_discrete.to(self.device)
N = inputs.size(0)
with torch.set_grad_enabled(True):
outputs, outputs_normed = self.model(inputs)
# Compute OT loss.
ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs, points)
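                # wd (Wasserstein distance) and ot_obj_value are tracked for
                # logging only; gradients flow through ot_loss alone.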
ot_loss = ot_loss * self.args.wot
ot_obj_value = ot_obj_value * self.args.wot
epoch_ot_loss.update(ot_loss.item(), N)
epoch_ot_obj_value.update(ot_obj_value.item(), N)
epoch_wd.update(wd, N)
# Compute counting loss.
count_loss = self.mae(outputs.sum(1).sum(1).sum(1),
torch.from_numpy(gd_count).float().to(self.device))
epoch_count_loss.update(count_loss.item(), N)
# Compute TV loss.
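                # Per-pixel L1 between the normalized prediction and the
                # count-normalized ground truth, re-weighted by the GT count so
                # denser images contribute proportionally more.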
                gd_count_tensor = torch.from_numpy(gd_count).float().to(self.device).view(-1, 1, 1, 1)
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)
                           * torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.wtv
epoch_tv_loss.update(tv_loss.item(), N)
loss = ot_loss + count_loss + tv_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
pred_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
pred_err = pred_count - gd_count
epoch_loss.update(loss.item(), N)
epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(np.abs(pred_err)), N)
self.logger.info(
'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec'
.format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_wd.get_avg(),
epoch_ot_obj_value.get_avg(), epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
time.time() - epoch_start))
model_state_dic = self.model.state_dict()
save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
torch.save({
'epoch': self.epoch,
'optimizer_state_dict': self.optimizer.state_dict(),
'model_state_dict': model_state_dic
}, save_path)
self.save_list.append(save_path)

    def val_epoch(self):
args = self.args
epoch_start = time.time()
self.model.eval() # Set model to evaluate mode
epoch_res = []
for inputs, count, name in self.dataloaders['val']:
inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, 'batch size must be 1 in validation mode'
            with torch.no_grad():
outputs, _ = self.model(inputs)
res = count[0].item() - torch.sum(outputs).item()
epoch_res.append(res)
        epoch_res = np.array(epoch_res)
        # By crowd-counting convention, the reported "MSE" is actually the
        # root mean squared error (RMSE).
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.info('Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec'
.format(self.epoch, mse, mae, time.time() - epoch_start))
model_state_dic = self.model.state_dict()
        # Selection criterion: weighted sum that favors MSE over MAE.
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
self.best_mse = mse
self.best_mae = mae
self.logger.info("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
self.best_mae,
self.epoch))
torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count)))
self.best_count += 1
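

# ---------------------------------------------------------------------------
# Minimal driver sketch. NOTE: this block is illustrative and not part of the
# original file; the argument names mirror those consumed by Trainer above,
# but the default values are assumptions, not the authors' settings.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train an OT-based crowd counter')
    parser.add_argument('--data-dir', required=True, help='root directory of the dataset')
    parser.add_argument('--dataset', default='qnrf', choices=['qnrf', 'nwpu', 'sha', 'shb'])
    parser.add_argument('--crop-size', type=int, default=512)
    parser.add_argument('--wot', type=float, default=0.1, help='OT loss weight')
    parser.add_argument('--wtv', type=float, default=0.01, help='TV loss weight')
    parser.add_argument('--reg', type=float, default=10.0, help='entropic regularization for OT')
    parser.add_argument('--num-of-iter-in-ot', type=int, default=100)
    parser.add_argument('--norm-cood', type=int, default=0)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--resume', default='', help='checkpoint (.tar) or weights (.pth) to resume from')
    parser.add_argument('--max-epoch', type=int, default=1000)
    parser.add_argument('--val-epoch', type=int, default=5)
    parser.add_argument('--val-start', type=int, default=50)

    trainer = Trainer(parser.parse_args())
    trainer.setup()
    trainer.train()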