ysharma HF staff commited on
Commit
631c9e2
·
1 Parent(s): d639c7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -65
app.py CHANGED
@@ -1,106 +1,77 @@
1
  import gradio as gr
2
- from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
3
- from diffusers.utils import load_image
4
- from transformers import DPTImageProcessor, DPTForDepthEstimation
5
  import torch
 
6
  import mediapy
7
  import sa_handler
8
  import pipeline_calls
9
 
10
 
11
-
12
  # init models
13
-
14
- depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
15
- feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
16
-
17
- controlnet = ControlNetModel.from_pretrained(
18
- "diffusers/controlnet-depth-sdxl-1.0",
19
- variant="fp16",
20
- use_safetensors=True,
21
- torch_dtype=torch.float16,
22
- ).to("cuda")
23
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
24
- pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
25
- "stabilityai/stable-diffusion-xl-base-1.0",
26
- controlnet=controlnet,
27
- vae=vae,
28
- variant="fp16",
29
- use_safetensors=True,
30
- torch_dtype=torch.float16,
31
  ).to("cuda")
32
  pipeline.enable_model_cpu_offload()
33
  pipeline.enable_vae_slicing()
34
-
35
- sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
36
- share_layer_norm=False,
37
  share_attention=True,
38
  adain_queries=True,
39
  adain_keys=True,
40
  adain_values=False,
41
  )
42
  handler = sa_handler.Handler(pipeline)
43
- handler.register(sa_args, )
44
-
45
 
46
 
47
 
48
- # run ControlNet depth with StyleAligned
49
- def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
50
- if depth_map == True:
51
- image = load_image(ref_image)
52
- depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
53
- else:
54
- depth_image = load_image(ref_image).resize((1024, 1024))
55
- #reference_prompt = ref_style_prompt #"a poster in minimalist origami style"
56
- #target_prompts = img_generation_prompt #["mona lisa"] #, "gal gadot"]
57
- controlnet_conditioning_scale = 0.8
58
- num_images_per_prompt = 3 # adjust according to VRAM size
59
- latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
60
- latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
61
- images = pipeline_calls.controlnet_call(pipeline, [ref_style_prompt, img_generation_prompt],
62
- image=depth_image,
63
- num_inference_steps=50,
64
- controlnet_conditioning_scale=controlnet_conditioning_scale,
65
- num_images_per_prompt=num_images_per_prompt,
66
- latents=latents)
67
- #mediapy.show_images([images[0], depth_image2] + images[1:], titles=["reference", "depth"] + [f'result {i}' for i in range(1, len(images))])
68
- return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)
69
 
70
 
71
  with gr.Blocks() as demo:
72
-
73
  with gr.Row():
74
-
75
  with gr.Column(variant='panel'):
76
  ref_style_prompt = gr.Textbox(
77
  label='Reference style prompt',
78
- info="Enter a Prompt to generate the reference image", placeholder='a poster in <style name> style'
 
79
  )
80
- depth_map = gr.Checkbox(label='Depth-map',)
81
  ref_style_image = gr.Image(visible=False, label='Reference style image')
82
-
83
- with gr.Column(variant='panel'):
84
- ref_image = gr.Image(label="Upload the reference image",
85
- type='filepath' )
86
- img_generation_prompt = gr.Textbox(
87
- label='ControlNet Prompt',
88
- info="Enter a Prompt to generate images using ControlNet and Style-aligned",
89
- )
90
 
91
- btn = gr.Button("Generate", size='sm')
 
 
 
 
 
 
 
 
92
  gallery = gr.Gallery(label="Style-Aligned ControlNet - Generated images",
93
  elem_id="gallery",
94
  columns=5,
95
  rows=1,
96
  object_fit="contain",
97
  height="auto",
 
 
98
  )
99
 
100
- btn.click(fn=style_aligned_controlnet,
101
- inputs=[ref_style_prompt, depth_map, ref_image, img_generation_prompt],
102
- outputs=[gallery, ref_style_image],
103
- api_name="style_aligned_controlnet")
104
 
105
 
106
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import torch
3
+ from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
4
  import mediapy
5
  import sa_handler
6
  import pipeline_calls
7
 
8
 
 
9
  # init models
10
+ model_ckpt = "stability/stable-diffusion-2-base"
11
+ scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
12
+ pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
13
+ model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ).to("cuda")
15
  pipeline.enable_model_cpu_offload()
16
  pipeline.enable_vae_slicing()
17
+ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,
18
+ share_layer_norm=True,
 
19
  share_attention=True,
20
  adain_queries=True,
21
  adain_keys=True,
22
  adain_values=False,
23
  )
24
  handler = sa_handler.Handler(pipeline)
25
+ handler.register(sa_args)
 
26
 
27
 
28
 
29
+ # run MultiDiffusion with StyleAligned
30
+ def style_aligned_multidiff(ref_style_prompt, img_generation_prompt):
31
+ view_batch_size = 25 # adjust according to VRAM size
32
+ reference_latent = torch.randn(1, 4, 64, 64,)
33
+ for target_prompt in target_prompts:
34
+ images = pipeline_calls.panorama_call(pipeline,
35
+ [ref_style_prompt, img_generation_prompt],
36
+ reference_latent=reference_latent,
37
+ view_batch_size=view_batch_size)
38
+
39
+ return images, gr.Image(value=images[0], visible=True)
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  with gr.Blocks() as demo:
 
43
  with gr.Row():
 
44
  with gr.Column(variant='panel'):
45
  ref_style_prompt = gr.Textbox(
46
  label='Reference style prompt',
47
+ info="Enter a Prompt to generate the reference image",
48
+ placeholder='a beautiful papercut art design'
49
  )
 
50
  ref_style_image = gr.Image(visible=False, label='Reference style image')
 
 
 
 
 
 
 
 
51
 
52
+ with gr.Column(variant='panel'):
53
+ img_generation_prompt = gr.Textbox(
54
+ label='MultiDiffusion Prompt',
55
+ info="Enter a Prompt to generate panaromic images using Style-aligned combined with MultiDiffusion",
56
+ )
57
+
58
+
59
+ btn = gr.Button("Style-aligned MultiDiffusion - Generate", size='sm')
60
+
61
  gallery = gr.Gallery(label="Style-Aligned ControlNet - Generated images",
62
  elem_id="gallery",
63
  columns=5,
64
  rows=1,
65
  object_fit="contain",
66
  height="auto",
67
+ allow_preview=True,
68
+ preview=True,
69
  )
70
 
71
+ btn.click(fn=style_aligned_multidiff,
72
+ inputs=[ref_style_prompt, img_generation_prompt],
73
+ outputs=gallery,
74
+ api_name="style_aligned_multidiffusion")
75
 
76
 
77
  demo.launch()