JiantaoLin committed
Commit 52a094c · 1 Parent(s): 0213646
pipeline/kiss3d_wrapper.py CHANGED
@@ -90,7 +90,8 @@ def init_wrapper_from_config(config_path):
 
     # load lora weights
     flux_pipe.load_lora_weights(flux_lora_pth)
-    flux_pipe.to(device=flux_device)
+    # flux_pipe.to(device=flux_device)
+    flux_pipe.enable_model_cpu_offload(device=flux_device)
     # flux_pipe = None
 
     # load redux model
@@ -163,7 +164,7 @@ def init_wrapper_from_config(config_path):
         logger.info('==> Loading LLM ...')
         llm_device = llm_configs.get('device', 'cpu')
         llm, llm_tokenizer = load_llm_model(llm_configs['base_model'])
-        llm.to(llm_device)
+        # llm.to(llm_device)
         # logger.warning(f"GPU memory allocated after load llm model on {llm_device}: {torch.cuda.memory_allocated(device=llm_device) / 1024**3} GB")
     else:
         llm, llm_tokenizer = None, None
@@ -267,11 +268,14 @@ class kiss3d_wrapper(object):
         return caption_text
     # @spaces.GPU
     def get_detailed_prompt(self, prompt, seed=None):
+        self.llm_model.to(self.config['llm']['device'])
         if self.llm_model is not None:
             detailed_prompt = get_llm_response(self.llm_model, self.llm_tokenizer, prompt, seed=seed)
 
             logger.info(f"LLM refined prompt result: \"{detailed_prompt}\"")
             return detailed_prompt
+        self.llm_model.to('cpu')
+        torch.cuda.empty_cache()
         return prompt
 
     def del_llm_model(self):
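For orientation, both changes in this file follow the same memory-saving pattern: keep large models parked on the CPU and occupy the GPU only while they are actually running. Below is a minimal, self-contained sketch of that pattern, not the repository's exact code; the model ID, device string, and the refine_prompt helper are illustrative assumptions.

import torch
from diffusers import FluxPipeline

flux_device = "cuda:0"  # illustrative; mirrors the 'device' entry in the config

# Flux: instead of pipe.to(flux_device), let diffusers stream each sub-module
# to the GPU only for the duration of its forward pass, then move it back.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("./checkpoint/flux_lora/rgb_normal_large.safetensors")
pipe.enable_model_cpu_offload(device=flux_device)

# LLM: move it to the GPU only for a single prompt-refinement call, then park
# it back on the CPU and release the cached GPU memory (hypothetical helper).
def refine_prompt(llm, tokenizer, prompt, device=flux_device):
    llm.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = llm.generate(**inputs, max_new_tokens=256)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    llm.to("cpu")
    torch.cuda.empty_cache()
    return text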
pipeline/pipeline_config/default.yaml CHANGED
@@ -1,6 +1,6 @@
 flux:
   base_model: "https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"
-  flux_dtype: 'fp8'
+  flux_dtype: 'bf16'
   lora: "./checkpoint/flux_lora/rgb_normal_large.safetensors"
   controlnet: "InstantX/FLUX.1-dev-Controlnet-Union"
   redux: "black-forest-labs/FLUX.1-Redux-dev"