import numpy as np
import torch
from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
                                    normalize_kernel2d)

# NOTE(review): this module was restored from a copy whose line breaks and
# indentation had been collapsed. Block nesting (notably inside
# `Competition.forward` and `soft_index`) was reconstructed from syntax and
# data flow -- confirm against the upstream source.


def l2_normalize(x):
    """L2-normalize `x` along its last dimension."""
    return F.normalize(x, p=2.0, dim=-1, eps=1e-6)


def reduce_max(x, dim, keepdim=True):
    """Return the maximum values of `x` along `dim` (indices are discarded)."""
    return torch.max(x, dim=dim, keepdim=keepdim)[0]


def coordinate_ims(batch_size, seq_length, imsize):
    """
    Build images of normalized (h, w) pixel coordinates.

    args:
        batch_size: B
        seq_length: T; pass 0 to get a "static" result without a time axis
        imsize: (H, W)
    returns:
        hw_ims: float32 CPU tensor, [B,T,H,W,2] (or [B,H,W,2] if
            seq_length == 0); channel 0 is the h coordinate and channel 1 the
            w coordinate, each running linearly from -1 to 1.
    """
    static = (seq_length == 0)
    B = batch_size
    T = 1 if static else seq_length
    H, W = imsize

    ones = torch.ones([B, H, W, 1], dtype=torch.float32)
    h = torch.divide(torch.arange(H).to(ones),
                     torch.tensor(H - 1, dtype=torch.float32))
    h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
    w = torch.divide(torch.arange(W).to(ones),
                     torch.tensor(W - 1, dtype=torch.float32))
    w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)

    h = torch.stack([h] * T, 1)
    w = torch.stack([w] * T, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims


def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
    """
    Compute the (optionally normalized) dot product between two tensors.

    args:
        queries: [B,N,D]
        keys: [B,N_k,D]
        normalize: if True, L2-normalize both inputs along the last dim first
        eps: epsilon passed to F.normalize
    returns:
        outputs: [B,N,N_k]
    """
    _, _, D_q = queries.size()
    _, _, D_k = keys.size()
    assert D_q == D_k, (queries.shape, keys.shape)

    if normalize:
        queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
        keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)

    return torch.matmul(queries, torch.transpose(keys, 1, 2))  # [B, N, N_k]


def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
    """
    Sample pixel indices from an unnormalized probability image.

    args:
        probs: [B,H,W] non-negative weights
        num_points: P, number of samples drawn per batch element
        eps: added to numerator and denominator for numerical safety
    returns:
        indices: [B,P,2] int32 (h, w) pixel indices
    """
    B, H, W = probs.shape
    probs = probs.reshape(B, H * W)
    probs = (probs + eps).clamp(min=0.) / (probs.sum(dim=-1, keepdim=True) + eps)
    dist = Categorical(probs=probs, validate_args=False)

    indices = dist.sample([num_points]).permute(1, 0).to(torch.int32)  # [B,P]
    indices_h = torch.div(indices, W, rounding_mode='floor').clamp(min=0, max=H - 1)
    indices_w = torch.fmod(indices, W).clamp(min=0, max=W - 1)
    return torch.stack([indices_h, indices_w], dim=-1)  # [B,P,2]


