owiedotch commited on
Commit
59593f5
1 Parent(s): 10d6431

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -59,7 +59,10 @@ model = instantiate_from_config(config)
59
  ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
60
  load_state_dict(model, ckpt, strict=True)
61
  model.freeze()
62
- model.to("cuda")
 
 
 
63
 
64
  @torch.no_grad()
65
  def process(
@@ -113,7 +116,7 @@ def process(
113
  control_img = np.array(control_img)
114
 
115
  # Convert to tensor (NCHW, [0,1])
116
- control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
117
  control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
118
  height, width = control.size(-2), control.size(-1)
119
  model.control_scales = [strength] * 13
@@ -122,7 +125,7 @@ def process(
122
  preds = []
123
  for _ in tqdm(range(num_samples)):
124
  shape = (1, 4, height // 8, width // 8)
125
- x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
126
 
127
  if not tile_diffusion and not tile_vae:
128
  samples = sampler.sample_ccsr(
 
59
  ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
60
  load_state_dict(model, ckpt, strict=True)
61
  model.freeze()
62
+
63
+ # Check if CUDA is available, otherwise use CPU
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ model.to(device)
66
 
67
  @torch.no_grad()
68
  def process(
 
116
  control_img = np.array(control_img)
117
 
118
  # Convert to tensor (NCHW, [0,1])
119
+ control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=device).clamp_(0, 1)
120
  control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
121
  height, width = control.size(-2), control.size(-1)
122
  model.control_scales = [strength] * 13
 
125
  preds = []
126
  for _ in tqdm(range(num_samples)):
127
  shape = (1, 4, height // 8, width // 8)
128
+ x_T = torch.randn(shape, device=device, dtype=torch.float32)
129
 
130
  if not tile_diffusion and not tile_vae:
131
  samples = sampler.sample_ccsr(