Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
·
06c31c1
1
Parent(s):
ddf6771
fixed device issues
Browse files- inference.py +6 -1
inference.py
CHANGED
@@ -9,6 +9,8 @@ import io
|
|
9 |
|
10 |
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
|
11 |
|
|
|
|
|
12 |
def inference1():
|
13 |
# new image from web page
|
14 |
image = requests.get("https://picsum.photos/120/80").content
|
@@ -18,7 +20,7 @@ def inference():
|
|
18 |
model = Unet(
|
19 |
image_channels=3,
|
20 |
)
|
21 |
-
model.load_state_dict(torch.load("./model_final.pt"))
|
22 |
|
23 |
diffusion = GaussianDiffusion(
|
24 |
model=model,
|
@@ -28,6 +30,9 @@ def inference():
|
|
28 |
image_size=(120, 80),
|
29 |
)
|
30 |
|
|
|
|
|
|
|
31 |
imageAPI = DiffusionImageAPI(diffusion)
|
32 |
|
33 |
images, versions = diffusion.sample(1)
|
|
|
9 |
|
10 |
from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
|
11 |
|
12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
|
14 |
def inference1():
|
15 |
# new image from web page
|
16 |
image = requests.get("https://picsum.photos/120/80").content
|
|
|
20 |
model = Unet(
|
21 |
image_channels=3,
|
22 |
)
|
23 |
+
model.load_state_dict(torch.load("./model_final.pt"), map_location=device)
|
24 |
|
25 |
diffusion = GaussianDiffusion(
|
26 |
model=model,
|
|
|
30 |
image_size=(120, 80),
|
31 |
)
|
32 |
|
33 |
+
model.to(device)
|
34 |
+
diffusion.to(device)
|
35 |
+
|
36 |
imageAPI = DiffusionImageAPI(diffusion)
|
37 |
|
38 |
images, versions = diffusion.sample(1)
|