primecai
init
3f03890
raw
history blame
5.79 kB
import gradio as gr
import torch
from PIL import Image
from diffusers.utils import load_image
from pipeline import FluxConditionalPipeline
from transformer import FluxTransformer2DConditionalModel
import os
# Location of the DSD checkpoint. init_pipeline() loads the "transformer"
# weights and "pytorch_lora_weights.safetensors" from it.
CHECKPOINT: str = "primecai/dsd_model"
# The loaded pipeline. It stays None until init_pipeline() has run (either
# eagerly from __main__ or lazily on the first request).
pipe = None
def init_pipeline():
    """Build the DSD pipeline on the GPU and store it in the module-level ``pipe``.

    Steps:
    - Load ``FluxTransformer2DConditionalModel`` from the ``transformer``
      subfolder of ``CHECKPOINT``.
    - Build a ``FluxConditionalPipeline`` on top of FLUX.1-dev with that
      transformer.
    - Apply the LoRA weights in ``pytorch_lora_weights.safetensors`` from
      ``CHECKPOINT``.
    - Move the pipeline to CUDA. Everything is loaded in bfloat16.

    The global ``pipe`` is assigned only after every step has succeeded. A
    failure part-way through (e.g. a download error, a missing LoRA file or a
    CUDA OOM) therefore leaves ``pipe`` unchanged. It no longer publishes a
    half-initialised pipeline that ``process_image_and_text`` would then keep
    reusing, because its lazy-init check only looks at ``pipe is None``.
    """
    global pipe
    # ``subfolder``/``weight_name`` resolve correctly for both a Hub repo id
    # and a local directory. ``os.path.join`` only produced a valid location
    # for local paths, and used backslashes on Windows.
    transformer = FluxTransformer2DConditionalModel.from_pretrained(
        CHECKPOINT,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    pipeline = FluxConditionalPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipeline.load_lora_weights(
        CHECKPOINT, weight_name="pytorch_lora_weights.safetensors"
    )
    pipeline.to("cuda")
    # Publish only a fully initialised pipeline.
    pipe = pipeline
def process_image_and_text(image, text, gemini_prompt, guidance, i_guidance, t_guidance):
    """Generate an image from a reference image and a text prompt.

    The reference image is first normalised with ``load_image``. It is then
    center-cropped to a square and resized to 512x512, and used as the
    pipeline's conditioning ``image``. The pipeline runs 28 inference steps at
    a 1024x512 (width x height) output size. If the pipeline has not been
    loaded yet, it is initialised on first use.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submitted without uploading one.
        text: Prompt text; surrounding whitespace is stripped.
        gemini_prompt: Forwarded as the pipeline's ``gemini_prompt`` argument.
            The UI labels it as Gemini prompt enhancement.
        guidance: Pipeline ``guidance_scale``.
        i_guidance: Pipeline ``guidance_scale_real_i`` (image real guidance).
        t_guidance: Pipeline ``guidance_scale_real_t`` (prompt real guidance).

    Returns:
        ``.images[0]`` of the pipeline output. This is presumably a PIL
        image, since the output component is ``gr.Image(type="pil")``.

    Raises:
        gr.Error: if no reference image was supplied.
    """
    if image is None:
        # Gradio passes None for an empty image input. Show a readable message
        # instead of crashing with an AttributeError on ``image.size``.
        raise gr.Error("Please upload a reference image.")

    # Normalise first (EXIF orientation + RGB) so that resampling happens in
    # RGB space. Pillow resizes palette ("P") images with nearest-neighbour,
    # which aliases badly.
    image = load_image(image)

    # Largest centered square. ``left + side`` equals ``(width + side) // 2``,
    # so this is the same box as a symmetric center crop.
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    control_image = image.crop((left, top, left + side, top + side)).resize((512, 512))

    if pipe is None:
        init_pipeline()

    result_image = pipe(
        prompt=(text or "").strip(),
        negative_prompt="",
        num_inference_steps=28,
        height=512,
        width=1024,
        guidance_scale=guidance,
        image=control_image,
        guidance_scale_real_i=i_guidance,
        guidance_scale_real_t=t_guidance,
        gemini_prompt=gemini_prompt,
    ).images[0]
    return result_image
def get_samples():
    """Return the Gradio example rows as ``[PIL.Image, prompt]`` pairs.

    Asset paths are resolved relative to this file rather than the current
    working directory, so the demo also starts when launched from elsewhere.

    Each image is decoded eagerly inside a context manager. As a result:
    - a missing or corrupt asset fails here, at startup;
    - the underlying file handle is closed, rather than being kept open by
      Pillow's lazy ``Image.open``.

    The loaded pixel data stays usable after the file is closed.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    sample_list = [
        {
            "image": "assets/wanrong_character.png",
            "text": "A chibi-style girl with pink hair, green eyes, wearing a black and gold ornate dress, dancing gracefully in a flower garden, anime art style with clean and detailed lines.",
        },
        {
            "image": "assets/ben_character_squared.png",
            "text": "A confident green-eye young woman with platinum blonde hair in a high ponytail, wearing an oversized orange jacket and black pants, is striking a dynamic pose, anime-style with sharp details and vibrant colors.",
        },
        {
            "image": "assets/seededit_example.png",
            "text": "an adorable small creature with big round orange eyes, fluffy brown fur, wearing a blue scarf with a golden charm, sitting atop a towering stack of colorful books in the middle of a vibrant futuristic city street with towering buildings and glowing neon signs, soft daylight illuminating the scene, detailed and whimsical 3D style.",
        },
        {
            "image": "assets/action_hero_figure.jpeg",
            "text": "A cartoonish muscular action hero figure with long blue hair and red headband sits on a crowded sidewalk on a Christmas evening, covered in snow and wearing a Christmas hat, holding a sign that reads 'DSD!', dramatic cinematic lighting, close-up view, 3D-rendered in a stylized, vibrant art style.",
        },
        {
            "image": "assets/anime_soldier.jpeg",
            "text": "An adorable cartoon goat soldier sits under a beach umbrella with 'DSD!' written on it, bright teal background with soft lighting, 3D-rendered in a playful and vibrant art style.",
        },
        {
            "image": "assets/goat_logo.jpeg",
            "text": "A shirt with this logo on it.",
        },
        {
            "image": "assets/cartoon_cat.png",
            "text": "A cheerful cartoon orange cat sits under a beach umbrella with 'DSD!' written on it under a sunny sky, simplistic and humorous comic art style.",
        },
    ]

    samples = []
    for sample in sample_list:
        with Image.open(os.path.join(base_dir, sample["image"])) as img:
            # Decode now; the context manager then closes the file while the
            # pixel data remains in memory.
            img.load()
        samples.append([img, sample["text"]])
    return samples
demo = gr.Blocks()
with demo:
    # Plain string: the header has no placeholders, so an f-string is
    # unnecessary. The blank lines end the ``<div>`` HTML block so that the
    # ``##`` line is parsed as a Markdown heading instead of literal text.
    gr.Markdown(
        """
<div align="center">

## Diffusion Self-Distillation (beta)

<a href="https://primecai.github.io/dsd/" target="_blank"><img src="https://img.shields.io/badge/Project-Website-blue" style="display:inline-block;"></a>
<a href="https://github.com/primecai/diffusion-self-distillation" target="_blank"><img src="https://img.shields.io/github/stars/primecai/diffusion-self-distillation?label=GitHub%20%E2%98%85&logo=github&color=C8C" style="display:inline-block;"></a>
<a href="https://huggingface.co/papers/2411.18616" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face%20-Space-yellow" style="display:inline-block;"></a>
<a href="https://x.com/prime_cai?lang=en" target="_blank"><img src="https://shields.io/twitter/follow/:?label=Subscribe%20for%20updates!" style="display:inline-block;"></a>

</div>
"""
    )
    # Input order must match process_image_and_text's positional parameters:
    # image, text, gemini_prompt, guidance, i_guidance, t_guidance.
    iface = gr.Interface(
        fn=process_image_and_text,
        inputs=[
            gr.Image(type="pil"),
            gr.Textbox(lines=2, label="text", info="Could be something as simple as 'this character playing soccer'."),
            gr.Checkbox(label="Gemini prompt", value=True, info="Use Gemini to enhance the prompt. This is recommended for most cases, unless you have a specific prompt similar to the examples in mind."),
            gr.Slider(minimum=1.0, maximum=6.0, step=0.5, value=3.5, label="guidance scale (tip: start with 3.5, then gradually increase if the consistency is consistently off)"),
            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for image (tip: increase if the image is not consistent)"),
            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.0, label="real guidance scale for prompt (tip: increase if the prompt is not consistent)"),
        ],
        outputs=gr.Image(type="pil"),
        # Example rows only fill the first two inputs (image, text).
        examples=get_samples(),
    )
if __name__ == "__main__":
    # Load the model before serving, so the first request does not pay the
    # load cost. Otherwise process_image_and_text() initialises it lazily.
    init_pipeline()
    launch_options = dict(debug=False, share=True)
    demo.launch(**launch_options)