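# NOTE: the following lines are residue from the web page this file was
# copied from; they are kept as comments and are not part of the program.
#   crowdcontrol / code / test.py
#   promptsai's picture
#   Upload 30 files
#   023485e verified
#   raw
#   history blame
#   2.99 kB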
import argparse
import torch
import os
import numpy as np
import datasets.crowd as crowd
from models import vgg19
# Command-line interface of the evaluation script. Every option has a
# default, so the script can be run without any arguments.
parser = argparse.ArgumentParser(description='Test ')
# The value is written verbatim into CUDA_VISIBLE_DEVICES further below.
parser.add_argument('--device', default='0', help='assign device')
parser.add_argument('--crop-size', type=int, default=512,
                    help='the crop size of the train image')
parser.add_argument('--model-path', type=str, default='pretrained_models/model_qnrf.pth',
                    help='saved model path')
# The dataset-specific sub-folder ('test', 'val' or 'test_data') is appended
# to this root when the dataset object is built.
parser.add_argument('--data-path', type=str,
                    default='data/QNRF-Train-Val-Test',
                    help='root directory of the dataset to evaluate on')
# Compared case-insensitively when the dataset is selected.
parser.add_argument('--dataset', type=str, default='qnrf',
                    help='dataset name: qnrf, nwpu, sha, shb')
# An empty string (the default) disables saving of density-map images.
parser.add_argument('--pred-density-map-path', type=str, default='',
                    help='save predicted density maps when pred-density-map-path is not empty.')
args = parser.parse_args()
# Expose only the requested GPU(s) to CUDA. This is done before the first
# CUDA query below so that the query sees the restricted device list.
os.environ['CUDA_VISIBLE_DEVICES'] = args.device  # set vis gpu
# Prefer the GPU, but fall back to the CPU when no CUDA device is usable so
# that moving the model and inputs to `device` does not fail on such hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Short aliases for the command-line values used by the rest of the script.
model_path = args.model_path
crop_size = args.crop_size
data_path = args.data_path
# Pick the dataset class and the evaluation sub-folder from the
# (case-insensitive) --dataset value. Each class receives the folder, the
# crop size, the literal 8 and method='val'.
# NOTE(review): the meaning of the positional 8 is defined in datasets.crowd
# (presumably a down-sampling ratio) - confirm there.
_dataset_key = args.dataset.lower()
if _dataset_key in ('sha', 'shb'):
    # Both ShanghaiTech parts go through the same class and folder name.
    dataset = crowd.Crowd_sh(os.path.join(data_path, 'test_data'), crop_size, 8, method='val')
elif _dataset_key == 'nwpu':
    dataset = crowd.Crowd_nwpu(os.path.join(data_path, 'val'), crop_size, 8, method='val')
elif _dataset_key == 'qnrf':
    dataset = crowd.Crowd_qnrf(os.path.join(data_path, 'test'), crop_size, 8, method='val')
else:
    # Any other dataset name is unsupported.
    raise NotImplementedError
# Iterate the dataset one sample at a time and in its stored order; the
# evaluation loop asserts that every batch holds exactly one image.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)
# Saving density maps is optional: OpenCV is imported and the output
# directory prepared only when an output path was given.
if args.pred_density_map_path:
    import cv2
    # exist_ok=True creates the directory (and missing parents) without a
    # separate existence check, so there is no check-then-create race, and a
    # path that exists but is not a directory raises immediately.
    os.makedirs(args.pred_density_map_path, exist_ok=True)
# Build the network, place it on the target device, then restore the saved
# weights and switch to inference mode.
model = vgg19()
model.to(device)
# map_location places the loaded tensors directly on `device`.
# NOTE(review): torch.load deserialises the checkpoint file, so only load
# checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
# Run the model on every image and record the signed counting error
# (ground-truth count minus the sum of the predicted density map).
image_errs = []
for inputs, count, name in dataloader:
    inputs = inputs.to(device)
    assert inputs.size(0) == 1, 'the batch size should equal to 1'
    # Inference only: no gradients are needed.
    with torch.no_grad():
        # The model returns two values; only the first one is used here.
        outputs, _ = model(inputs)
    # Convert each count to a Python number once and reuse it, instead of
    # repeating the reduction and the device-to-host transfer.
    gt_count = count[0].item()
    pred_count = torch.sum(outputs).item()
    img_err = gt_count - pred_count
    print(name, img_err, gt_count, pred_count)
    image_errs.append(img_err)
    if args.pred_density_map_path:
        # First channel of the single image in the batch.
        vis_img = outputs[0, 0].cpu().numpy()
        # normalize density map values from 0 to 1, then map it to 0-255.
        # The 1e-5 term keeps the division defined for a constant map.
        vis_img = (vis_img - vis_img.min()) / (vis_img.max() - vis_img.min() + 1e-5)
        vis_img = (vis_img * 255).astype(np.uint8)
        vis_img = cv2.applyColorMap(vis_img, cv2.COLORMAP_JET)
        out_file = os.path.join(args.pred_density_map_path, str(name[0]) + '.png')
        # cv2.imwrite signals failure through its return value rather than an
        # exception; report it instead of silently losing the image.
        if not cv2.imwrite(out_file, vis_img):
            print('warning: failed to write density map to {}'.format(out_file))
# Summarise the per-image errors over the whole evaluation set.
image_errs = np.array(image_errs)
# Mean absolute error.
mae = np.abs(image_errs).mean()
# Reported under the label "mse", but computed as the square root of the
# mean squared error.
mse = np.sqrt((image_errs ** 2).mean())
print('{}: mae {}, mse {}\n'.format(model_path, mae, mse))