Beijia11 committed on
Commit
c4fce07
·
1 Parent(s): 6ded12b

merge demo.py and app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -159
app.py CHANGED
@@ -2,20 +2,30 @@ import os
2
  import sys
3
  import gradio as gr
4
  import torch
5
- import subprocess
6
  import argparse
7
- import glob
8
- import spaces
 
 
 
9
 
10
  project_root = os.path.dirname(os.path.abspath(__file__))
11
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
12
  sys.path.append(project_root)
13
 
 
 
 
 
 
 
14
  HERE_PATH = os.path.normpath(os.path.dirname(__file__))
15
  sys.path.insert(0, HERE_PATH)
16
  from huggingface_hub import hf_hub_download
17
  hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')
18
 
 
 
19
 
20
  # Parse command line arguments
21
  parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
@@ -31,21 +41,53 @@ GPU_ID = args.gpu
31
  DEFAULT_MODEL_PATH = args.model_path
32
  OUTPUT_DIR = args.output_dir
33
 
34
- # if 'CUDA_HOME' not in os.environ:
35
- # for cuda_path in ['/usr/local/cuda', '/usr/cuda', '/opt/cuda']:
36
- # if os.path.exists(cuda_path):
37
- # os.environ['CUDA_HOME'] = cuda_path
38
- # print(cuda_path)
39
- # break
40
- # if 'CUDA_HOME' not in os.environ:
41
- # os.environ['CUDA_HOME'] = '/usr/local/cuda'
42
- # print("set default cuda path in: /usr/local/cuda")
43
-
44
  # Create necessary directories
45
  os.makedirs("outputs", exist_ok=True)
46
  # Create project tmp directory instead of using system temp
47
  os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
48
  os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def save_uploaded_file(file):
51
  if file is None:
@@ -86,59 +128,22 @@ def save_uploaded_file(file):
86
 
87
  return temp_path
88
 
89
def create_run_command(args):
    """Build the ``demo.py`` command line from a dict of parameters.

    Args:
        args (dict): Maps option names (without the leading ``--``) to
            values. A missing, ``None`` or empty ``prompt`` becomes ``""``;
            a missing, ``None`` or empty ``checkpoint_path`` falls back to
            ``DEFAULT_MODEL_PATH``. Entries whose value is ``None`` are
            omitted from the command.

    Returns:
        list[str]: ``["python", "demo.py", "--key", "value", ...]`` in the
        dict's iteration order. Booleans are rendered as ``true``/``false``.
    """
    # Work on a shallow copy so the defaults filled in below do not leak
    # back into the caller's dictionary.
    params = dict(args)

    if params.get("prompt") is None or params["prompt"] == "":
        params["prompt"] = ""
    if params.get("checkpoint_path") is None or params["checkpoint_path"] == "":
        params["checkpoint_path"] = DEFAULT_MODEL_PATH

    # Debug output
    print(f"DEBUG: Command args: {params}")

    cmd = ["python", "demo.py"]
    for key, value in params.items():
        if value is None:
            continue
        # Booleans are passed as lowercase true/false, e.g. for --repaint.
        rendered = str(value).lower() if isinstance(value, bool) else str(value)
        cmd.extend((f"--{key}", rendered))

    return cmd
