mart9992's picture
m
b793f0c
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import einops
import numpy as np
import torch
import torch.nn as nn
class Normalize(nn.Module):
    """Apply L2 normalization (``p=2``) along a fixed dimension."""

    def __init__(self, dim: int) -> None:
        """Store the dimension along which inputs are normalized.

        Args:
            dim: Dimension passed to ``torch.nn.functional.normalize``.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` L2-normalized along ``self.dim``."""
        return nn.functional.normalize(x, p=2, dim=self.dim)
class LearnableLogitScaling(nn.Module):
    """Multiply the input by ``exp(log_logit_scale)``, capped at ``max_logit_scale``.

    The scale is kept in log space under the name ``log_logit_scale``: as an
    ``nn.Parameter`` when ``learnable`` is true, otherwise as a registered
    buffer.
    """

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        """Create the scaling module.

        Args:
            logit_scale_init: Initial scale; its natural log is what gets stored.
            learnable: Whether the log-scale is a parameter or a buffer.
            max_logit_scale: Upper bound applied to the exponentiated scale
                in ``forward``.
        """
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable

        # 0-dim tensor holding log(logit_scale_init).
        initial_log_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if not learnable:
            self.register_buffer("log_logit_scale", initial_log_scale)
        else:
            self.log_logit_scale = nn.Parameter(initial_log_scale)

    def forward(self, x):
        """Return ``x`` multiplied by the clamped, exponentiated log-scale."""
        scale = self.log_logit_scale.exp().clamp(max=self.max_logit_scale)
        return scale * x

    def extra_repr(self):
        """Return the constructor settings for inclusion in the module repr."""
        return f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
class EinOpsRearrange(nn.Module):
    """Module wrapper that applies ``einops.rearrange`` with a stored expression."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        """Remember the rearrange expression and its extra keyword arguments.

        Args:
            rearrange_expr: Pattern string forwarded to ``einops.rearrange``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``einops.rearrange`` on every call.
        """
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        """Rearrange tensor ``x`` using the stored expression and kwargs."""
        assert isinstance(x, torch.Tensor)
        pattern = self.rearrange_expr
        extra_kwargs = self.kwargs
        return einops.rearrange(x, pattern, **extra_kwargs)
class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        """Format one tensor as ``(name): tensor(<shape>, requires_grad=<bool>)``.

        Args:
            name: Label printed in front of the tensor description.
            tensor: The tensor to describe. For backward compatibility a
                ``(name, tensor)`` pair, as yielded by ``named_parameters()``
                or ``named_buffers()``, is also accepted; its element at
                index 1 is then described.

        Returns:
            A single newline-terminated line with the tensor's shape and
            ``requires_grad`` flag.
        """
        if not isinstance(tensor, torch.Tensor):
            # Legacy call style: the whole (name, tensor) pair was passed in.
            tensor = tensor[1]
        shape = tuple(tensor.shape)
        return f"({name}): tensor({shape}, requires_grad={tensor.requires_grad})\n"

    def extra_repr(self) -> str:
        """List parameters and buffers of this module, one per line.

        A parameter is listed only when the first component of its dotted
        name is not the name of a module returned by ``named_modules()``.

        Returns:
            The concatenated lines; an empty string when there is nothing to
            list.
        """
        module_names = {module_name for module_name, _ in self.named_modules()}

        lines = []
        for full_name, param in self.named_parameters():
            name = full_name.split(".")[0]
            if name not in module_names:
                lines.append(self.get_readable_tensor_repr(name, param))

        # NOTE(review): unlike parameters, buffers are not filtered against
        # module names, so buffers of submodules are listed under the
        # submodule's name. Kept as-is; confirm whether that is intended.
        for full_name, buf in self.named_buffers():
            name = full_name.split(".")[0]
            lines.append(self.get_readable_tensor_repr(name, buf))

        return "".join(lines)
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Cast ``tensor`` to ``tgt_dtype`` only if its dtype equals ``src_dtype``.

    Args:
        tensor: Tensor to inspect.
        src_dtype: Dtype that triggers the cast.
        tgt_dtype: Dtype to cast to when triggered.

    Returns:
        A ``(tensor, updated)`` pair. ``updated`` is ``True`` when the cast
        was applied; otherwise the input tensor is returned unchanged with
        ``False``.
    """
    if tensor.dtype != src_dtype:
        return tensor, False
    return tensor.to(dtype=tgt_dtype), True
class QuickGELU(nn.Module):
    """QuickGELU activation: ``x * sigmoid(1.702 * x)``."""

    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        """Return ``x`` gated elementwise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class SelectElement(nn.Module):
    """Select a fixed index along dimension 1 of the input."""

    def __init__(self, index) -> None:
        """Store the index used along dimension 1 in ``forward``."""
        super().__init__()
        self.index = index

    def forward(self, x):
        """Return ``x[:, self.index, ...]``; ``x`` must have at least 3 dims."""
        assert x.dim() >= 3
        selector = (slice(None), self.index, Ellipsis)
        return x[selector]
class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP
    """

    def __init__(self, proj: nn.Module) -> None:
        """Store the projection module applied after pooling."""
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        """Pick one position per batch row from ``x`` and project it.

        Args:
            x: 3-D tensor of shape B x L x D.
            seq_len: Per-row index into the L dimension.

        Returns:
            ``self.proj`` applied to the selected features.
        """
        assert x.ndim == 3
        # x is of shape B x L x D
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_indices = torch.arange(x.shape[0])
        pooled = x[batch_indices, seq_len]
        return self.proj(pooled)