JiantaoLin committed
Commit 52a094c · 1 Parent(s): 0213646
new
pipeline/kiss3d_wrapper.py CHANGED
@@ -90,7 +90,8 @@ def init_wrapper_from_config(config_path):
 
     # load lora weights
     flux_pipe.load_lora_weights(flux_lora_pth)
-    flux_pipe.to(device=flux_device)
+    # flux_pipe.to(device=flux_device)
+    flux_pipe.enable_model_cpu_offload(device=flux_device)
     # flux_pipe = None
 
     # load redux model
@@ -163,7 +164,7 @@ def init_wrapper_from_config(config_path):
         logger.info('==> Loading LLM ...')
         llm_device = llm_configs.get('device', 'cpu')
         llm, llm_tokenizer = load_llm_model(llm_configs['base_model'])
-        llm.to(llm_device)
+        # llm.to(llm_device)
         # logger.warning(f"GPU memory allocated after load llm model on {llm_device}: {torch.cuda.memory_allocated(device=llm_device) / 1024**3} GB")
     else:
         llm, llm_tokenizer = None, None
@@ -267,11 +268,14 @@ class kiss3d_wrapper(object):
         return caption_text
     # @spaces.GPU
     def get_detailed_prompt(self, prompt, seed=None):
+        self.llm_model.to(self.config['llm']['device'])
         if self.llm_model is not None:
             detailed_prompt = get_llm_response(self.llm_model, self.llm_tokenizer, prompt, seed=seed)
 
             logger.info(f"LLM refined prompt result: \"{detailed_prompt}\"")
             return detailed_prompt
+        self.llm_model.to('cpu')
+        torch.cuda.empty_cache()
         return prompt
 
     def del_llm_model(self):
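The change swaps the unconditional flux_pipe.to(device=flux_device) for enable_model_cpu_offload(device=flux_device) and moves the LLM onto the accelerator only while get_detailed_prompt runs, parking it back on CPU and emptying the CUDA cache afterwards. Below is a minimal sketch of that on-demand placement pattern, assuming a plain torch.nn.Module; the on_device helper is illustrative and not part of the repository.

import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def on_device(module: nn.Module, device: str):
    """Temporarily place `module` on `device`, then park it back on CPU."""
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release the allocator blocks just vacated

# Usage with a stand-in module; the real wrapper would pass self.llm_model and
# self.config['llm']['device'] instead of these placeholders.
llm = nn.Linear(8, 8)
device = "cuda" if torch.cuda.is_available() else "cpu"
with on_device(llm, device):
    out = llm(torch.randn(1, 8, device=device))
print(out.shape)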
pipeline/pipeline_config/default.yaml CHANGED
@@ -1,6 +1,6 @@
 flux:
   base_model: "https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"
-  flux_dtype: '
+  flux_dtype: 'bf16'
   lora: "./checkpoint/flux_lora/rgb_normal_large.safetensors"
   controlnet: "InstantX/FLUX.1-dev-Controlnet-Union"
   redux: "black-forest-labs/FLUX.1-Redux-dev"
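The YAML change pins flux_dtype to 'bf16'. How that string is turned into a torch dtype is not shown in this commit; the snippet below is a hedged sketch under the assumption that the loader maps the usual short names, and DTYPE_MAP / resolve_dtype are illustrative, not the repository's code.

import torch

# Assumed short-name -> torch dtype mapping; torch.float8_e4m3fn requires PyTorch >= 2.1.
DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
    "fp8": torch.float8_e4m3fn,
}

def resolve_dtype(name: str) -> torch.dtype:
    """Translate a flux_dtype string from default.yaml into a torch dtype."""
    return DTYPE_MAP.get(name, torch.bfloat16)  # fall back to bf16, matching the new config value

print(resolve_dtype("bf16"))  # torch.bfloat16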