Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
·
60e1bdf
1
Parent(s):
5373453
fixes
Browse files- inference.py +1 -2
inference.py
CHANGED
@@ -10,7 +10,6 @@ import io
|
|
10 |
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
|
11 |
|
12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
-
device = "cpu"
|
14 |
|
15 |
def inference1():
|
16 |
# new image from web page
|
@@ -21,7 +20,7 @@ def inference():
|
|
21 |
model = Unet(
|
22 |
image_channels=3,
|
23 |
)
|
24 |
-
model.load_state_dict(torch.load("./model_final.pt"
|
25 |
|
26 |
diffusion = GaussianDiffusion(
|
27 |
model=model,
|
|
|
10 |
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
|
11 |
|
12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
13 |
|
14 |
def inference1():
|
15 |
# new image from web page
|
|
|
20 |
model = Unet(
|
21 |
image_channels=3,
|
22 |
)
|
23 |
+
model.load_state_dict(torch.load("./model_final.pt", map_location=device))
|
24 |
|
25 |
diffusion = GaussianDiffusion(
|
26 |
model=model,
|