+ proper use of hparams

flowutils.py  +28 -16  CHANGED
@@ -61,18 +61,27 @@ class ConditionalDiagGaussian(BaseDistribution):
         return log_p
 
 
-
 def build_flows(
-    latent_size, num_flows=4, hidden_units=128, context_size=64
+    latent_size, num_flows=4, num_blocks_per_flow=2, hidden_units=128, context_size=64
 ):
     # Define flows
 
     flows = []
+
+    flows.append(
+        nf.flows.MaskedAffineAutoregressive(
+            latent_size,
+            hidden_features=hidden_units,
+            num_blocks=num_blocks_per_flow,
+            context_features=context_size,
+        )
+    )
+
     for i in range(num_flows):
         flows += [
             nf.flows.CoupledRationalQuadraticSpline(
                 latent_size,
-                num_blocks=
+                num_blocks=num_blocks_per_flow,
                 num_hidden_channels=hidden_units,
                 num_context_channels=context_size,
             )
@@ -85,7 +94,7 @@ def build_flows(
         nn.Linear(context_size, context_size),
         nn.SiLU(),
         # output mean and scales for K=latent_size dimensions
-        nn.Linear(context_size, latent_size * 2)
+        nn.Linear(context_size, latent_size * 2),
     )
 
     q0 = ConditionalDiagGaussian(latent_size, context_encoder)
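As an aside, here is a minimal sketch of how the reworked build_flows could be exercised, assuming nf is the normflows package and that build_flows returns nf.ConditionalNormalizingFlow(q0, flows) built from the q0 = ConditionalDiagGaussian(...) seen above; the flowutils import path, batch size, and hparam values are illustrative only.

import torch

from flowutils import build_flows  # hypothetical import path

latent_size, context_size = 10, 64
flow = build_flows(
    latent_size,
    num_flows=4,
    num_blocks_per_flow=2,  # the new hparam: residual blocks per transform
    hidden_units=128,
    context_size=context_size,
)

x = torch.randn(32, latent_size)       # a batch of per-patch features
ctx = torch.randn(32, context_size)    # matching conditioning vectors
logpx = flow.log_prob(x, context=ctx)  # per-sample log densities, shape (32,)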
@@ -190,7 +199,8 @@ class PatchFlow(torch.nn.Module):
         input_size,
         patch_size=3,
         context_embedding_size=128,
-
+        num_flows=4,
+        num_blocks_per_flow=2,
         hidden_units=128,
     ):
         super().__init__()
@@ -198,8 +208,12 @@ class PatchFlow(torch.nn.Module):
         self.local_pooler = SpatialNormer(
             in_channels=num_sigmas, kernel_size=patch_size
         )
-        self.
-            latent_size=num_sigmas,
+        self.flows = build_flows(
+            latent_size=num_sigmas,
+            context_size=context_embedding_size,
+            num_flows=num_flows,
+            num_blocks_per_flow=num_blocks_per_flow,
+            hidden_units=hidden_units,
         )
         self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)
 
@@ -213,7 +227,7 @@ class PatchFlow(torch.nn.Module):
     def init_weights(self):
         # Initialize weights with Xavier
         linear_modules = list(
-            filter(lambda m: isinstance(m, nn.Linear), self.
+            filter(lambda m: isinstance(m, nn.Linear), self.flows.modules())
        )
         total = len(linear_modules)
 
@@ -252,12 +266,10 @@ class PatchFlow(torch.nn.Module):
             p = rearrange(p, "n b c -> (n b) c")
 
             # Compute log densities for each patch
-            logpx = self.
+            logpx = self.flows.log_prob(p, context=ctx)
             logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
             patch_logpx.append(logpx)
-            # del ctx, p
 
-            # print(p[:4], ctx[:4], logpx)
         # Convert back to image
         logpx = torch.cat(patch_logpx, dim=0)
         logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)
@@ -290,12 +302,12 @@ class PatchFlow(torch.nn.Module):
         # # Concatenate global context to local context
         # context_vector = torch.cat([context_vector, gctx], dim=1)
 
-        z, ldj = flow_model.
-            patch_feature,
-            context=context_vector,
-        )
+        # z, ldj = flow_model.flows.inverse_and_log_det(
+        #     patch_feature,
+        #     context=context_vector,
+        # )
 
-        loss =
+        loss = flow_model.flows.forward_kld(patch_feature, context_vector)
         loss *= n_patches
 
         if train:
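The last hunk replaces the hand-rolled inverse_and_log_det pass with normflows' forward_kld, which returns the batch-mean negative log-likelihood -E[log p(x | context)]; multiplying by n_patches rescales the per-patch mean into a per-image objective. A minimal sketch of that training step follows, with the optimizer plumbing assumed (the diff only shows the loss computation up to if train:).

import torch

def train_step(flow_model, patch_feature, context_vector, n_patches, opt, train=True):
    # forward_kld(x, context) == -mean log p(x | context) over the patch batch
    loss = flow_model.flows.forward_kld(patch_feature, context_vector)
    # rescale the per-patch mean so the objective reflects all patches in the image
    loss *= n_patches
    if train:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.detach()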