vae-comparison / app.py
rizavelioglu
fix spaces-gpu error: import spaces first
e354de8
import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL
from diffusers.utils.remote_utils import remote_decode
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token stored in the
# "access_token" environment variable (e.g. a Space secret).
# When no token is configured, huggingface_hub.login() would fall back to an
# interactive prompt, which cannot be answered in a headless deployment, so the
# call is skipped and the Hub is accessed anonymously instead (gated repos will
# then fail to download with an authentication error).
hf_token = os.getenv("access_token")
if hf_token:
    login(token=hf_token)
class PadToSquare:
    """Transform that pads an image tensor to a square by replicating edge pixels.

    The input is unpacked as ``(channels, height, width)``. The shorter side is
    padded on both ends; when the total padding is odd, the extra pixel goes to
    the right / bottom edge.
    """

    def __call__(self, img):
        _channels, height, width = img.shape
        side = max(height, width)

        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top

        # torchvision's 4-element padding order is (left, top, right, bottom).
        return transforms.functional.pad(
            img,
            (left, top, right, bottom),
            padding_mode="edge",
        )
class VAETester:
    """Round-trip an image through several VAEs and score each reconstruction.

    Local entries encode and decode with an ``AutoencoderKL`` on ``device``.
    Remote entries encode with the matching local VAE and decode through a
    Hugging Face remote-decode endpoint. Every reconstruction is compared
    pixel-wise against the square-padded, resized original.
    """

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        """Build the image transforms and load every VAE.

        Args:
            device: Torch device for the local VAEs and the encoder input.
            img_size: Side length (pixels) the square-padded image is resized to.
        """
        self.device = device
        # VAE input: square-pad, resize, scale integer pixels to [0, 1], then map to [-1, 1].
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Reference image for scoring: same geometry as the VAE input, but kept in [0, 1].
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # (x - (-1)) / 2 == (x + 1) / 2: maps [-1, 1] back to [0, 1]
        # (values outside that range are clipped later when scoring).
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()

    def _get_endpoint(self, base_name: str) -> str:
        """Return the remote-decode endpoint URL for the local VAE named ``base_name``.

        Raises:
            KeyError: if no remote endpoint is configured for ``base_name``.
        """
        endpoints = {
            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
            "FLUX.1-schnell": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
        }
        return endpoints[base_name]

    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load the local VAEs and build the ordered model-config mapping.

        Returns:
            A dict keyed by display name, in display order. Local entries are
            ``{"type": "local", "vae": AutoencoderKL}``; remote entries are
            ``{"type": "remote", "local_vae_key": str, "endpoint": str}``, where
            ``local_vae_key`` names the local VAE whose encoder is reused.
        """
        local_vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1-schnell": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae").to(self.device),
            "FLUX.1-dev": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
        }
        # Display order; each "(remote)" entry must come after its local base model.
        order = [
            "stable-diffusion-v1-4",
            "sd-vae-ft-mse",
            "sd-vae-ft-mse (remote)",
            "sdxl-vae",
            "sdxl-vae (remote)",
            "stable-diffusion-3-medium",
            "FLUX.1-schnell",
            "FLUX.1-schnell (remote)",
            "FLUX.1-dev",
        ]
        vae_models = {}
        for name in order:
            if "(remote)" not in name:
                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
            else:
                base_name = name.replace(" (remote)", "")
                vae_models[name] = {
                    "type": "remote",
                    "local_vae_key": base_name,
                    "endpoint": self._get_endpoint(base_name),
                }
        return vae_models

    def _roundtrip(self, img_transformed: torch.Tensor, model_config: Dict) -> torch.Tensor:
        """Encode ``img_transformed`` and decode it with the VAE described by ``model_config``.

        Raises:
            ValueError: if ``model_config["type"]`` is neither "local" nor "remote".
        """
        model_type = model_config["type"]
        if model_type == "local":
            vae = model_config["vae"]
            with torch.no_grad():
                # sample() (not mode()) draws from the latent distribution, so
                # scores can vary slightly between runs.
                latents = vae.encode(img_transformed).latent_dist.sample()
                return vae.decode(latents).sample
        if model_type == "remote":
            encoder = self.vae_models[model_config["local_vae_key"]]["vae"]
            with torch.no_grad():
                latents = encoder.encode(img_transformed).latent_dist.sample()
                # The latents are sent exactly as the encoder produced them;
                # do_scaling=False presumably stops the endpoint from applying
                # its own latent scaling — confirm against the endpoint config.
                return remote_decode(
                    endpoint=model_config["endpoint"],
                    tensor=latents,
                    do_scaling=False,
                    output_type="pt",
                    return_type="pt",
                    partial_postprocess=False,
                )
        raise ValueError(f"Unknown VAE model type: {model_type!r}")

    def _score(self, img_transformed: torch.Tensor, original_base: torch.Tensor, model_config: Dict, tolerance: float):
        """Reconstruct the prepared input with one VAE and compare it to ``original_base``.

        Returns:
            ``(diff_image, recon_image, diff_score)``: a black/white PIL mask of
            pixels whose difference exceeds ``tolerance`` in any channel, the
            reconstructed PIL image, and the number of flagged pixels.
        """
        decoded = self._roundtrip(img_transformed, model_config)
        # Cast to float32 so scoring does not depend on the dtype the decoder
        # (local or remote) happened to return.
        reconstructed = self.output_transform(decoded.squeeze(0)).float().cpu().clip(0, 1)
        diff = (original_base - reconstructed).abs()
        # A pixel is flagged when any of its channels deviates by more than the tolerance.
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score

    def _prepare(self, img: torch.Tensor):
        """Return ``(batched VAE input on self.device, [0, 1] reference image on CPU)``."""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()
        return img_transformed, original_base

    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
        """Process ``img`` through a single VAE (local or remote).

        Args:
            img: Image tensor as returned by ``torchvision.io.read_image``.
            model_config: One value from ``self.vae_models``.
            tolerance: Per-channel absolute-difference threshold in [0, 1] units.

        Returns:
            ``(diff_image, recon_image, diff_score)``.

        Raises:
            ValueError: if ``model_config["type"]`` is not "local" or "remote".
        """
        img_transformed, original_base = self._prepare(img)
        return self._score(img_transformed, original_base, model_config, tolerance)

    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process ``img`` through every configured VAE.

        Returns:
            A dict mapping model name to ``(diff_image, recon_image, diff_score)``,
            in the same order as ``self.vae_models``.
        """
        # The transforms are deterministic, so prepare the inputs once instead of per model.
        img_transformed, original_base = self._prepare(img)
        return {
            name: self._score(img_transformed, original_base, model_config, tolerance)
            for name, model_config in self.vae_models.items()
        }
