# NOTE: the following lines are page metadata captured with the source, not code.
# rahulvenkk
# app.py updated
# 6dfcb0f
# raw
# history blame
# 27.5 kB
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from kornia.filters.kernels import (get_spatial_gradient_kernel2d,
normalize_kernel2d)
def l2_normalize(x):
    """Scale `x` to unit Euclidean length along its last dimension.

    Norms below 1e-6 are clamped to 1e-6 by `F.normalize`, so an all-zero
    vector comes back as zeros instead of NaNs.
    """
    unit_vectors = F.normalize(x, p=2.0, dim=-1, eps=1e-6)
    return unit_vectors
def reduce_max(x, dim, keepdim=True):
    """Return the maximum of `x` along `dim`, dropping the argmax indices."""
    max_values, _ = torch.max(x, dim=dim, keepdim=keepdim)
    return max_values
def coordinate_ims(batch_size, seq_length, imsize):
    """Build images of normalized (h, w) pixel coordinates in [-1, 1].

    Args:
        batch_size: number of copies B along the leading dimension.
        seq_length: number of time steps T. Passing 0 requests a "static"
            grid, i.e. the time dimension is dropped from the result.
        imsize: (H, W) spatial size.

    Returns:
        float32 tensor (created on the CPU) of shape [B, T, H, W, 2], or
        [B, H, W, 2] when seq_length == 0. Channel 0 is the row coordinate and
        channel 1 the column coordinate; the first row/column maps to -1 and
        the last one to +1.
    """
    static = (seq_length == 0)
    T = 1 if static else seq_length
    B = batch_size
    H, W = imsize

    def _axis(n):
        # Dividing by (n - 1) would be 0/0 = NaN for a single-pixel axis;
        # max(..., 1) keeps that lone pixel at the coordinate -1 instead.
        steps = torch.arange(n, dtype=torch.float32)
        return 2.0 * (steps / float(max(n - 1, 1)) - 0.5)

    h = _axis(H).view(1, 1, H, 1, 1).expand(B, T, H, W, 1)
    w = _axis(W).view(1, 1, 1, W, 1).expand(B, T, H, W, 1)
    hw_ims = torch.cat([h, w], dim=-1)  # [B,T,H,W,2], freshly allocated
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
def dot_product_attention(queries, keys, normalize=True, eps=1e-8):
    """Compute all pairwise dot products between two sets of vectors.

    Args:
        queries: [B, N, D] tensor.
        keys: [B, N_k, D] tensor; the feature size must match `queries`.
        normalize: if True, both inputs are L2-normalized along the last
            dimension first, so the result is a cosine similarity.
        eps: lower bound on the norm used by the normalization.

    Returns:
        [B, N, N_k] tensor whose entry (b, i, j) is the (optionally
        normalized) dot product of queries[b, i] and keys[b, j].

    Raises:
        AssertionError: if the feature dimensions differ.
    """
    # Unpacking three values keeps the original requirement that both
    # inputs are exactly 3-dimensional.
    _, _, D_q = queries.size()
    _, _, D_k = keys.size()
    assert D_q == D_k, (queries.shape, keys.shape)
    if normalize:
        queries = F.normalize(queries, p=2.0, dim=-1, eps=eps)
        keys = F.normalize(keys, p=2.0, dim=-1, eps=eps)
    return torch.matmul(queries, torch.transpose(keys, 1, 2))  # [B, N, N_k]
def sample_image_inds_from_probs(probs, num_points, eps=1e-9):
    """Draw pixel locations from per-image (unnormalized) probability maps.

    Args:
        probs: [B, H, W] non-negative weights, one map per batch element.
        num_points: number of locations P to draw per map (with replacement).
        eps: small constant added to every weight and to the normalizer.

    Returns:
        int32 tensor [B, P, 2] of (row, column) pixel indices.
    """
    B, H, W = probs.shape
    flat_probs = probs.reshape(B, H * W)
    normalizer = flat_probs.sum(dim=-1, keepdim=True) + eps
    flat_probs = torch.clamp(flat_probs + eps, min=0.0) / normalizer
    sampler = Categorical(probs=flat_probs, validate_args=False)
    flat_inds = sampler.sample([num_points]).transpose(0, 1).to(torch.int32)  # [B,P]
    # Convert flat indices back to 2-D and keep them inside the image.
    rows = torch.div(flat_inds, W, rounding_mode='floor').clamp(min=0, max=H - 1)
    cols = torch.fmod(flat_inds, W).clamp(min=0, max=W - 1)
    return torch.stack([rows, cols], dim=-1)  # [B,P,2]
