JiantaoLin
commited on
Commit
·
30d56f8
1
Parent(s):
52a094c
new
Browse files
pipeline/kiss3d_wrapper.py
CHANGED
@@ -91,7 +91,7 @@ def init_wrapper_from_config(config_path):
|
|
91 |
# load lora weights
|
92 |
flux_pipe.load_lora_weights(flux_lora_pth)
|
93 |
# flux_pipe.to(device=flux_device)
|
94 |
-
flux_pipe.enable_model_cpu_offload(device=flux_device)
|
95 |
# flux_pipe = None
|
96 |
|
97 |
# load redux model
|
@@ -490,6 +490,7 @@ class kiss3d_wrapper(object):
|
|
490 |
"""
|
491 |
return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
|
492 |
"""
|
|
|
493 |
print(f"==> generate_3d_bundle_image_text: {prompt}")
|
494 |
if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
|
495 |
flux_pipeline = self.flux_pipeline
|
@@ -548,7 +549,7 @@ class kiss3d_wrapper(object):
|
|
548 |
torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
|
549 |
logger.info(f"Save generated 3D bundle image to {save_path}")
|
550 |
return gen_3d_bundle_image_, save_path
|
551 |
-
|
552 |
return gen_3d_bundle_image_
|
553 |
|
554 |
def reconstruct_3d_bundle_image(self,
|
|
|
91 |
# load lora weights
|
92 |
flux_pipe.load_lora_weights(flux_lora_pth)
|
93 |
# flux_pipe.to(device=flux_device)
|
94 |
+
# flux_pipe.enable_model_cpu_offload(device=flux_device)
|
95 |
# flux_pipe = None
|
96 |
|
97 |
# load redux model
|
|
|
490 |
"""
|
491 |
return: gen_3d_bundle_image, torch.Tensor of shape (3, 1024, 2048), range [0., 1.]
|
492 |
"""
|
493 |
+
self.flux_pipeline.to(self.config['flux'].get('device', 'cpu'))
|
494 |
print(f"==> generate_3d_bundle_image_text: {prompt}")
|
495 |
if isinstance(self.flux_pipeline, FluxImg2ImgPipeline):
|
496 |
flux_pipeline = self.flux_pipeline
|
|
|
549 |
torchvision.utils.save_image(gen_3d_bundle_image_, save_path)
|
550 |
logger.info(f"Save generated 3D bundle image to {save_path}")
|
551 |
return gen_3d_bundle_image_, save_path
|
552 |
+
self.flux_pipeline.to('cpu')
|
553 |
return gen_3d_bundle_image_
|
554 |
|
555 |
def reconstruct_3d_bundle_image(self,
|