mishig (HF staff) committed
Commit 3806189
1 Parent(s): ecd427f

Update app.py

Files changed (1):
  1. app.py +69 -63
app.py CHANGED
@@ -1,16 +1,15 @@
-from controlnet_aux import OpenposeDetector
 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from diffusers import UniPCMultistepScheduler
 import gradio as gr
 import torch
 import base64
 from io import BytesIO
-from PIL import Image
-# live conditioning
-canvas_html = "<pose-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></pose-canvas>"
+from PIL import Image, ImageFilter
+
+canvas_html = '<pose-maker/>'
 load_js = """
 async () => {
-const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/pose-gradio.js"
+const url = "https://huggingface.co/datasets/mishig/gradio-components/raw/main/mannequinAll.js"
 fetch(url)
 .then(res => res.text())
 .then(text => {
@@ -21,22 +20,18 @@ async () => {
   });
 }
 """
+
 get_js_image = """
-async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
-  const canvasEl = document.getElementById("canvas-root");
-  const data = canvasEl? canvasEl._data : null;
-  return [image_in_img, prompt, image_file_live_opt, data]
+async (canvas, prompt) => {
+  const poseMakerEl = document.querySelector("pose-maker");
+  const imgBase64 = poseMakerEl.captureScreenshot();
+  return [imgBase64, prompt]
 }
 """
 
-# Constants
-low_threshold = 100
-high_threshold = 200
-
 # Models
-pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
 controlnet = ControlNetModel.from_pretrained(
-    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
+    "lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16
 )
 pipe = StableDiffusionControlNetPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
@@ -54,62 +49,73 @@ pipe.enable_xformers_memory_efficient_attention()
 generator = torch.manual_seed(0)
 
 
-def get_pose(image):
-    return pose_model(image)
-
-
-def generate_images(image, prompt, image_file_live_opt='file', live_conditioning=None):
-    if image is None and 'image' not in live_conditioning:
-        raise gr.Error("Please provide an image")
+def generate_images(canvas, prompt):
     try:
-        if image_file_live_opt == 'file':
-            pose = get_pose(image)
-        elif image_file_live_opt == 'webcam':
-            base64_img = live_conditioning['image']
-            image_data = base64.b64decode(base64_img.split(',')[1])
-            pose = Image.open(BytesIO(image_data)).convert(
-                'RGB').resize((512, 512))
+        base64_img = canvas
+        image_data = base64.b64decode(base64_img.split(',')[1])
+        input_img = Image.open(BytesIO(image_data)).convert(
+            'RGB').resize((512, 512))
+        input_img = input_img.filter(ImageFilter.GaussianBlur(radius=5))
         output = pipe(
             prompt,
-            pose,
+            input_img,
             generator=generator,
-            num_images_per_prompt=3,
+            num_images_per_prompt=2,
             num_inference_steps=20,
         )
         all_outputs = []
-        all_outputs.append(pose)
         for image in output.images:
             all_outputs.append(image)
         return all_outputs
     except Exception as e:
        raise gr.Error(str(e))
 
+def placeholder_fn(axis):
+    pass
 
-def toggle(choice):
-    if choice == "file":
-        return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
-    elif choice == "webcam":
-        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
+js_change_rotation_axis = """
+async (axis) => {
+  const poseMakerEl = document.querySelector("pose-maker");
+  poseMakerEl.changeRotationAxis(axis);
+}
+"""
 
+js_pose_template = """
+async (pose) => {
+  const poseMakerEl = document.querySelector("pose-maker");
+  poseMakerEl.setPose(pose);
+}
+"""
 
 with gr.Blocks() as blocks:
-    gr.Markdown("""
-    ## Generate controlled outputs with ControlNet and Stable Diffusion
-    This Space uses pose estimated lines as the additional conditioning
-    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
-    """)
+    gr.HTML(
+        """
+        <div style="text-align: center; margin: 0 auto;">
+          <div
+            style="
+              display: inline-flex;
+              align-items: center;
+              gap: 0.8rem;
+              font-size: 1.75rem;
+            "
+          >
+            <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
+              Pose in 3D &amp; Render with ControlNet (SD-1.5)
+            </h1>
+          </div>
+          <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
+            Using <a href="https://github.com/lllyasviel/ControlNet">ControlNet</a> and <a href="https://boytchev.github.io/mannequin.js/">three.js/mannequin.js</a>
+          </p>
+          <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>
+        </div>
+        """
+    )
     with gr.Row():
-        live_conditioning = gr.JSON(value={}, visible=False)
         with gr.Column():
-            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
-                                           label="How would you like to upload your image?")
-            image_in_img = gr.Image(source="upload", visible=True, type="pil")
-            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
-
-            image_file_live_opt.change(fn=toggle,
-                                       inputs=[image_file_live_opt],
-                                       outputs=[image_in_img, canvas],
-                                       queue=False)
+            canvas = gr.HTML(canvas_html, elem_id="canvas_html", visible=True)
+            with gr.Row():
+                rotation_axis = gr.Radio(["x", "y", "z"], value="x", label="Joint rotation axis")
+                pose_template = gr.Radio(["regular", "ballet", "handstand", "split", "kick", "chilling"], value="regular", label="Pose template")
             prompt = gr.Textbox(
                 label="Enter your prompt",
                 max_lines=1,
@@ -118,20 +124,20 @@ with gr.Blocks() as blocks:
             run_button = gr.Button("Generate")
         with gr.Column():
             gallery = gr.Gallery().style(grid=[2], height="auto")
+    rotation_axis.change(fn=placeholder_fn,
+                         inputs=[rotation_axis],
+                         outputs=[],
+                         queue=False,
+                         _js=js_change_rotation_axis)
+    pose_template.change(fn=placeholder_fn,
+                         inputs=[pose_template],
+                         outputs=[],
+                         queue=False,
+                         _js=js_pose_template)
     run_button.click(fn=generate_images,
-                     inputs=[image_in_img, prompt,
-                             image_file_live_opt, live_conditioning],
+                     inputs=[canvas, prompt],
                      outputs=[gallery],
                      _js=get_js_image)
     blocks.load(None, None, None, _js=load_js)
 
-    gr.Examples(fn=generate_images,
-                examples=[
-                    ["./yoga1.jpeg",
-                     "best quality, extremely detailed"]
-                ],
-                inputs=[image_in_img, prompt],
-                outputs=[gallery],
-                cache_examples=True)
-
 blocks.launch(debug=True)
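
A note on the input path this commit introduces, since it bypasses the usual Gradio image upload: get_js_image runs in the browser first, asks the <pose-maker> element for a screenshot, and hands Gradio a base64 data URL in place of the canvas component's value; generate_images then decodes and preprocesses it server-side. The sketch below isolates that decode-and-blur step. It is a minimal illustration, not code from the commit: decode_canvas and the tiny synthetic test image are invented for the example, and only the stdlib plus Pillow are assumed.

import base64
from io import BytesIO

from PIL import Image, ImageFilter

def decode_canvas(data_url):
    # As in app.py: drop the "data:image/png;base64," prefix, decode the
    # payload, force RGB at the 512x512 conditioning size, then blur
    # (radius 5), presumably to smooth the hard-edged mannequin render
    # before it is fed to the depth ControlNet.
    payload = data_url.split(',')[1]
    img = Image.open(BytesIO(base64.b64decode(payload)))
    img = img.convert('RGB').resize((512, 512))
    return img.filter(ImageFilter.GaussianBlur(radius=5))

# Stand-in for poseMakerEl.captureScreenshot(): a 4x4 PNG as a data URL.
buf = BytesIO()
Image.new('RGB', (4, 4), 'red').save(buf, format='PNG')
data_url = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()
print(decode_canvas(data_url).size)  # (512, 512)

The two new radios use the same bridge in the other direction: Gradio events need a Python callable, so the commit pairs the no-op placeholder_fn with a _js snippet that drives the <pose-maker> element client-side (_js is the Gradio 3.x spelling of this hook; later versions renamed it). Here is a stripped-down sketch of that pattern, with the custom-element call swapped for a console.log so it runs without the component:

import gradio as gr

log_choice_js = """
async (choice) => {
  console.log("radio changed to", choice);
}
"""

def noop(choice):
    # Server-side placeholder; the _js snippet does the real work in-browser.
    pass

with gr.Blocks() as demo:
    axis = gr.Radio(["x", "y", "z"], value="x", label="Joint rotation axis")
    axis.change(fn=noop, inputs=[axis], outputs=[], queue=False, _js=log_choice_js)

demo.launch()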