def get_gradient_image(image, mode='sobel', order=1, normalize_kernel=True):
    """Filter every channel of `image` with kornia's spatial-gradient kernels.

    Args:
        image: [B, C, H, W] tensor.
        mode: kernel family forwarded to `get_spatial_gradient_kernel2d`.
        order: derivative order forwarded to the same function.
        normalize_kernel: if True the kernel is passed through
            `normalize_kernel2d` before use.

    Returns:
        [B, C, K, H, W] tensor with K = 3 when order == 2 and K = 2 otherwise.
    """
    B, C, H, W = list(image.size())

    # Build the filter bank and match the image's dtype/device.
    base_kernel = get_spatial_gradient_kernel2d(mode, order)
    if normalize_kernel:
        base_kernel = normalize_kernel2d(base_kernel)
    conv_kernel = base_kernel.to(image).detach()
    conv_kernel = conv_kernel.unsqueeze(1).unsqueeze(1)
    conv_kernel = conv_kernel.flip(-3)

    # Replicate-pad the spatial dims so the output keeps the input's H and W.
    pad_first = base_kernel.size(1) // 2
    pad_second = base_kernel.size(2) // 2
    num_derivatives = 3 if (order == 2) else 2
    flat_image = image.reshape(B * C, 1, H, W)
    padded = F.pad(flat_image, [pad_first, pad_first, pad_second, pad_second], 'replicate')
    padded = padded[:, :, None]  # [B*C,1,1,H+p,W+p]

    gradients = F.conv3d(padded, conv_kernel, padding=0)
    return gradients.view(B, C, num_derivatives, H, W)
def sample_coordinates_at_borders(image, num_points=16, mask=None, sum_edges=True, normalized_coordinates=True):
    """Sample `num_points` (h, w) locations, favouring strong edges of `image`.

    Args:
        image: [B, C, H, W] tensor.
        num_points: number of locations to draw per image.
        mask: optional tensor whose spatial dims match `image`; it multiplies
            the image before the gradient and the edge map afterwards.
            Defaults to all ones.
        sum_edges: combine per-channel edge strength by sum (True) or max.
        normalized_coordinates: if True return float32 coordinates rescaled to
            [-1, 1]; otherwise return the raw int32 pixel indices.

    Returns:
        [B, num_points, 2] tensor of (h, w) coordinates.
    """
    B, C, H, W = list(image.size())
    if mask is None:
        mask = torch.ones(size=(B, 1, H, W)).to(image)
    else:
        assert mask.shape[2:] == image.shape[2:], (mask.size(), image.size())

    # Sobel gradient magnitude per channel: [B,C,2,H,W] -> [B,C,H,W]
    gradients = get_gradient_image(image * mask, mode='sobel', order=1)
    magnitude = torch.sqrt(torch.square(gradients).sum(dim=2))
    edges = magnitude.sum(1) if sum_edges else magnitude.max(1)[0]  # [B,H,W]
    edges = edges * mask[:, 0]

    coordinates = sample_image_inds_from_probs(edges, num_points=num_points)
    if not normalized_coordinates:
        return coordinates

    extent = torch.tensor([H - 1, W - 1], dtype=torch.float32)[None, None].to(coordinates.device)
    coordinates = coordinates.to(torch.float32)
    coordinates /= extent
    return 2.0 * coordinates - 1.0
def index_into_images(images, indices, channels_last=False):
    """Read the values of `images` at P integer pixel locations per image.

    Args:
        images: [B, C, H, W] tensor, or [B, H, W, C] when `channels_last`.
        indices: [B, P, 2] tensor of (row, column) positions. Non-integer
            dtypes are converted with `.to(torch.long)`.
        channels_last: set when `images` is laid out as [B, H, W, C].

    Returns:
        [B, P, C] tensor with the channel vector found at each location.
    """
    assert indices.size(-1) == 2, indices.size()
    images_hwc = images if channels_last else images.permute(0, 2, 3, 1)
    B, H, W, C = images_hwc.shape
    _, P, _ = indices.shape
    inds_h, inds_w = indices.to(torch.long).unbind(dim=-1)  # [B,P] each
    inds_b = torch.arange(B, dtype=torch.long, device=indices.device).unsqueeze(-1).expand(-1, P)
    # Index with separate tensors (tuple semantics); indexing with a *list* of
    # tensors is deprecated non-tuple sequence indexing.
    return images_hwc[inds_b, inds_h, inds_w]  # [B,P,C]
