Beijia11 commited on
Commit
a5e3686
·
1 Parent(s): c4fce07

add gpu for zerogpu

Browse files
Files changed (2) hide show
  1. app.py +7 -0
  2. models/pipelines.py +10 -1
app.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torchvision.transforms as transforms
9
  from moviepy.editor import VideoFileClip
10
  from diffusers.utils import load_image, load_video
 
11
 
12
  project_root = os.path.dirname(os.path.abspath(__file__))
13
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
@@ -131,12 +132,14 @@ def save_uploaded_file(file):
131
  das_pipeline = None
132
  moge_model = None
133
 
 
134
  def get_das_pipeline():
135
  global das_pipeline
136
  if das_pipeline is None:
137
  das_pipeline = DiffusionAsShaderPipeline(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
138
  return das_pipeline
139
 
 
140
  def get_moge_model():
141
  global moge_model
142
  if moge_model is None:
@@ -144,6 +147,7 @@ def get_moge_model():
144
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
145
  return moge_model
146
 
 
147
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
148
  """Process video motion transfer task"""
149
  try:
@@ -217,6 +221,7 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
217
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
218
  return None
219
 
 
220
  def process_camera_control(source, prompt, camera_motion, tracking_method):
221
  """Process camera control task"""
222
  try:
@@ -288,6 +293,7 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
288
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
289
  return None
290
 
 
291
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
292
  """Process object manipulation task"""
293
  try:
@@ -381,6 +387,7 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
381
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
382
  return None
383
 
 
384
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
385
  """Process mesh animation task"""
386
  try:
 
8
  import torchvision.transforms as transforms
9
  from moviepy.editor import VideoFileClip
10
  from diffusers.utils import load_image, load_video
11
+ import spaces
12
 
13
  project_root = os.path.dirname(os.path.abspath(__file__))
14
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
 
132
  das_pipeline = None
133
  moge_model = None
134
 
135
+ @spaces.GPU
136
  def get_das_pipeline():
137
  global das_pipeline
138
  if das_pipeline is None:
139
  das_pipeline = DiffusionAsShaderPipeline(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
140
  return das_pipeline
141
 
142
+ @spaces.GPU
143
  def get_moge_model():
144
  global moge_model
145
  if moge_model is None:
 
147
  moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
148
  return moge_model
149
 
150
+
151
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
152
  """Process video motion transfer task"""
153
  try:
 
221
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
222
  return None
223
 
224
+
225
  def process_camera_control(source, prompt, camera_motion, tracking_method):
226
  """Process camera control task"""
227
  try:
 
293
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
294
  return None
295
 
296
+
297
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
298
  """Process object manipulation task"""
299
  try:
 
387
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
388
  return None
389
 
390
+
391
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
392
  """Process mesh animation task"""
393
  try:
models/pipelines.py CHANGED
@@ -24,6 +24,7 @@ from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
24
  from submodules.MoGe.moge.model import MoGeModel
25
  from image_gen_aux import DepthPreprocessor
26
  from moviepy.editor import ImageSequenceClip
 
27
 
28
  class DiffusionAsShaderPipeline:
29
  def __init__(self, gpu_id=0, output_dir='outputs'):
@@ -55,6 +56,7 @@ class DiffusionAsShaderPipeline:
55
  transforms.ToTensor()
56
  ])
57
 
 
58
  @torch.no_grad()
59
  def _infer(
60
  self,
@@ -181,6 +183,7 @@ class DiffusionAsShaderPipeline:
181
 
182
  return intr
183
 
 
184
  def _apply_poses(self, pts, intr, poses):
185
  """
186
  Args:
@@ -217,6 +220,7 @@ class DiffusionAsShaderPipeline:
217
  return tracking_pts
218
 
219
  ##============= SpatialTracker =============##
 
220
  def generate_tracking_spatracker(self, video_tensor, density=70):
221
  """Generate tracking video
222
 
@@ -271,6 +275,7 @@ class DiffusionAsShaderPipeline:
271
  del tracker, self.depth_preprocessor
272
  torch.cuda.empty_cache()
273
 
 
274
  def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
275
  video = video.unsqueeze(0).to(self.device)
276
  vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
@@ -360,6 +365,7 @@ class DiffusionAsShaderPipeline:
360
  outline=tuple(color),
361
  )
362
 
 
363
  def visualize_tracking_moge(self, points, mask, save_tracking=True):
364
  """Visualize tracking results from MoGe model
365
 
@@ -446,6 +452,7 @@ class DiffusionAsShaderPipeline:
446
 
447
  return tracking_path, tracking_video
448
 
 
449
  def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
450
  """Generate final video with motion transfer
451
 
@@ -486,6 +493,7 @@ class DiffusionAsShaderPipeline:
486
  """
487
  self.object_motion = motion_type
488
 
 
489
  class FirstFrameRepainter:
490
  def __init__(self, gpu_id=0, output_dir='outputs'):
491
  """Initialize FirstFrameRepainter
@@ -498,7 +506,8 @@ class FirstFrameRepainter:
498
  self.output_dir = output_dir
499
  self.max_depth = 65.0
500
  os.makedirs(output_dir, exist_ok=True)
501
-
 
502
  def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
503
  """Repaint first frame using Flux
504
 
 
24
  from submodules.MoGe.moge.model import MoGeModel
25
  from image_gen_aux import DepthPreprocessor
26
  from moviepy.editor import ImageSequenceClip
27
+ import spaces
28
 
29
  class DiffusionAsShaderPipeline:
30
  def __init__(self, gpu_id=0, output_dir='outputs'):
 
56
  transforms.ToTensor()
57
  ])
58
 
59
+ @spaces.GPU(duration=240)
60
  @torch.no_grad()
61
  def _infer(
62
  self,
 
183
 
184
  return intr
185
 
186
+ @spaces.GPU
187
  def _apply_poses(self, pts, intr, poses):
188
  """
189
  Args:
 
220
  return tracking_pts
221
 
222
  ##============= SpatialTracker =============##
223
+ @spaces.GPU
224
  def generate_tracking_spatracker(self, video_tensor, density=70):
225
  """Generate tracking video
226
 
 
275
  del tracker, self.depth_preprocessor
276
  torch.cuda.empty_cache()
277
 
278
+ @spaces.GPU
279
  def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
280
  video = video.unsqueeze(0).to(self.device)
281
  vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
 
365
  outline=tuple(color),
366
  )
367
 
368
+ @spaces.GPU
369
  def visualize_tracking_moge(self, points, mask, save_tracking=True):
370
  """Visualize tracking results from MoGe model
371
 
 
452
 
453
  return tracking_path, tracking_video
454
 
455
+ @spaces.GPU(duration=240)
456
  def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
457
  """Generate final video with motion transfer
458
 
 
493
  """
494
  self.object_motion = motion_type
495
 
496
+ @spaces.GPU(duration=120)
497
  class FirstFrameRepainter:
498
  def __init__(self, gpu_id=0, output_dir='outputs'):
499
  """Initialize FirstFrameRepainter
 
506
  self.output_dir = output_dir
507
  self.max_depth = 65.0
508
  os.makedirs(output_dir, exist_ok=True)
509
+
510
+ @spaces.GPU(duration=120)
511
  def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
512
  """Repaint first frame using Flux
513