Beijia11 commited on
Commit
939de0d
·
1 Parent(s): a5e3686
Files changed (1) hide show
  1. models/spatracker/predictor.py +4 -0
models/spatracker/predictor.py CHANGED
@@ -17,6 +17,7 @@ from models.spatracker.models.build_spatracker import (
17
  from models.spatracker.models.core.model_utils import (
18
  meshgrid2d, bilinear_sample2d, smart_cat
19
  )
 
20
 
21
 
22
  class SpaTrackerPredictor(torch.nn.Module):
@@ -33,6 +34,7 @@ class SpaTrackerPredictor(torch.nn.Module):
33
  self.model = model
34
  self.model.eval()
35
 
 
36
  @torch.no_grad()
37
  def forward(
38
  self,
@@ -78,6 +80,7 @@ class SpaTrackerPredictor(torch.nn.Module):
78
 
79
  return tracks, visibilities, T_Firsts
80
 
 
81
  def _compute_dense_tracks(
82
  self, video, grid_query_frame, grid_size=30, backward_tracking=False,
83
  depth_predictor=None, video_depth=None, wind_length=8
@@ -113,6 +116,7 @@ class SpaTrackerPredictor(torch.nn.Module):
113
 
114
  return tracks, visibilities, T_Firsts
115
 
 
116
  def _compute_sparse_tracks(
117
  self,
118
  video,
 
17
  from models.spatracker.models.core.model_utils import (
18
  meshgrid2d, bilinear_sample2d, smart_cat
19
  )
20
+ import spaces
21
 
22
 
23
  class SpaTrackerPredictor(torch.nn.Module):
 
34
  self.model = model
35
  self.model.eval()
36
 
37
+ @spaces.GPU
38
  @torch.no_grad()
39
  def forward(
40
  self,
 
80
 
81
  return tracks, visibilities, T_Firsts
82
 
83
+ @spaces.GPU
84
  def _compute_dense_tracks(
85
  self, video, grid_query_frame, grid_size=30, backward_tracking=False,
86
  depth_predictor=None, video_depth=None, wind_length=8
 
116
 
117
  return tracks, visibilities, T_Firsts
118
 
119
+ @spaces.GPU
120
  def _compute_sparse_tracks(
121
  self,
122
  video,