# minerva-generate-docker/diffmodels/diffusion_utils.py
# Utility functions for loading and using diffusers models.
import diffusers
import transformers
import torch
from typing import Union
import os
import warnings
import numpy as np
from PIL import Image
import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
def build_generator(
    device: torch.device,
    seed: int,
):
    """
    Build a torch.Generator on the given device with a given seed.
    """
    generator = torch.Generator(device).manual_seed(seed)
    return generator
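# Example usage (hypothetical values; assumes `pipe` is an already loaded DiffusionPipeline):
#   generator = build_generator(torch.device("cuda"), seed=42)
#   image = pipe("a photo of a cat", generator=generator).images[0]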
def load_stablediffusion_model(
    model_id: Union[str, os.PathLike],
    device: torch.device,
):
    """
    Load a complete Stable Diffusion pipeline from a model id,
    swap in a DPMSolverMultistepScheduler, and move it to `device`.
    Falls back to CPU if the requested device is unavailable.
    Returns the loaded diffusers.DiffusionPipeline.
    """
    pipe = diffusers.DiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    pipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    try:
        pipe = pipe.to(device)
    except Exception:
        warnings.warn(
            f'Could not load model to device: {device}. Using CPU instead.'
        )
        pipe = pipe.to('cpu')
    return pipe
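# Example usage (assumes a valid model id and, for gated repos, a logged-in Hugging Face token):
#   pipe = load_stablediffusion_model("stabilityai/stable-diffusion-2-1", torch.device("cuda"))
#   image = pipe("a scenic mountain lake", generator=build_generator(torch.device("cuda"), 42)).images[0]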
def visualize_image_grid(
    imgs: list,
    rows: int,
    cols: int):
    """
    Paste a list of equally sized PIL images into a single rows x cols grid image.
    """
    assert len(imgs) == rows * cols
    # Create the grid canvas, assuming each image has the same size
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
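# Example usage (assumes `images` is a list of four equally sized PIL images):
#   grid = visualize_image_grid(images, rows=2, cols=2)
#   grid.save("grid.png")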
def build_pipeline(
    autoencoder: Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
    tokenizer: Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
    text_encoder: Union[str, os.PathLike] = "openai/clip-vit-large-patch14",
    unet: Union[str, os.PathLike] = "CompVis/stable-diffusion-v1-4",
    device: torch.device = torch.device('cuda'),
):
    """
    Create a pipeline for Stable Diffusion by loading each component separately.
    Arguments:
        autoencoder: path or model id the VAE will be loaded from
        tokenizer: path or model id of the tokenizer
        text_encoder: path or model id of the text encoder
        unet: path or model id the UNet will be loaded from
    Returns:
        (vae, tokenizer, text_encoder, unet), all moved to `device`.
    """
    # Load the VAE for encoding images into the latent space
    vae = diffusers.AutoencoderKL.from_pretrained(autoencoder, subfolder='vae')
    # Load tokenizer & text encoder for encoding text into embeddings
    tokenizer = transformers.CLIPTokenizer.from_pretrained(tokenizer)
    text_encoder = transformers.CLIPTextModel.from_pretrained(text_encoder)
    # Use the UNet model for conditioning the diffusion process
    unet = diffusers.UNet2DConditionModel.from_pretrained(unet, subfolder='unet')
    # Move all the components to the target device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    return vae, tokenizer, text_encoder, unet
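# Example usage (with the default component ids above; any compatible diffusers/transformers ids work):
#   vae, tokenizer, text_encoder, unet = build_pipeline(device=torch.device("cpu"))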
# TODO: Add negative prompting
def custom_stablediffusion_inference(
    vae,
    tokenizer,
    text_encoder,
    unet,
    noise_scheduler,
    prompt: Union[str, list],
    device: torch.device,
    num_inference_steps=100,
    image_size=(512, 512),
    guidance_scale=8,
    seed=42,
    return_image_step=5,
):
    # Get the text embeddings that will condition the diffusion process
    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)
    text_input = tokenizer(
        prompt,
        padding='max_length',
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors='pt').to(device)
    text_embeddings = text_encoder(
        text_input.input_ids.to(device)
    )[0]
    # Get the unconditional (empty prompt) embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    empty = [""] * batch_size
    uncond_input = tokenizer(
        empty,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt').to(device)
    uncond_embeddings = text_encoder(
        uncond_input.input_ids.to(device)
    )[0]
    # Concatenate the unconditional and text embeddings to get the conditioning vector
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    # Generate the initial latent noise (the VAE downsamples images by a factor of 8)
    latents = torch.randn(
        (batch_size, unet.in_channels, image_size[0] // 8, image_size[1] // 8),
        generator=torch.manual_seed(seed) if seed is not None else None
    )
    latents = latents.to(device)
    # Initialize the scheduler and scale the latents by its initial noise sigma
    noise_scheduler.set_timesteps(num_inference_steps)
    latents = latents * noise_scheduler.init_noise_sigma
    for i, t in tqdm.tqdm(enumerate(noise_scheduler.timesteps)):
        # Duplicate the latents so the conditional and unconditional passes run in one batch
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with torch.no_grad():
            # Get the noise prediction from the UNet
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # Perform classifier-free guidance with the text embeddings
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Compute the previous noisy sample x_t -> x_t-1
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
        # Periodically decode the current latents with the VAE decoder to inspect progress
        if i % return_image_step == 0:
            with torch.no_grad():
                latents_copy = deepcopy(latents)
                image = vae.decode(1 / 0.18215 * latents_copy).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()  # b x h x w x c
            images = (image * 255).round().astype("uint8")
            pil_images = [Image.fromarray(img) for img in images]
            yield pil_images[0]
    # Yield the final decoded image once sampling is complete
    yield pil_images[0]
if __name__ == "__main__":
    device = torch.device("cpu")
    model_id = "stabilityai/stable-diffusion-2-1"
    tokenizer_id = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    # noise_scheduler = diffusers.LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
    noise_scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
    prompt = ("A Hyperrealistic photograph of Italian architectural modern home in Italy, lens flares, "
              "cinematic, hdri, matte painting, concept art, celestial, soft render, highly detailed, octane "
              "render, architectural HD, HQ, 4k, 8k")
    vae, tokenizer, text_encoder, unet = build_pipeline(
        autoencoder=model_id,
        tokenizer=tokenizer_id,
        text_encoder=tokenizer_id,
        unet=model_id,
        device=device,
    )
    image_iter = custom_stablediffusion_inference(vae, tokenizer, text_encoder, unet, noise_scheduler, prompt=prompt, device=device, seed=None)
    for i, image in enumerate(image_iter):
        image.save(f"step_{i}.png")