from math import pi

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# Adapted from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Return a fixed 2D sine/cosine positional-embedding table for a square grid.

    Args:
        embed_dim: channels per position. Must be divisible by 4: it is split
            in half per grid plane, and each half again into sin and cos parts.
        grid_size: number of positions along both the height and the width.
        cls_token: when True *and* ``extra_tokens > 0``, prepend
            ``extra_tokens`` all-zero rows (e.g. for class tokens).
        extra_tokens: number of zero rows to prepend; ignored unless
            ``cls_token`` is True.

    Returns:
        ``np.ndarray`` of shape ``[grid_size*grid_size, embed_dim]`` (row
        index ``h * grid_size + w``), or
        ``[extra_tokens + grid_size*grid_size, embed_dim]`` when zero rows are
        prepended. Note that ``cls_token=True`` with the default
        ``extra_tokens=0`` adds no rows.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    # Default 'xy' indexing: grid[0] holds the column (w) index and grid[1]
    # the row (h) index, each as a (grid_size, grid_size) plane.
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode two stacked coordinate planes with 1D sine/cosine embeddings.

    Args:
        embed_dim: total output dimension; must be even. The first half of the
            channels encodes ``grid[0]`` and the second half ``grid[1]``.
        grid: array whose leading axis has length 2; each of ``grid[0]`` and
            ``grid[1]`` is flattened to M positions.

    Returns:
        ``np.ndarray`` of shape ``(M, embed_dim)``.
    """
    assert embed_dim % 2 == 0

    # NOTE: get_2d_sincos_pos_embed stacks its meshgrid w-first, so there
    # grid[0] is the column (w) coordinate despite the `emb_h` name below.
    # The channel order is kept as-is so existing tables stay identical.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])

    emb = np.concatenate([emb_h, emb_w], axis=1)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine/cosine features.

    Args:
        embed_dim: output dimension D per position; must be even.
        pos: positions to encode; any array-like (list or ndarray), flattened
            to shape (M,).

    Returns:
        ``np.ndarray`` of shape (M, D): the first D/2 columns are
        ``sin(pos * omega)`` and the last D/2 are ``cos(pos * omega)``, with
        ``omega[i] = 1 / 10000 ** (2i / D)``.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    # asarray so that plain Python lists are accepted as documented.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def broadcat(tensors, dim=-1):
    """Concatenate tensors along ``dim`` after expanding every other axis.

    All tensors must have the same number of dimensions. On each axis other
    than ``dim`` the sizes may take at most two distinct values; every tensor
    is ``expand``-ed to the largest size on that axis, so the smaller size is
    expected to be 1 (``Tensor.expand`` rejects anything else).

    Raises:
        AssertionError: on mismatched ranks or more than two distinct sizes
            on a non-concatenated axis.
    """
    num_tensors = len(tensors)
    shape_lens = {len(t.shape) for t in tensors}
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = next(iter(shape_lens))
    dim = (dim + shape_len) if dim < 0 else dim
    # dims[i] is the tuple of every tensor's size on axis i.
    dims = list(zip(*(list(t.shape) for t in tensors)))
    expandable_dims = [(i, sizes) for i, sizes in enumerate(dims) if i != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in expandable_dims), 'invalid dimensions for broadcastable concatentation'
    max_dims = [(i, max(sizes)) for i, sizes in expandable_dims]
    expanded_dims = [(i, (size,) * num_tensors) for i, size in max_dims]
    # The concatenation axis keeps each tensor's own size.
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*(sizes for _, sizes in expanded_dims)))
    tensors = [t.expand(*shape) for t, shape in zip(tensors, expandable_shapes)]
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    """Map each adjacent channel pair ``(x1, x2)`` of the last dim to ``(-x2, x1)``.

    The last dimension must be even; pairs are interleaved channels
    ``(0, 1), (2, 3), ...``.
    """
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


#################################################################################
#                                 VisionRotary                                  #
#################################################################################
# References:
# EVA: https://github.com/baaivision/EVA
# "The Road to Upgrading Transformers 2: Rotary Position Embedding, Drawing on
#   the Strengths of Many" (in Chinese): https://spaces.ac.cn/archives/8265
# "The Road to Upgrading Transformers 4: Rotary Position Embedding for 2D
#   Positions" (in Chinese): https://spaces.ac.cn/archives/8397

class VisionRotaryEmbeddingFast(nn.Module):
    """2D rotary position embedding (RoPE) for a grid of tokens.

    Precomputes ``freqs_cos`` / ``freqs_sin`` buffers of shape ``(H*W, C)``
    for the ``ft_hw`` grid, flattened row-major (token index ``h * W + w``).
    With F frequencies, ``C = 4 * F``: the first ``2F`` channels are rotated
    by the row (h) position and the last ``2F`` by the column (w) position,
    one angle per adjacent channel pair. For ``freqs_for`` 'lang' or 'pixel'
    with an even ``dim``, ``F = dim // 2`` and so ``C = 2 * dim``.

    Args:
        dim: number of channels per spatial axis used to derive the
            'lang' / 'pixel' frequencies.
        pt_hw: ``(H, W)`` grid whose range the positions are expressed in, or
            an int for a square grid. Required.
        ft_hw: ``(H, W)`` grid (or int) the tables are built for; defaults to
            ``pt_hw``. Its positions are rescaled to span ``[0, pt_hw)``.
        custom_freqs: explicit frequency vector (tensor or sequence); when
            given (not None) it overrides ``freqs_for`` and is cast to float32.
        freqs_for: 'lang' (``theta ** (-2i / dim)``), 'pixel'
            (``linspace(1, max_freq / 2, dim // 2) * pi``) or 'constant'
            (``num_freqs`` ones).
        theta: base for the 'lang' frequencies.
        max_freq: upper bound parameter for the 'pixel' frequencies.
        num_freqs: number of frequencies for 'constant'.

    Raises:
        ValueError: if ``freqs_for`` is unknown and no ``custom_freqs`` given.
        TypeError: if ``pt_hw`` is not given.
    """

    def __init__(
        self,
        dim,
        pt_hw=None,  # (H, W) or int; required
        ft_hw=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # Unlike a 1D RoPE, a 2D RoPE splits the channels between the two
        # spatial axes (see https://spaces.ac.cn/archives/8397).

        # `is not None` rather than truthiness: bool() of a multi-element
        # tensor raises, which made tensor-valued custom_freqs unusable.
        if custom_freqs is not None:
            freqs = torch.as_tensor(custom_freqs).float()
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if pt_hw is None:
            raise TypeError('pt_hw must be given as an (H, W) pair or an int')
        if isinstance(pt_hw, int):
            pt_hw = (pt_hw, pt_hw)
        if ft_hw is None:
            ft_hw = pt_hw
        elif isinstance(ft_hw, int):
            ft_hw = (ft_hw, ft_hw)

        # Positions of the ft_hw grid rescaled to span [0, pt_hw), so a
        # resized grid covers the same position range as pt_hw.
        h_t = torch.arange(ft_hw[0]) / ft_hw[0] * pt_hw[0]
        w_t = torch.arange(ft_hw[1]) / ft_hw[1] * pt_hw[1]

        # Outer products: (H, F) row angles and (W, F) column angles.
        h_freqs = torch.einsum('..., f -> ... f', h_t, freqs)
        w_freqs = torch.einsum('..., f -> ... f', w_t, freqs)

        # Duplicate each angle so it covers one interleaved channel pair,
        # matching the pairing used by rotate_half.
        h_freqs = repeat(h_freqs, '... n -> ... (n r)', r=2)
        w_freqs = repeat(w_freqs, '... n -> ... (n r)', r=2)

        # (H, 1, 2F) ++ (1, W, 2F) -> (H, W, 4F): row angles first.
        freqs = broadcat((h_freqs[:, None, :], w_freqs[None, :, :]), dim=-1)
        # Flatten the grid row-major into token order: (H*W, 4F).
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

    def forward(self, t):
        """Rotate ``t`` by its 2D token positions.

        ``t`` must broadcast against the ``(H*W, C)`` buffers, e.g. shape
        ``(..., H*W, C)``. Each channel pair in the first half is rotated by
        its row angle and each pair in the second half by its column angle,
        i.e. per pair of pairs:

            [[cos(h*theta), -sin(h*theta), 0,            0            ],
             [sin(h*theta),  cos(h*theta), 0,            0            ],
             [0,             0,            cos(w*theta), -sin(w*theta)],
             [0,             0,            sin(w*theta),  cos(w*theta)]]
        """
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin