# StarryXL-Demo / app.py
# Last commit 4de31af by eienmojiki (verified): "Update app.py"
import os
import gc
import random
from typing import Callable, Dict, Optional, Tuple
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
# Hugging Face repo id of the SDXL checkpoint this demo serves.
MODEL = "eienmojiki/Starry-XL-v5.2"
# Optional Hub access token; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Image-size bounds, overridable through the environment.
MIN_IMAGE_SIZE, MAX_IMAGE_SIZE = (
    int(os.environ.get(env_name, fallback))
    for env_name, fallback in (("MIN_IMAGE_SIZE", "512"), ("MAX_IMAGE_SIZE", "2048"))
)
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Example outputs are pre-computed only with a GPU and CACHE_EXAMPLES=1.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
# Sampler names shown in the UI dropdown; each name is looked up in
# get_scheduler() to build the matching diffusers scheduler.
sampler_list = [
    # DPM-Solver++ variants using Karras sigmas
    "DPM++ 2M Karras", "DPM++ SDE Karras", "DPM++ 2M SDE Karras",
    # Euler family and DDIM
    "Euler", "Euler a", "DDIM",
]
# Pre-filled rows for gr.Examples. Each row follows the order of the
# `inputs=` list given to gr.Examples: prompt, negative prompt, seed, width,
# height, guidance scale, steps, sampler, clip skip.
# The prompt is a raw string so backslash-escaped parentheses such as
# "\(blue archive\)" are kept verbatim without Python's
# "invalid escape sequence" warning (the resulting value is unchanged).
examples = [
    [
        r"""1girl, midori \(blue archive\), blue archive,
(ningen mame:0.9), ciloranko, sho \(sho lwlw\), (tianliang duohe fangdongye:0.8), ask \(askzy\), wlop,
indoors, plant, hair bow, cake, cat ears, food, smile, animal ear headphones, bare legs, short shorts, drawing \(object\), feet, legs, on back, bed, solo, green eyes, cat, table, window blinds, headphones, nintendo switch, toes, bow, toenails, looking at viewer, chips \(food\), potted plant, halo, calendar \(object\), tray, blonde hair, green halo, lying, barefoot, bare shoulders, blunt bangs, green shorts, picture frame, fake animal ears, closed mouth, shorts, handheld game console, green bow, animal ears, on bed, medium hair, knees up, upshorts, eating, potato chips, pillow, blush, dolphin shorts, ass, character doll, alternate costume,
masterpiece, newest, absurdres""",
        """bad anatomy,blurry,(worst quality:1.8),low quality,hands bad,face bad,(normal quality:1.3),bad hands,mutated hands and fingers,extra legs,extra arms,duplicate,cropped,text,jpeg,artifacts,signature,watermark,username,blurry,artist name,trademark,title,multiple view,Reference sheet,long body,multiple breasts,mutated,bad anatomy,disfigured,bad proportions,duplicate,bad feet,artist name,ugly,text font ui,missing limb,monochrome,""",
        1399560451,  # seed
        896,  # width
        1152,  # height
        5.0,  # guidance scale
        26,  # inference steps
        "DPM++ 2M SDE Karras",  # sampler
        2,  # clip skip
    ]
]
# Prefer reproducible cuDNN kernels: no autotuning, deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# First GPU when CUDA is available, CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for the next generation.

    Args:
        seed: Seed currently shown in the UI.
        randomize_seed: When true, ignore ``seed`` and draw a new one.

    Returns:
        A uniformly random integer in ``[0, MAX_SEED]`` if ``randomize_seed``
        is true, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def seed_everything(seed: int) -> torch.Generator:
    """Seed torch (CPU and every CUDA device) and NumPy's global RNGs.

    Args:
        seed: Value used for every RNG.

    Returns:
        A new CPU ``torch.Generator`` seeded with ``seed``; ``generate`` passes
        it to the diffusion pipeline.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)
    # Generator.manual_seed returns the generator itself.
    return torch.Generator().manual_seed(seed)
def get_scheduler(scheduler_config: Dict, name: str) -> Optional[object]:
    """Build a new diffusers scheduler for the UI sampler called ``name``.

    The caller passes in the config of the *currently installed* scheduler.
    ``from_config`` reuses that config's values for any option not overridden,
    so options that distinguish one sampler from another (Karras sigmas,
    DPM-Solver algorithm type) are always passed explicitly here. Otherwise,
    e.g. ``use_karras_sigmas=True`` from a "DPM++ ... Karras" run would leak
    into a later plain "Euler" run, and ``algorithm_type="sde-dpmsolver++"``
    would leak into "DPM++ 2M Karras" / "DPM++ SDE Karras".

    Args:
        scheduler_config: Config of the scheduler currently on the pipeline.
        name: One of the names in ``sampler_list``.

    Returns:
        A freshly constructed scheduler instance, or ``None`` if ``name`` is
        not a known sampler.
    """
    # Factories are lambdas so that only the selected scheduler is built.
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="dpmsolver++"
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(
            scheduler_config, use_karras_sigmas=False
        ),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    factory = scheduler_factory_map.get(name)
    if factory is None:
        return None
    return factory()
def load_pipeline(model_name):
    """Load the SDXL pipeline for ``model_name`` and move it to ``device``.

    The pipeline uses the community ``lpw_stable_diffusion_xl`` custom
    pipeline, loads safetensors weights, disables the watermarker, and
    authenticates with ``HF_TOKEN``.

    Previously ``pipe`` was only assigned when CUDA was available, so calling
    this on a CPU-only machine raised UnboundLocalError. It now loads in
    float32 on CPU instead.

    Args:
        model_name: Hugging Face repo id (or local path) of the checkpoint.

    Returns:
        The loaded pipeline, already moved to ``device``.
    """
    # Half precision is used on GPU as before; on CPU keep full precision,
    # since fp16 inference there is generally poorly supported.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_name,
        torch_dtype=dtype,
        custom_pipeline="lpw_stable_diffusion_xl",
        safety_checker=None,
        use_safetensors=True,
        add_watermarker=False,
        use_auth_token=HF_TOKEN,
    )
    pipe.to(device)
    return pipe
def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    """Resize ``samples`` to exactly ``height`` x ``width`` pixels.

    Args:
        samples: Batched image/latent tensor; the (height, width) ``size``
            passed to ``interpolate`` assumes an (N, C, H, W) layout.
        width: Target width.
        height: Target height.
        upscale_method: Any ``mode`` accepted by
            ``torch.nn.functional.interpolate`` (e.g. "nearest", "bilinear").

    Returns:
        The resized tensor.
    """
    target_size = (height, width)
    resized = torch.nn.functional.interpolate(samples, size=target_size, mode=upscale_method)
    return resized


def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    """Scale the spatial size of ``samples`` by ``scale_by``.

    Height and width are read from ``samples.shape[2]`` and ``samples.shape[3]``,
    multiplied by ``scale_by`` and rounded to the nearest integer before
    resizing with ``common_upscale``.
    """
    src_height, src_width = samples.shape[2], samples.shape[3]
    return common_upscale(
        samples,
        round(src_width * scale_by),
        round(src_height * scale_by),
        upscale_method,
    )
def free_memory() -> None:
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    The order matters. Tensors kept alive only by reference cycles are freed
    by ``gc.collect()``, and only after that can ``torch.cuda.empty_cache()``
    hand their cached blocks back. The previous order (empty_cache first)
    left that memory cached.
    """
    gc.collect()
    torch.cuda.empty_cache()
