Spaces:
Runtime error
Runtime error
File size: 4,308 Bytes
fd1c028 631c9e2 fd1c028 092fcaa fd1c028 621ea3a 631c9e2 092fcaa 6d9201f 092fcaa d639c7d 631c9e2 fd1c028 6d9201f 092fcaa 631c9e2 fd1c028 092fcaa 6d9201f 631c9e2 6d9201f fd1c028 6d9201f fd1c028 6d9201f d639c7d 6d9201f d639c7d 9901ecd d639c7d 6d9201f d639c7d 9901ecd 631c9e2 6d9201f 631c9e2 6d9201f 9901ecd 631c9e2 9901ecd 6d9201f 9901ecd 6d9201f 9901ecd 631c9e2 d639c7d 6d9201f 9901ecd 6d9201f 9901ecd fd1c028 6d9201f 9901ecd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
import sa_handler
import pipeline_calls
# init models
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
# Load the panorama (MultiDiffusion) pipeline in half precision.
# The pipeline is deliberately NOT moved with `.to("cuda")` here:
# enable_model_cpu_offload() below keeps the weights on the CPU and moves each
# sub-model to the GPU only while it runs. Per the diffusers memory guide,
# moving the pipeline to CUDA beforehand just uploads every weight to the GPU
# before offloading pulls it back, wasting time and peak VRAM.
pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
)
# Configure the pipeline for CPU offloading and VAE slicing (both reduce peak
# GPU memory usage).
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# StyleAligned options (semantics defined in sa_handler.StyleAlignedArgs):
# share group-norm, layer-norm and attention, and apply AdaIN to the attention
# queries and keys but not to the values.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
# Initialize the style-aligned handler for this pipeline and register the
# options above with it.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)
# Define the function to run MultiDiffusion with StyleAligned
def style_aligned_multidiff(ref_style_prompt, img_generation_prompt):
    """Generate a reference style image and a style-aligned panorama.

    Wired to the "Generate" button (API name ``style_aligned_multidiffusion``)
    and to the examples table of the Gradio UI.

    Args:
        ref_style_prompt: Prompt used to generate the reference (style) image.
        img_generation_prompt: Prompt used to generate the MultiDiffusion
            panorama.

    Returns:
        A pair ``(images, ref_image)``.
        - ``images`` is the value returned by ``pipeline_calls.panorama_call``
          and is shown in the gallery.
        - ``ref_image`` is a visible ``gr.Image`` holding ``images[0]``, which
          un-hides the reference image component.

    Raises:
        gr.Error: If anything in the generation fails. The message includes
            the underlying error, which is also chained as ``__cause__``.
    """
    try:
        view_batch_size = 25  # adjust according to VRAM size
        # Random starting latent for the reference image. 4x64x64 presumably
        # matches a 512x512 SD-2-base image (VAE factor 8); how it is consumed
        # is up to pipeline_calls.panorama_call.
        reference_latent = torch.randn(1, 4, 64, 64)
        images = pipeline_calls.panorama_call(
            pipeline,
            [ref_style_prompt, img_generation_prompt],
            reference_latent=reference_latent,
            view_batch_size=view_batch_size,
        )
        return images, gr.Image(value=images[0], visible=True)
    except Exception as e:
        # gr.Error is shown to the user in the UI. Chaining keeps the original
        # exception (and its traceback) as the explicit cause in server logs.
        raise gr.Error(f"Error in generating images: {e}") from e
# Create a Gradio UI.
# Layout:
# - Left column: the reference-style prompt, and the reference image (hidden
#   until the first generation).
# - Right column: the panorama prompt.
# - Below: the generate button, the results gallery and the example prompts.
with gr.Blocks() as demo:
    gr.HTML('<h1 style="text-align: center;">Style-aligned with MultiDiffusion</h1>')
    with gr.Row():
        with gr.Column(variant='panel'):
            # Textbox for reference style prompt
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info='Enter a Prompt to generate the reference image',
                placeholder='A poster in a papercut art style.'
            )
            # Image display for the reference style image. It starts hidden;
            # style_aligned_multidiff returns a visible gr.Image for it.
            ref_style_image = gr.Image(visible=False, label='Reference style image')
        with gr.Column(variant='panel'):
            # Textbox for prompt for MultiDiffusion panoramas
            img_generation_prompt = gr.Textbox(
                label='MultiDiffusion Prompt',
                info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
                placeholder='A village in a papercut art style.'
            )
    # Button to trigger image generation
    btn = gr.Button('Style-aligned MultiDiffusion - Generate', size='sm')
    # Gallery to display generated style image and the panorama
    gallery = gr.Gallery(label='Style-Aligned MultiDiffusion - Generated images',
                         elem_id='gallery',
                         columns=5,
                         rows=1,
                         object_fit='contain',
                         height='auto',
                         allow_preview=True,
                         preview=True,
                         )
    # Button click event; also exposed as a named API endpoint.
    btn.click(fn=style_aligned_multidiff,
              inputs=[ref_style_prompt, img_generation_prompt],
              outputs=[gallery, ref_style_image],
              api_name='style_aligned_multidiffusion')
    # Example inputs for the Gradio demo
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', 'A village in a papercut art style.'],
            ['A poster in a papercut art style.', 'Futuristic cityscape in a papercut art style.'],
            ['A poster in a papercut art style.', 'A jungle in a papercut art style.'],
            ['A poster in a flat design style.', 'Giraffes in a flat design style.'],
            ['A poster in a flat design style.', 'Houses in a flat design style.'],
            ['A poster in a flat design style.', 'Mountains in a flat design style.'],
        ],
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        fn=style_aligned_multidiff,
    )
# Launch the Gradio demo (outside the Blocks context, once the UI is built).
demo.launch()