ahsanMah committed
Commit b55c3d2 · 1 Parent(s): a186356

proper use of hparams

Files changed (1):
  1. flowutils.py +28 -16
flowutils.py CHANGED
@@ -61,18 +61,27 @@ class ConditionalDiagGaussian(BaseDistribution):
         return log_p
 
 
-
 def build_flows(
-    latent_size, num_flows=4, num_blocks=2, hidden_units=128, context_size=64
+    latent_size, num_flows=4, num_blocks_per_flow=2, hidden_units=128, context_size=64
 ):
     # Define flows
 
     flows = []
+
+    flows.append(
+        nf.flows.MaskedAffineAutoregressive(
+            latent_size,
+            hidden_features=hidden_units,
+            num_blocks=num_blocks_per_flow,
+            context_features=context_size,
+        )
+    )
+
     for i in range(num_flows):
         flows += [
             nf.flows.CoupledRationalQuadraticSpline(
                 latent_size,
-                num_blocks=num_blocks,
+                num_blocks=num_blocks_per_flow,
                 num_hidden_channels=hidden_units,
                 num_context_channels=context_size,
             )
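As a reading aid (not part of the commit): the new MaskedAffineAutoregressive layer is prepended ahead of the spline flows, and depth/width now flow through num_blocks_per_flow and hidden_units. Assuming build_flows wraps q0 and this flows list in normflows' ConditionalNormalizingFlow, a minimal usage sketch with hypothetical sizes would be:

import torch

flow = build_flows(
    latent_size=10, num_flows=4, num_blocks_per_flow=2,
    hidden_units=128, context_size=64,
)
x = torch.randn(32, 10)    # e.g. flattened per-patch features
ctx = torch.randn(32, 64)  # matching context embeddings
log_px = flow.log_prob(x, context=ctx)  # one log-density per row, shape (32,)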
@@ -85,7 +94,7 @@ def build_flows(
         nn.Linear(context_size, context_size),
         nn.SiLU(),
         # output mean and scales for K=latent_size dimensions
-        nn.Linear(context_size, latent_size * 2)
+        nn.Linear(context_size, latent_size * 2),
     )
 
     q0 = ConditionalDiagGaussian(latent_size, context_encoder)
@@ -190,7 +199,8 @@ class PatchFlow(torch.nn.Module):
         input_size,
         patch_size=3,
         context_embedding_size=128,
-        num_blocks=2,
+        num_flows=4,
+        num_blocks_per_flow=2,
         hidden_units=128,
     ):
         super().__init__()
@@ -198,8 +208,12 @@ class PatchFlow(torch.nn.Module):
         self.local_pooler = SpatialNormer(
             in_channels=num_sigmas, kernel_size=patch_size
         )
-        self.flow = build_flows(
-            latent_size=num_sigmas, context_size=context_embedding_size
+        self.flows = build_flows(
+            latent_size=num_sigmas,
+            context_size=context_embedding_size,
+            num_flows=num_flows,
+            num_blocks_per_flow=num_blocks_per_flow,
+            hidden_units=hidden_units,
         )
         self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)
 
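The two hunks above are the heart of the commit: previously num_blocks and hidden_units were accepted by PatchFlow.__init__ but never forwarded, so build_flows fell back to its own defaults for everything except latent_size and context_size. A hypothetical construction showing the now-effective hyperparameters (the full signature, e.g. where num_sigmas comes from, lies outside these hunks):

model = PatchFlow(
    input_size=(10, 32, 32),   # hypothetical shape, for illustration only
    patch_size=3,
    context_embedding_size=128,
    num_flows=4,               # now reaches build_flows
    num_blocks_per_flow=2,     # previously silently ignored
    hidden_units=128,          # previously silently ignored
)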
@@ -213,7 +227,7 @@ class PatchFlow(torch.nn.Module):
     def init_weights(self):
         # Initialize weights with Xavier
         linear_modules = list(
-            filter(lambda m: isinstance(m, nn.Linear), self.flow.modules())
+            filter(lambda m: isinstance(m, nn.Linear), self.flows.modules())
         )
         total = len(linear_modules)
 
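For completeness, the Xavier initialization that the method name and comment refer to is applied outside this hunk; a minimal sketch of what such a loop typically looks like (the actual body is not shown in this diff):

for m in linear_modules:
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)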
@@ -252,12 +266,10 @@ class PatchFlow(torch.nn.Module):
         p = rearrange(p, "n b c -> (n b) c")
 
         # Compute log densities for each patch
-        logpx = self.flow.log_prob(p, context=ctx)
+        logpx = self.flows.log_prob(p, context=ctx)
         logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
         patch_logpx.append(logpx)
-        # del ctx, p
 
-        # print(p[:4], ctx[:4], logpx)
         # Convert back to image
         logpx = torch.cat(patch_logpx, dim=0)
         logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)
@@ -290,12 +302,12 @@ class PatchFlow(torch.nn.Module):
         # # Concatenate global context to local context
         # context_vector = torch.cat([context_vector, gctx], dim=1)
 
-        z, ldj = flow_model.flow.inverse_and_log_det(
-            patch_feature,
-            context=context_vector,
-        )
+        # z, ldj = flow_model.flows.inverse_and_log_det(
+        #     patch_feature,
+        #     context=context_vector,
+        # )
 
-        loss = -torch.mean(flow_model.flow.q0.log_prob(z, context_vector) + ldj)
+        loss = flow_model.flows.forward_kld(patch_feature, context_vector)
         loss *= n_patches
 
         if train:
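Note on the loss change (assuming flow_model.flows is normflows' ConditionalNormalizingFlow): forward_kld performs the same computation the deleted lines spelled out by hand, i.e. an inverse pass through the flows with accumulated log-determinants followed by the base-distribution log-density:

# Equivalent manual form of the new loss line, kept for reference;
# forward_kld returns -mean(q0.log_prob(z, context) + sum_log_det).
z, ldj = flow_model.flows.inverse_and_log_det(patch_feature, context=context_vector)
manual_loss = -torch.mean(flow_model.flows.q0.log_prob(z, context_vector) + ldj)
# manual_loss should match flow_model.flows.forward_kld(patch_feature, context_vector)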