Anton Forsman commited on
Commit
06ebfb2
1 Parent(s): 94539f6

fix device issue

Browse files
Files changed (1) hide show
  1. unet.py +4 -0
unet.py CHANGED
@@ -411,6 +411,10 @@ class ConditionalUnet(nn.Module):
411
 
412
  self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
413
 
 
 
 
 
414
  def forward(self, x, t, cond=None):
415
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
416
  cond = cond.unsqueeze(0)
 
411
 
412
  self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
413
 
414
+ def to(self, device):
415
+ self.device = device
416
+ return super().to(device)
417
+
418
  def forward(self, x, t, cond=None):
419
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
420
  cond = cond.unsqueeze(0)