# Utility functions for loading and using diffusers models
import diffusers
import transformers
import torch
from typing import Union
import os
import warnings
import numpy as np
from PIL import Image
import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
def build_generator(
device : torch.device,
seed : int,
):
"""
Build a torch.Generator with a given seed.
"""
generator = torch.Generator(device).manual_seed(seed)
return generator
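
# Example (sketch): a generator built this way can be passed to a diffusers pipeline
# call to make sampling reproducible. The model id below is only illustrative.
#
#   generator = build_generator(torch.device("cuda"), seed=42)
#   pipe = diffusers.DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
#   image = pipe("a photo of a cat", generator=generator).images[0]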
def load_stablediffusion_model(
    model_id : Union[str, os.PathLike],
    device : torch.device,
    ):
    """
    Load a complete Stable Diffusion pipeline from a model id and move it to the given device.
    Returns the loaded pipeline (falls back to CPU if the device is unavailable).
    """
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    try:
        pipe = pipe.to(device)
    except Exception:
        warnings.warn(
            f'Could not load model to device:{device}. Using CPU instead.'
        )
        pipe = pipe.to('cpu')
        device = 'cpu'
    return pipe
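
# Example (sketch): load the full pipeline and sample one image. The model id mirrors
# the one used in __main__ below; the prompt and generation arguments are illustrative.
#
#   pipe = load_stablediffusion_model("stabilityai/stable-diffusion-2-1", torch.device("cuda"))
#   generator = build_generator(torch.device("cuda"), seed=0)
#   image = pipe("a watercolor landscape", num_inference_steps=25, generator=generator).images[0]
#   image.save("sample.png")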
def visualize_image_grid(
    imgs : list,
    rows : int,
    cols : int):
    """
    Arrange a list of equally sized PIL images into a rows x cols grid and return it as one image.
    """
    assert len(imgs) == rows*cols
    # create grid
    w, h = imgs[0].size # assuming each image is the same size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i,img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
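
# Example (sketch): arrange four PIL images returned by a pipeline into a 2x2 grid.
#
#   images = pipe(["a cat", "a dog", "a fox", "an owl"]).images
#   grid = visualize_image_grid(images, rows=2, cols=2)
#   grid.save("grid.png")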
def build_pipeline(
autoencoder : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
tokenizer : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
text_encoder : Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
unet : Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
device : torch.device = torch.device('cuda'),
):
"""
    Create a pipeline for Stable Diffusion by loading each component separately.
Arguments:
autoencoder: path to model that autoencoder will be loaded from
tokenizer: path to tokenizer
text_encoder: path to text_encoder
unet: path to unet
"""
# Load the VAE for encoding images into the latent space
vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder = 'vae')
# Load tokenizer & text encoder for encoding text into the latent space
tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder)
# Use the UNet model for conditioning the diffusion process
unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder = 'unet')
# Move all the components to device
vae = vae.to(device)
text_encoder = text_encoder.to(device)
unet = unet.to(device)
return vae, tokenizer, text_encoder, unet
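
# Note: the noise scheduler is not created by build_pipeline; it is loaded separately,
# as done in __main__ below (sketch, model id illustrative):
#
#   noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
#       "stabilityai/stable-diffusion-2-1", subfolder="scheduler")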
#TODO : Add negative prompting
def custom_stablediffusion_inference(
vae,
tokenizer,
text_encoder,
unet,
noise_scheduler,
    prompt : Union[str, list],
device : torch.device,
num_inference_steps = 100,
image_size = (512,512),
guidance_scale = 8,
seed = 42,
return_image_step = 5,
):
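    # Runs the denoising loop manually with the separately loaded components and
    # yields an intermediate PIL image every `return_image_step` steps (the last
    # decoded image is yielded once more after the loop finishes).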
# Get the text embeddings that will condition the diffusion process
if isinstance(prompt,str):
prompt = [prompt]
batch_size = len(prompt)
text_input = tokenizer(
prompt,
padding = 'max_length',
truncation = True,
max_length = tokenizer.model_max_length,
return_tensors = 'pt').to(device)
text_embeddings = text_encoder(
text_input.input_ids.to(device)
)[0]
# Get the text embeddings for classifier-free guidance
max_length = text_input.input_ids.shape[-1]
empty = [""] * batch_size
uncond_input = tokenizer(
empty,
padding = 'max_length',
truncation = True,
max_length = max_length,
return_tensors = 'pt').to(device)
uncond_embeddings = text_encoder(
uncond_input.input_ids.to(device)
)[0]
# Concatenate the text embeddings to get the conditioning vector
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Generate initial noise
latents = torch.randn(
(1, unet.in_channels, image_size[0] // 8, image_size[1] // 8),
generator=torch.manual_seed(seed) if seed is not None else None
)
print(latents.shape)
latents = latents.to(device)
# Initialize scheduler for noise generation
noise_scheduler.set_timesteps(num_inference_steps)
latents = latents * noise_scheduler.init_noise_sigma
for i,t in tqdm.tqdm(enumerate(noise_scheduler.timesteps)):
        # Duplicate the latents so the UNet sees both the unconditional and the text-conditioned input (classifier-free guidance)
latent_model_input = torch.cat([latents] * 2)
latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
with torch.no_grad():
# Get the noise prediction from the UNet
noise_pred = unet(latent_model_input, t, encoder_hidden_states = text_embeddings).sample
# Perform guidance from the text embeddings
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Compute the previous noisy sample x_t -> x_t-1
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
        # Periodically decode the current latents with the VAE decoder to yield intermediate images
if i % return_image_step == 0:
with torch.no_grad():
latents_copy = deepcopy(latents)
image = vae.decode(1/0.18215 * latents_copy).sample
image = (image / 2 + 0.5).clamp(0,1)
image = image.detach().cpu().permute(0,2,3,1).numpy() # bxhxwxc
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(img) for img in images]
yield pil_images[0]
yield pil_images[0]
if __name__ == "__main__":
device = torch.device("cpu")
model_id = "stabilityai/stable-diffusion-2-1"
tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
#noise_scheduler = diffusers.LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id,subfolder="scheduler")
prompt = "A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares,\
cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane\
render, architectural HD, HQ, 4k, 8k"
vae, tokenizer, text_encoder, unet = build_pipeline(
autoencoder = model_id,
tokenizer=tokenizer_id,
text_encoder=tokenizer_id,
unet=model_id,
device=device,
)
image_iter = custom_stablediffusion_inference(vae, tokenizer, text_encoder, unet, noise_scheduler, prompt = prompt, device=device, seed = None)
for i, image in enumerate(image_iter):
image.save(f"step_{i}.png")