Spaces:
Sleeping
Sleeping
# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py | |
# But we use nn.Linear instead of Conv2d and it's about 8x faster. | |
from functools import partial | |
import torch.nn as nn | |
from einops import rearrange | |
from torch import _assert | |
from torch.nn.modules.utils import _pair | |
try: | |
from flash_attn.ops.fused_dense import FusedDense | |
except ImportError: | |
FusedDense = None | |
class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    Follows the constructor/forward API of timm's ``PatchEmbed`` (v0.6.11), but
    projects each non-overlapping ``patch_size`` patch with a linear layer
    (``nn.Linear`` or flash-attn's ``FusedDense``) instead of a ``Conv2d``.

    Args:
        img_size: Expected input spatial size; an int or an ``(H, W)`` pair.
        patch_size: Patch size; an int or a ``(ph, pw)`` pair.
        in_chans: Number of input channels.
        embed_dim: Embedding dimension produced for each patch.
        norm_layer: Optional callable; ``norm_layer(embed_dim)`` builds the
            module applied to the embeddings. Identity when falsy.
        flatten: If True, ``forward`` returns ``(B, num_patches, embed_dim)``;
            otherwise ``(B, grid_h, grid_w, embed_dim)``.
        bias: Whether the projection has a bias term.
        fused_bias_fc: Request ``FusedDense`` for the projection. It is only
            used when ``bias`` is also True; otherwise ``nn.Linear`` is used.

    Raises:
        ImportError: If ``fused_bias_fc`` is set but ``FusedDense`` is
            unavailable (its import failed).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        fused_bias_fc=False,
    ):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Floor division: as with a Conv2d whose kernel == stride == patch_size,
        # trailing rows/columns that do not fill a whole patch are not embedded.
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        # FusedDense is only selected when a bias is requested; otherwise fall
        # back to a plain nn.Linear.
        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
        # Input features are ordered (c, p1, p2), matching the rearrange in forward.
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def _crop_to_grid(self, x):
        """Drop trailing rows/columns of ``x`` that do not fill a whole patch.

        Keeps the spatial extent equal to ``grid_size * patch_size`` so that the
        patchify step agrees with ``self.grid_size`` / ``self.num_patches``.
        Returns ``x`` itself when no cropping is needed.
        """
        h_keep = self.grid_size[0] * self.patch_size[0]
        w_keep = self.grid_size[1] * self.patch_size[1]
        if x.shape[-2] == h_keep and x.shape[-1] == w_keep:
            return x
        return x[..., :h_keep, :w_keep]

    def forward(self, x):
        """Embed a batch of images into patch tokens.

        Args:
            x: Tensor of shape ``(B, in_chans, H, W)`` with ``(H, W)`` equal to
                ``img_size``.

        Returns:
            ``(B, num_patches, embed_dim)`` if ``flatten`` is True, else
            ``(B, grid_h, grid_w, embed_dim)``, after ``norm`` is applied.
        """
        _, _, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
        )
        # rearrange's "(h p1) (w p2)" requires exact divisibility; crop first so a
        # non-divisible img_size behaves like a strided Conv2d instead of erroring.
        x = self._crop_to_grid(x)
        x = self.proj(
            rearrange(
                x,
                "b c (h p1) (w p2) -> b h w (c p1 p2)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
            )
        )
        if self.flatten:
            x = rearrange(x, "b h w c -> b (h w) c")
        x = self.norm(x)
        return x