Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
•
06ebfb2
1
Parent(s):
94539f6
fix device issue
Browse files
unet.py
CHANGED
@@ -411,6 +411,10 @@ class ConditionalUnet(nn.Module):
|
|
411 |
|
412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
413 |
|
|
|
|
|
|
|
|
|
414 |
def forward(self, x, t, cond=None):
|
415 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
416 |
cond = cond.unsqueeze(0)
|
|
|
411 |
|
412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
413 |
|
414 |
+
def to(self, device):
|
415 |
+
self.device = device
|
416 |
+
return super().to(device)
|
417 |
+
|
418 |
def forward(self, x, t, cond=None):
|
419 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
420 |
cond = cond.unsqueeze(0)
|