Anton Forsman committed on
Commit 475978d · 1 Parent(s): 8f59309
Files changed (2)
  1. inference.py +2 -1
  2. unet.py +3 -1
inference.py CHANGED
@@ -21,6 +21,7 @@ def inference1():
 def inference():
     model = Unet(
         image_channels=3,
+        dropout=0.1,
     )
     model = ConditionalUnet(
         unet=model,
@@ -33,7 +34,7 @@ def inference():
         noise_steps=1000,
         beta_0=1e-4,
         beta_T=0.02,
-        image_size=(120, 80),
+        image_size=(192, 128),
     )
 
     model.to(device)
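This change adds dropout to the Unet and raises the sampling resolution from (120, 80) to (192, 128). The unchanged context also shows the DDPM-style schedule parameters (noise_steps=1000, beta_0=1e-4, beta_T=0.02); for reference, a minimal sketch of the linear beta schedule such values typically define is below. The helper name linear_beta_schedule is illustrative and not part of this repository, whose diffusion wrapper may build the schedule differently.

import torch

# Linear DDPM beta schedule built from the values visible in the diff.
# Illustrative sketch only; the repo's diffusion wrapper may differ.
def linear_beta_schedule(noise_steps=1000, beta_0=1e-4, beta_T=0.02):
    betas = torch.linspace(beta_0, beta_T, noise_steps)  # (noise_steps,)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)             # used for q(x_t | x_0)
    return betas, alpha_bar

betas, alpha_bar = linear_beta_schedule()
print(betas.shape)           # torch.Size([1000])
print(alpha_bar[-1].item())  # near zero: x_T is essentially pure noise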
unet.py CHANGED
@@ -411,7 +411,9 @@ class ConditionalUnet(nn.Module):
 
         self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
 
-    def forward(self, x, t, cond=torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])):
+    def forward(self, x, t, cond=None):
+        cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+        cond = cond.unsqueeze(0)
         # cond: (batch_size, n), where n is the number of classes that we are conditioning on
         t = self.unet.time_encoding(t)
 
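The reworked forward makes cond optional, and the unchanged nn.Embedding(num_classes + 1, ..., padding_idx=0) context suggests index 0 serves as the "no class" entry. A rough standalone sketch of how such a class-conditioning embedding behaves is below; num_classes=14 matches the length of the default cond tensor, starting_channels=64 is an assumed stand-in, and the pooling step is illustrative since the rest of forward is not shown in this diff.

import torch
import torch.nn as nn

num_classes = 14         # length of the all-zero cond tensor in the diff
starting_channels = 64   # assumed stand-in for unet.starting_channels

class_embedding = nn.Embedding(num_classes + 1, starting_channels, padding_idx=0)

# cond: (batch_size, n) of class indices; index 0 is the padding entry,
# so its embedding row is initialized to zeros and receives no gradient.
cond = torch.tensor([[0] * num_classes])  # what the new forward builds when cond is None
emb = class_embedding(cond)               # (1, n, starting_channels)
pooled = emb.sum(dim=1)                   # (1, starting_channels); all zeros in the unconditional case
print(pooled.abs().max().item())          # 0.0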