File size: 7,141 Bytes
cd6dcce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
#!/usr/bin/env python3
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.utils import optional_import
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
class PatchEmbeddingBlock(nn.Module):
    """
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    The block splits a 5-D input into non-overlapping 3-D patches, projects each
    patch to ``hidden_size`` features, adds a learnable position embedding of
    shape ``(1, n_patches, hidden_size)`` and applies dropout.

    Two projection types are supported:

    - ``"conv"``: a ``Conv3d`` whose kernel size and stride both equal ``patch_size``.
    - ``"perceptron"``: an einops ``Rearrange`` that flattens every patch, followed
      by a ``Linear`` layer.

    Note: ``cls_token`` is registered as a parameter but is not used by
    ``forward`` in this class.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        hidden_size: int,
        num_heads: int,
        pos_embed: str,
        dropout_rate: float = 0.0,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads. Only used here to validate
                that ``hidden_size`` is divisible by it.
            pos_embed: position embedding layer type, ``"conv"`` or ``"perceptron"``.
            dropout_rate: fraction of the input units to drop.

        Raises:
            AssertionError: when ``dropout_rate`` is outside ``[0, 1]``, when
                ``hidden_size`` is not divisible by ``num_heads``, when any
                ``patch_size`` entry exceeds the matching ``img_size`` entry, or
                when ``pos_embed == "perceptron"`` and any ``img_size`` entry is
                not divisible by the matching ``patch_size`` entry.
            KeyError: when ``pos_embed`` is not a supported type.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        for m, p in zip(img_size, patch_size):
            if m < p:
                raise AssertionError("patch_size should be smaller than img_size.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        if pos_embed == "perceptron":
            # The Rearrange pattern factors every spatial axis as (n * p), so all
            # dimensions must be divisible, not only the first one.
            if any(m % p != 0 for m, p in zip(img_size, patch_size)):
                raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")

        self.n_patches = (
            (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
        )
        self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]

        self.pos_embed = pos_embed
        self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
        if self.pos_embed == "conv":
            self.patch_embeddings = nn.Conv3d(
                in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
            )
        elif self.pos_embed == "perceptron":
            self.patch_embeddings = nn.Sequential(
                Rearrange(
                    "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
                    p1=patch_size[0],
                    p2=patch_size[1],
                    p3=patch_size[2],
                ),
                nn.Linear(self.patch_dim, hidden_size),
            )
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)
        self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
        # Visits every sub-module (and this module itself) to initialise
        # Linear / LayerNorm parameters.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``m`` in place if it is a ``Linear`` or ``LayerNorm`` module.

        ``Linear`` weights get a truncated normal (std 0.02, clipped to [-2, 2])
        and a zero bias; ``LayerNorm`` gets zero bias and unit weight. Any other
        module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def trunc_normal_(self, tensor, mean, std, a, b):
        """Fill ``tensor`` in place from a normal distribution truncated to ``[a, b]``.

        Args:
            tensor: tensor to overwrite.
            mean: mean of the underlying normal distribution.
            std: standard deviation of the underlying normal distribution.
            a: lower truncation bound.
            b: upper truncation bound.

        Returns:
            The same ``tensor`` object, after in-place modification.
        """
        # From PyTorch official master until it's in a few official releases - RW
        # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
        def norm_cdf(x):
            return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

        with torch.no_grad():
            # CDF values of the truncation bounds under N(mean, std).
            lower = norm_cdf((a - mean) / std)
            upper = norm_cdf((b - mean) / std)
            # Sample uniformly in the erf domain, then invert to get a
            # standard normal restricted to the requested interval.
            tensor.uniform_(2 * lower - 1, 2 * upper - 1)
            tensor.erfinv_()
            tensor.mul_(std * math.sqrt(2.0))
            tensor.add_(mean)
            # Guard against values drifting just outside [a, b] numerically.
            tensor.clamp_(min=a, max=b)
            return tensor

    def forward(self, x):
        """Embed ``x`` into patch tokens and add position embeddings.

        Args:
            x: input tensor accepted by the configured projection; for the
                ``"conv"`` type this is the 5-D input of ``Conv3d``.

        Returns:
            Tensor of shape ``(batch, n_patches, hidden_size)`` after dropout.
        """
        if self.pos_embed == "conv":
            x = self.patch_embeddings(x)
            # (B, hidden, D', H', W') -> (B, hidden, N) -> (B, N, hidden)
            x = x.flatten(2)
            x = x.transpose(-1, -2)
        elif self.pos_embed == "perceptron":
            x = self.patch_embeddings(x)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
class PatchEmbed3D(nn.Module):
    """Project a 5-D input into patch embeddings with a strided ``Conv3d``.

    Args:
        img_size (Sequence[int]): Nominal input size, used only to derive
            ``grid_size`` and ``num_patches``. Default: (96, 96, 96).
        patch_size (tuple): Patch token size; used as both kernel size and
            stride of the projection. Default: (4, 4, 4).
        in_chans (int): Number of input channels. Default: 1.
        embed_dim (int): Number of projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer class, called as
            ``norm_layer(embed_dim)``. Default: None
    """

    def __init__(
        self,
        img_size: Sequence[int] = (96, 96, 96),
        patch_size=(4, 4, 4),
        in_chans: int = 1,
        embed_dim: int = 96,
        norm_layer=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Number of whole patches that fit along each of the three axes.
        self.grid_size = tuple(img_size[axis] // patch_size[axis] for axis in range(3))
        grid_d, grid_h, grid_w = self.grid_size
        self.num_patches = grid_d * grid_h * grid_w

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Pad ``x`` to a multiple of the patch size, project it, optionally normalize.

        Returns a tensor of shape ``(B, embed_dim, D', H', W')``.
        """
        _, _, depth, height, width = x.size()
        patch_d, patch_h, patch_w = self.patch_size[0], self.patch_size[1], self.patch_size[2]

        # Amount of zero padding needed at the trailing edge of each axis.
        extra_w = patch_w - width % patch_w if width % patch_w != 0 else 0
        extra_h = patch_h - height % patch_h if height % patch_h != 0 else 0
        extra_d = patch_d - depth % patch_d if depth % patch_d != 0 else 0
        if extra_w or extra_h or extra_d:
            # F.pad takes pairs starting from the last dimension: (W, H, D).
            x = F.pad(x, (0, extra_w, 0, extra_h, 0, extra_d))

        projected = self.proj(x)  # B C D Wh Ww
        if self.norm is None:
            return projected

        out_d, out_h, out_w = projected.size(2), projected.size(3), projected.size(4)
        # Normalize over the channel axis: move channels last, then restore.
        tokens = projected.flatten(2).transpose(1, 2)
        tokens = self.norm(tokens)
        return tokens.transpose(1, 2).view(-1, self.embed_dim, out_d, out_h, out_w)
|