Update visual.py
Browse files
visual.py
CHANGED
@@ -125,7 +125,7 @@ class Resampler(nn.Module):
|
|
125 |
self.ln_q = norm_layer(embed_dim)
|
126 |
self.ln_kv = norm_layer(embed_dim)
|
127 |
|
128 |
-
self.apply(self._init_weights)
|
129 |
|
130 |
def _init_weights(self, m):
|
131 |
if isinstance(m, nn.Linear):
|
@@ -189,7 +189,7 @@ class VisualAttention(nn.Module):
|
|
189 |
# query/key/value: [sq, b, h]
|
190 |
sq, b, _ = query.size()
|
191 |
|
192 |
-
assert query
|
193 |
sk = sq
|
194 |
mixed_x_layer = self.in_proj(query)
|
195 |
|
|
|
125 |
self.ln_q = norm_layer(embed_dim)
|
126 |
self.ln_kv = norm_layer(embed_dim)
|
127 |
|
128 |
+
# self.apply(self._init_weights)
|
129 |
|
130 |
def _init_weights(self, m):
|
131 |
if isinstance(m, nn.Linear):
|
|
|
189 |
# query/key/value: [sq, b, h]
|
190 |
sq, b, _ = query.size()
|
191 |
|
192 |
+
assert torch.allclose(query, key), 'Only Support Self-Attention Currently'
|
193 |
sk = sq
|
194 |
mixed_x_layer = self.in_proj(query)
|
195 |
|