File size: 3,449 Bytes
c7b13d5
 
 
 
90da0e7
f03bfaf
 
 
 
 
 
 
90da0e7
 
 
 
f03bfaf
90da0e7
f03bfaf
90da0e7
 
 
 
 
f03bfaf
 
 
 
 
 
 
 
 
331f30a
29f3401
dff35d4
 
 
 
 
f03bfaf
 
dff35d4
f03bfaf
 
 
 
dff35d4
f03bfaf
dff35d4
f03bfaf
 
c259892
 
 
 
 
 
 
 
 
 
f03bfaf
 
 
 
 
0f4b6ed
 
 
dff35d4
 
c259892
f03bfaf
dff35d4
 
0f4b6ed
f03bfaf
0f4b6ed
f03bfaf
dff35d4
 
f03bfaf
c259892
 
94b54b7
f03bfaf
 
90da0e7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import spaces
import os
import torch
import random
from huggingface_hub import snapshot_download
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
import gradio as gr

# Fetch the Kolors checkpoint from the Hugging Face Hub; `ckpt_dir` is the
# local directory snapshot_download returns.
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")


def _ckpt_subdir(name):
    """Return the path of the checkpoint subfolder *name* inside ``ckpt_dir``."""
    return os.path.join(ckpt_dir, name)


# Load each component from its own subfolder. The text encoder, VAE and UNet
# are cast to float16 with .half(); the tokenizer shares the text encoder's folder.
text_encoder = ChatGLMModel.from_pretrained(_ckpt_subdir('text_encoder'), torch_dtype=torch.float16).half()
tokenizer = ChatGLMTokenizer.from_pretrained(_ckpt_subdir('text_encoder'))
vae = AutoencoderKL.from_pretrained(_ckpt_subdir("vae"), revision=None).half()
scheduler = EulerDiscreteScheduler.from_pretrained(_ckpt_subdir("scheduler"))
unet = UNet2DConditionModel.from_pretrained(_ckpt_subdir("unet"), revision=None).half()

# Assemble the pipeline and move it to the GPU in one expression; the
# pipeline's .to() result is what gets bound to `pipe`.
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    # NOTE(review): presumably, as in diffusers' SDXL pipeline, this stops an
    # empty negative prompt from being replaced by zero embeddings — confirm
    # against the kolors pipeline implementation.
    force_zeros_for_empty_prompt=False,
).to("cuda")

@spaces.GPU(duration=200)
def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
    """Run the module-level Kolors pipeline on a prompt and report the seed used.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.
        height: Output height in pixels; coerced to ``int`` because UI widgets
            may hand numeric values over as floats.
        width: Output width in pixels; coerced to ``int`` for the same reason.
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        guidance_scale: Guidance strength forwarded unchanged to the pipeline.
        num_images_per_prompt: How many images to generate; coerced to ``int``.
        use_random_seed: When truthy, ignore ``seed`` and draw a fresh one.
        seed: Seed used when ``use_random_seed`` is false. An empty Seed field
            (``None``) falls back to a random seed instead of raising.
        progress: Gradio progress tracker. It is not referenced in the body; it is
            declared so Gradio can track the pipeline's tqdm progress.

    Returns:
        A tuple ``(images, seed)``: the pipeline's ``.images`` output and the
        integer seed that was actually used, so the UI can display it.
    """
    if use_random_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # the number field may deliver a float

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        # Shapes and repeat counts must be integers inside the pipeline.
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=int(num_images_per_prompt),
        # Seed a generator on the pipeline's device so results are reproducible.
        generator=torch.Generator(pipe.device).manual_seed(seed),
    ).images
    return images, seed

# HTML passed as the `description` of the Gradio interface: a one-line tagline
# followed by links to the website, tech report, model page and source code.
description = """
<p align="center">Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis</p>
<p><center>
<a href="https://kolors.kuaishou.com/" target="_blank">[Official Website]</a>
<a href="https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf" target="_blank">[Tech Report]</a>
<a href="https://huggingface.co/Kwai-Kolors/Kolors" target="_blank">[Model Page]</a>
<a href="https://github.com/Kwai-Kolors/Kolors" target="_blank">[Github]</a>
</center></p>
"""

# Gradio interface. Components are created in the same order as they appear
# in the interface: prompts, advanced settings, accordion, then outputs.
main_inputs = [
    gr.Textbox(label="Prompt"),
    gr.Textbox(label="Negative Prompt"),
]

# Generation knobs, passed to generate_image after the two prompt fields and
# in this exact order.
advanced_inputs = [
    gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Height"),
    gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Width"),
    gr.Slider(minimum=20, maximum=50, value=20, step=1, label="Number of Inference Steps"),
    gr.Slider(minimum=1, maximum=20, value=5, step=0.5, label="Guidance Scale"),
    gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of images per prompt"),
    gr.Checkbox(label="Use Random Seed", value=True),
    gr.Number(label="Seed", value=0, precision=0),
]
advanced_accordion = gr.Accordion(label="Advanced settings", open=False)

# One output per element of generate_image's (images, seed) return tuple.
result_outputs = [
    gr.Gallery(label="Result", elem_id="gallery", show_label=False),
    gr.Number(label="Seed Used"),
]

iface = gr.Interface(
    fn=generate_image,
    inputs=main_inputs,
    additional_inputs=advanced_inputs,
    additional_inputs_accordion=advanced_accordion,
    outputs=result_outputs,
    title="Kolors",
    description=description,
    theme='bethecloud/storj_theme',
)

iface.launch(debug=True)