@spaces.GPU(duration=15)
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio callback: run the selected image through every configured VAE.

    Args:
        image_path: Path of the input image (the Gradio image input uses
            ``type="filepath"``); may be ``None`` when nothing was provided.
        tolerance: Per-channel absolute-difference threshold in [0, 1] units.
        img_size: Side length the square-padded image is resized to.

    Returns:
        ``(diff_images, recon_images, scores_text)``. The first two are lists of
        ``(PIL image, model name)`` gallery items. On failure both galleries are
        empty and ``scores_text`` carries the error message.
    """
    import logging

    from torchvision.io import ImageReadMode

    if not image_path:
        return [], [], "Error: no input image provided."
    try:
        # Constructed inside the try so model-loading failures are reported in the UI.
        tester = VAETester(img_size=int(img_size))
        # Force 3-channel RGB: grayscale or RGBA files would otherwise decode to
        # 1 or 4 channels, which these RGB VAEs cannot encode.
        img_tensor = read_image(image_path, mode=ImageReadMode.RGB)
        results = tester.process_all_models(img_tensor, tolerance)
        diff_images = []
        recon_images = []
        scores = []
        for name, (diff_img, recon_img, score) in results.items():
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:,.0f}")
        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        logging.getLogger(__name__).exception("VAE comparison failed")
        # Empty lists (not [None]) so the galleries can render the error state.
        error_msg = f"Error: {str(e)}"
        return [], [], error_msg
def _list_example_images(directory: str = "examples") -> list:
    """Return sorted ``"<directory>/<filename>"`` paths of example images.

    Returns an empty list when ``directory`` does not exist. Hidden files
    (e.g. ``.gitkeep``, ``.DS_Store``) and sub-directories are skipped so they
    are not offered as example inputs.
    """
    if not os.path.isdir(directory):
        return []
    return [
        f"{directory}/{filename}"
        for filename in sorted(os.listdir(directory))
        if not filename.startswith(".") and os.path.isfile(os.path.join(directory, filename))
    ]


examples = _list_example_images()

with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
    gr.Markdown("# VAE Comparison Tool")
    gr.Markdown("""
Upload an image or select an example to compare how different VAEs reconstruct it. Now includes remote VAEs via Hugging Face's remote decoding feature!
1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
2. Each VAE (local or remote) encodes the image into a latent space and decodes it back.
3. Outputs include:
- **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
- **Reconstructed Images**: Outputs from each VAE.
- **Sum of Differences**: Total pixels exceeding tolerance (lower is better).
Adjust tolerance to change sensitivity.
""")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged."
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            # One line per configured model (9 entries in VAETester's display order).
            scores_output = gr.Textbox(label="Sum of differences (lower is better)", lines=9, elem_classes="monospace-text")
    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")
    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch()