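# NOTE: the next lines are residue from a web file viewer (presumably the
# uploader name, commit hash and file size). They are not Python and would
# be a SyntaxError if left uncommented:
# LinB203
# m
# a220803
# raw
# history blame
# 5.98 kB
# Everything below is commented out. The paste also lost the original Python
# indentation, so the code must be re-indented before it can be re-enabled.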
# import numpy as np
# import torch
# import torch.nn as nn
# from math import pi
# from einops import rearrange, repeat
#
# #################################################################################
# # Sine/Cosine Positional Embedding Functions #
# #################################################################################
# # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
#
# def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
# # Fixed (non-learned) 2D sine/cosine positional embedding for a square
# # grid_size x grid_size grid, adapted from the MAE reference implementation.
# # embed_dim must be divisible by 4: each axis gets embed_dim // 2 channels,
# # and get_1d_sincos_pos_embed_from_grid asserts that half is even.
# # NOTE(review): zero rows are prepended only when cls_token is truthy AND
# # extra_tokens > 0. So cls_token=True with the default extra_tokens=0 does
# # NOT produce the "1+grid_size*grid_size" shape the docstring mentions;
# # pass extra_tokens=1 for a single class token.
# # NOTE(review): the result is a float64 numpy array (omega is float64 in
# # get_1d_sincos_pos_embed_from_grid, and np.zeros defaults to float64).
# # Callers may need to cast it.
# """
# grid_size: int of the grid height and width
# return:
# pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
# """
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# # np.meshgrid defaults to indexing='xy':
# #   grid[0][i, j] == j (column / w coordinate)
# #   grid[1][i, j] == i (row / h coordinate)
# grid = np.stack(grid, axis=0)
#
# # Shape (2, 1, grid_size, grid_size). The 1D helper flattens each plane to
# # grid_size*grid_size positions in row-major order (h outer, w inner).
# grid = grid.reshape([2, 1, grid_size, grid_size])
# pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
# if cls_token and extra_tokens > 0:
# pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
# return pos_embed
#
#
# def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
# # Encode a two-plane coordinate grid into an (M, embed_dim) array, where M
# # is the number of positions in each plane:
# #   grid[0] fills the first embed_dim // 2 channels,
# #   grid[1] fills the last embed_dim // 2 channels.
# # NOTE(review): when called from get_2d_sincos_pos_embed, grid[0] holds
# # the w (column) coordinates, because np.meshgrid(grid_w, grid_h) puts w
# # first. So "emb_h" below actually encodes w. This only changes which half
# # of the channels encodes which axis, but the names are misleading.
# assert embed_dim % 2 == 0
# # The 1D helper also asserts embed_dim // 2 is even, so effectively
# # embed_dim % 4 == 0 is required.
#
# # use half of dimensions to encode grid_h
# emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
# emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
#
# emb = np.concatenate([emb_h, emb_w], axis=1)
# return emb
#
#
# def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
# """
# embed_dim: output dimension for each position
# pos: a list of positions to be encoded: size (M,)
# out: (M, D)
# """
# # pos is flattened with .reshape(-1), so despite the docstring it must be
# # a numpy array (of any shape), not a plain Python list.
# # Output layout: the first embed_dim // 2 columns are sin(pos * omega_i)
# # and the last embed_dim // 2 are cos(pos * omega_i). The two halves are
# # concatenated, not interleaved.
# assert embed_dim % 2 == 0
# # omega_i = 1 / 10000**(2i / embed_dim) for i in [0, embed_dim // 2).
# # These are geometric frequencies, as in the Transformer sinusoidal encoding.
# omega = np.arange(embed_dim // 2, dtype=np.float64)
# omega /= embed_dim / 2.
# omega = 1. / 10000**omega
#
# pos = pos.reshape(-1)
# # Outer product pos x omega -> (M, embed_dim // 2)
# out = np.einsum('m,d->md', pos, omega)
#
# emb_sin = np.sin(out)
# emb_cos = np.cos(out)
#
# emb = np.concatenate([emb_sin, emb_cos], axis=1)
# return emb
#
# def broadcat(tensors, dim=-1):
# # Concatenate `tensors` along `dim` after expanding every other dimension
# # to the largest size found there. Example: (H, 1, F) and (1, W, F) along
# # -1 become (H, W, 2F). All tensors must have the same rank.
# # NOTE(review): the second assert only checks that each non-concat dim
# # has at most two distinct sizes. Tensor.expand can only grow size-1 dims,
# # so sizes like {2, 3} pass the assert but then fail inside expand().
# # NOTE(review): "concatentation" in that assertion message is a typo.
# num_tensors = len(tensors)
# shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
# assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
# shape_len = list(shape_lens)[0]
# # Normalize a negative dim to its non-negative index.
# dim = (dim + shape_len) if dim < 0 else dim
# # dims[i] is the tuple of i-th dimension sizes across all tensors.
# dims = list(zip(*map(lambda t: list(t.shape), tensors)))
# expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
# assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
# max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
# # Every tensor is expanded to the max size on each non-concat dim...
# expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
# # ...while the concat dim keeps each tensor's own size.
# expanded_dims.insert(dim, (dim, dims[dim]))
# # Transpose back into one target shape per tensor.
# expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
# tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
# return torch.cat(tensors, dim=dim)
#
#
# def rotate_half(x):
# # RoPE helper for interleaved channel pairs. It views the last dim as
# # adjacent pairs (x1, x2) and returns (-x2, x1) for each pair, i.e. a
# # +90-degree rotation of every 2D pair. The last dim must be even.
# x = rearrange(x, '... (d r) -> ... d r', r=2)
# x1, x2 = x.unbind(dim=-1)
# x = torch.stack((-x2, x1), dim=-1)
# return rearrange(x, '... d r -> ... (d r)')
#
# #################################################################################
# # VisionRotary #
# #################################################################################
# # References:
# # EVA: https://github.com/baaivision/EVA
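# # Transformer升级之路:2、博采众长的旋转式位置编码 ("The Road to Transformer Upgrades, Part 2: Rotary Position Embedding, Drawing on the Strengths of Many"): https://spaces.ac.cn/archives/8265
# # Transformer升级之路:4、二维位置的旋转式位置编码 ("The Road to Transformer Upgrades, Part 4: Rotary Position Embedding for 2D Positions"): https://spaces.ac.cn/archives/8397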
#
# class VisionRotaryEmbeddingFast(nn.Module):
# # 2D rotary position embedding (RoPE) for an H x W token grid, adapted from
# # EVA. Here H, W = ft_hw (or pt_hw when ft_hw is None).
# #
# # __init__ precomputes cos/sin tables of shape (H*W, 4n), where n is the
# # number of base frequencies:
# #   dim // 2 for freqs_for='lang' or 'pixel',
# #   num_freqs for 'constant',
# #   the length of custom_freqs when given.
# # The first 2n channels rotate by the row (h) position and the last 2n by
# # the column (w) position. For even dim the rotated channel count is 2 * dim.
# #
# # Args (as used by the code below):
# #   dim: per-axis rotary width; 'lang'/'pixel' derive dim // 2 frequencies.
# #   pt_hw: (H, W) reference grid whose range the positions are scaled into.
# #   ft_hw: (H, W) grid actually embedded; defaults to pt_hw.
# #   custom_freqs: frequency tensor overriding freqs_for (presumably 1D).
# #   freqs_for: selects the frequencies; anything else raises ValueError.
# #     'lang': 1 / theta ** (2i / dim)
# #     'pixel': linspace(1, max_freq / 2, dim // 2) * pi
# #     'constant': num_freqs ones
# #
# # NOTE(review): the default pt_hw=(int, int) is a tuple of *types*, not a
# # size. Omitting pt_hw makes torch.arange(ft_hw[0]) raise, so callers must
# # always pass it explicitly.
# # NOTE(review): `if custom_freqs:` relies on truthiness. A multi-element
# # tensor raises "Boolean value of Tensor with more than one value is
# # ambiguous", and a single zero silently falls through to freqs_for.
# # `is not None` is presumably what was intended.
# def __init__(
# self,
# dim,
# pt_hw=(int, int), # (H, W)
# ft_hw=None,
# custom_freqs = None,
# freqs_for = 'lang',
# theta = 10000,
# max_freq = 10,
# num_freqs = 1,
# ):
# super().__init__()
# # Unlike a 1d RoPE, a 2d RoPE requires splitting the dimension into four parts
# # References: https://spaces.ac.cn/archives/8397
#
# if custom_freqs:
# freqs = custom_freqs
# elif freqs_for == 'lang':
# freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
# elif freqs_for == 'pixel':
# freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
# elif freqs_for == 'constant':
# freqs = torch.ones(num_freqs).float()
# else:
# raise ValueError(f'unknown modality {freqs_for}')
#
# # Map the ft_hw grid onto [0, pt_hw): row k gets position
# # k * pt_hw[0] / ft_hw[0], and likewise for columns. Presumably pt/ft are
# # the pre-training and fine-tuning resolutions -- confirm with callers.
# if ft_hw is None: ft_hw = pt_hw
# h_t = torch.arange(ft_hw[0]) / ft_hw[0] * pt_hw[0]
# w_t = torch.arange(ft_hw[1]) / ft_hw[1] * pt_hw[1]
#
# # Outer products position x frequency: (H, n) and (W, n).
# h_freqs = torch.einsum('..., f -> ... f', h_t, freqs)
# w_freqs = torch.einsum('..., f -> ... f', w_t, freqs)
#
# # Duplicate each angle so that both channels of an interleaved pair (see
# # rotate_half) share it: (H, 2n) and (W, 2n).
# h_freqs = repeat(h_freqs, '... n -> ... (n r)', r=2)
# w_freqs = repeat(w_freqs, '... n -> ... (n r)', r=2)
#
# # (H, 1, 2n) and (1, W, 2n) -> (H, W, 4n), then flattened row-major
# # (h outer, w inner) to (H*W, 4n). Note that `freqs` is rebound here.
# freqs = broadcat((h_freqs[:, None, :], w_freqs[None, :, :]), dim=-1)
# freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
# freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
#
# # Registered with the default persistent=True, so the tables are saved in
# # state_dict and move with the module on .to(device).
# self.register_buffer("freqs_cos", freqs_cos)
# self.register_buffer("freqs_sin", freqs_sin)
#
# def forward(self, t):
# # Apply the rotation: t * cos + rotate_half(t) * sin.
# # NOTE(review): t's trailing dims must broadcast against (H*W, 4n).
# # Presumably t is (..., H*W, head_dim) with head_dim == 4n; confirm at
# # call sites. Extra tokens (e.g. a class token) must be split off first,
# # because the tables only have H*W rows.
# # 2d RoPE: [[cos(h*theta), -sin(h*theta), 0, 0 ],
# # [sin(h*theta), cos(h*theta), 0, 0 ],
# # [0, 0, cos(w*theta), -sin(w*theta)],
# # [0, 0, sin(w*theta), cos(w*theta) ],]
# # The matrix above is schematic. Channels are actually rotated as adjacent
# # pairs (see rotate_half); the first half of the channels use h angles and
# # the second half use w angles.
#
# return t * self.freqs_cos + rotate_half(t) * self.freqs_sin