franciszzj committed on
Commit
16c2627
·
1 Parent(s): bafa7b2

update gradio app

Browse files
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import numpy as np
2
  from PIL import Image
 
3
  from leffa.transform import LeffaTransform
4
  from leffa.model import LeffaModel
5
  from leffa.inference import LeffaInference
@@ -8,6 +9,9 @@ from utils.densepose_predictor import DensePosePredictor
8
 
9
  import gradio as gr
10
 
 
 
 
11
 
12
  def leffa_predict(src_image_path, ref_image_path, control_type):
13
  assert control_type in [
@@ -20,14 +24,20 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
20
 
21
  # Mask
22
  if control_type == "virtual_tryon":
23
- automasker = AutoMasker()
 
 
 
24
  src_image = src_image.convert("RGB")
25
  mask = automasker(src_image, "upper")["mask"]
26
  elif control_type == "pose_transfer":
27
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
28
 
29
  # DensePose
30
- densepose_predictor = DensePosePredictor()
 
 
 
31
  src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
32
  src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
33
  src_image_iuv = Image.fromarray(src_image_iuv_array)
@@ -72,43 +82,56 @@ if __name__ == "__main__":
72
  # control_type = sys.argv[3]
73
  # leffa_predict(src_image_path, ref_image_path, control_type)
74
 
75
- # Launch a gr.Interface
76
- gr_demo = gr.Interface(
77
- fn=leffa_predict,
78
- inputs=[
79
- gr.Image(sources=["upload", "webcam", "clipboard"],
80
- type="filepath",
81
- label="Source Person Image",
82
- width=768,
83
- height=1024,
84
- ),
85
- gr.Image(sources=["upload", "webcam", "clipboard"],
86
- type="filepath",
87
- label="Reference Image",
88
- width=768,
89
- height=1024,
90
- ),
91
- gr.Radio(["virtual_tryon", "pose_transfer"],
92
- label="Control Type",
93
- default="virtual_tryon",
94
- ),
95
- ],
96
- outputs=[
97
- gr.Image(label="Generated Person Image",
98
- width=768,
99
- height=1024,
100
- )
101
- ],
102
- title="Leffa",
103
- description="Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer).",
104
- article="Controllable person image generation aims to generate a person image conditioned on reference images, allowing precise control over the person’s appearance or pose. However, prior methods often distort fine-grained textural details from the reference image, despite achieving high overall image quality. We attribute these distortions to inadequate attention to corresponding regions in the reference image. To address this, we thereby propose \textbf{learning flow fields in attention} (\textbf{\ours{}}), which explicitly guides the target query to attend to the correct reference key in the attention layer during training. Specifically, it is realized via a regularization loss on top of the attention map within a diffusion-based baseline. Our extensive experiments show that Leffa achieves state-of-the-art performance in controlling appearance (virtual try-on) and pose (pose transfer), significantly reducing fine-grained detail distortion while maintaining high image quality. Additionally, we show that our loss is model-agnostic and can be used to improve the performance of other diffusion models.",
105
- examples=[
106
- ["./examples/14092_00_person.jpg", "./examples/04181_00_garment.jpg", "virtual_tryon"],
107
- ["./examples/14092_00_person.jpg", "./examples/14684_00_person.jpg", "pose_transfer"],
108
- ],
109
- # cache_examples=True,
110
- examples_per_page=10,
111
- allow_flagging=False,
112
- theme=gr.themes.Default(),
113
- )
114
- gr_demo.launch(share=True, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  from PIL import Image
3
+ from huggingface_hub import snapshot_download
4
  from leffa.transform import LeffaTransform
5
  from leffa.model import LeffaModel
6
  from leffa.inference import LeffaInference
 
9
 
10
  import gradio as gr
11
 
12
+ # Download checkpoints
13
+ snapshot_download(repo_id="franciszzj/Leffa", local_dir="./")
14
+
15
 
16
  def leffa_predict(src_image_path, ref_image_path, control_type):
17
  assert control_type in [
 
24
 
25
  # Mask
26
  if control_type == "virtual_tryon":
27
+ automasker = AutoMasker(
28
+ densepose_path="./ckpts/densepose",
29
+ schp_path="./ckpts/schp",
30
+ )
31
  src_image = src_image.convert("RGB")
32
  mask = automasker(src_image, "upper")["mask"]
33
  elif control_type == "pose_transfer":
34
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
35
 
36
  # DensePose
37
+ densepose_predictor = DensePosePredictor(
38
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
39
+ weights_path="./ckpts/densepose/model_final_162be9.pkl",
40
+ )
41
  src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
42
  src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
43
  src_image_iuv = Image.fromarray(src_image_iuv_array)
 
82
  # control_type = sys.argv[3]
83
  # leffa_predict(src_image_path, ref_image_path, control_type)
84
 
85
+ with gr.Blocks().queue() as demo:
86
+ gr.Markdown(
87
+ "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation")
88
+ gr.Markdown("Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer).")
89
+ with gr.Row():
90
+ with gr.Column():
91
+ src_image = gr.Image(
92
+ sources=["upload"],
93
+ type="filepath",
94
+ label="Source Person Image",
95
+ width=384,
96
+ height=512,
97
+ )
98
+ with gr.Row():
99
+ control_type = gr.Dropdown(
100
+ ["virtual_tryon", "pose_transfer"], label="Control Type")
101
+
102
+ example = gr.Examples(
103
+ inputs=src_image,
104
+ examples_per_page=10,
105
+ examples=["./examples/14684_00_person.jpg",
106
+ "./examples/14092_00_person.jpg"],
107
+ )
108
+
109
+ with gr.Column():
110
+ ref_image = gr.Image(
111
+ sources=["upload"],
112
+ type="filepath",
113
+ label="Reference Image",
114
+ width=384,
115
+ height=512,
116
+ )
117
+ with gr.Row():
118
+ gen_button = gr.Button("Generate")
119
+
120
+ example = gr.Examples(
121
+ inputs=ref_image,
122
+ examples_per_page=10,
123
+ examples=["./examples/04181_00_garment.jpg",
124
+ "./examples/14684_00_person.jpg"],
125
+ )
126
+
127
+ with gr.Column():
128
+ gen_image = gr.Image(
129
+ label="Generated Person Image",
130
+ width=384,
131
+ height=512,
132
+ )
133
+
134
+ gen_button.click(fn=leffa_predict, inputs=[
135
+ src_image, ref_image, control_type], outputs=[gen_image])
136
+
137
+ demo.launch(share=True, server_port=7860)
utils/densepose_predictor.py CHANGED
@@ -10,13 +10,15 @@ from detectron2.engine import DefaultPredictor
10
 
11
 
12
  class DensePosePredictor(object):
13
- def __init__(self):
 
 
 
14
  cfg = get_cfg()
15
  add_densepose_config(cfg)
16
  cfg.merge_from_file(
17
- "ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml"
18
- ) # Use the path to the config file from densepose
19
- cfg.MODEL.WEIGHTS = "ckpts/densepose/model_final_162be9.pkl" # Use the path to the pre-trained model weights
20
  cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Adjust as needed
22
  self.predictor = DefaultPredictor(cfg)
 
10
 
11
 
12
  class DensePosePredictor(object):
13
+ def __init__(self,
14
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
15
+ weights_path="./ckpts/densepose/model_final_162be9.pkl"
16
+ ):
17
  cfg = get_cfg()
18
  add_densepose_config(cfg)
19
  cfg.merge_from_file(
20
+ config_path) # Use the path to the config file from densepose
21
+ cfg.MODEL.WEIGHTS = weights_path # Use the path to the pre-trained model weights
 
22
  cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Adjust as needed
24
  self.predictor = DefaultPredictor(cfg)
utils/garment_agnostic_mask_predictor.py CHANGED
@@ -200,21 +200,21 @@ def hull_mask(mask_area: np.ndarray):
200
  class AutoMasker:
201
  def __init__(
202
  self,
 
 
203
  device="cuda",
204
  ):
205
- densepose_ckpt = "./ckpts/densepose"
206
- schp_ckpt = "./ckpts/schp"
207
  np.random.seed(0)
208
  torch.manual_seed(0)
209
  torch.cuda.manual_seed(0)
210
 
211
- self.densepose_processor = DensePose(densepose_ckpt, device)
212
  self.schp_processor_atr = SCHP(
213
- ckpt_path=os.path.join(schp_ckpt, "exp-schp-201908301523-atr.pth"),
214
  device=device,
215
  )
216
  self.schp_processor_lip = SCHP(
217
- ckpt_path=os.path.join(schp_ckpt, "exp-schp-201908261155-lip.pth"),
218
  device=device,
219
  )
220
 
 
200
  class AutoMasker:
201
  def __init__(
202
  self,
203
+ densepose_path: str = "./ckpts/densepose",
204
+ schp_path: str = "./ckpts/schp",
205
  device="cuda",
206
  ):
 
 
207
  np.random.seed(0)
208
  torch.manual_seed(0)
209
  torch.cuda.manual_seed(0)
210
 
211
+ self.densepose_processor = DensePose(densepose_path, device)
212
  self.schp_processor_atr = SCHP(
213
+ ckpt_path=os.path.join(schp_path, "exp-schp-201908301523-atr.pth"),
214
  device=device,
215
  )
216
  self.schp_processor_lip = SCHP(
217
+ ckpt_path=os.path.join(schp_path, "exp-schp-201908261155-lip.pth"),
218
  device=device,
219
  )
220