smhh24 committed on
Commit 9d81d6f · verified · 1 Parent(s): 352bba2

Update app.py

Files changed (1)
app.py +2 -2
app.py CHANGED
@@ -28,6 +28,7 @@ def depth_estimation(image, model_path, encoder='vits'):
     config_path = 'configs/config_v2_vits14.json'

     # Ensure model path exists or download if needed
+    model_path="checkpoint/latest.pth"
     if not os.path.exists(model_path):
         return "Model checkpoint not found. Please upload a valid model path."

@@ -36,7 +37,7 @@ def depth_estimation(image, model_path, encoder='vits'):
     # Preprocess image
     rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)  # C, H, W
     predictions = model.infer(rgb)
-    depth = predictions["depth"].squeeze().to('cpu').numpy()
+    depth = predictions["depth"].squeeze().to(device).numpy()

     min_depth = depth.min()
     max_depth = depth.max()
@@ -87,7 +88,6 @@ def main():
         fn=depth_estimation,
         inputs=[
             gr.Image(type="numpy", label="Input Image"),
-            gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
             gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
         ],
         outputs=[
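Taken together, the two changed lines hardcode the checkpoint path inside depth_estimation (so the Model Path textbox is dropped from the Gradio inputs) and convert the predicted depth tensor to NumPy via .to(device) instead of .to('cpu'). Below is a minimal, self-contained sketch of the resulting post-processing step; the random tensor standing in for predictions["depth"], its shape, and the uint8 visualisation at the end are illustrative assumptions, not lines from app.py.

# Minimal sketch of the post-commit depth post-processing (not the full app.py).
import numpy as np
import torch

# The sketch pins device to 'cpu': torch's .numpy() only accepts CPU tensors,
# so the committed `.to(device)` line matches the old `.to('cpu')` behaviour
# only when device resolves to 'cpu'.
device = 'cpu'

# Stand-in for the model output; in app.py this comes from model.infer(rgb).
# The [1, 1, H, W] shape is an assumption for illustration.
predictions = {"depth": torch.rand(1, 1, 480, 640)}

# The line added by this commit: drop singleton dims, move to `device`, to NumPy.
depth = predictions["depth"].squeeze().to(device).numpy()

# Context lines visible in the diff: min/max of the map, used for normalisation.
min_depth = depth.min()
max_depth = depth.max()

# Assumed visualisation step: scale to [0, 255] uint8 for display.
depth_vis = ((depth - min_depth) / (max_depth - min_depth + 1e-8) * 255.0).astype(np.uint8)
print(depth_vis.shape, depth_vis.dtype)  # (480, 640) uint8

After this commit the interface exposes only the image and the encoder dropdown; the checkpoint location comes entirely from the hardcoded model_path assignment rather than from user input.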