def soft_index(images, indices, scale_by_imsize=True):
    """Bilinearly interpolate `images` at continuous (h, w) locations.

    Args:
        images: [B, C, H, W] tensor.
        indices: [B, P, 2] float tensor of (h, w) positions.
        scale_by_imsize: if True, positions are taken to be normalized to
            [-1, 1] and are mapped to pixels as ``(x + 1) * size / 2``;
            otherwise they are used as pixel coordinates directly. Either way
            they are then clamped to ``[0, size - 1]``.

    Returns:
        [B, P, C] tensor of interpolated values.

    NOTE(review): the normalized -> pixel mapping scales by ``size`` while the
    samplers in this file normalize by ``size - 1``; confirm this is intended.
    """
    assert indices.shape[-1] == 2, indices.shape
    B, C, H, W = images.shape
    _, P, _ = indices.shape

    h_inds, w_inds = indices.unbind(dim=-1)  # [B,P] each
    if scale_by_imsize:
        h_inds = (h_inds + 1.0) * H * 0.5
        w_inds = (w_inds + 1.0) * W * 0.5
    h_inds = h_inds.clamp(min=0.0, max=H - 1)
    w_inds = w_inds.clamp(min=0.0, max=W - 1)

    h_floor, w_floor = torch.floor(h_inds), torch.floor(w_inds)
    h_ceil, w_ceil = torch.ceil(h_inds), torch.ceil(w_inds)

    # Weights come from the fractional part only. Using (ceil - x) for the
    # "floor" corner makes every weight zero whenever x is an exact integer
    # (floor == ceil), which silently returned all-zero samples.
    h_frac = h_inds - h_floor
    w_frac = w_inds - w_floor
    top_left_weight = (1.0 - h_frac) * (1.0 - w_frac)
    top_right_weight = (1.0 - h_frac) * w_frac
    bot_left_weight = h_frac * (1.0 - w_frac)
    bot_right_weight = h_frac * w_frac

    # Gather the four neighbours, sharing one channels-last view and batch index.
    images_hwc = images.permute(0, 2, 3, 1)
    batch_idx = torch.arange(B, dtype=torch.long, device=indices.device).unsqueeze(-1).expand(-1, P)

    def _lookup(rows, cols):
        return images_hwc[batch_idx, rows.to(torch.long), cols.to(torch.long)]  # [B,P,C]

    im_vals = _lookup(h_floor, w_floor) * top_left_weight[..., None]
    im_vals = im_vals + _lookup(h_floor, w_ceil) * top_right_weight[..., None]
    im_vals = im_vals + _lookup(h_ceil, w_floor) * bot_left_weight[..., None]
    im_vals = im_vals + _lookup(h_ceil, w_ceil) * bot_right_weight[..., None]
    return im_vals.reshape(B, P, C)
def compute_compatibility(positions, plateau, phenotypes=None, availability=None, noise=0.1):
    """
    Compute how well "fit" each agent is for the position it's at on the plateau,
    according to its "phenotype".

    The fitness is the dot product between the agent's L2-normalized phenotype
    and the L2-normalized (optionally noised) plateau vector sampled at the
    agent's position.

    args:
        positions: [B,P,2] agent (h,w) positions, normalized coordinates
        plateau: [B,H,W,Q] channels-last feature map
        phenotypes: [B,P,D] or None; when None they are read off the plateau
            at `positions`
        availability: [B,H,W,A] or None; when given, P must be a multiple of A
            and agent p only sees the plateau weighted by channel (p mod A)
        noise: scale of the uniform [0,1) noise added to the sampled plateau
            values (skipped when <= 0)
    returns:
        compatibility: [B,P,1]
    """
    B, H, W, Q = plateau.shape
    P = positions.shape[1]
    if phenotypes is None:
        # soft_index expects channels-first images; the plateau is channels-last,
        # so it must be permuted here exactly as for `plateau_values` below.
        phenotypes = soft_index(plateau.permute(0, 3, 1, 2), positions)

    if availability is not None:
        assert list(availability.shape)[:-1] == list(plateau.shape)[:-1], (availability.shape, plateau.shape)
        A = availability.size(-1)
        assert P % A == 0, (P, A)
        S = P // A  # population size
        print("computing availability -- needlessly?", [B,H,W,A,Q])
        plateau = availability[..., None] * plateau[..., None, :]  # [B,H,W,A,Q]
        plateau = plateau.view(B, H, W, A * Q)

    plateau_values = soft_index(plateau.permute(0, 3, 1, 2), positions, scale_by_imsize=True)
    if noise > 0:
        plateau_values += noise * torch.rand(size=plateau_values.size(), dtype=torch.float32).to(plateau_values.device)

    if availability is not None:
        plateau_values = l2_normalize(plateau_values.view(B, P, A, Q))
        # one-hot selector picking availability channel (p mod A) for agent p
        inds = torch.tile(torch.eye(A)[None].expand(B, -1, -1), (1, S, 1))[..., None]  # [B,P,A,1]
        plateau_values = torch.sum(plateau_values * inds.to(plateau_values), dim=-2)  # [B,P,Q]
    else:
        plateau_values = l2_normalize(plateau_values)

    compatibility = torch.sum(
        l2_normalize(phenotypes) * plateau_values, dim=-1, keepdim=True)  # [B,P,1]
    return compatibility
def compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=None, eps=1e-6):
    """Intersection-over-union between every pair of masks.

    Args:
        masks: [B, N, P] tensor, P masks over N locations.
        masks_target: optional [B, N, P'] tensor to compare against; defaults
            to `masks` itself.
        mask_thresh: if given, both inputs are binarized with `> mask_thresh`.
        eps: lower bound applied to the union to avoid dividing by zero.

    Returns:
        [B, P, P'] tensor of IoU values.
    """
    B, N, P = masks.shape
    targets = masks if masks_target is None else masks_target
    if mask_thresh is not None:
        targets = (targets > mask_thresh).to(torch.float32)
        masks = (masks > mask_thresh).to(torch.float32)

    source_col = masks.unsqueeze(-1)    # [B,N,P,1]
    target_row = targets.unsqueeze(-2)  # [B,N,1,P']
    intersection = (source_col * target_row).sum(dim=1)
    union = torch.maximum(source_col, target_row).sum(dim=1)
    return intersection / torch.clamp(union, min=eps)  # [B,P,P']
def compete_agents(masks, fitnesses, alive,
                   mask_thresh=0.5, compete_thresh=0.2,
                   sticky_winners=True):
    """
    Decide which agents survive a round of pairwise competition.

    Two distinct agents are in dispute when the IoU of their masks exceeds
    `compete_thresh`; in a dispute the agent with the lower fitness is killed.
    With `sticky_winners`, agents currently alive cannot be killed by agents
    currently dead, and a dead agent disputing an alive one stays dead.

    args:
        masks: [B,N,P]
        fitnesses: [B,P,1]
        alive: [B,P,1]
        mask_thresh: binarization threshold forwarded to compute_pairwise_overlaps
        compete_thresh: IoU above which two agents are in dispute
        sticky_winners: protect current winners as described above
    returns:
        still_alive: [B,P,1] float32, 1 for survivors and 0 otherwise
    """
    B, N, P = masks.shape
    assert list(alive.shape) == [B,P,1], alive.shape
    assert list(fitnesses.shape) == [B,P,1], fitnesses.shape

    ## territorial disputes between *different* agents
    iou = compute_pairwise_overlaps(masks, masks_target=None, mask_thresh=mask_thresh)
    not_self = ~torch.eye(P, dtype=torch.bool, device=iou.device).unsqueeze(0).expand(B, -1, -1)
    disputes = (iou > compete_thresh) & not_self  # [B,P,P] <bool>

    ## entry (i,j): agent i loses its dispute with agent j
    weaker = fitnesses < fitnesses.transpose(1, 2)
    killed = disputes & weaker

    if sticky_winners:
        winners = alive > 0.5
        losers = ~winners
        ## winners can't lose to last round's losers
        shielded = winners & losers.transpose(1, 2)  # [B,P,P]
        killed = killed & ~shielded
        ## losers can't overtake last round's winners
        blocked = losers & winners.transpose(1, 2) & disputes
        killed = killed | blocked

    ## losing to *any* competitor is fatal
    killed = killed.any(dim=2, keepdim=True)
    return (~killed).to(torch.float32)
def compute_distance_weighted_vectors(vector_map, positions, mask=None, beta=1.0, eps=1e-8):
    """
    For each position, average `vector_map` over all pixels with softmax
    weights that grow as the pixel gets closer to the position.

    args:
        vector_map: [B,H,W,D]
        positions: [B,P,2] (h,w) in the same normalized coordinates as
            coordinate_ims
        mask: optional [B,H,W,1]; multiplies the softmax logits per pixel
            (defaults to all ones)
        beta: scale applied to the logits before the softmax
        eps: numerical stabilizer for the distance and its inverse
    returns:
        [B,P,D] weighted means
    """
    B, H, W, D = vector_map.shape
    assert positions.size(-1) == 2, positions.size()
    B, P, _ = positions.shape
    N = H * W
    if mask is not None:
        assert list(mask.shape) == [B,H,W,1]
    else:
        mask = torch.ones_like(vector_map[..., 0:1])

    pixel_coords = coordinate_ims(B, 0, [H, W]).view(B, N, 2).to(vector_map.device)
    offsets = pixel_coords.unsqueeze(1) - positions.unsqueeze(2)  # [B,P,N,2]
    distances = torch.sqrt(offsets[..., 0] ** 2 + offsets[..., 1] ** 2 + eps)  # [B,P,N]

    ## max distance is 2*sqrt(2)
    closeness = (2.0 * np.sqrt(2.0)) / (distances + eps)
    weights = F.softmax(beta * closeness * mask.view(B, 1, N), dim=-1)  # [B,P,N]

    weighted = vector_map.view(B, 1, N, D) * weights[..., None]
    return torch.sum(weighted, dim=2, keepdim=False)  # [B,P,D]