# CLIP text encoders already loaded for a given clip_skip value, so repeated
# generations reuse them instead of calling from_pretrained on every request.
_TEXT_ENCODER_CACHE: Dict[int, CLIPTextModel] = {}
# The layer arithmetic below assumes MODEL's text encoder has 12 hidden layers.
_CLIP_NUM_LAYERS = 12


def _get_text_encoder(clip_skip: int) -> CLIPTextModel:
    """Return MODEL's text encoder truncated for ``clip_skip``, loading it once.

    ``clip_skip=1`` keeps all layers; each extra step drops one more final
    layer (``num_hidden_layers = 12 - (clip_skip - 1)``).
    """
    encoder = _TEXT_ENCODER_CACHE.get(clip_skip)
    if encoder is None:
        encoder = CLIPTextModel.from_pretrained(
            MODEL,
            subfolder="text_encoder",
            num_hidden_layers=_CLIP_NUM_LAYERS - (clip_skip - 1),
            torch_dtype=torch.float16,
        )
        _TEXT_ENCODER_CACHE[clip_skip] = encoder
    return encoder


@spaces.GPU(enable_queue=False)
def generate(
    prompt: str,
    negative_prompt: Optional[str] = None,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 5.0,
    num_inference_steps: int = 26,
    sampler: str = "Euler a",
    clip_skip: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the module-level SDXL ``pipe``.

    Args:
        prompt: Positive prompt.
        negative_prompt: Optional negative prompt.
        seed: RNG seed (see ``seed_everything``).
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        sampler: Name from ``sampler_list``; unknown names keep the current
            scheduler.
        clip_skip: Number of final text-encoder layers to skip, counting from 1
            (1 = no skip). Must be between 1 and 12.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(image, seed)``: the generated PIL image and the seed actually used.

    Raises:
        gr.Error: If the pipeline is not loaded, ``clip_skip`` is out of range,
            or generation fails. Gradio shows the message to the user instead
            of receiving no outputs.
    """
    if pipe is None:
        raise gr.Error("The model is not loaded because no GPU is available.")
    # Slider/API values may arrive as floats; the seeding and model-config
    # calls below need integers.
    seed, width, height = int(seed), int(width), int(height)
    num_inference_steps, clip_skip = int(num_inference_steps), int(clip_skip)
    # Out-of-range values would build encoders with 0 or extra, untrained
    # layers, and would grow the encoder cache without bound.
    if not 1 <= clip_skip <= _CLIP_NUM_LAYERS:
        raise gr.Error(f"clip_skip must be between 1 and {_CLIP_NUM_LAYERS}.")

    generator = seed_everything(seed)
    scheduler = get_scheduler(pipe.scheduler.config, sampler)
    if scheduler is not None:
        # Never install None as the scheduler; unknown names keep the old one.
        pipe.scheduler = scheduler
    pipe.text_encoder = _get_text_encoder(clip_skip)
    pipe.to(device)
    try:
        gr.Info("Generating image...")
        img = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            num_images_per_prompt=1,
            output_type="pil",
        ).images[0]
    except Exception as e:
        # Keep the server-side log, but surface the failure to the UI instead
        # of implicitly returning None for the two declared outputs.
        print(f"An error occurred: {e}")
        raise gr.Error(f"An error occurred: {e}") from e
    return img, seed
# Load the SDXL pipeline only when CUDA is available; otherwise `pipe` stays
# None.
pipe = load_pipeline(MODEL) if torch.cuda.is_available() else None
# Gradio UI: header, generation controls, result view, examples, and the
# event chain (optional seed randomization, then generation).
with gr.Blocks(
    theme=gr.themes.Base(
        font=[gr.themes.GoogleFont("Teachers"), "Arial", "sans-serif"],
        primary_hue="rose",
        secondary_hue="pink",
    )
) as demo:
    # Header HTML. The model-page link previously pointed at the malformed
    # host "huggingface.co." (stray trailing dot).
    gr.HTML(
        """
<style>
.title-container {
display: flex;
justify-content: center;
align-items: center;
height: 80vh; /* Adjust this value to position the title vertically */
}
.title {
font-size: 1.5em;
text-align: center;
color: #333;
font-family: 'Helvetica Neue', sans-serif;
text-transform: uppercase;
letter-spacing: 0.1em;
padding: 0.5em 0;
background: transparent;
}
.title span {
background: -webkit-linear-gradient(45deg, #FFBF00, #F28C28);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
</style>
<h1 class="title"><span>Starry XL 5.2</span></h1>
<p>Explore <strong>Starry XL</strong> with this intuitive demo! We always welcome your feedback, so please share any issues you encounter in the Community tab. If you love Starry XL, contribute to its development by creating a pull request.</p>
<p><strong>Model page:</strong> <a href="https://huggingface.co/eienmojiki/Starry-XL-v5.2">eienmojiki/Starry-XL-v5.2</a></p>
"""
    )
    # NOTE(review): the layout nesting below is reconstructed; the source had
    # no surviving indentation. Components and wiring are unchanged.
    with gr.Group():
        prompt = gr.Text(
            info="Your prompt here OwO",
            label="Prompt",
            placeholder="Tips: Follow the instruction at the model page for better prompt.",
        )
        negative_prompt = gr.Text(
            info="Enter your negative prompt here",
            label="Negative Prompt",
            placeholder="(Optional)",
        )
        # NOTE(review): MIN_IMAGE_SIZE is read from the environment but these
        # sliders hard-code a minimum of 256 — confirm which bound is intended.
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        sampler = gr.Dropdown(
            label="Sampler",
            choices=sampler_list,
            interactive=True,
            value="Euler a",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
            clip_skip = gr.Slider(
                label="Clip Skip",
                minimum=1,
                maximum=2,
                step=1,
                value=1,
            )
    run_button = gr.Button("Run")
    result = gr.Image(
        label="Result",
        show_label=False,
    )
    with gr.Group():
        used_seed = gr.Number(label="Used Seed", interactive=False)
    gr.Examples(
        examples=examples,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            sampler,
            clip_skip,
        ],
        outputs=[result, used_seed],
        # NOTE(review): this lambda only forwards to generate(); presumably
        # kept on purpose (e.g. to hide generate's signature from
        # gr.Examples) — confirm before simplifying to fn=generate.
        fn=lambda *args, **kwargs: generate(*args, **kwargs),
        cache_examples=CACHE_EXAMPLES,
    )
    # Submitting either prompt box or clicking Run first (optionally)
    # re-rolls the seed, bypassing the queue and the public API, then runs
    # generate(), which is exposed as the "run" API endpoint.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            sampler,
            clip_skip,
        ],
        outputs=[result, used_seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (max_size=20), then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(show_error=True)