Wuvin commited on
Commit
8bfc447
1 Parent(s): 2fc8dce
gradio_app/custom_models/mvimg_prediction.py CHANGED
@@ -13,9 +13,9 @@ training_config = "gradio_app/custom_models/image2mvimage.yaml"
13
  checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
14
 
15
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
16
- pipeline.enable_model_cpu_offload()
17
 
18
  def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
 
19
  if isinstance(img_list, Image.Image):
20
  img_list = [img_list]
21
  img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
 
13
  checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
14
 
15
  trainer, pipeline = load_pipeline(training_config, checkpoint_path)
 
16
 
17
  def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
18
+ pipeline.enable_model_cpu_offload()
19
  if isinstance(img_list, Image.Image):
20
  img_list = [img_list]
21
  img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]