File size: 4,308 Bytes
fd1c028
 
631c9e2
fd1c028
092fcaa
 
 
fd1c028
621ea3a
631c9e2
 
 
092fcaa
6d9201f
092fcaa
d639c7d
631c9e2
 
fd1c028
 
 
 
 
6d9201f
092fcaa
631c9e2
fd1c028
092fcaa
6d9201f
631c9e2
6d9201f
 
 
 
 
 
 
 
 
 
 
fd1c028
6d9201f
fd1c028
6d9201f
d639c7d
 
6d9201f
d639c7d
 
9901ecd
 
d639c7d
6d9201f
d639c7d
9901ecd
631c9e2
6d9201f
631c9e2
 
6d9201f
9901ecd
631c9e2
9901ecd
6d9201f
9901ecd
6d9201f
9901ecd
 
 
 
 
 
 
631c9e2
d639c7d
6d9201f
9901ecd
 
 
 
 
6d9201f
9901ecd
 
 
 
 
 
 
 
 
 
 
 
 
fd1c028
6d9201f
9901ecd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
import sa_handler
import pipeline_calls


# Initialise the models: Stable Diffusion 2 (base) weights with a DDIM
# scheduler, wrapped in the MultiDiffusion panorama pipeline and loaded in
# half precision.
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
# NOTE: the pipeline is intentionally NOT moved to CUDA with `.to("cuda")`.
# Model CPU offloading (enabled just below) takes care of moving each
# sub-model to the GPU when it is needed; per the diffusers memory guide,
# placing the whole pipeline on the GPU first is redundant and reduces the
# memory savings.
pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
)
# Configure the pipeline for CPU offloading and VAE slicing
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# StyleAligned options passed to the handler. The exact meaning of each flag
# is defined in `sa_handler`; here everything is enabled except
# `adain_values`.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
# Initialize the style-aligned handler and register the options on the
# pipeline created above.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)


# Define the function to run MultiDiffusion with StyleAligned
def style_aligned_multidiff(ref_style_prompt: str, img_generation_prompt: str):
    """Run the style-aligned MultiDiffusion pipeline for a pair of prompts.

    Both prompts are sent to ``pipeline_calls.panorama_call`` in a single
    call, reference prompt first, together with a freshly sampled random
    reference latent.

    Args:
        ref_style_prompt: Prompt used for the reference style image.
        img_generation_prompt: Prompt used for the panorama.

    Returns:
        A 2-tuple ``(images, image_update)``: ``images`` is whatever
        ``panorama_call`` returned (shown in the gallery), and
        ``image_update`` is a visible ``gr.Image`` holding ``images[0]``
        (presumably the reference style image, since its prompt comes
        first -- confirm against ``pipeline_calls``).

    Raises:
        gr.Error: If anything fails during generation. The original
            exception is attached as ``__cause__`` so the full traceback is
            kept in the server logs.
    """
    view_batch_size = 25  # adjust according to VRAM size
    try:
        # Random starting latent for the reference image.
        # NOTE(review): (1, 4, 64, 64) looks like the SD latent shape for a
        # 512x512 image -- confirm against `panorama_call`.
        reference_latent = torch.randn(1, 4, 64, 64)
        images = pipeline_calls.panorama_call(
            pipeline,
            [ref_style_prompt, img_generation_prompt],
            reference_latent=reference_latent,
            view_batch_size=view_batch_size,
        )
        return images, gr.Image(value=images[0], visible=True)
    except Exception as e:
        # Gradio only shows gr.Error messages in the UI, so wrap every
        # failure in one, but chain the cause so it is not lost.
        raise gr.Error(f"Error in generating images:{e}") from e

# Create a Gradio UI: two prompt boxes side by side, a generate button, a
# gallery for the results, and a set of clickable example prompt pairs.
with gr.Blocks() as demo:
    gr.HTML('<h1 style="text-align: center;">Style-aligned with MultiDiffusion</h1>')
    with gr.Row():
        with gr.Column(variant='panel'):
            # Textbox for reference style prompt
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info='Enter a Prompt to generate the reference image',
                placeholder='A poster in a papercut art style.',
            )
            # Image display for the reference style image; hidden until the
            # first generation returns a visible gr.Image update.
            ref_style_image = gr.Image(visible=False, label='Reference style image')

        with gr.Column(variant='panel'):
            # Textbox for prompt for MultiDiffusion panoramas
            img_generation_prompt = gr.Textbox(
                label='MultiDiffusion Prompt',
                info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
                placeholder='A village in a papercut art style.',
            )

    # Button to trigger image generation
    btn = gr.Button('Style-aligned MultiDiffusion - Generate', size='sm')
    # Gallery to display generated style image and the panorama
    # (label fixed: this demo is MultiDiffusion, not ControlNet).
    gallery = gr.Gallery(
        label='Style-Aligned MultiDiffusion - Generated images',
        elem_id='gallery',
        columns=5,
        rows=1,
        object_fit='contain',
        height='auto',
        allow_preview=True,
        preview=True,
    )
    # Button click event: both prompts in, gallery and reference image out.
    btn.click(
        fn=style_aligned_multidiff,
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        api_name='style_aligned_multidiffusion',
    )

    # Example inputs for the Gradio demo; each row is
    # [reference style prompt, MultiDiffusion prompt].
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', 'A village in a papercut art style.'],
            ['A poster in a papercut art style.', 'Futuristic cityscape in a papercut art style.'],
            ['A poster in a papercut art style.', 'A jungle in a papercut art style.'],
            ['A poster in a flat design style.', 'Giraffes in a flat design style.'],
            ['A poster in a flat design style.', 'Houses in a flat design style.'],
            ['A poster in a flat design style.', 'Mountains in a flat design style.'],
        ],
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        fn=style_aligned_multidiff,
    )

# Launch the Gradio demo
demo.launch()