# Utility functions for loading and running Stable Diffusion models with diffusers
import os
import warnings
from copy import deepcopy
from typing import Union

import diffusers
import torch
import tqdm
import transformers
from PIL import Image

def build_generator(
        device : torch.device,
        seed : int,
):
    """
    Build a torch.Generator with a given seed.
    """
    generator = torch.Generator(device).manual_seed(seed)
    return generator
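
# Example usage (a minimal sketch; not executed on import). Re-seeding with the
# same value makes pipeline calls reproducible:
#
#   generator = build_generator(torch.device("cuda"), seed=42)
#   # pass `generator=generator` to a diffusers pipeline call for deterministic sampling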

def load_stablediffusion_model(
        model_id : Union[str, os.PathLike],
        device : torch.device,
        ):
    """
    Load a complete Stable Diffusion pipeline from a model id and move it to `device`.
    Returns the loaded pipeline; falls back to CPU if the requested device is unavailable.
    """
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",  # half-precision weights; newer diffusers releases prefer variant="fp16"
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    # Swap in a faster multistep solver for inference
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    try:
        pipe = pipe.to(device)
    except Exception:
        warnings.warn(
            f'Could not move model to device: {device}. Using CPU instead.'
        )
        pipe = pipe.to('cpu')

    return pipe
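
# Example usage (a minimal sketch; the model id is an assumption and loading
# requires a configured Hugging Face auth token; not executed on import):
#
#   pipe = load_stablediffusion_model("CompVis/stable-diffusion-v1-4", torch.device("cuda"))
#   generator = build_generator(torch.device("cuda"), seed=0)
#   image = pipe("a watercolor painting of a lighthouse", generator=generator).images[0]
#   image.save("lighthouse.png")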


def visualize_image_grid(
        imgs : list,
        rows : int,
        cols : int):
    """
    Paste a list of equally sized PIL images into a single rows x cols grid image.
    """
    assert len(imgs) == rows*cols

    # Create an empty canvas large enough for the full grid
    w, h = imgs[0].size  # assuming each image is the same size

    grid = Image.new('RGB', size=(cols*w, rows*h))

    # Paste each image at its (row, column) position
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
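
# Example usage (a minimal sketch, assuming `pipe` was created with
# load_stablediffusion_model above; not executed on import):
#
#   images = pipe(["a red fox in the snow"] * 4, num_inference_steps=25).images
#   grid = visualize_image_grid(images, rows=2, cols=2)
#   grid.save("grid.png")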


def build_pipeline(
        autoencoder : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
        tokenizer : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
        text_encoder : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
        unet : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
        device : torch.device = torch.device('cuda'),
        ):
    """
    Create a Stable Diffusion pipeline by loading each component separately.
    Arguments:
        autoencoder: path or model id the VAE will be loaded from
        tokenizer: path or model id of the tokenizer
        text_encoder: path or model id of the text encoder
        unet: path or model id of the UNet
        device: device the components will be moved to
    """
    # Load the VAE for encoding images into (and decoding them from) the latent space
    vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder='vae')

    # Load tokenizer & text encoder for embedding prompts
    tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
    text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder)

    # Load the UNet that predicts the noise residual at each diffusion step
    unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder='unet')

    # Move all the components to the target device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)

    return vae, tokenizer, text_encoder, unet
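
# Example usage (a minimal sketch; the model ids mirror the defaults above and the
# scheduler choice is an assumption; not executed on import):
#
#   vae, tok, enc, unet = build_pipeline(device=torch.device("cuda"))
#   scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
#       "CompVis/stable-diffusion-v1-4", subfolder="scheduler")
#   # pass these components to custom_stablediffusion_inference below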

#TODO : Add negative prompting
def custom_stablediffusion_inference(
        vae,
        tokenizer,
        text_encoder,
        unet,
        noise_scheduler,
        prompt : list,
        device : torch.device,
        num_inference_steps = 100,
        image_size = (512,512),
        guidance_scale = 8,
        seed = 42,
        return_image_step = 5,
    ):
    # Get the text embeddings that will condition the diffusion process
    if isinstance(prompt,str):
        prompt = [prompt]

    batch_size = len(prompt)
    text_input = tokenizer(
        prompt,
        padding = 'max_length',
        truncation = True,
        max_length = tokenizer.model_max_length,
        return_tensors = 'pt').to(device)
    
    text_embeddings = text_encoder(
        text_input.input_ids.to(device)
    )[0]

    # Get the unconditional (empty-prompt) embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    empty = [""] * batch_size
    uncond_input = tokenizer(
        empty,
        padding = 'max_length',
        truncation = True,
        max_length = max_length,
        return_tensors = 'pt').to(device)
    
    uncond_embeddings = text_encoder(
        uncond_input.input_ids.to(device)
    )[0]

    # Concatenate the text embeddings to get the conditioning vector
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Generate the initial latent noise (the VAE downsamples images by a factor of 8)
    latents = torch.randn(
        (batch_size, unet.config.in_channels, image_size[0] // 8, image_size[1] // 8),
        generator=torch.manual_seed(seed) if seed is not None else None
    )
    latents = latents.to(device)

    # Initialize the scheduler and scale the noise to its expected initial sigma
    noise_scheduler.set_timesteps(num_inference_steps)
    latents = latents * noise_scheduler.init_noise_sigma
    num_timesteps = len(noise_scheduler.timesteps)
    for i, t in tqdm.tqdm(enumerate(noise_scheduler.timesteps), total=num_timesteps):
        # Duplicate the latents so one forward pass covers both the unconditional
        # and the text-conditioned branch of classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            # Get the noise prediction from the UNet
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance: push the prediction away from the unconditional branch
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_{t-1}
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Periodically (and on the final step) decode the latents with the VAE to inspect progress
        if i % return_image_step == 0 or i == num_timesteps - 1:
            with torch.no_grad():
                latents_copy = deepcopy(latents)
                # 0.18215 is the latent scaling factor used by Stable Diffusion's VAE
                image = vae.decode(1 / 0.18215 * latents_copy).sample

            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()  # b x h x w x c
            images = (image * 255).round().astype("uint8")

            pil_images = [Image.fromarray(img) for img in images]

            yield pil_images[0]

if __name__ == "__main__":
    device = torch.device("cpu")
    model_id = "stabilityai/stable-diffusion-2-1"
    tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    #noise_scheduler = diffusers.LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
    noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id,subfolder="scheduler")
    prompt = "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares,\
            cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane\
            render, architectural HD, HQ, 4k, 8k"
    
    vae, tokenizer, text_encoder, unet = build_pipeline(
        autoencoder = model_id,
        tokenizer=tokenizer_id,
        text_encoder=tokenizer_id,
        unet=model_id,
        device=device,
        )
    image_iter = custom_stablediffusion_inference(vae, tokenizer, text_encoder, unet, noise_scheduler, prompt = prompt, device=device, seed = None)
    for i, image in enumerate(image_iter):
        image.save(f"step_{i}.png")