jadechoghari commited on
Commit
c56ba76
·
verified ·
1 Parent(s): c577b6d

Create transformer.py

Browse files
Files changed (1) hide show
  1. transformer.py +135 -0
transformer.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class ModLN(nn.Module):
13
+ """
14
+ Modulation with adaLN.
15
+
16
+ References:
17
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
18
+ """
19
+ def __init__(self, inner_dim: int, mod_dim: int, eps: float):
20
+ super().__init__()
21
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
22
+ self.mlp = nn.Sequential(
23
+ nn.SiLU(),
24
+ nn.Linear(mod_dim, inner_dim * 2),
25
+ )
26
+
27
+ @staticmethod
28
+ def modulate(x, shift, scale):
29
+ # x: [N, L, D]
30
+ # shift, scale: [N, D]
31
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
32
+
33
+ def forward(self, x, cond):
34
+ shift, scale = self.mlp(cond).chunk(2, dim=-1) # [N, D]
35
+ return self.modulate(self.norm(x), shift, scale) # [N, L, D]
36
+
37
+
38
+ class ConditionModulationBlock(nn.Module):
39
+ """
40
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
41
+ """
42
+ # use attention from torch.nn.MultiHeadAttention
43
+ # Block contains a cross-attention layer, a self-attention layer, and a MLP
44
+ def __init__(self, inner_dim: int, cond_dim: int, mod_dim: int, num_heads: int, eps: float,
45
+ attn_drop: float = 0., attn_bias: bool = False,
46
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
47
+ super().__init__()
48
+ self.norm1 = ModLN(inner_dim, mod_dim, eps)
49
+ self.cross_attn = nn.MultiheadAttention(
50
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
51
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
52
+ self.norm2 = ModLN(inner_dim, mod_dim, eps)
53
+ self.self_attn = nn.MultiheadAttention(
54
+ embed_dim=inner_dim, num_heads=num_heads,
55
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
56
+ self.norm3 = ModLN(inner_dim, mod_dim, eps)
57
+ self.mlp = nn.Sequential(
58
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
59
+ nn.GELU(),
60
+ nn.Dropout(mlp_drop),
61
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
62
+ nn.Dropout(mlp_drop),
63
+ )
64
+
65
+ def forward(self, x, cond, mod):
66
+ # x: [N, L, D]
67
+ # cond: [N, L_cond, D_cond]
68
+ # mod: [N, D_mod]
69
+ x = x + self.cross_attn(self.norm1(x, mod), cond, cond, need_weights=False)[0]
70
+ before_sa = self.norm2(x, mod)
71
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
72
+ x = x + self.mlp(self.norm3(x, mod))
73
+ return x
74
+
75
+
76
+ class TriplaneTransformer(nn.Module):
77
+ """
78
+ Transformer with condition and modulation that generates a triplane representation.
79
+
80
+ Reference:
81
+ Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
82
+ """
83
+ def __init__(self, inner_dim: int, image_feat_dim: int, camera_embed_dim: int,
84
+ triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
85
+ num_layers: int, num_heads: int,
86
+ eps: float = 1e-6):
87
+ super().__init__()
88
+
89
+ # attributes
90
+ self.triplane_low_res = triplane_low_res
91
+ self.triplane_high_res = triplane_high_res
92
+ self.triplane_dim = triplane_dim
93
+
94
+ # modules
95
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
96
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
97
+ self.layers = nn.ModuleList([
98
+ ConditionModulationBlock(
99
+ inner_dim=inner_dim, cond_dim=image_feat_dim, mod_dim=camera_embed_dim, num_heads=num_heads, eps=eps)
100
+ for _ in range(num_layers)
101
+ ])
102
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
103
+ self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
104
+
105
+ def forward(self, image_feats, camera_embeddings):
106
+ # image_feats: [N, L_cond, D_cond]
107
+ # camera_embeddings: [N, D_mod]
108
+
109
+ assert image_feats.shape[0] == camera_embeddings.shape[0], \
110
+ f"Mismatched batch size: {image_feats.shape[0]} vs {camera_embeddings.shape[0]}"
111
+
112
+ N = image_feats.shape[0]
113
+ H = W = self.triplane_low_res
114
+ L = 3 * H * W
115
+
116
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
117
+ for layer in self.layers:
118
+ x = layer(x, image_feats, camera_embeddings)
119
+ x = self.norm(x)
120
+
121
+ # separate each plane and apply deconv
122
+ x = x.view(N, 3, H, W, -1)
123
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
124
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
125
+ x = self.deconv(x) # [3*N, D', H', W']
126
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
127
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
128
+ x = x.contiguous()
129
+
130
+ assert self.triplane_high_res == x.shape[-2], \
131
+ f"Output triplane resolution does not match with expected: {x.shape[-2]} vs {self.triplane_high_res}"
132
+ assert self.triplane_dim == x.shape[-3], \
133
+ f"Output triplane dimension does not match with expected: {x.shape[-3]} vs {self.triplane_dim}"
134
+
135
+ return x