112
-
113
@spaces.GPU(duration=240)
def run_process(cmd):
    """Run a command, echo its stdout live, and return the captured stdout.

    Args:
        cmd (list[str]): Program and arguments, passed to ``subprocess.Popen``
            without a shell.

    Returns:
        str: The captured stdout lines joined with ``"\\n"``.

    Raises:
        subprocess.CalledProcessError: If the process exits with a non-zero
            status. ``output`` holds the captured stdout and ``stderr`` the
            captured stderr text.
    """
    import tempfile  # local import: not among the file's top-level imports

    print(f"Running command: {' '.join(cmd)}")

    # stderr goes to a temporary file instead of a pipe. With two pipes and
    # only stdout being drained, a child that fills the stderr pipe buffer
    # blocks forever and stdout never reaches EOF.
    with tempfile.TemporaryFile(mode="w+") as stderr_file:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=stderr_file,
            universal_newlines=True,
        )

        output = []
        with process.stdout:
            # Iterating the stream stops at EOF by itself.
            for line in process.stdout:
                print(line, end="")
                output.append(line)

        return_code = process.wait()

        # Lines keep their trailing newline; the "\n" join is kept as-is so
        # the returned text matches what callers received before.
        joined_output = "\n".join(output)

        if return_code:
            stderr_file.seek(0)
            stderr = stderr_file.read()
            print(f"Error: {stderr}")
            raise subprocess.CalledProcessError(
                return_code, cmd, output=joined_output, stderr=stderr
            )

    return joined_output
140
 
141
- @spaces.GPU(duration=240)
142
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
143
  """Process video motion transfer task"""
144
  try:
@@ -150,42 +155,68 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
150
  print(f"DEBUG: Repaint option: {mt_repaint_option}")
151
  print(f"DEBUG: Repaint image: {mt_repaint_image}")
152
 
153
- args = {
154
- "input_path": input_video_path,
155
- "prompt": f"\"{prompt}\"",
156
- "checkpoint_path": DEFAULT_MODEL_PATH,
157
- "output_dir": OUTPUT_DIR,
158
- "gpu": GPU_ID
159
- }
 
160
 
161
- # Priority: Custom Image > Yes > No
162
  if mt_repaint_image is not None:
163
- # Custom image takes precedence if provided
164
  repaint_path = save_uploaded_file(mt_repaint_image)
165
- print(f"DEBUG: Repaint path: {repaint_path}")
166
- args["repaint"] = repaint_path
167
  elif mt_repaint_option == "Yes":
168
- # Otherwise use Yes/No selection
169
- args["repaint"] = "true"
170
-
171
- # Create and run command
172
- cmd = create_run_command(args)
173
- output = run_process(cmd)
174
-
175
- # Find generated video files
176
- output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
177
- if output_files:
178
- # Sort by modification time, return the latest file
179
- latest_file = max(output_files, key=os.path.getmtime)
180
- return latest_file
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  else:
182
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  except Exception as e:
184
  import traceback
185
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
186
  return None
187
 
188
- @spaces.GPU(duration=240)
189
  def process_camera_control(source, prompt, camera_motion, tracking_method):
190
  """Process camera control task"""
191
  try:
@@ -197,36 +228,66 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
197
  print(f"DEBUG: Camera motion: '{camera_motion}'")
198
  print(f"DEBUG: Tracking method: '{tracking_method}'")
199
 
