Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
·
475978d
1
Parent(s):
8f59309
fix
Browse files- inference.py +2 -1
- unet.py +3 -1
inference.py
CHANGED
@@ -21,6 +21,7 @@ def inference1():
|
|
21 |
def inference():
|
22 |
model = Unet(
|
23 |
image_channels=3,
|
|
|
24 |
)
|
25 |
model = ConditionalUnet(
|
26 |
unet=model,
|
@@ -33,7 +34,7 @@ def inference():
|
|
33 |
noise_steps=1000,
|
34 |
beta_0=1e-4,
|
35 |
beta_T=0.02,
|
36 |
-
image_size=(
|
37 |
)
|
38 |
|
39 |
model.to(device)
|
|
|
21 |
def inference():
|
22 |
model = Unet(
|
23 |
image_channels=3,
|
24 |
+
dropout=0.1,
|
25 |
)
|
26 |
model = ConditionalUnet(
|
27 |
unet=model,
|
|
|
34 |
noise_steps=1000,
|
35 |
beta_0=1e-4,
|
36 |
beta_T=0.02,
|
37 |
+
image_size=(192, 128),
|
38 |
)
|
39 |
|
40 |
model.to(device)
|
unet.py
CHANGED
@@ -411,7 +411,9 @@ class ConditionalUnet(nn.Module):
|
|
411 |
|
412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
413 |
|
414 |
-
def forward(self, x, t, cond=
|
|
|
|
|
415 |
# cond: (batch_size, n), where n is the number of classes that we are conditioning on
|
416 |
t = self.unet.time_encoding(t)
|
417 |
|
|
|
411 |
|
412 |
self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
|
413 |
|
414 |
+
def forward(self, x, t, cond=None):
|
415 |
+
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
416 |
+
cond = cond.unsqueeze(0)
|
417 |
# cond: (batch_size, n), where n is the number of classes that we are conditioning on
|
418 |
t = self.unet.time_encoding(t)
|
419 |
|