guardiancc committed (verified) · Commit 46da058 · 1 Parent(s): c63eb09

Update mimicmotion/pipelines/pipeline_mimicmotion.py

mimicmotion/pipelines/pipeline_mimicmotion.py CHANGED
@@ -16,11 +16,12 @@ from diffusers.schedulers import EulerDiscreteScheduler
 from diffusers.utils import BaseOutput, logging
 from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+import threading
 
 from ..modules.pose_net import PoseNet
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
+import concurrent.futures
 
 def _append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
@@ -221,29 +222,37 @@ class MimicMotionPipeline(DiffusionPipeline):
         decode_chunk_size: int = 8):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
-
         latents = 1 / self.vae.config.scaling_factor * latents
-
+
         forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
         accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
-
-        # decode decode_chunk_size frames at a time to avoid OOM
-        frames = []
-        for i in range(0, latents.shape[0], decode_chunk_size):
-            num_frames_in = latents[i: i + decode_chunk_size].shape[0]
+
+        # Helper function to process one chunk of frames
+        def process_chunk(start, end, frames_list):
             decode_kwargs = {}
             if accepts_num_frames:
-                # we only pass num_frames_in if it's expected
-                decode_kwargs["num_frames"] = num_frames_in
-
-            frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
-            frames.append(frame.cpu())
-        frames = torch.cat(frames, dim=0)
-
+                decode_kwargs["num_frames"] = end - start
+            frame = self.vae.decode(latents[start:end], **decode_kwargs).sample
+            frames_list.append(frame.cpu())
+
+        threads = []
+        frames = []
+
+        # Split the work into chunks and create a thread to process each one
+        for i in range(0, latents.shape[0], decode_chunk_size):
+            t = threading.Thread(target=process_chunk, args=(i, i + decode_chunk_size, frames))
+            threads.append(t)
+            t.start()
+
+        # Wait for all threads to finish
+        for t in threads:
+            t.join()
+
         # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+        frames = torch.cat(frames, dim=0)
         frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
+
+        # Cast to float32 for compatibility with bfloat16
        frames = frames.float()
         return frames
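
A note on the threaded decode introduced above: each worker appends its decoded chunk to the shared frames list in whatever order the threads happen to finish, so torch.cat(frames, dim=0) can assemble the video out of frame order; the import concurrent.futures added alongside threading is never used; and num_frames is set to end - start without clamping, so a final short chunk can report the wrong count. Below is a minimal order-preserving sketch built on that unused import, not the committed code. The function decode_in_chunks and its parameters are hypothetical stand-ins for the method's self.vae, latents, accepts_num_frames, and decode_chunk_size.

import concurrent.futures

import torch


def decode_in_chunks(vae, latents, accepts_num_frames, decode_chunk_size=8, max_workers=2):
    """Hypothetical order-preserving variant of the threaded chunked decode."""
    def decode_chunk(start):
        # Clamp the end so the last (possibly short) chunk reports its true size.
        end = min(start + decode_chunk_size, latents.shape[0])
        decode_kwargs = {"num_frames": end - start} if accepts_num_frames else {}
        return vae.decode(latents[start:end], **decode_kwargs).sample.cpu()

    starts = range(0, latents.shape[0], decode_chunk_size)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() yields results in input order regardless of completion order,
        # so chunks cannot be concatenated out of sequence.
        frames = list(pool.map(decode_chunk, starts))
    return torch.cat(frames, dim=0)

Threads can overlap useful work here because PyTorch releases the GIL inside heavy ops, but decoding every chunk at once gives up the memory ceiling the original sequential loop existed for ("decode decode_chunk_size frames at a time to avoid OOM"); bounding max_workers keeps peak memory close to the sequential version's.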