200
- args = {
201
- "input_path": input_media_path,
202
- "prompt": prompt,
203
- "checkpoint_path": DEFAULT_MODEL_PATH,
204
- "output_dir": OUTPUT_DIR,
205
- "gpu": GPU_ID,
206
- "tracking_method": tracking_method
207
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- if camera_motion and camera_motion.strip():
210
- args["camera_motion"] = camera_motion
211
 
212
- # Create and run command
213
- cmd = create_run_command(args)
214
- output = run_process(cmd)
 
 
 
 
 
215
 
216
- # Find generated video files
217
- output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
218
- if output_files:
219
- # Sort by modification time, return the latest file
220
- latest_file = max(output_files, key=os.path.getmtime)
221
- return latest_file
222
- else:
223
- return None
224
  except Exception as e:
225
  import traceback
226
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
227
  return None
228
 
229
- @spaces.GPU(duration=240)
230
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
231
  """Process object manipulation task"""
232
  try:
@@ -236,36 +297,90 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
236
  return None
237
 
238
  object_mask_path = save_uploaded_file(object_mask)
 
 
 
239
 
240
- args = {
241
- "input_path": input_image_path,
242
- "prompt": prompt,
243
- "checkpoint_path": DEFAULT_MODEL_PATH,
244
- "output_dir": OUTPUT_DIR,
245
- "gpu": GPU_ID,
246
- "object_motion": object_motion,
247
- "object_mask": object_mask_path,
248
- "tracking_method": tracking_method
249
- }
250
 
251
- # Create and run command
252
- cmd = create_run_command(args)
253
- output = run_process(cmd)
 
 
 
254
 
255
- # Find generated video files
256
- output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
257
- if output_files:
258
- # Sort by modification time, return the latest file
259
- latest_file = max(output_files, key=os.path.getmtime)
260
- return latest_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  else:
262
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  except Exception as e:
264
  import traceback
265
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
266
  return None
267
 
268
- @spaces.GPU(duration=240)
269
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
270
  """Process mesh animation task"""
271
  try:
@@ -278,36 +393,34 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
278
  if tracking_video_path is None:
279
  return None
280
 
281
- args = {
282
- "input_path": input_video_path,
283
- "prompt": prompt,
284
- "checkpoint_path": DEFAULT_MODEL_PATH,
285
- "output_dir": OUTPUT_DIR,
286
- "gpu": GPU_ID,
287
- "tracking_path": tracking_video_path
288
- }
289
 
290
- # Priority: Custom Image > Yes > No
 
 
 
291
  if ma_repaint_image is not None:
292
- # Custom image takes precedence if provided
293
  repaint_path = save_uploaded_file(ma_repaint_image)
294
- args["repaint"] = repaint_path
 
295
  elif ma_repaint_option == "Yes":
296
- # Otherwise use Yes/No selection
297
- args["repaint"] = "true"
298
-
299
- # Create and run command
300
- cmd = create_run_command(args)
301
- output = run_process(cmd)
 
 
 
 
 
 
 
 
 
 
302
 
303
- # Find generated video files
304
- output_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mp4"))
305
- if output_files:
306
- # Sort by modification time, return the latest file
307
- latest_file = max(output_files, key=os.path.getmtime)
308
- return latest_file
309
- else:
310
- return None
311
  except Exception as e:
312
  import traceback
313
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
 
2
  import sys
3
  import gradio as gr
4
  import torch
 
5
  import argparse
6
+ from PIL import Image
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ from moviepy.editor import VideoFileClip
10
+ from diffusers.utils import load_image, load_video
11
 
12
  project_root = os.path.dirname(os.path.abspath(__file__))
13
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(project_root, "tmp", "gradio")
14
  sys.path.append(project_root)
15
 
16
# Make the bundled MoGe submodule importable. Appending to sys.path does not
# fail when the directory is missing, so check for it explicitly; otherwise
# the warning below could never be printed.
_moge_dir = os.path.join(project_root, "submodules/MoGe")
if os.path.isdir(_moge_dir):
    sys.path.append(_moge_dir)
else:
    print("Warning: MoGe not found, motion transfer will not be applied")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+
22
  HERE_PATH = os.path.normpath(os.path.dirname(__file__))
23
  sys.path.insert(0, HERE_PATH)
24
  from huggingface_hub import hf_hub_download
25
  hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_final.pth', local_dir=f'{HERE_PATH}/checkpoints/')
26
 
27
+ from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
28
+ from submodules.MoGe.moge.model import MoGeModel
29
 
30
  # Parse command line arguments
31
  parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
 
41
  DEFAULT_MODEL_PATH = args.model_path
42
  OUTPUT_DIR = args.output_dir
43
 
 
 
 
 
 
 
 
 
 
 
44
  # Create necessary directories
45
  os.makedirs("outputs", exist_ok=True)
46
  # Create project tmp directory instead of using system temp
47
  os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
48
  os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
49
def load_media(media_path, max_frames=49, transform=None):
    """Load a video or image and return it as a fixed-length frame tensor.

    Args:
        media_path (str): Path to a video or image file. Files ending in
            ``.mp4``, ``.avi`` or ``.mov`` (case-insensitive) are treated as
            video; anything else is loaded as a single image.
        max_frames (int): Exact number of frames in the result. Longer
            videos are truncated; shorter inputs are padded by repeating
            the last frame.
        transform (callable): Applied to every frame before stacking.
            Defaults to a resize to (480, 720) followed by ``ToTensor``.

    Returns:
        Tuple[torch.Tensor, float, bool]: The stacked frames ([T, C, H, W]
        with the default transform), the estimated FPS, and an ``is_video``
        flag.

    Raises:
        ValueError: If no frames could be decoded from ``media_path``.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((480, 720)),
            transforms.ToTensor(),
        ])

    # Decide between video and image purely from the file extension.
    ext = os.path.splitext(media_path)[1].lower()
    is_video = ext in ('.mp4', '.avi', '.mov')

    if is_video:
        frames = list(load_video(media_path))
        # Close the clip explicitly so the underlying reader is released
        # instead of lingering until garbage collection.
        clip = VideoFileClip(media_path)
        try:
            duration = clip.duration
            clip_fps = clip.fps
        finally:
            clip.close()
        # FPS is estimated from the decoded frame count; fall back to the
        # container's reported rate when the duration is missing or zero.
        fps = len(frames) / duration if duration else clip_fps
    else:
        # Handle an image as a single frame.
        frames = [load_image(media_path)]
        fps = 8  # Default fps for images

    if not frames:
        raise ValueError(f"No frames could be loaded from {media_path}")

    # Ensure we have exactly max_frames: truncate, or pad with the last frame.
    if len(frames) > max_frames:
        frames = frames[:max_frames]
    else:
        last_frame = frames[-1]
        missing = max_frames - len(frames)
        frames.extend(last_frame.copy() for _ in range(missing))

    # Convert frames to a single tensor.
    video_tensor = torch.stack([transform(frame) for frame in frames])

    return video_tensor, fps, is_video
91
 
92
  def save_uploaded_file(file):
93
  if file is None:
 
128
 
129
  return temp_path
130
 
131
+ das_pipeline = None
132
+ moge_model = None
133
+
134
def get_das_pipeline():
    """Return the shared DiffusionAsShaderPipeline, creating it on first use.

    The instance is stored in the module-level ``das_pipeline`` variable and
    is built with ``GPU_ID`` and ``OUTPUT_DIR``.
    """
    global das_pipeline
    # Fast path: the pipeline was already built by an earlier call.
    if das_pipeline is not None:
        return das_pipeline
    das_pipeline = DiffusionAsShaderPipeline(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
    return das_pipeline
139
+
140
def get_moge_model():
    """Return the shared MoGe model, loading it on first use.

    The model is loaded from the "Ruicheng/moge-vitl" checkpoint, moved to
    the device of the shared DAS pipeline, and stored in the module-level
    ``moge_model`` variable.
    """
    global moge_model
    # Fast path: the model was already loaded by an earlier call.
    if moge_model is not None:
        return moge_model
    pipeline = get_das_pipeline()
    loaded = MoGeModel.from_pretrained("Ruicheng/moge-vitl")
    moge_model = loaded.to(pipeline.device)
    return moge_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
 
147
  def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
148
  """Process video motion transfer task"""
149
  try:
 
155
  print(f"DEBUG: Repaint option: {mt_repaint_option}")
156
  print(f"DEBUG: Repaint image: {mt_repaint_image}")
157
 
158
+
159
+ das = get_das_pipeline()
160
+ video_tensor, fps, is_video = load_media(input_video_path)
161
+ if not is_video:
162
+ tracking_method = "moge"
163
+ print("Image input detected, using MoGe for tracking video generation.")
164
+ else:
165
+ tracking_method = "spatracker"
166
 
167
+ repaint_img_tensor = None
168
  if mt_repaint_image is not None:
 
169
  repaint_path = save_uploaded_file(mt_repaint_image)
170
+ repaint_img_tensor, _, _ = load_media(repaint_path)
171
+ repaint_img_tensor = repaint_img_tensor[0]
172
  elif mt_repaint_option == "Yes":
173
+ repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
174
+ repaint_img_tensor = repainter.repaint(
175
+ video_tensor[0],
176
+ prompt=prompt,
177
+ depth_path=None
178
+ )
179
+ tracking_tensor = None
180
+ if tracking_method == "moge":
181
+ moge = get_moge_model()
182
+ infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
183
+ H, W = infer_result["points"].shape[0:2]
184
+ pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
185
+ poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
186
+
187
+ pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
188
+
189
+ cam_motion = CameraMotionGenerator(None)
190
+ cam_motion.set_intr(infer_result["intrinsics"])
191
+
192
+ pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
193
+
194
+ _, tracking_tensor = das.visualize_tracking_moge(
195
+ pred_tracks.cpu().numpy(),
196
+ infer_result["mask"].cpu().numpy()
197
+ )
198
+ print('Export tracking video via MoGe')
199
  else:
200
+ pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
201
+
202
+ _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
203
+ print('Export tracking video via SpaTracker')
204
+
205
+ output_path = das.apply_tracking(
206
+ video_tensor=video_tensor,
207
+ fps=8,
208
+ tracking_tensor=tracking_tensor,
209
+ img_cond_tensor=repaint_img_tensor,
210
+ prompt=prompt,
211
+ checkpoint_path=DEFAULT_MODEL_PATH
212
+ )
213
+
214
+ return output_path
215
  except Exception as e:
216
  import traceback
217
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
218
  return None
219
 
 
220
  def process_camera_control(source, prompt, camera_motion, tracking_method):
221
  """Process camera control task"""
222
  try:
 
228
  print(f"DEBUG: Camera motion: '{camera_motion}'")
229
  print(f"DEBUG: Tracking method: '{tracking_method}'")
230
 
231
+ das = get_das_pipeline()
232
+
233
+ video_tensor, fps, is_video = load_media(input_media_path)
234
+ if not is_video and tracking_method == "spatracker":
235
+ tracking_method = "moge"
236
+ print("Image input detected with spatracker selected, switching to MoGe")
237
+
238
+ cam_motion = CameraMotionGenerator(camera_motion)
239
+ repaint_img_tensor = None
240
+ tracking_tensor = None
241
+
242
+ if tracking_method == "moge":
243
+ moge = get_moge_model()
244
+
245
+ infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
246
+ H, W = infer_result["points"].shape[0:2]
247
+ pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
248
+ cam_motion.set_intr(infer_result["intrinsics"])
249
+
250
+ if camera_motion:
251
+ poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
252
+ print("Camera motion applied")
253
+ else:
254
+ poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
255
+
256
+ pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
257
+ pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
258
+
259
+ _, tracking_tensor = das.visualize_tracking_moge(
260
+ pred_tracks.cpu().numpy(),
261
+ infer_result["mask"].cpu().numpy()
262
+ )
263
+ print('Export tracking video via MoGe')
264
+ else:
265
+
266
+ pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
267
+ if camera_motion:
268
+ poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
269
+ pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
270
+ print("Camera motion applied")
271
+
272
+ _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
273
+ print('Export tracking video via SpaTracker')
274
 
 
 
275
 
276
+ output_path = das.apply_tracking(
277
+ video_tensor=video_tensor,
278
+ fps=8,
279
+ tracking_tensor=tracking_tensor,
280
+ img_cond_tensor=repaint_img_tensor,
281
+ prompt=prompt,
282
+ checkpoint_path=DEFAULT_MODEL_PATH
283
+ )
284
 
285
+ return output_path
 
 
 
 
 
 
 
286
  except Exception as e:
287
  import traceback
288
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
289
  return None
290
 
 
291
  def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
292
  """Process object manipulation task"""
293
  try:
 
297
  return None
298
 
299
  object_mask_path = save_uploaded_file(object_mask)
300
+ if object_mask_path is None:
301
+ print("Object mask not provided")
302
+ return None
303
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ das = get_das_pipeline()
306
+ video_tensor, fps, is_video = load_media(input_image_path)
307
+ if not is_video and tracking_method == "spatracker":
308
+ tracking_method = "moge"
309
+ print("Image input detected with spatracker selected, switching to MoGe")
310
+
311
 
312
+ mask_image = Image.open(object_mask_path).convert('L')
313
+ mask_image = transforms.Resize((480, 720))(mask_image)
314
+ mask = torch.from_numpy(np.array(mask_image) > 127)
315
+
316
+ motion_generator = ObjectMotionGenerator(device=das.device)
317
+ repaint_img_tensor = None
318
+ tracking_tensor = None
319
+ if tracking_method == "moge":
320
+ moge = get_moge_model()
321
+
322
+
323
+ infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
324
+ H, W = infer_result["points"].shape[0:2]
325
+ pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
326
+
327
+ pred_tracks = motion_generator.apply_motion(
328
+ pred_tracks=pred_tracks,
329
+ mask=mask,
330
+ motion_type=object_motion,
331
+ distance=50,
332
+ num_frames=49,
333
+ tracking_method="moge"
334
+ )
335
+ print(f"Object motion '{object_motion}' applied using provided mask")
336
+ poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
337
+ pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
338
+
339
+
340
+ cam_motion = CameraMotionGenerator(None)
341
+ cam_motion.set_intr(infer_result["intrinsics"])
342
+ pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
343
+
344
+ _, tracking_tensor = das.visualize_tracking_moge(
345
+ pred_tracks.cpu().numpy(),
346
+ infer_result["mask"].cpu().numpy()
347
+ )
348
+ print('Export tracking video via MoGe')
349
  else:
350
+
351
+ pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor)
352
+
353
+
354
+ pred_tracks = motion_generator.apply_motion(
355
+ pred_tracks=pred_tracks.squeeze(),
356
+ mask=mask,
357
+ motion_type=object_motion,
358
+ distance=50,
359
+ num_frames=49,
360
+ tracking_method="spatracker"
361
+ ).unsqueeze(0)
362
+ print(f"Object motion '{object_motion}' applied using provided mask")
363
+
364
+
365
+ _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
366
+ print('Export tracking video via SpaTracker')
367
+
368
+
369
+ output_path = das.apply_tracking(
370
+ video_tensor=video_tensor,
371
+ fps=8,
372
+ tracking_tensor=tracking_tensor,
373
+ img_cond_tensor=repaint_img_tensor,
374
+ prompt=prompt,
375
+ checkpoint_path=DEFAULT_MODEL_PATH
376
+ )
377
+
378
+ return output_path
379
  except Exception as e:
380
  import traceback
381
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
382
  return None
383
 
 
384
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
385
  """Process mesh animation task"""
386
  try:
 
393
  if tracking_video_path is None:
394
  return None
395
 
 
 
 
 
 
 
 
 
396
 
397
+ das = get_das_pipeline()
398
+ video_tensor, fps, is_video = load_media(input_video_path)
399
+ tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
400
+ repaint_img_tensor = None
401
  if ma_repaint_image is not None:
 
402
  repaint_path = save_uploaded_file(ma_repaint_image)
403
+ repaint_img_tensor, _, _ = load_media(repaint_path)
404
+ repaint_img_tensor = repaint_img_tensor[0] # 获取第一帧
405
  elif ma_repaint_option == "Yes":
406
+
407
+ repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
408
+ repaint_img_tensor = repainter.repaint(
409
+ video_tensor[0],
410
+ prompt=prompt,
411
+ depth_path=None
412
+ )
413
+
414
+ output_path = das.apply_tracking(
415
+ video_tensor=video_tensor,
416
+ fps=8,
417
+ tracking_tensor=tracking_tensor,
418
+ img_cond_tensor=repaint_img_tensor,
419
+ prompt=prompt,
420
+ checkpoint_path=DEFAULT_MODEL_PATH
421
+ )
422
 
423
+ return output_path
 
 
 
 
 
 
 
424
  except Exception as e:
425
  import traceback
426
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")