JiantaoLin commited on
Commit
30d56f8
·
1 Parent(s): 52a094c
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +3 -2
pipeline/kiss3d_wrapper.py CHANGED
@@ -91,7 +91,7 @@ def init_wrapper_from_config(config_path):
91
  # load lora weights
92
  flux_pipe.load_lora_weights(flux_lora_pth)
93
  # flux_pipe.to(device=flux_device)
94
- flux_pipe.enable_model_cpu_offload(device=flux_device)
95
  # flux_pipe = None
96
 
97
  # load redux model
@@ -490,6 +490,7 @@ class kiss3d_wrapper(object):
490
  """
491
  return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
492
  """
 
493
  print(f"==> generate_3d_bundle_image_text: {prompt}")
494
  if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
495
  flux_pipeline = self.flux_pipeline
@@ -548,7 +549,7 @@ class kiss3d_wrapper(object):
548
  torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
549
  logger.info(f"Save generated 3D bundle image to {save_path}")
550
  return gen_3d_bundle_image_, save_path
551
-
552
  return gen_3d_bundle_image_
553
 
554
  def reconstruct_3d_bundle_image(self,
 
91
  # load lora weights
92
  flux_pipe.load_lora_weights(flux_lora_pth)
93
  # flux_pipe.to(device=flux_device)
94
+ # flux_pipe.enable_model_cpu_offload(device=flux_device)
95
  # flux_pipe = None
96
 
97
  # load redux model
 
490
  """
491
  return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
492
  """
493
+ self.flux_pipeline.to(self.config['flux'].get('device', 'cpu'))
494
  print(f"==> generate_3d_bundle_image_text: {prompt}")
495
  if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
496
  flux_pipeline = self.flux_pipeline
 
549
  torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
550
  logger.info(f"Save generated 3D bundle image to {save_path}")
551
  return gen_3d_bundle_image_, save_path
552
+ self.flux_pipeline.to('cpu')
553
  return gen_3d_bundle_image_
554
 
555
  def reconstruct_3d_bundle_image(self,