jbilcke-hf's picture
jbilcke-hf HF staff
up
69f3483
raw
history blame
No virus
4.96 kB
from collections import deque
from typing import Tuple, Callable
from einops import rearrange
import torch
import torch.nn.functional as F
def get_nn_feats(x, y, threshold=0.9):
    """Swap each token of ``x`` for its nearest neighbour in ``y`` when similar enough.

    Both inputs are treated as 3-D tensors whose dim 1 is the token axis and
    whose last dim is the feature axis (they are concatenated on dim 1,
    normalised on dim -1 and matched with a batched matmul).

    Args:
        x: Query features, or a ``deque`` of tensors to concatenate on dim 1.
        y: Reference features, or a ``deque`` of tensors to concatenate on
            dim 1. An empty deque means "no references yet".
        threshold: Minimum cosine similarity required to replace a token of
            ``x`` with its best match from ``y``.

    Returns:
        A tensor where every token whose best cosine similarity is
        ``>= threshold`` has been replaced by the matching token of ``y``;
        all other tokens are taken from ``x``. If ``y`` is an empty deque,
        ``x`` is returned unchanged.
    """
    # isinstance (not an exact type check) so deque subclasses are accepted too.
    if isinstance(x, deque):
        x = torch.cat(list(x), dim=1)
    if isinstance(y, deque):
        if not y:
            # Nothing to match against: no token can reach the threshold,
            # so every token keeps its original value.
            return x
        y = torch.cat(list(y), dim=1)

    # Cosine similarity between every query token and every reference token.
    x_norm = F.normalize(x, p=2, dim=-1)
    y_norm = F.normalize(y, p=2, dim=-1)
    cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))

    best_similarity, best_index = torch.max(cosine_similarity, dim=-1)
    keep_original = best_similarity < threshold

    # Broadcast the winning token index over the feature axis so gather can
    # pull out whole feature vectors from y.
    gather_index = best_index.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
    nearest_neighbor = torch.gather(y, 1, gather_index)

    return torch.where(keep_original.unsqueeze(-1), x, nearest_neighbor)
def get_nn_latent(x, y, threshold=0.9):
    """Swap each spatial location of ``x`` for its nearest neighbour among ``y``.

    Every spatial position of a latent is treated as one token with ``c``
    features. Tokens from all reference latents are pooled, and each token of
    ``x`` is replaced by its most cosine-similar reference token when that
    similarity reaches ``threshold``.

    Args:
        x: Latent of shape ``(n, c, h, w)``.
        y: Sequence of reference latents, each laid out as ``(n, c, h, w)``;
            their tokens are concatenated along the token axis. May be empty.
        threshold: Minimum cosine similarity required to replace a token.

    Returns:
        A tensor laid out as ``(n, c, h, w)`` with the same ``c``, ``h`` and
        ``w`` as ``x``. If ``y`` is empty, ``x`` is returned unchanged.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    assert len(x.shape) == 4, "x must have shape (n, c, h, w)"
    _, c, h, w = x.shape

    # len() rather than truthiness so a stacked tensor of references works too.
    if len(y) == 0:
        # No reference tokens: nothing can reach the threshold, keep x as is.
        return x

    x_ = rearrange(x, 'n c h w -> n (h w) c')
    y_ = torch.cat(
        [rearrange(reference, 'n c h w -> n (h w) c') for reference in y],
        dim=1,
    )

    # Cosine similarity between every token of x and every reference token.
    x_norm = F.normalize(x_, p=2, dim=-1)
    y_norm = F.normalize(y_, p=2, dim=-1)
    cosine_similarity = torch.matmul(x_norm, y_norm.transpose(1, 2))

    max_cosine_values, nearest_neighbors_indices = torch.max(cosine_similarity, dim=-1)
    mask = max_cosine_values < threshold

    # Broadcast the winning token index over the channel axis for gather.
    indices_expanded = nearest_neighbors_indices.unsqueeze(-1).expand(-1, -1, x_norm.size(-1))
    nearest_neighbor_tensor = torch.gather(y_, 1, indices_expanded)

    # Use values from x where the cosine similarity is below the threshold.
    x_expanded = x_.expand_as(nearest_neighbor_tensor)
    selected_tensor = torch.where(mask.unsqueeze(-1), x_expanded, nearest_neighbor_tensor)

    return rearrange(selected_tensor, 'n (h w) c -> n c h w', h=h, w=w, c=c)
def random_bipartite_soft_matching(
    metric: torch.Tensor, use_grid: bool = False, ratio: float = 0.5
) -> Tuple[Callable, Callable, Callable]:
    """
    Applies ToMe with the two sets as (r chosen randomly, the rest).
    Input size is [batch, tokens, channels].

    ``r = int(ratio * N)`` source tokens are merged into the remaining
    ``N - r`` destination tokens, so each merge function returns tensors with
    ``N - r`` tokens.

    Args:
        metric: Features of shape [batch, tokens, channels] used to decide
            which destination token each source token is merged into
            (highest cosine similarity).
        use_grid: If True, tokens are paired as (2i, 2i + 1) and one token of
            every pair is randomly picked as source, the other as
            destination. Requires ``ratio == 0.5``; the token count is
            expected to be even. If False, the split is a random permutation.
        ratio: Fraction of tokens used as sources (merged away).

    Returns:
        ``(merge_kv_out, merge_kv, merge_out)``. Each callable takes tensors
        of shape [batch, tokens, channels'] and a ``mode`` that is forwarded
        as ``reduce=`` to ``Tensor.scatter_reduce``, and returns the merged
        destination tensor(s).
    """
    with torch.no_grad():
        B, N, _ = metric.shape
        if use_grid:
            assert ratio == 0.5
            sample = torch.randint(2, size=(B, N // 2, 1), device=metric.device)
            sample_alternate = 1 - sample
            # Start index of every (2i, 2i + 1) pair. Kept at batch size 1 so
            # it broadcasts against `sample` for any batch size B.
            grid = torch.arange(0, N, 2, device=metric.device).view(1, N // 2, 1)
            rand_idx = torch.cat([sample + grid, sample_alternate + grid], dim=1)
        else:
            rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)

        r = int(ratio * N)
        a_idx = rand_idx[:, :r, :]  # source tokens (merged away)
        b_idx = rand_idx[:, r:, :]  # destination tokens (kept)

        def split(x):
            C = x.shape[-1]
            a = x.gather(dim=1, index=a_idx.expand(B, r, C))
            b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
            return a, b

        # Cosine similarity between every source and every destination token.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # For each source token, the destination token it is merged into.
        _, dst_idx = scores.max(dim=-1)
        dst_idx = dst_idx[..., None]

    def _merge(x: torch.Tensor, mode: str) -> torch.Tensor:
        # Reduce every source token of x into its matched destination token.
        src, dst = split(x)
        C = src.shape[-1]
        return dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)

    def merge_kv_out(
        keys: torch.Tensor, values: torch.Tensor, outputs: torch.Tensor, mode="mean"
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return _merge(keys, mode), _merge(values, mode), _merge(outputs, mode)

    def merge_kv(
        keys: torch.Tensor, values: torch.Tensor, mode="mean"
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return _merge(keys, mode), _merge(values, mode)

    def merge_out(outputs: torch.Tensor, mode="mean") -> torch.Tensor:
        return _merge(outputs, mode)

    return merge_kv_out, merge_kv, merge_out