Spaces:
Runtime error
Runtime error
import open_clip | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torchvision | |
from tqdm.auto import tqdm | |
from PIL import Image, ImageColor | |
from torchvision import transforms | |
from diffusers import DDIMScheduler, DDPMPipeline | |
# Pick the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Pretrained unconditional Ukiyo-e DDPM pipeline; only its UNet is used for sampling.
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)

# A DDIM scheduler built from the same pretrained config drives the sampling loop.
# 40 steps is only an initial setting: `generate` calls `set_timesteps` again
# with the user-selected number of inference steps on every request.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(40)
# Color guidance
# -------------------------------------------------------------------------------
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Score how far, on average, image pixels are from a target colour.

    Args:
        images: batch laid out as (b, c, h, w), with pixel values in (-1, 1).
        target_color: (R, G, B) triple in [0, 1]. The default (0.1, 0.9, 0.5)
            is a light teal.

    Returns:
        0-dim tensor holding the mean absolute difference between every pixel
        channel and the matching channel of ``target_color``.
    """
    rgb = torch.tensor(target_color).to(images.device)
    # Rescale [0, 1] -> [-1, 1] and reshape to (1, c, 1, 1) so it broadcasts over (b, c, h, w).
    target = (rgb * 2 - 1).view(1, -1, 1, 1)
    return (images - target).abs().mean()
# CLIP guidance
# -------------------------------------------------------------------------------
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)
# CLIP is only used as a frozen scorer. open_clip returns the model in train
# mode, so switch to inference mode, and stop tracking gradients for its
# weights: guidance only needs d(loss)/d(image), not d(loss)/d(weights).
clip_model.eval()
clip_model.requires_grad_(False)

# Augment an image tensor and normalize it with CLIP's preprocessing statistics
# (the mean/std below are defined for pixel values in [0, 1]).
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # random crop, resized to CLIP's 224x224 input
        transforms.RandomAffine(
            5
        ),  # random rotation of up to +/-5 degrees (first arg is `degrees`)
        transforms.RandomHorizontalFlip(),  # further augmentations can be added here
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
# CLIP guidance function
def clip_loss(image, text_features):
    """Squared great-circle distance between CLIP embeddings of images and text.

    Args:
        image: batch of predicted images, (b, c, h, w), in the diffusion
            model's (-1, 1) pixel range.
        text_features: CLIP text embeddings, (n_prompts, dim).

    Returns:
        0-dim tensor: mean distance over every (image, prompt) pair; lower
        means the images match the prompt better. Each call draws new random
        augmentations via ``tfms``, so the value is stochastic.
    """
    # The Normalize statistics in `tfms` are CLIP's, defined for [0, 1] pixels;
    # feeding (-1, 1) directly would shift the input distribution CLIP sees.
    image = (image + 1) / 2
    image_features = clip_model.encode_image(tfms(image))  # augment + normalize, then encode
    # Text features may come from a reduced-precision pass; match the image dtype.
    text_features = text_features.to(image_features.dtype)
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)  # (b, 1, d)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)  # (1, n, d)
    # For unit vectors, |u - v| / 2 = sin(theta / 2), so 2 * arcsin(.)^2 is the
    # squared great-circle distance (up to a constant factor).
    chord = input_normed.sub(embed_normed).norm(dim=2)  # (b, n)
    dists = chord.div(2).arcsin().pow(2).mul(2)
    return dists.mean()
# Sample generator loop
# -------------------------------------------------------------------------------
def _encode_prompt(prompt):
    """Return CLIP text embeddings for ``prompt``, computed without gradients."""
    tokens = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        return clip_model.encode_text(tokens)


def _prompt_guidance_grad(noise_pred, t, x, text_features, loss_scale, n_cuts):
    """Return -d(CLIP loss)/dx averaged over ``n_cuts`` independent evaluations.

    ``clip_loss`` applies fresh random crops/augmentations on every call, so
    averaging several "cuts" gives a less noisy guidance gradient.
    """
    grad = torch.zeros_like(x)
    for _ in range(n_cuts):
        x_in = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x_in).pred_original_sample
        loss = clip_loss(x0, text_features) * loss_scale
        grad -= torch.autograd.grad(loss, x_in)[0] / n_cuts
    return grad


def _to_pil_grid(x, nrow=4):
    """Tile a (b, c, h, w) batch with values in (-1, 1) into one PIL image."""
    grid = torchvision.utils.make_grid(x.detach(), nrow=nrow)
    grid = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5  # (-1, 1) -> (0, 1)
    return Image.fromarray(np.array(grid * 255).astype(np.uint8))


def generate(color,
             color_loss_scale,
             num_examples=4,
             seed=None,
             prompt=None,
             prompt_loss_scale=None,
             prompt_n_cuts=None,
             inference_steps=50,
             ):
    """Sample Ukiyo-e images with colour guidance and optional CLIP text guidance.

    Args:
        color: colour string understood by ``PIL.ImageColor`` (e.g. "#DF5C16").
        color_loss_scale: weight of the colour-guidance gradient.
        num_examples: number of images to sample.
        seed: torch RNG seed; ``None`` leaves the RNG untouched (0 is honoured).
        prompt: optional text prompt; empty/None disables CLIP guidance.
        prompt_loss_scale: weight of the CLIP-guidance gradient (None -> 0).
        prompt_n_cuts: random CLIP evaluations averaged per step (None/0 -> 1).
        inference_steps: number of DDIM denoising steps.

    Returns:
        A PIL image: the samples tiled in a grid with 4 images per row.
    """
    # Gradio sliders/number boxes may deliver floats; the sampler needs ints.
    num_examples = int(num_examples)
    inference_steps = int(inference_steps)
    scheduler.set_timesteps(num_inference_steps=inference_steps)

    # `if seed:` would silently ignore a seed of 0.
    if seed is not None:
        torch.manual_seed(int(seed))

    text_features = None
    if prompt:
        text_features = _encode_prompt(prompt)
        prompt_loss_scale = float(prompt_loss_scale or 0.0)
        prompt_n_cuts = max(1, int(prompt_n_cuts or 1))

    # Target colour as RGB rescaled from (0, 255) to (0, 1).
    target_color = [c / 255 for c in ImageColor.getcolor(color, "RGB")]

    x = torch.randn(num_examples, 3, 256, 256).to(device)
    for t in tqdm(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Colour guidance: gradient of the colour loss on the predicted x0.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        x_cond = x.detach() + cond_color_grad

        # Prompt guidance is computed from the original x (not the colour-shifted
        # one) and added on top of x_cond.
        if text_features is not None:
            cond_prompt_grad = _prompt_guidance_grad(
                noise_pred, t, x, text_features, prompt_loss_scale, prompt_n_cuts
            )
            # Scale by sqrt(alpha_bar_t) for the *current timestep* t (not the loop
            # index), so prompt guidance is weaker at high-noise steps.
            alpha_bar = float(scheduler.alphas_cumprod[int(t)])
            x_cond = x_cond + cond_prompt_grad * alpha_bar ** 0.5

        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    return _to_pil_grid(x, nrow=4)
# GRADIO Interface
# -------------------------------------------------------------------------------
TITLE = "Ukiyo-e postal generator service 🎴!"
DESCRIPTION = (
    "This model is a diffusion model for unconditional image generation of "
    "Ukiyo-e images ✍ 🎨. \nThe model was trained by fine-tuning the "
    "google/ddpm-celebahq-256 pretrained model on the dataset: "
    "https://huggingface.co/datasets/huggan/ukiyoe2photo"
)
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"
# One component per positional parameter of `generate`, in the same order.
# See the gradio docs for the types of inputs and outputs available.
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),
    gr.Slider(label="color_guidance_scale (how strong to blend the color)", minimum=0, maximum=30, value=6.7),
    # NOTE(review): with minimum=2 and step=4 only 2, 6 and 10 are selectable,
    # so maximum=12 is unreachable — confirm the intended step.
    gr.Slider(label="num_examples (# images generated)", minimum=2, maximum=12, value=2, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (how strongly to follow the text prompt)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
outputs = gr.Image(label="result")
# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    # Example rows (currently disabled); values follow the order of `inputs`.
    #examples=[
    #["#DF5C16", 6.7, 12, 666, None, None, None, 40],
    #["#C01660", 13.5, 12, 1990, None, None, None, 40],
    #["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
    #["#39A291", 5.0, 2, 666, "A sakura tree", 60, 4, 40],
    #["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
    #["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
    #],
    title=TITLE,
    description=DESCRIPTION,
)
if __name__ == "__main__":
    # `launch(enable_queue=True)` was removed in Gradio 4 (TypeError at startup);
    # enabling the request queue via `queue()` works in both Gradio 3.x and 4.x.
    demo.queue().launch()