Kevin L
Initial commit
5fd3fc6
import math
import torch
import numpy as np
from fbrs.inference.clicker import Click
from .base import BaseTransform
class Crops(BaseTransform):
    """Split one large image into overlapping fixed-size crops and merge results back.

    ``transform`` cuts the image into a grid of ``crop_size`` windows and
    re-expresses the clicks in each window's local coordinates.
    ``inv_transform`` pastes per-crop probability maps back onto a
    full-size canvas, averaging wherever windows overlap.
    """

    def __init__(self, crop_size=(320, 480), min_overlap=0.2):
        """Store the crop geometry.

        Args:
            crop_size: ``(height, width)`` of every crop.
            min_overlap: overlap ratio forwarded to ``get_offsets``.
        """
        super().__init__()
        self.crop_height, self.crop_width = crop_size
        self.min_overlap = min_overlap
        # Tiling state: filled in by transform(), cleared by reset().
        self.x_offsets = self.y_offsets = self._counts = None

    def transform(self, image_nd, clicks_lists):
        """Cut ``image_nd`` into crops and shift the clicks into each crop.

        Args:
            image_nd: 4-D tensor holding exactly one image; axes 2 and 3 are
                read as height and width.
            clicks_lists: list holding exactly one list of clicks.

        Returns:
            ``(image_crops, clicks_lists)``. When the image is smaller than
            the crop along either axis, the inputs are returned unchanged and
            ``inv_transform`` becomes a pass-through. Otherwise the crops are
            concatenated along dim 0 (row-major over the offset grid) and
            there is one shifted click list per crop, in the same order.
        """
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        height, width = image_nd.shape[2:4]
        crop_h, crop_w = self.crop_height, self.crop_width

        # Forget any previous tiling before deciding whether to tile at all.
        self._counts = None
        fits_crop = height >= crop_h and width >= crop_w
        if not fits_crop:
            return image_nd, clicks_lists

        self.x_offsets = get_offsets(width, crop_w, self.min_overlap)
        self.y_offsets = get_offsets(height, crop_h, self.min_overlap)
        # Top-left corners of all crops, rows first, then columns.
        origins = [(top, left) for top in self.y_offsets for left in self.x_offsets]

        # Per-pixel number of crops covering it; used later for averaging.
        self._counts = np.zeros((height, width))
        tiles = []
        for top, left in origins:
            rows = slice(top, top + crop_h)
            cols = slice(left, left + crop_w)
            self._counts[rows, cols] += 1
            tiles.append(image_nd[:, :, rows, cols])

        batch = torch.cat(tiles, dim=0)
        self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32)

        # Every crop receives all clicks, translated to its local frame
        # (coords[0] is shifted by the y offset, coords[1] by the x offset).
        source_clicks = clicks_lists[0]
        shifted_clicks = [
            [
                Click(is_positive=click.is_positive,
                      coords=(click.coords[0] - top, click.coords[1] - left))
                for click in source_clicks
            ]
            for top, left in origins
        ]
        return batch, shifted_clicks

    def inv_transform(self, prob_map):
        """Merge per-crop probability maps into one full-size map.

        Args:
            prob_map: tensor indexed as ``prob_map[crop_index, 0]``, with
                crops in the order produced by ``transform``.

        Returns:
            ``prob_map`` itself when the last ``transform`` did not tile;
            otherwise a ``(1, 1, H, W)`` tensor where each pixel is the mean
            of all crop predictions covering it.
        """
        if self._counts is None:
            return prob_map

        merged = torch.zeros((1, 1, *self._counts.shape),
                             dtype=prob_map.dtype, device=prob_map.device)
        origins = [(top, left) for top in self.y_offsets for left in self.x_offsets]
        for crop_index, (top, left) in enumerate(origins):
            merged[0, 0, top:top + self.crop_height, left:left + self.crop_width] += prob_map[crop_index, 0]

        # Turn the accumulated sums into averages over the overlapping crops.
        return torch.div(merged, self._counts)

    def get_state(self):
        """Return the tiling state as ``(x_offsets, y_offsets, counts)``."""
        return self.x_offsets, self.y_offsets, self._counts

    def set_state(self, state):
        """Restore a tiling state previously obtained from ``get_state``."""
        x_offsets, y_offsets, counts = state
        self.x_offsets = x_offsets
        self.y_offsets = y_offsets
        self._counts = counts

    def reset(self):
        """Drop all tiling state so ``inv_transform`` becomes a pass-through."""
        self.x_offsets = self.y_offsets = self._counts = None
def get_offsets(length, crop_size, min_overlap_ratio=0.2):
    """Return the start offsets of fixed-size crops that cover ``[0, length)``.

    The number of crops is the smallest count for which neighbouring crops
    can overlap by roughly ``min_overlap_ratio`` of ``crop_size`` (the
    overlap in pixels is truncated to an int). Crops are then evenly strided
    and the last ones are clamped so they never run past ``length``.

    Args:
        length: size of the axis to cover.
        crop_size: size of one crop along that axis.
        min_overlap_ratio: target minimum overlap between neighbouring crops,
            as a fraction of ``crop_size``. A value of 1 divides by zero.

    Returns:
        List of integer start offsets, beginning with 0. When the axis is not
        longer than one crop, the result is ``[0]``.
    """
    # One crop anchored at 0 is all that can be placed. Handling the strict
    # "<" case here also avoids the division by (N - 1) == 0 further down,
    # which used to raise ZeroDivisionError for some short lengths.
    if length <= crop_size:
        return [0]

    size_ratio = length / crop_size
    N = math.ceil((size_ratio - min_overlap_ratio) / (1 - min_overlap_ratio))

    # Spread the surplus coverage (N crops vs. the size_ratio actually
    # needed) evenly over the N - 1 seams between neighbouring crops.
    overlap_ratio = (N - size_ratio) / (N - 1)
    overlap_width = int(crop_size * overlap_ratio)
    stride = crop_size - overlap_width
    last_valid_offset = length - crop_size

    offsets = [0]
    for _ in range(1, N):
        # Truncating overlap_width can push late crops past the end; clamp.
        offsets.append(min(offsets[-1] + stride, last_valid_offset))
    return offsets