def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
    """
    Compute spatial gradients of each image channel with a kornia kernel.

    args:
        image: [B,C,H,W]
        mode, order: forwarded to kornia's get_spatial_gradient_kernel2d
        normalize_kernel: if True, pass the kernel through normalize_kernel2d
    returns:
        gradient_image: [B,C,G,H,W] with G = 3 if order == 2 else 2
    """
    B, C, H, W = list(image.size())

    # prepare kernel
    kernel = get_spatial_gradient_kernel2d(mode, order)
    if normalize_kernel:
        kernel = normalize_kernel2d(kernel)
    tmp_kernel = kernel.to(image).detach()
    tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
    kernel_flip = tmp_kernel.flip(-3)

    # pad spatial dims of image
    padding = [kernel.size(1) // 2, kernel.size(1) // 2,
               kernel.size(2) // 2, kernel.size(2) // 2]
    out_channels = 3 if (order == 2) else 2
    padded_image = F.pad(image.reshape(B * C, 1, H, W),
                         padding, 'replicate')[:, :, None]  # [B*C,1,1,H+p,W+p]
    gradient_image = F.conv3d(padded_image, kernel_flip, padding=0)
    return gradient_image.view(B, C, out_channels, H, W)


def sample_coordinates_at_borders(image, num_points=16, mask=None,
                                  sum_edges=True, normalized_coordinates=True):
    """
    Sample num_points (h, w) coordinates from the borders of the input image.

    Points are drawn with probability proportional to the Sobel gradient
    magnitude of `image * mask`, restricted to `mask`.

    args:
        image: [B,C,H,W]
        num_points: P
        mask: [B,1,H,W] (spatial dims must match image) or None for all ones
        sum_edges: sum the gradient magnitude over channels if True,
            otherwise take the max over channels
        normalized_coordinates: if True, return float32 coordinates in
            [-1, 1] (pixel index / (size - 1), rescaled); otherwise return
            the int32 pixel indices
    returns:
        coordinates: [B,P,2]
    """
    B, C, H, W = list(image.size())
    if mask is not None:
        assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())
    else:
        mask = torch.ones(size=(B, 1, H, W)).to(image)

    gradient_image = get_gradient_image(image * mask, mode='sobel', order=1)  # [B,C,2,H,W]
    gradient_magnitude = torch.sqrt(torch.square(gradient_image).sum(dim=2))
    if sum_edges:
        edges = gradient_magnitude.sum(1)  # [B,H,W]
    else:
        edges = gradient_magnitude.max(1)[0]

    # masking the image creates edges along the mask boundary's outside too;
    # drop any probability mass that falls outside the mask
    edges = edges * mask[:, 0]

    coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
    if normalized_coordinates:
        coordinates = coordinates.to(torch.float32)
        coordinates /= torch.tensor(
            [H - 1, W - 1], dtype=torch.float32)[None, None].to(coordinates.device)
        coordinates = 2.0 * coordinates - 1.0
    return coordinates


def index_into_images(images, indices, channels_last=False):
    """
    Index into an image at P points to get its values.

    args:
        images: [B,C,H,W] (or [B,H,W,C] if channels_last)
        indices: [B,P,2] (h, w) pixel indices; cast to long before indexing
    returns:
        values: [B,P,C]
    """
    assert indices.size(-1) == 2, indices.size()
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [B,C,H,W]
    B, C, H, W = images.shape
    _, P, _ = indices.shape

    idx = indices.to(torch.long)
    inds_h, inds_w = idx[..., 0], idx[..., 1]  # [B,P] each
    inds_b = torch.arange(
        B, dtype=torch.long, device=idx.device).unsqueeze(-1).expand(-1, P)
    # index with explicit tensors rather than a list of tensors, whose
    # interpretation as a multi-dimensional index is ambiguous
    return images.permute(0, 2, 3, 1)[inds_b, inds_h, inds_w]  # [B,P,C]


def soft_index(images, indices, scale_by_imsize=True):
    """
    Bilinearly sample an image at P real-valued (h, w) positions.

    args:
        images: [B,C,H,W]
        indices: [B,P,2] positions; in [-1, 1] if scale_by_imsize, otherwise
            already in pixel units
        scale_by_imsize: map [-1, 1] to [0, H] x [0, W] before sampling
    returns:
        im_vals: [B,P,C]

    Positions are clamped to [0, H-1] x [0, W-1] before sampling.
    """
    assert indices.shape[-1] == 2, indices.shape
    B, C, H, W = images.shape
    _, P, _ = indices.shape

    h_inds, w_inds = indices[..., 0], indices[..., 1]
    if scale_by_imsize:
        # NOTE(review): this scales by H and W whereas
        # sample_coordinates_at_borders normalizes by H-1 and W-1 -- confirm
        # the mismatch is intended.
        h_inds = (h_inds + 1.0) * torch.tensor(H).to(h_inds) * 0.5
        w_inds = (w_inds + 1.0) * torch.tensor(W).to(w_inds) * 0.5

    h_inds = h_inds.clamp(min=0., max=H - 1)
    w_inds = w_inds.clamp(min=0., max=W - 1)

    h_floor = torch.floor(h_inds)
    w_floor = torch.floor(w_inds)
    # Use floor + 1 (not ceil) as the far neighbour when computing weights:
    # with ceil, an exactly-integer coordinate has floor == ceil, all four
    # weights collapse to zero and the sampled value is wrongly 0.
    h_next = h_floor + 1.0
    w_next = w_floor + 1.0

    bot_right_weight = (h_inds - h_floor) * (w_inds - w_floor)
    bot_left_weight = (h_inds - h_floor) * (w_next - w_inds)
    top_right_weight = (h_next - h_inds) * (w_inds - w_floor)
    top_left_weight = (h_next - h_inds) * (w_next - w_inds)

    # the far neighbour may fall one pixel outside the image; its weight is
    # zero whenever that happens, so clamping its index is safe
    h_ceil = h_next.clamp(max=H - 1)
    w_ceil = w_next.clamp(max=W - 1)

    top_left_vals = index_into_images(images, torch.stack([h_floor, w_floor], -1))
    top_right_vals = index_into_images(images, torch.stack([h_floor, w_ceil], -1))
    bot_left_vals = index_into_images(images, torch.stack([h_ceil, w_floor], -1))
    bot_right_vals = index_into_images(images, torch.stack([h_ceil, w_ceil], -1))

    im_vals = top_left_vals * top_left_weight[..., None]
    im_vals = im_vals + top_right_vals * top_right_weight[..., None]
    im_vals = im_vals + bot_left_vals * bot_left_weight[..., None]
    im_vals = im_vals + bot_right_vals * bot_right_weight[..., None]
    return im_vals.view(B, P, C)


