Anton Forsman commited on
Commit
06c31c1
·
1 Parent(s): ddf6771

fixed device issues

Browse files
Files changed (1) hide show
  1. inference.py +6 -1
inference.py CHANGED
@@ -9,6 +9,8 @@ import io
9
 
10
  from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
11
 
 
 
12
  def inference1():
13
  # new image from web page
14
  image = requests.get("https://picsum.photos/120/80").content
@@ -18,7 +20,7 @@ def inference():
18
  model = Unet(
19
  image_channels=3,
20
  )
21
- model.load_state_dict(torch.load("./model_final.pt"))
22
 
23
  diffusion = GaussianDiffusion(
24
  model=model,
@@ -28,6 +30,9 @@ def inference():
28
  image_size=(120, 80),
29
  )
30
 
 
 
 
31
  imageAPI = DiffusionImageAPI(diffusion)
32
 
33
  images, versions = diffusion.sample(1)
 
9
 
10
  from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
11
 
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
  def inference1():
15
  # new image from web page
16
  image = requests.get("https://picsum.photos/120/80").content
 
20
  model = Unet(
21
  image_channels=3,
22
  )
23
+ model.load_state_dict(torch.load("./model_final.pt"), map_location=device)
24
 
25
  diffusion = GaussianDiffusion(
26
  model=model,
 
30
  image_size=(120, 80),
31
  )
32
 
33
+ model.to(device)
34
+ diffusion.to(device)
35
+
36
  imageAPI = DiffusionImageAPI(diffusion)
37
 
38
  images, versions = diffusion.sample(1)