bardofcodes committed
Commit 6149018 · verified · 1 Parent(s): 199fa3a

Delete analogy_projector.py

Files changed (1)
  1. analogy_projector.py +0 -227
analogy_projector.py DELETED
@@ -1,227 +0,0 @@
- """
- ADOBE CONFIDENTIAL
- Copyright 2024 Adobe
- All Rights Reserved.
- NOTICE: All information contained herein is, and remains
- the property of Adobe and its suppliers, if any. The intellectual
- and technical concepts contained herein are proprietary to Adobe
- and its suppliers and are protected by all applicable intellectual
- property laws, including trade secret and copyright laws.
- Dissemination of this information or reproduction of this material
- is strictly forbidden unless prior written permission is obtained
- from Adobe.
- """
-
- import einops
- import numpy as np
- import torch as th
- import torch.nn as nn
- from diffusers import ModelMixin
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- # Ref: https://github.com/tatp22/multidim-positional-encoding/tree/master
-
-
- OUT_SIZE = 768
- IN_SIZE = 2048
-
-
- def get_emb(sin_inp):
-     """
-     Gets a base embedding for one dimension with sin and cos intertwined
-     """
-     emb = th.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
-     return th.flatten(emb, -2, -1)
-
-
- class PositionalEncoding1D(nn.Module):
-     def __init__(self, channels):
-         """
-         :param channels: The last dimension of the tensor you want to apply pos emb to.
-         """
-         super(PositionalEncoding1D, self).__init__()
-         self.org_channels = channels
-         channels = int(np.ceil(channels / 2) * 2)
-         self.channels = channels
-         inv_freq = 1.0 / (10000 ** (th.arange(0, channels, 2).float() / channels))
-         self.register_buffer("inv_freq", inv_freq)
-         self.register_buffer("cached_penc", None, persistent=False)
-
-     def forward(self, tensor):
-         """
-         :param tensor: A 3d tensor of size (batch_size, x, ch)
-         :return: Positional Encoding Matrix of size (batch_size, x, ch)
-         """
-         if len(tensor.shape) != 3:
-             raise RuntimeError("The input tensor has to be 3d!")
-
-         if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
-             return self.cached_penc
-
-         self.cached_penc = None
-         batch_size, x, orig_ch = tensor.shape
-         pos_x = th.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
-         sin_inp_x = th.einsum("i,j->ij", pos_x, self.inv_freq)
-         emb_x = get_emb(sin_inp_x)
-         emb = th.zeros((x, self.channels), device=tensor.device, dtype=tensor.dtype)
-         emb[:, : self.channels] = emb_x
-
-         self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
-         return self.cached_penc
-
-
-
- class PositionalEncoding3D(nn.Module):
-     def __init__(self, channels):
-         """
-         :param channels: The last dimension of the tensor you want to apply pos emb to.
-         """
-         super(PositionalEncoding3D, self).__init__()
-         self.org_channels = channels
-         channels = int(np.ceil(channels / 6) * 2)
-         if channels % 2:
-             channels += 1
-         self.channels = channels
-         inv_freq = 1.0 / (10000 ** (th.arange(0, channels, 2).float() / channels))
-         self.register_buffer("inv_freq", inv_freq)
-         self.register_buffer("cached_penc", None, persistent=False)
-
-     def forward(self, tensor):
-         """
-         :param tensor: A 5d tensor of size (batch_size, x, y, z, ch)
-         :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch)
-         """
-         if len(tensor.shape) != 5:
-             raise RuntimeError("The input tensor has to be 5d!")
-
-         if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
-             return self.cached_penc
-
-         self.cached_penc = None
-         batch_size, x, y, z, orig_ch = tensor.shape
-         pos_x = th.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
-         pos_y = th.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
-         pos_z = th.arange(z, device=tensor.device, dtype=self.inv_freq.dtype)
-         sin_inp_x = th.einsum("i,j->ij", pos_x, self.inv_freq)
-         sin_inp_y = th.einsum("i,j->ij", pos_y, self.inv_freq)
-         sin_inp_z = th.einsum("i,j->ij", pos_z, self.inv_freq)
-         emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
-         emb_y = get_emb(sin_inp_y).unsqueeze(1)
-         emb_z = get_emb(sin_inp_z)
-         emb = th.zeros(
-             (x, y, z, self.channels * 3),
-             device=tensor.device,
-             dtype=tensor.dtype,
-         )
-         emb[:, :, :, : self.channels] = emb_x
-         emb[:, :, :, self.channels : 2 * self.channels] = emb_y
-         emb[:, :, :, 2 * self.channels :] = emb_z
-
-         self.cached_penc = emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1)
-         return self.cached_penc
-
- class AnalogyProjector(ModelMixin, ConfigMixin):
-
-     @register_to_config
-     def __init__(self):
-         super(AnalogyProjector, self).__init__()
-         self.projector = DinoSiglipMixer()
-         self.pos_embd_1D = PositionalEncoding1D(OUT_SIZE)
-         self.pos_embd_3D = PositionalEncoding3D(OUT_SIZE)
-
-
-     def forward(self, dino_in, siglip_in, batch_size):
-
-         image_embeddings = self.projector(dino_in, siglip_in)
-
-         image_embeddings = einops.rearrange(image_embeddings, '(k b) t d -> b k t d', b=batch_size)
-         image_embeddings = self.position_embd(image_embeddings)
-         return image_embeddings
-
-     def position_embd(self, image_embeddings, concat=False):
-         canvas_embd = image_embeddings[:, :, 1:, :]
-         batch_size = canvas_embd.shape[0]
-         type_size = canvas_embd.shape[1]
-         xy_size = canvas_embd.shape[2]
-
-         x_size = int(xy_size ** 0.5)
-
-         canvas_embd = canvas_embd.reshape(batch_size, type_size, x_size, x_size, -1)
-         if concat:
-             canvas_embd = th.cat([canvas_embd, self.pos_embd_3D(canvas_embd)], -1)
-         else:
-             canvas_embd = self.pos_embd_3D(canvas_embd) + canvas_embd
-         canvas_embd = canvas_embd.reshape(batch_size, type_size, xy_size, -1)
-
-         class_embd = image_embeddings[:, :, 0, :]
-         if concat:
-             class_embd = th.cat([class_embd, self.pos_embd_1D(class_embd)], -1)
-         else:
-             class_embd = self.pos_embd_1D(class_embd) + class_embd
-         all_embd_list = []
-         for i in range(type_size):
-             all_embd_list.append(class_embd[:, i:i+1])
-             all_embd_list.append(canvas_embd[:, i])
-         image_embeddings = th.cat(all_embd_list, 1)
-         return image_embeddings
-
-
- class HighLowMixer(th.nn.Module):
-     def __init__(self, in_size=IN_SIZE, out_size=OUT_SIZE):
-         super().__init__()
-         mid_size = (in_size + out_size) // 2
-
-         self.lower_projector = th.nn.Sequential(
-             th.nn.LayerNorm(IN_SIZE//2),
-             th.nn.SiLU()
-         )
-         self.upper_projector = th.nn.Sequential(
-             th.nn.LayerNorm(IN_SIZE//2),
-             th.nn.SiLU()
-         )
-         self.projectors = th.nn.ModuleList([
-             # add layer norm
-             th.nn.Linear(in_size, mid_size),
-             th.nn.SiLU(),
-             th.nn.Linear(mid_size, out_size)
-         ])
-         # initialize
-         for proj in self.projectors:
-             if isinstance(proj, th.nn.Linear):
-                 th.nn.init.xavier_uniform_(proj.weight)
-                 th.nn.init.zeros_(proj.bias)
-
-     def forward(self, lower_in, upper_in):
-         # Also format lower_in
-         lower_in = self.lower_projector(lower_in)
-         upper_in = self.upper_projector(upper_in)
-         x = th.cat([lower_in, upper_in], -1)
-         for proj in self.projectors:
-             x = proj(x)
-         return x
-
- class DinoSiglipMixer(th.nn.Module):
-     def __init__(self, in_size=OUT_SIZE * 2, out_size=OUT_SIZE):
-         super().__init__()
-         self.dino_projector = HighLowMixer()
-         self.siglip_projector = HighLowMixer()
-         self.projectors = th.nn.Sequential(
-             th.nn.SiLU(),
-             th.nn.Linear(in_size, out_size),
-         )
-         # initialize
-         for proj in self.projectors:
-             if isinstance(proj, th.nn.Linear):
-                 th.nn.init.xavier_uniform_(proj.weight)
-                 th.nn.init.zeros_(proj.bias)
-
-
-     def forward(self, dino_in, siglip_in):
-         # Split each feature into lower/upper halves and mix them
-         lower, upper = th.chunk(dino_in, 2, -1)
-         dino_out = self.dino_projector(lower, upper)
-         lower, upper = th.chunk(siglip_in, 2, -1)
-         siglip_out = self.siglip_projector(lower, upper)
-         x = th.cat([dino_out, siglip_out], -1)
-         for proj in self.projectors:
-             x = proj(x)
-         return x
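
For context, the deleted module mixed DINO and SigLIP patch features into a single 768-dimensional token sequence (DinoSiglipMixer over two HighLowMixer branches) and then added sinusoidal position encodings: 1D over the image index for class tokens and 3D over image × height × width for patch tokens (AnalogyProjector.position_embd). A minimal usage sketch under assumed shapes — batch of 2, 4 input images per sample, one class token plus a 16×16 patch grid, and 2048-dim backbone features to match IN_SIZE — assuming torch, einops and diffusers are installed and the file above is still importable as analogy_projector:

import torch as th
from analogy_projector import AnalogyProjector

batch_size, k = 2, 4                  # k input images per sample (assumed)
tokens = 1 + 16 * 16                  # class token + 16x16 patches; patch count must be a perfect square
dino_in = th.randn(k * batch_size, tokens, 2048)    # (k*b, t, IN_SIZE)
siglip_in = th.randn(k * batch_size, tokens, 2048)  # (k*b, t, IN_SIZE)

projector = AnalogyProjector()
out = projector(dino_in, siglip_in, batch_size)
print(out.shape)                      # (2, 4 * 257, 768): per image, class token then patch tokens, OUT_SIZE channels

position_embd interleaves, for each of the k images, its position-encoded class token and its position-encoded patch tokens, so the output sequence length is k * (1 + 16 * 16).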