def compute_compatibility(positions, plateau, phenotypes=None,
                          availability=None, noise=0.1):
    """
    Compute how well "fit" each agent is for the position it's at on the
    plateau, according to its "phenotype".

    args:
        positions: [B,P,2] normalized (h, w) positions
        plateau: [B,H,W,Q]
        phenotypes: [B,P,D] or None (then sampled from the plateau at
            `positions`)
        availability: [B,H,W,A] or None; P must be a multiple of A
        noise: scale of uniform [0, 1) noise added to the sampled plateau
            values before normalization
    returns:
        compatibility: [B,P,1] dot product of the normalized phenotype and
            the normalized (noisy) plateau value at each position
    """
    B, H, W, Q = plateau.shape
    P = positions.shape[1]
    if phenotypes is None:
        # soft_index expects channels-first images
        phenotypes = soft_index(plateau.permute(0, 3, 1, 2), positions)

    if availability is not None:
        assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], \
            (availability.shape, plateau.shape)
        A = availability.size(-1)
        assert P % A == 0, (P, A)
        S = P // A  # population size
        print("computing availability -- needlessly?", [B,H,W,A,Q])
        plateau = availability[..., None] * plateau[..., None, :]  # [B,H,W,A,Q]
        plateau = plateau.view(B, H, W, A * Q)

    plateau_values = soft_index(plateau.permute(0, 3, 1, 2), positions,
                                scale_by_imsize=True)
    if noise > 0:
        plateau_values += noise * torch.rand(
            size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)

    if availability is not None:
        plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
        # agent p only reads the availability channel p % A
        inds = torch.tile(torch.eye(A)[None].expand(B, -1, -1), (1, S, 1))[..., None]  # [B,P,A,1]
        plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2)  # [B,P,Q]
    else:
        plateau_values = l2_normalize(plateau_values)

    compatibility = torch.sum(
        l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True)  # [B,P,1]
    return compatibility


def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
    """
    Find overlaps (intersection over union) between every pair of masks.

    args:
        masks: [B,N,P]
        masks_target: [B,N,P'] or None to compare `masks` with itself
        mask_thresh: if not None, binarize both inputs at this threshold
        eps: lower bound on the union to avoid division by zero
    returns:
        iou: [B,P,P']
    """
    B, N, P = masks.shape
    if masks_target is None:
        masks_target = masks
    if mask_thresh is not None:
        masks = (masks > mask_thresh).to(torch.float32)
        masks_target = (masks_target > mask_thresh).to(torch.float32)

    ## union and intersection
    overlaps = masks[..., None] * masks_target[..., None, :]  # [B,N,P,P]
    I = overlaps.sum(dim=1)
    U = torch.maximum(masks[..., None], masks_target[..., None, :]).sum(dim=1)
    return I / U.clamp(min=eps)  # [B,P,P]


