from PIL import Image
import torch.utils.data as data
import os
from glob import glob
import torch
import torchvision.transforms.functional as F
from torchvision import transforms
import random
import numpy as np
import scipy.io as sio


def random_crop(im_h, im_w, crop_h, crop_w):
    # Sample the top-left corner of a crop_h x crop_w window inside an im_h x im_w image.
    res_h = im_h - crop_h
    res_w = im_w - crop_w
    i = random.randint(0, res_h)
    j = random.randint(0, res_w)
    return i, j, crop_h, crop_w


def gen_discrete_map(im_height, im_width, points):
    """
    Generate the discrete (binned) point map.
    points: [num_gt, 2], each row is [width, height], i.e. (x, y).
    """
    discrete_map = np.zeros([im_height, im_width], dtype=np.float32)
    h, w = discrete_map.shape[:2]
    num_gt = points.shape[0]
    if num_gt == 0:
        return discrete_map

    # Fast path: clamp coordinates to the map bounds, then scatter-add one count per point
    # into a flattened map.
    points_np = np.array(points).round().astype(int)
    p_h = np.minimum(points_np[:, 1], np.array([h - 1] * num_gt).astype(int))
    p_w = np.minimum(points_np[:, 0], np.array([w - 1] * num_gt).astype(int))
    p_index = torch.from_numpy(p_h * im_width + p_w).long()  # scatter_add_ needs int64 indices
    discrete_map = torch.zeros(im_width * im_height).scatter_add_(
        0, index=p_index, src=torch.ones(im_width * im_height)).view(im_height, im_width).numpy()

    ''' slow method
    for p in points:
        p = np.round(p).astype(int)
        p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0])
        discrete_map[p[0], p[1]] += 1
    '''
    assert np.sum(discrete_map) == num_gt
    return discrete_map


class Base(data.Dataset):
    def __init__(self, root_path, crop_size, downsample_ratio=8):
        self.root_path = root_path
        self.c_size = crop_size
        self.d_ratio = downsample_ratio
        assert self.c_size % self.d_ratio == 0
        self.dc_size = self.c_size // self.d_ratio
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        pass

    def __getitem__(self, item):
        pass

    def train_transform(self, img, keypoints):
        wd, ht = img.size
        st_size = 1.0 * min(wd, ht)
        assert st_size >= self.c_size
        assert len(keypoints) >= 0
        i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
        img = F.crop(img, i, j, h, w)
        if len(keypoints) > 0:
            # Shift keypoints into crop coordinates and drop those that fall outside the crop.
            keypoints = keypoints - [j, i]
            idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
                       (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
            keypoints = keypoints[idx_mask]
        else:
            keypoints = np.empty([0, 2])

        gt_discrete = gen_discrete_map(h, w, keypoints)
        down_w = w // self.d_ratio
        down_h = h // self.d_ratio
        # Sum the counts inside each d_ratio x d_ratio cell to get the downsampled map.
        gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
        assert np.sum(gt_discrete) == len(keypoints)

        # Random horizontal flip; keypoint x-coordinates are mirrored accordingly.
        if len(keypoints) > 0:
            if random.random() > 0.5:
                img = F.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
                keypoints[:, 0] = w - keypoints[:, 0]
        else:
            if random.random() > 0.5:
                img = F.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
        gt_discrete = np.expand_dims(gt_discrete, 0)

        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(
            gt_discrete.copy()).float()


class Crowd_qnrf(Base):
    def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))
        if method not in ['train', 'val']:
            raise Exception("not implement")

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, item):
        img_path = self.im_list[item]
        gd_path = img_path.replace('jpg', 'npy')
        img = Image.open(img_path).convert('RGB')
        if self.method == 'train':
            keypoints = np.load(gd_path)
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            keypoints = np.load(gd_path)
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, len(keypoints), name
class Crowd_nwpu(Base):
    def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))
        if method not in ['train', 'val', 'test']:
            raise Exception("not implement")

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, item):
        img_path = self.im_list[item]
        gd_path = img_path.replace('jpg', 'npy')
        img = Image.open(img_path).convert('RGB')
        if self.method == 'train':
            keypoints = np.load(gd_path)
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            keypoints = np.load(gd_path)
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, len(keypoints), name
        elif self.method == 'test':
            img = self.trans(img)
            name = os.path.basename(img_path).split('.')[0]
            return img, name


class Crowd_sh(Base):
    def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
        super().__init__(root_path, crop_size, downsample_ratio)
        self.method = method
        if method not in ['train', 'val']:
            raise Exception("not implement")
        self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
        print('number of img: {}'.format(len(self.im_list)))

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, item):
        img_path = self.im_list[item]
        name = os.path.basename(img_path).split('.')[0]
        gd_path = os.path.join(self.root_path, 'ground-truth', 'GT_{}.mat'.format(name))
        img = Image.open(img_path).convert('RGB')
        keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]
        if self.method == 'train':
            return self.train_transform(img, keypoints)
        elif self.method == 'val':
            img = self.trans(img)
            return img, len(keypoints), name

    def train_transform(self, img, keypoints):
        wd, ht = img.size
        st_size = 1.0 * min(wd, ht)
        # Resize the image so its short side is at least the crop size, scaling keypoints to match.
        if st_size < self.c_size:
            rr = 1.0 * self.c_size / st_size
            wd = round(wd * rr)
            ht = round(ht * rr)
            st_size = 1.0 * min(wd, ht)
            img = img.resize((wd, ht), Image.BICUBIC)
            keypoints = keypoints * rr
        assert st_size >= self.c_size, 'short side {}x{} is smaller than crop size'.format(wd, ht)
        assert len(keypoints) >= 0
        i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
        img = F.crop(img, i, j, h, w)
        if len(keypoints) > 0:
            keypoints = keypoints - [j, i]
            idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
                       (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
            keypoints = keypoints[idx_mask]
        else:
            keypoints = np.empty([0, 2])

        gt_discrete = gen_discrete_map(h, w, keypoints)
        down_w = w // self.d_ratio
        down_h = h // self.d_ratio
        gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
        assert np.sum(gt_discrete) == len(keypoints)

        if len(keypoints) > 0:
            if random.random() > 0.5:
                img = F.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
                keypoints[:, 0] = w - keypoints[:, 0] - 1
        else:
            if random.random() > 0.5:
                img = F.hflip(img)
                gt_discrete = np.fliplr(gt_discrete)
        gt_discrete = np.expand_dims(gt_discrete, 0)

        return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(
            gt_discrete.copy()).float()
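
# --- Usage sketch (not part of the loaders above) -------------------------------
# A minimal, hypothetical example of driving one of these datasets with a DataLoader.
# The root path and crop size below are placeholders; the QNRF layout (*.jpg images
# with matching *.npy keypoint files in the same directory) is assumed.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = Crowd_qnrf(root_path='data/QNRF/train',  # placeholder path
                         crop_size=512, downsample_ratio=8, method='train')

    # Each training item carries a variable number of keypoints, so the default
    # collate only works with batch_size=1; larger batches need a custom collate_fn.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    img, keypoints, gt_discrete = next(iter(loader))
    print(img.shape)          # (1, 3, 512, 512)
    print(keypoints.shape)    # (1, num_points, 2)
    print(gt_discrete.shape)  # (1, 1, 64, 64) with downsample_ratio=8
    # The binned map preserves the total count of the cropped keypoints.
    print(gt_discrete.sum().item(), keypoints.shape[1])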