rahulvenkk commited on
Commit
110d56f
1 Parent(s): a45652e

modified app.py gradio cuda

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -26,7 +26,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
 
27
 
28
  # Load CWM 3-frame model (automatically download pre-trained checkpoint)
29
- model = model_factory.load_model('vitb_8x8patch_3frames').to(device)
30
 
31
  model.requires_grad_(False)
32
  model.eval()
@@ -91,7 +91,7 @@ import os
91
  # print("Preloaded images:", preloaded_images)
92
  @spaces.GPU
93
  def get_c(x, points):
94
- x = utils.imagenet_normalize(x).to(device)
95
  with torch.no_grad():
96
  counterfactual = model.get_counterfactual(x, points)
97
  return counterfactual
 
26
 
27
 
28
  # Load CWM 3-frame model (automatically download pre-trained checkpoint)
29
+ model = model_factory.load_model('vitb_8x8patch_3frames')#.to(device)
30
 
31
  model.requires_grad_(False)
32
  model.eval()
 
91
  # print("Preloaded images:", preloaded_images)
92
  @spaces.GPU
93
  def get_c(x, points):
94
+ x = utils.imagenet_normalize(x)#.to(device)
95
  with torch.no_grad():
96
  counterfactual = model.get_counterfactual(x, points)
97
  return counterfactual