def compete_agents(masks, fitnesses, alive, mask_thresh=0.5,
                   compete_thresh=0.2, sticky_winners=True):
    """
    Kill off agents (which mask dimensions are "alive") based on mask
    overlap and fitnesses of each.

    args:
        masks: [B,N,P]
        fitnesses: [B,P,1]
        alive: [B,P,1]
        mask_thresh: threshold for binarizing masks before computing overlap
        compete_thresh: IoU above which two agents are in dispute
        sticky_winners: if True, agents alive on entry cannot be killed by
            agents that were dead on entry, and previously dead agents lose
            every dispute against previously alive ones

    returns:
        still_alive: [B,P,1] float32
    """
    B, N, P = masks.shape
    assert list(alive.shape) == [B, P, 1], alive.shape
    assert list(fitnesses.shape) == [B, P, 1], fitnesses.shape

    ## find territorial disputes
    overlaps = compute_pairwise_overlaps(masks, masks_target=None,
                                         mask_thresh=mask_thresh)
    disputes = overlaps > compete_thresh  # [B,P,P]

    ## agents don't fight themselves
    not_self = torch.logical_not(
        torch.eye(P, dtype=torch.bool, device=disputes.device)
        .unsqueeze(0).expand(B, -1, -1))
    disputes = torch.logical_and(disputes, not_self)

    ## kill off the agents with lower fitness in each dispute
    killed = torch.logical_and(
        disputes, fitnesses < torch.transpose(fitnesses, 1, 2))

    ## once an agent wins, it always wins again
    if sticky_winners:
        winners = (alive > 0.5)
        losers = torch.logical_not(winners)

        ## winners can't lose to last round's losers
        winners_vs_losers = torch.logical_and(
            winners, torch.transpose(losers, 1, 2))  # [B,P,P]
        killed = torch.logical_and(killed, torch.logical_not(winners_vs_losers))

        ## losers can't overtake last round's winners
        losers_vs_winners = torch.logical_and(
            losers, torch.transpose(winners, 1, 2))
        losers_vs_winners_disputes = torch.logical_and(losers_vs_winners, disputes)
        killed = torch.logical_or(killed, losers_vs_winners_disputes)

    ## if an agent was killed by *any* competitor, it's dead
    killed = torch.any(killed, dim=2, keepdim=True)
    return torch.logical_not(killed).to(torch.float32)


def compute_distance_weighted_vectors(vector_map, positions, mask=None,
                                      beta=1.0, eps=1e-8):
    """
    Compute vectors whose values are a weighted mean of vector_map, where
    weights are given by (inverse) distance.

    args:
        vector_map: [B,H,W,D]
        positions: [B,P,2] normalized (h, w) positions
        mask: [B,H,W,1] or None for all ones
        beta: softmax temperature applied to the inverse distances
        eps: numerical-stability constant
    returns:
        distance_weighted_vectors: [B,P,D]
    """
    B, H, W, D = vector_map.shape
    assert positions.size(-1) == 2, positions.size()
    B, P, _ = positions.shape
    N = H * W

    if mask is None:
        mask = torch.ones_like(vector_map[..., 0:1]).to(vector_map.device)
    else:
        assert list(mask.shape) == [B, H, W, 1]

    hw_grid = coordinate_ims(B, 0, [H, W]).view(B, N, 2).to(vector_map.device)
    delta_positions = hw_grid[:, None] - positions[:, :, None]  # [B,P,N,2]
    distances = torch.sqrt(
        delta_positions[..., 0] ** 2 + delta_positions[..., 1] ** 2 + eps)  # [B,P,N]

    ## max distance is 2*sqrt(2)
    inv_distances = (2.0 * np.sqrt(2.0)) / (distances + eps)
    # masked-out pixels get a logit of 0 (not -inf), so they keep a small
    # non-zero weight after the softmax
    inv_distances = F.softmax(
        beta * inv_distances * mask.view(B, 1, N), dim=-1)  # [B,P,N]

    distance_weighted_vectors = torch.sum(
        vector_map.view(B, 1, N, D) * inv_distances[..., None],
        dim=2, keepdim=False)  # [B,P,D]
    return distance_weighted_vectors