def masks_from_phenotypes(plateau, phenotypes, normalize=True):
    """Turn phenotypes into soft masks via rectified similarity to the plateau.

    Args:
        plateau: [B, H, W, Q] feature map.
        phenotypes: [B, M, Q] one feature vector per mask.
        normalize: forwarded to `dot_product_attention`.

    Returns:
        [B, H*W, M] non-negative tensor; entry (b, n, m) is the ReLU of the
        similarity between pixel n and phenotype m.
    """
    B, H, W, Q = plateau.shape
    flat_plateau = plateau.view(B, H * W, Q)
    similarity = dot_product_attention(
        queries=flat_plateau,
        keys=phenotypes,
        normalize=normalize)
    return F.relu(similarity)
class Competition(nn.Module):
    """Segment a feature map ("plateau") by letting point "agents" compete.

    Agents are sampled at edges of the plateau map, each gets a "phenotype"
    (the plateau vector under it) and a mask (pixels similar to the phenotype).
    For `num_competition_rounds` rounds, agents with overlapping masks compete;
    losers are re-sampled in regions not yet covered by a surviving mask.
    """

    def __init__(
            self,
            size=None,
            num_masks=16,
            num_competition_rounds=5,
            mask_beta=10.0,
            reduce_func=reduce_max,
            stop_gradient=True,
            stop_gradient_phenotypes=True,
            normalization_func=l2_normalize,
            sum_edges=True,
            mask_thresh=0.5,
            compete_thresh=0.2,
            sticky_winners=True,
            selection_strength=100.0,
            homing_strength=10.0,
            mask_dead_segments=True
    ):
        """
        args:
            size: optional (H, W); required when the plateau is passed flattened
                as [B,[T],N,Q]
            num_masks: number of agents / mask channels M
            num_competition_rounds: number of competition iterations
            mask_beta: softmax temperature of the final mask competition
                (None skips the softmax)
            reduce_func: reduction over the mask dimension used to find
                uncovered regions
            stop_gradient: controls `self.sg_func` (detach or identity)
            stop_gradient_phenotypes: detach phenotypes read off the plateau
            normalization_func: normalization applied to plateau and phenotypes
            sum_edges: forwarded to sample_coordinates_at_borders
            mask_thresh, compete_thresh, sticky_winners: forwarded to compete_agents
            selection_strength: weight of the winners' phenotype update
            homing_strength: `beta` used when re-initializing loser phenotypes
            mask_dead_segments: zero the output masks of dead agents
        """
        super().__init__()
        self.num_masks = self.M = num_masks
        self.num_competition_rounds = num_competition_rounds
        self.mask_beta = mask_beta
        self.reduce_func = reduce_func
        self.normalization_func = normalization_func

        ## stop gradients
        self.sg_func = lambda x: (x.detach() if stop_gradient else x)
        self.sg_phenotypes_func = lambda x: (x.detach() if stop_gradient_phenotypes else x)

        ## agent sampling kwargs
        self.sum_edges = sum_edges

        ## competition kwargs
        self.mask_thresh = mask_thresh
        self.compete_thresh = compete_thresh
        self.sticky_winners = sticky_winners
        self.selection_strength = selection_strength
        self.homing_strength = homing_strength
        self.mask_dead_segments = mask_dead_segments

        ## shapes (filled in by process_plateau_input on every forward call)
        self.B = self.T = self.BT = self.N = self.Q = None
        self.H = self.W = None
        self.is_temporal = False
        self.size = size  # [H,W]
        if self.size:
            assert len(self.size) == 2, self.size

    def reshape_batch_time(self, x, merge=True):
        """Merge [B,T,...] -> [B*T,...] (merge=True) or split it back (merge=False).

        The batch/time sizes are cached on the module; a merge records them if
        they are unset and otherwise asserts that `x` agrees with them.
        """
        if merge:
            self.is_temporal = True
            B, T = x.size()[0:2]
            if self.B:
                assert (B == self.B), (B, self.B)
            else:
                self.B = B

            if self.T:
                assert (T == self.T), (T, self.T)
            else:
                self.T = T

            assert B*T == (self.B * self.T), (B*T, self.B*self.T)
            if self.BT is None:
                self.BT = self.B * self.T

            return torch.reshape(x, [self.BT] + list(x.size())[2:])

        else:  # split
            BT = x.size()[0]
            assert self.B and self.T, (self.B, self.T)
            if self.BT is not None:
                assert BT == self.BT, (BT, self.BT)
            else:
                self.BT = BT

            return torch.reshape(x, [self.B, self.T] + list(x.size())[1:])

    def process_plateau_input(self, plateau):
        """Bring the plateau to shape [BT,H,W,Q] and record the shape fields.

        Accepted layouts: [B,T,H,W,Q]; [B,H,W,Q] (when `size` is unset);
        [B,T,N,Q] (when `size` is set); [B,N,Q] (requires `size`).

        raises:
            ValueError: for any other rank.
        """
        shape = plateau.size()
        if len(shape) == 5:
            self.is_temporal = True
            self.B, self.T, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.BT = self.B * self.T
            plateau = self.reshape_batch_time(plateau)
        elif (len(shape) == 4) and (self.size is None):
            self.is_temporal = False
            self.B, self.H, self.W, self.Q = shape
            self.N = self.H * self.W
            self.T = 1
            self.BT = self.B*self.T
        elif (len(shape) == 4) and (self.size is not None):
            self.is_temporal = True
            self.B, self.T, self.N, self.Q = shape
            self.BT = self.B * self.T
            self.H, self.W = self.size
            plateau = self.reshape_batch_time(plateau)
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        elif len(shape) == 3:
            # `shape` is a multi-element tuple, so it has to be wrapped before
            # %-formatting; otherwise building the message itself raises TypeError.
            assert self.size is not None, \
                "You need to specify an image size to reshape the plateau of shape %s" % (tuple(shape),)
            self.is_temporal = False
            self.B, self.N, self.Q = shape
            self.T = 1
            self.BT = self.B
            self.H, self.W = self.size
            plateau = torch.reshape(plateau, [self.BT, self.H, self.W, self.Q])
        else:
            raise ValueError("input plateau map with shape %s cannot be reshaped to [BT, H, W, Q]" % (tuple(shape),))

        return plateau

    def forward(self,
                plateau,
                agents=None,
                alive=None,
                phenotypes=None,
                compete=True,
                update_pointers=True,
                yoke_phenotypes_to_agents=True,
                noise=0.1
    ):
        """
        Find the uniform regions within the plateau map
        by competition between visual "indices."

        args:
            plateau: [B,[T],H,W,Q] feature map with smooth "plateaus"
                (flattened [B,[T],N,Q] is accepted when `size` was given)
            agents: optional [B,[T],M,2] initial agent positions
            alive: optional [B,[T],M,1] liveness of the supplied agents
            phenotypes: optional [B,[T],M,Q] initial phenotypes
            compete: run compete_agents each round
            update_pointers: update fitnesses / winner phenotypes each round
            yoke_phenotypes_to_agents: compute masks from the plateau vector
                under each agent instead of from `phenotypes`
            noise: forwarded to compute_compatibility

        returns:
            masks: [B, [T], H, W, M] <float> one mask in each of M channels
            agents: [B, [T], M, 2] <float> positions of agents in normalized coordinates
            alive: [B, [T], M, 1] <float> binary indicator of which masks are valid
            phenotypes: [B, [T], M, Q]
            unharvested: [B, [T], H, W, 1] <float> map of regions that weren't covered
        """
        ## preprocess
        plateau = self.process_plateau_input(plateau)  # [BT,H,W,Q]
        plateau = self.normalization_func(plateau)

        ## sample initial indices ("agents") from borders of the plateau map
        if agents is None:
            agents = sample_coordinates_at_borders(
                plateau.permute(0,3,1,2),
                num_points=self.M,
                mask=None,
                sum_edges=self.sum_edges)
        else:
            if self.is_temporal:
                agents = agents.view(self.BT, *agents.shape[2:])

        ## the agents have "phenotypes" depending on where they're situated on the plateau map
        if phenotypes is None:
            phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau.permute(0,3,1,2),
                               agents, scale_by_imsize=True)))
        elif self.is_temporal:
            phenotypes = phenotypes.view(self.BT, *phenotypes.shape[2:])

        ## the "fitness" of an agent -- how likely it is to survive competition --
        ## is how well its phenotype matches the plateau vector at its current position
        ## initially all of these agents are "alive"
        if alive is None:
            alive = torch.ones_like(agents[...,-1:])  # [BT,M,1]
            fitnesses = compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise)
            alive_mask = None
        else:
            if self.is_temporal:
                alive = alive.view(self.BT, *alive.shape[2:])
            alive_mask = (alive > 0.5).float()
            fitnesses = alive_mask + compute_compatibility(agents, plateau, phenotypes, availability=None, noise=noise) * (1 - alive_mask)

        alive_t = torch.transpose(alive, 1, 2)  # [BT, 1, M]

        ## compute the masks at initialization
        masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

        ## find the "unharvested" regions of the plateau map not covered by agents
        unharvested = self.reduce_func(masks_pred, dim=-1, keepdim=True).clamp(max=1.0)
        unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

        ## agents passed in as dead are re-sampled in the uncovered regions
        if alive_mask is not None:
            new_agents = sample_coordinates_at_borders(
                plateau.permute(0,3,1,2), num_points=self.M,
                mask=unharvested.permute(0,3,1,2),
                sum_edges=self.sum_edges)
            agents = agents * alive_mask + new_agents * (1.0 - alive_mask)
            new_phenotypes = self.sg_phenotypes_func(
                self.normalization_func(
                    soft_index(plateau.permute(0,3,1,2),
                               new_agents, scale_by_imsize=True)))
            phenotypes = phenotypes * alive_mask + new_phenotypes * (1.0 - alive_mask)

        for r in range(self.num_competition_rounds):
            alive_t = torch.transpose(alive, 1, 2)  # [BT, 1, M]

            ## update the fitnesses
            if update_pointers and compete:
                fitnesses = compute_compatibility(
                    positions=agents,
                    plateau=plateau,
                    phenotypes=phenotypes,
                    availability=None,
                    noise=noise
                )

            ## kill agents that have wandered off the map
            in_bounds = torch.all(
                torch.logical_and(agents < 1.0, agents > -1.0),
                dim=-1, keepdim=True)  # [BT,M,1]
            fitnesses *= in_bounds.to(fitnesses)

            ## break ties in fitness
            fitnesses -= 0.001 * torch.arange(self.M, dtype=torch.float32)[None,:,None].expand(self.BT,-1,-1).to(fitnesses.device)

            ## recompute the masks (why?)
            if yoke_phenotypes_to_agents:
                occupied_regions = self.sg_phenotypes_func(
                    soft_index(plateau.permute(0,3,1,2), agents, scale_by_imsize=True))
                masks_pred = masks_from_phenotypes(plateau, occupied_regions, normalize=True)  # [BT,N,M]

            ## have each pair of agents compete.
            ## If their masks overlap, the winner is the one with higher fitness
            if compete:
                alive = compete_agents(masks_pred, fitnesses, alive,
                                       mask_thresh=self.mask_thresh,
                                       compete_thresh=self.compete_thresh,
                                       sticky_winners=self.sticky_winners)
                alive *= in_bounds.to(alive)
                alive_t = torch.transpose(alive, 1, 2)

            if not yoke_phenotypes_to_agents:
                masks_pred = masks_from_phenotypes(plateau, phenotypes, normalize=True)

            ## update which parts of the plateau are "unharvested"
            unharvested = self.reduce_func(masks_pred * alive_t, dim=-1, keepdim=True).clamp(max=1.0)
            unharvested = 1.0 - unharvested.view(self.BT, self.H, self.W, 1)

            ## update phenotypes of the winners
            if update_pointers:
                if self.mask_thresh is not None:
                    winner_phenotypes = (masks_pred[...,None] > self.mask_thresh).to(plateau)
                else:
                    # Without a threshold the name used to be unbound (NameError);
                    # fall back to the soft masks as per-pixel weights.
                    winner_phenotypes = masks_pred[...,None]
                if self.selection_strength > 0:
                    winner_phenotypes = winner_phenotypes * plateau.view(self.BT, self.N, 1, self.Q)
                    winner_phenotypes = self.normalization_func(winner_phenotypes.mean(dim=1))  # [BT,M,Q]
                    # Out-of-place: `+=` would write into a caller-supplied
                    # phenotypes tensor and into tensors autograd may have saved.
                    phenotypes = phenotypes + (alive * winner_phenotypes) * self.selection_strength

            ## reinitialize losing agent positions
            alive_mask = (alive > 0.5).to(torch.float32)
            loser_agents = sample_coordinates_at_borders(
                plateau.permute(0,3,1,2), num_points=self.M,
                mask=unharvested.permute(0,3,1,2),
                sum_edges=self.sum_edges)
            agents = agents * alive_mask + loser_agents * (1.0 - alive_mask)

            ## reinitialize loser agent phenotypes
            loser_phenotypes = self.normalization_func(
                compute_distance_weighted_vectors(plateau, agents, mask=unharvested, beta=self.homing_strength))
            phenotypes = alive_mask * phenotypes + (1.0 - alive_mask) * loser_phenotypes
            phenotypes = self.normalization_func(phenotypes)

        ## run a final competition between the surviving masks
        if self.mask_beta is not None:
            masks_pred = F.softmax(
                self.mask_beta * masks_pred * alive_t - \
                self.mask_beta * (1.0 - alive_t), dim=-1)
        if self.mask_dead_segments:
            # Out-of-place: softmax/relu keep their output for the backward
            # pass, so it must not be modified in place.
            masks_pred = masks_pred * alive_t

        masks_pred = masks_pred.view(self.BT,self.H,self.W,self.M)
        if self.is_temporal:
            masks_pred = self.reshape_batch_time(masks_pred, merge=False)
            agents = self.reshape_batch_time(agents, merge=False)
            alive = self.reshape_batch_time(alive, merge=False)
            phenotypes = self.reshape_batch_time(phenotypes, merge=False)
            unharvested = self.reshape_batch_time(unharvested, merge=False)

        return (masks_pred, agents, alive, phenotypes, unharvested)

    @staticmethod
    def masks_to_segments(masks):
        """Return, per pixel, the index of the mask channel with the largest value."""
        return masks.argmax(-1)

    @staticmethod
    def flatten_plateau_with_masks(plateau, masks, alive, flatten_masks=True):
        """Replace every pixel's plateau vector by the mean vector of its segment.

        args:
            plateau: [B,H,W,Q]
            masks: [B,H,W,M]
            alive: [B,M,1]
            flatten_masks: if True, masks are first converted to one-hot
                assignments of each pixel to its strongest alive mask
        returns:
            (flat_plateau [B,H,W,Q], phenotypes [B,M,Q]), both L2-normalized
            along the last dimension; rows of dead masks stay zero
        """
        B,M,_ = alive.shape
        Q = plateau.shape[-1]
        if flatten_masks:
            masks = F.one_hot((alive[...,None,None,:,0] * masks).argmax(-1), num_classes=M).float()
        flat_plateau = torch.zeros_like(plateau)
        phenotypes = torch.zeros((B,M,Q), device=plateau.device).float()
        for b in range(B):
            m_inds = torch.where(alive[b,:,0])[0]
            masks_b = masks[b,...,m_inds]
            num_px = masks_b.sum((0,1)).clamp(min=1)[:,None]  # [K,1]
            phenos_b = torch.einsum('hwk,hwq->kq', masks_b, plateau[b]) / num_px  # [K,Q]
            flat_plateau_b = (masks_b[...,None] * phenos_b[None,None]).sum(-2)  # [H,W,Q]
            phenotypes[b,m_inds,:] = phenos_b
            flat_plateau[b] = flat_plateau_b

        _norm = lambda x: F.normalize(x, p=2, dim=-1)
        return (_norm(flat_plateau), _norm(phenotypes))

    @staticmethod
    def plot_agents(agents, alive, size=(128, 128)):
        """Rasterize alive agents into an index map.

        args:
            agents: [B,M,2] positions in normalized [-1,1] (h,w) coordinates
            alive: [B,M,1]
            size: (H, W) of the output map
        returns:
            [B,H,W] long tensor holding the agent index at the (up to four)
            pixels surrounding each alive agent and -1 everywhere else
        """
        B,M,_ = alive.shape
        max_h, max_w = size[0] - 1, size[1] - 1
        agent_map = -1 * torch.ones((B,*size), device=alive.device, dtype=torch.long)
        for b in range(B):
            inds = torch.where(alive[b,:,0])
            for i in inds[0]:
                pos = agents[b,i]*0.5 + 0.5
                pos = pos * torch.tensor(size, device=pos.device)
                hmin, wmin = torch.floor(pos).long().tolist()
                hmax, wmax = torch.ceil(pos).long().tolist()
                # A coordinate of +1 maps to `size`, one past the last pixel, and
                # coordinates below -1 would wrap around; clamp into the map.
                hmin, hmax = min(max(hmin, 0), max_h), min(max(hmax, 0), max_h)
                wmin, wmax = min(max(wmin, 0), max_w), min(max(wmax, 0), max_w)
                agent_map[b,[hmin,hmin,hmax,hmax],[wmin,wmax,wmin,wmax]] = i
        return agent_map
if __name__ == '__main__':
    # Smoke test: a 32x32 plateau made of three vertical stripes
    # (8, 16 and 8 columns wide), each with a constant feature vector.
    Comp = Competition(num_masks=32, num_competition_rounds=5)
    stripe_specs = [
        (8, [1.,0.2,0.]),
        (16, [0.,1.,0.2]),
        (8, [0.1,0.,1.]),
    ]
    stripes = [
        torch.ones(size=(32, width)).unsqueeze(-1) * torch.tensor(vector)
        for width, vector in stripe_specs
    ]
    plateau = torch.cat(stripes, dim=-2).unsqueeze(0)

    masks, agents, alive, phenotypes, unharvested = Comp(plateau)

    # Print the per-pixel segment ids, then the pixel count of each alive mask.
    print(np.argmax(masks[0,...], axis=-1))
    mask_inds = np.where(alive[0,:,0].numpy())[0]
    for ind in mask_inds:
        num_pixels = (np.argmax(masks[0], -1) == ind).sum()
        print("num pixels in mask %d ---> %d" % (ind, num_pixels))