bardofcodes committed on
Commit d83c313 · verified · 1 Parent(s): 6149018

Create analogy_projector.py

analogy_projector/analogy_projector.py ADDED
@@ -0,0 +1,227 @@
"""
ADOBE CONFIDENTIAL
Copyright 2024 Adobe
All Rights Reserved.
NOTICE: All information contained herein is, and remains
the property of Adobe and its suppliers, if any. The intellectual
and technical concepts contained herein are proprietary to Adobe
and its suppliers and are protected by all applicable intellectual
property laws, including trade secret and copyright laws.
Dissemination of this information or reproduction of this material
is strictly forbidden unless prior written permission is obtained
from Adobe.
"""

import einops
import numpy as np
import torch as th
import torch.nn as nn
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
# Ref: https://github.com/tatp22/multidim-positional-encoding/tree/master


OUT_SIZE = 768
IN_SIZE = 2048


def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = th.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return th.flatten(emb, -2, -1)

class PositionalEncoding1D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (th.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch)
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        pos_x = th.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = th.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = th.zeros((x, self.channels), device=tensor.device, dtype=tensor.dtype)
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc

class PositionalEncoding3D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding3D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 6) * 2)
        if channels % 2:
            channels += 1
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (th.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
        """
        if len(tensor.shape) != 5:
            raise RuntimeError("The input tensor has to be 5d!")

        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            return self.cached_penc

        self.cached_penc = None
        batch_size, x, y, z, orig_ch = tensor.shape
        pos_x = th.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = th.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_z = th.arange(z, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = th.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = th.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_z = th.einsum("i,j->ij", pos_z, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)
        emb_z = get_emb(sin_inp_z)
        emb = th.zeros(
            (x, y, z, self.channels * 3),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        emb[:, :, :, : self.channels] = emb_x
        emb[:, :, :, self.channels : 2 * self.channels] = emb_y
        emb[:, :, :, 2 * self.channels :] = emb_z

        self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
        return self.cached_penc

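
# Quick shape reference for the two encodings above (added note, not part of
# the original file): with OUT_SIZE = 768, PositionalEncoding1D(OUT_SIZE) maps
# a (batch, n, 768) tensor to a sinusoidal encoding of the same shape, and
# PositionalEncoding3D(OUT_SIZE) does the same for (batch, x, y, z, 768)
# inputs, e.g.
#   PositionalEncoding3D(OUT_SIZE)(th.zeros(2, 4, 16, 16, OUT_SIZE)).shape
#   -> th.Size([2, 4, 16, 16, 768])
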
class AnalogyProjector(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(self):
        super(AnalogyProjector, self).__init__()
        self.projector = DinoSiglipMixer()
        self.pos_embd_1D = PositionalEncoding1D(OUT_SIZE)
        self.pos_embd_3D = PositionalEncoding3D(OUT_SIZE)

    def forward(self, dino_in, siglip_in, batch_size):
        # Fuse the DINO and SigLIP token features into OUT_SIZE-dim embeddings.
        image_embeddings = self.projector(dino_in, siglip_in)

        # Regroup so the k images belonging to each batch element sit on their own axis.
        image_embeddings = einops.rearrange(image_embeddings, '(k b) t d -> b k t d', b=batch_size)
        image_embeddings = self.position_embd(image_embeddings)
        return image_embeddings

    def position_embd(self, image_embeddings, concat=False):
        # Tokens after the leading one are treated as a square grid of patch
        # tokens and receive a 3D positional encoding (image index, row, column).
        canvas_embd = image_embeddings[:, :, 1:, :]
        batch_size = canvas_embd.shape[0]
        type_size = canvas_embd.shape[1]
        xy_size = canvas_embd.shape[2]

        x_size = int(xy_size ** 0.5)

        canvas_embd = canvas_embd.reshape(batch_size, type_size, x_size, x_size, -1)
        if concat:
            canvas_embd = th.cat([canvas_embd, self.pos_embd_3D(canvas_embd)], -1)
        else:
            canvas_embd = self.pos_embd_3D(canvas_embd) + canvas_embd
        canvas_embd = canvas_embd.reshape(batch_size, type_size, xy_size, -1)

        # The leading (class) token of each image receives a 1D positional
        # encoding over the image axis.
        class_embd = image_embeddings[:, :, 0, :]
        if concat:
            class_embd = th.cat([class_embd, self.pos_embd_1D(class_embd)], -1)
        else:
            class_embd = self.pos_embd_1D(class_embd) + class_embd

        # Re-interleave as [class_token_i, patch_tokens_i] for each image i.
        all_embd_list = []
        for i in range(type_size):
            all_embd_list.append(class_embd[:, i:i+1])
            all_embd_list.append(canvas_embd[:, i])
        image_embeddings = th.cat(all_embd_list, 1)
        return image_embeddings

class HighLowMixer(th.nn.Module):
    def __init__(self, in_size=IN_SIZE, out_size=OUT_SIZE):
        super().__init__()
        mid_size = (in_size + out_size) // 2

        # Normalize the two feature halves independently before mixing them.
        self.lower_projector = th.nn.Sequential(
            th.nn.LayerNorm(IN_SIZE // 2),
            th.nn.SiLU()
        )
        self.upper_projector = th.nn.Sequential(
            th.nn.LayerNorm(IN_SIZE // 2),
            th.nn.SiLU()
        )
        self.projectors = th.nn.ModuleList([
            # add layer norm
            th.nn.Linear(in_size, mid_size),
            th.nn.SiLU(),
            th.nn.Linear(mid_size, out_size)
        ])
        # Initialize the linear layers with Xavier weights and zero biases.
        for proj in self.projectors:
            if isinstance(proj, th.nn.Linear):
                th.nn.init.xavier_uniform_(proj.weight)
                th.nn.init.zeros_(proj.bias)

    def forward(self, lower_in, upper_in):
        # Normalize each half, concatenate, and project down to out_size.
        lower_in = self.lower_projector(lower_in)
        upper_in = self.upper_projector(upper_in)
        x = th.cat([lower_in, upper_in], -1)
        for proj in self.projectors:
            x = proj(x)
        return x

class DinoSiglipMixer(th.nn.Module):
    def __init__(self, in_size=OUT_SIZE * 2, out_size=OUT_SIZE):
        super().__init__()
        self.dino_projector = HighLowMixer()
        self.siglip_projector = HighLowMixer()
        self.projectors = th.nn.Sequential(
            th.nn.SiLU(),
            th.nn.Linear(in_size, out_size),
        )
        # Initialize the linear layer with Xavier weights and zero biases.
        for proj in self.projectors:
            if isinstance(proj, th.nn.Linear):
                th.nn.init.xavier_uniform_(proj.weight)
                th.nn.init.zeros_(proj.bias)

    def forward(self, dino_in, siglip_in):
        # Split each input into two halves, mix each stream with its own
        # HighLowMixer, then fuse both streams into a single OUT_SIZE embedding.
        lower, upper = th.chunk(dino_in, 2, -1)
        dino_out = self.dino_projector(lower, upper)
        lower, upper = th.chunk(siglip_in, 2, -1)
        siglip_out = self.siglip_projector(lower, upper)
        x = th.cat([dino_out, siglip_out], -1)
        for proj in self.projectors:
            x = proj(x)
        return x
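

if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; not part of the
    # original commit). The token layout is an assumption: one leading class
    # token plus a 16x16 patch grid per image, 3 images per batch element,
    # and IN_SIZE-dim (2048) backbone features for both DINO and SigLIP.
    batch_size, num_images, num_tokens = 2, 3, 1 + 16 * 16
    dino_in = th.randn(num_images * batch_size, num_tokens, IN_SIZE)
    siglip_in = th.randn(num_images * batch_size, num_tokens, IN_SIZE)

    projector = AnalogyProjector()
    out = projector(dino_in, siglip_in, batch_size)
    # Expected: (batch_size, num_images * num_tokens, OUT_SIZE) = (2, 771, 768)
    print(out.shape)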