def masks_from_phenotypes(plateau, phenotypes, normalize=True):
    """
    Compute one soft mask per phenotype as the rectified dot product between
    every plateau vector and the phenotype.

    args:
        plateau: [B,H,W,Q]
        phenotypes: [B,M,Q]
    returns:
        masks: [B,H*W,M], non-negative
    """
    B, H, W, Q = plateau.shape
    masks = dot_product_attention(
        queries=plateau.view(B, H * W, Q),
        keys=phenotypes,
        normalize=normalize)
    return F.relu(masks)


class Competition(nn.Module):
    """
    Segment a "plateau" feature map by letting a population of agents compete.

    Each agent has a position and a phenotype vector; its mask is the set of
    pixels whose plateau vector matches the phenotype. Agents whose masks
    overlap fight, the less fit one dies and is re-sampled in regions that
    are not yet covered.
    """

    def __init__(
            self,
            size=None,
            num_masks=16,
            num_competition_rounds=5,
            mask_beta=10.0,
            reduce_func=reduce_max,
            stop_gradient=True,
            stop_gradient_phenotypes=True,
            normalization_func=l2_normalize,
            sum_edges=True,
            mask_thresh=0.5,
            compete_thresh=0.2,
            sticky_winners=True,
            selection_strength=100.0,
            homing_strength=10.0,
            mask_dead_segments=True
    ):
        """
        args:
            size: (H, W) used to un-flatten plateaus given as [..., N, Q]
            num_masks: M, number of agents / output mask channels
            num_competition_rounds: number of competition iterations
            mask_beta: temperature of the final softmax across masks, or
                None to skip it
            reduce_func: reduction over masks used to find covered pixels
            stop_gradient: controls `self.sg_func`
            stop_gradient_phenotypes: detach phenotypes sampled from the map
            normalization_func: normalization applied to plateau/phenotypes
            sum_edges: forwarded to sample_coordinates_at_borders
            mask_thresh, compete_thresh, sticky_winners: forwarded to
                compete_agents
            selection_strength: weight of the mask-averaged plateau vector
                when updating surviving agents' phenotypes
            homing_strength: `beta` for compute_distance_weighted_vectors
            mask_dead_segments: zero the output masks of dead agents
        """
        super().__init__()
        self.num_masks = self.M = num_masks
        self.num_competition_rounds = num_competition_rounds
        self.mask_beta = mask_beta
        self.reduce_func = reduce_func
        self.normalization_func = normalization_func

        ## stop gradients
        self.sg_func = lambda x: (x.detach() if stop_gradient else x)
        self.sg_phenotypes_func = lambda x: (x.detach() if stop_gradient_phenotypes else x)

        ## agent sampling kwargs
        self.sum_edges = sum_edges

        ## competition kwargs
        self.mask_thresh = mask_thresh
        self.compete_thresh = compete_thresh
        self.sticky_winners = sticky_winners
        self.selection_strength = selection_strength
        self.homing_strength = homing_strength
        self.mask_dead_segments = mask_dead_segments

        ## shapes (filled in by process_plateau_input)
        self.B = self.T = self.BT = self.N = self.Q = None
        self.H = self.W = None
        self.is_temporal = False
        self.size = size  # [H,W]
        if self.size:
            assert len(self.size) == 2, self.size

    def reshape_batch_time(self, x, merge=True):
        """
        Merge the leading [B,T] dims of `x` into [BT] (merge=True) or split a
        leading [BT] dim back into [B,T] (merge=False), checking the sizes
        against the ones cached on the module.
        """
        if merge:
            self.is_temporal = True
            B, T = x.size()[0:2]
            if self.B:
                assert (B == self.B), (B, self.B)
            else:
                self.B = B

            if self.T:
                assert (T == self.T), (T, self.T)
            else:
                self.T = T

            assert B * T == (self.B * self.T), (B * T, self.B * self.T)
            if self.BT is None:
                self.BT = self.B * self.T

            return torch.reshape(x, [self.BT] + list(x.size())[2:])

        # split
        BT = x.size()[0]
        assert self.B and self.T, (self.B, self.T)
        if self.BT is not None:
            assert BT == self.BT, (BT, self.BT)
        else:
            self.BT = BT

        return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])

    def process_plateau_input(self, plateau):
        """
        Reshape the input plateau to [BT,H,W,Q] and cache its shape.

        Accepted layouts:
            [B,T,H,W,Q]                  temporal
            [B,H,W,Q]  (self.size None)  static
            [B,T,N,Q]  (self.size set)   temporal, N == H*W
            [B,N,Q]    (self.size set)   static, N == H*W

        raises:
            ValueError: if the plateau has any other rank
        """
        shape = plateau.size()

        if len(shape) == 5:
            self.is_temporal = True
            self.B, self.T, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.BT = self.B * self.T
            plateau = self.reshape_batch_time(plateau)
        elif (len(shape) == 4) and (self.size is None):
            self.is_temporal = False
            self.B, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.T = 1
            self.BT = self.B * self.T
        elif (len(shape) == 4) and (self.size is not None):
            self.is_temporal = True
            self.B, self.T, self.N, self.Q = shape
            self.BT = self.B * self.T
            self.H, self.W = self.size
            plateau = self.reshape_batch_time(plateau)
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        elif len(shape) == 3:
            # `shape` is a tuple subclass: wrap it so that `%` formats it as
            # one value instead of raising TypeError
            assert self.size is not None, \
                "You need to specify an image size to reshape the plateau of shape %s" % (tuple(shape),)
            self.is_temporal = False
            self.B, self.N, self.Q = shape
            self.T = 1
            self.BT = self.B
            self.H, self.W = self.size
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        else:
            raise ValueError(
                "input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % (tuple(shape),))

        return plateau

    def forward(self,
                plateau,
                agents=None,
                alive=None,
                phenotypes=None,
                compete=True,
                update_pointers=True,
                yoke_phenotypes_to_agents=True,
                noise=0.1
    ):
        """
        Find the uniform regions within the plateau map
        by competition between visual "indices."

        args:
            plateau: [B,[T],H,W,Q] feature map with smooth "plateaus"
            agents: [B,[T],M,2] initial positions, or None to sample them
            alive: [B,[T],M,1] which of the given agents are alive, or None
            phenotypes: [B,[T],M,Q] initial phenotypes, or None to sample
                them from the plateau at the agent positions
            compete: run the pairwise competition in each round
            update_pointers: update fitnesses and winner phenotypes
            yoke_phenotypes_to_agents: compute masks from the plateau vector
                under each agent rather than from its phenotype
            noise: forwarded to compute_compatibility

        returns:
            masks: [B, [T], H, W, M] one mask in each of M channels
            agents: [B, [T], M, 2] positions of agents in normalized coordinates
            alive: [B, [T], M, 1] binary vector indicating which masks are valid
            phenotypes: [B, [T], M, Q]
            unharvested: [B, [T], H, W, 1] map of regions that weren't covered
        """
        ## preprocess
        plateau = self.process_plateau_input(plateau)  # [BT,H,W,Q]
        plateau = self.normalization_func(plateau)
        plateau_chw = plateau.permute(0, 3, 1, 2)  # [BT,Q,H,W] view

        ## sample initial indices ("agents") from borders of the plateau map
        if agents is None:
            agents = sample_coordinates_at_borders(
                plateau_chw,
                num_points=self.M,
                mask=None,
                sum_edges=self.sum_edges)
        elif self.is_temporal:
            agents = agents.view(self.BT, *agents.shape[2:])

        ## the agents have "phenotypes" depending on where they're situated on the plateau map
        if phenotypes is None:
            phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau_chw, agents, scale_by_imsize=True)))
        elif self.is_temporal:
            phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:])

        ## the "fitness" of an agent -- how likely it is to survive competition --
        ## is how well its phenotype matches the plateau vector at its current position

        ## initially all of these agents are "alive"
        if alive is None:
            alive = torch.ones_like(agents[..., -1:])  # [BT,M,1]
            fitnesses = compute_compatibility(
                agents, plateau, phenotypes, availability=None, noise=noise)
            alive_mask = None
        else:
            if self.is_temporal:
                alive = alive.view(self.BT, *alive.shape[2:])
            alive_mask = (alive > 0.5).float()
            # agents that are already alive get the maximal fitness of 1
            fitnesses = alive_mask + compute_compatibility(
                agents, plateau, phenotypes,
                availability=None, noise=noise) * (1 - alive_mask)
        alive_t = torch.transpose(alive, 1, 2)  # [BT, 1, M]

        ## compute the masks at initialization
        masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

        ## find the "unharvested" regions of the plateau map not covered by agents
        unharvested = self.reduce_func(masks_pred, dim=-1, keepdim=True).clamp(max=1.0)
        unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

        ## replace the agents that came in dead with fresh ones
        if alive_mask is not None:
            new_agents = sample_coordinates_at_borders(
                plateau_chw,
                num_points=self.M,
                mask=unharvested.permute(0, 3, 1, 2),
                sum_edges=self.sum_edges)
            agents = agents * alive_mask + new_agents * (1.0 - alive_mask)
            new_phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau_chw, new_agents, scale_by_imsize=True)))
            phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask)

        for r in range(self.num_competition_rounds):
            alive_t = torch.transpose(alive, 1, 2)  # [BT, 1, M]

            ## update the fitnesses
            if update_pointers and compete:
                fitnesses = compute_compatibility(
                    positions=agents,
                    plateau=plateau,
                    phenotypes=phenotypes,
                    availability=None,
                    noise=noise
                )

            ## kill agents that have wandered off the map
            in_bounds = torch.all(
                torch.logical_and(agents < 1.0, agents > -1.0),
                dim=-1, keepdim=True)  # [BT,M,1]
            fitnesses = fitnesses * in_bounds.to(fitnesses)

            ## break ties in fitness (lower-index agents win)
            fitnesses = fitnesses - 0.001 * torch.arange(
                self.M, dtype=torch.float32)[None, :, None].expand(
                    self.BT, -1, -1).to(fitnesses.device)

            ## recompute the masks from the plateau vector under each agent
            if yoke_phenotypes_to_agents:
                occupied_regions = self.sg_phenotypes_func(
                    soft_index(plateau_chw, agents, scale_by_imsize=True))
                masks_pred = masks_from_phenotypes(
                    plateau, occupied_regions, normalize=True)  # [BT,N,M]

            ## have each pair of agents compete.
            ## If their masks overlap, the winner is the one with higher fitness
            if compete:
                alive = compete_agents(masks_pred, fitnesses, alive,
                                       mask_thresh=self.mask_thresh,
                                       compete_thresh=self.compete_thresh,
                                       sticky_winners=self.sticky_winners)
                alive = alive * in_bounds.to(alive)
                alive_t = torch.transpose(alive, 1, 2)

            if not yoke_phenotypes_to_agents:
                masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

            ## update which parts of the plateau are "unharvested"
            unharvested = self.reduce_func(
                masks_pred * alive_t, dim=-1, keepdim=True).clamp(max=1.0)
            unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

            ## update phenotypes of the winners
            if update_pointers and self.selection_strength > 0:
                if self.mask_thresh is not None:
                    winner_weights = (masks_pred[..., None] > self.mask_thresh).to(plateau)
                else:
                    # no threshold configured: weight by the soft masks, as
                    # compute_pairwise_overlaps does in the same situation
                    winner_weights = masks_pred[..., None]
                winner_phenotypes = winner_weights * plateau.view(self.BT, self.N, 1, self.Q)
                winner_phenotypes = self.normalization_func(
                    winner_phenotypes.mean(dim=1))  # [BT,M,Q]
                # out-of-place: `phenotypes` may alias the caller's tensor or
                # a tensor autograd still needs
                phenotypes = phenotypes + (alive * winner_phenotypes) * self.selection_strength

            ## reinitialize losing agent positions
            alive_mask = (alive > 0.5).to(torch.float32)
            loser_agents = sample_coordinates_at_borders(
                plateau_chw,
                num_points=self.M,
                mask=unharvested.permute(0, 3, 1, 2),
                sum_edges=self.sum_edges)
            agents = agents * alive_mask + loser_agents * (1.0 - alive_mask)

            ## reinitialize loser agent phenotypes
            loser_phenotypes = self.normalization_func(
                compute_distance_weighted_vectors(
                    plateau, agents, mask=unharvested, beta=self.homing_strength))
            phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes
            phenotypes = self.normalization_func(phenotypes)

        ## run a final competition between the surviving masks
        if self.mask_beta is not None:
            masks_pred = F.softmax(
                self.mask_beta * masks_pred * alive_t
                - self.mask_beta * (1.0 - alive_t), dim=-1)
        if self.mask_dead_segments:
            # out-of-place: softmax/relu outputs are needed for backward
            masks_pred = masks_pred * alive_t

        masks_pred = masks_pred.view(self.BT, self.H, self.W, self.M)
        if self.is_temporal:
            masks_pred = self.reshape_batch_time(masks_pred, merge=False)
            agents = self.reshape_batch_time(agents, merge=False)
            alive = self.reshape_batch_time(alive, merge=False)
            phenotypes = self.reshape_batch_time(phenotypes, merge=False)
            unharvested = self.reshape_batch_time(unharvested, merge=False)

        return (masks_pred, agents, alive, phenotypes, unharvested)

    @staticmethod
    def masks_to_segments(masks):
        """Convert [..., M] masks to an integer segment map via argmax."""
        return masks.argmax(-1)

    @staticmethod
    def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True):
        """
        Replace every plateau vector by the mean plateau vector of the mask
        it belongs to.

        args:
            plateau: [B,H,W,Q]
            masks: [B,H,W,M]
            alive: [B,M,1]
            flatten_masks: convert masks to one-hot by argmax over the alive
                channels first
        returns:
            flat_plateau: [B,H,W,Q], L2-normalized along Q
            phenotypes: [B,M,Q], L2-normalized mean vector of each alive mask
                (zeros for dead ones)
        """
        B, M, _ = alive.shape
        Q = plateau.shape[-1]
        if flatten_masks:
            masks = F.one_hot(
                (alive[..., None, None, :, 0] * masks).argmax(-1),
                num_classes=M).float()

        flat_plateau = torch.zeros_like(plateau)
        phenotypes = torch.zeros((B, M, Q), device=plateau.device).float()
        for b in range(B):
            m_inds = torch.nonzero(alive[b, :, 0], as_tuple=True)[0]
            masks_b = masks[b, ..., m_inds]  # [H,W,K]
            num_px = masks_b.sum((0, 1)).clamp(min=1)[:, None]  # [K,1]
            phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px  # [K,Q]
            flat_plateau_b = (masks_b[..., None] * phenos_b[None, None]).sum(-2)  # [H,W,Q]
            phenotypes[b, m_inds, :] = phenos_b
            flat_plateau[b] = flat_plateau_b

        return (F.normalize(flat_plateau, p=2, dim=-1),
                F.normalize(phenotypes, p=2, dim=-1))

    @staticmethod
    def plot_agents(agents, alive, size=(128, 128)):
        """
        Rasterize the alive agents into an index map.

        args:
            agents: [B,M,2] normalized (h, w) positions
            alive: [B,M,1]
            size: (H, W) of the output map
        returns:
            agent_map: [B,H,W] long tensor; -1 everywhere except the (up to
                four) pixels surrounding each alive agent, which hold the
                agent's index
        """
        B, M, _ = alive.shape
        H, W = size
        agent_map = -1 * torch.ones((B, H, W), device=alive.device, dtype=torch.long)
        for b in range(B):
            for i in torch.nonzero(alive[b, :, 0], as_tuple=True)[0]:
                pos = agents[b, i] * 0.5 + 0.5
                pos = pos * torch.tensor([H, W], device=pos.device)
                # ceil(pos) can equal H or W (and pos can be negative for
                # out-of-range agents); keep the indices on the map
                hmin, wmin = torch.floor(pos).long().tolist()
                hmax, wmax = torch.ceil(pos).long().tolist()
                hmin, hmax = (min(max(v, 0), H - 1) for v in (hmin, hmax))
                wmin, wmax = (min(max(v, 0), W - 1) for v in (wmin, wmax))
                agent_map[b, [hmin, hmin, hmax, hmax], [wmin, wmax, wmin, wmax]] = i
        return agent_map


if __name__ == '__main__':

    Comp = Competition(num_masks=32, num_competition_rounds=5)

    # three vertical stripes with distinct feature vectors
    left = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([1., 0.2, 0.])
    middle = torch.ones(size=(32, 16)).unsqueeze(-1) * torch.tensor([0., 1., 0.2])
    right = torch.ones(size=(32, 8)).unsqueeze(-1) * torch.tensor([0.1, 0., 1.])
    plateau = torch.cat([left, middle, right], dim=-2).unsqueeze(0)

    masks, agents, alive, phenotypes, unharvested = Comp(plateau)
    mask_inds = np.where(alive[0, :, 0].numpy())[0]
    segments = np.argmax(masks[0].numpy(), axis=-1)
    print(segments)
    for ind in mask_inds:
        print("num pixels in mask %d ---> %d" % (ind, (segments == ind).sum()))