# vae-comparison / app.py -- Hugging Face Space by rizavelioglu
# (commit 2a90016, "add app file"; 5.21 kB)
import gradio as gr
import torch
from diffusers import AutoencoderKL
from PIL import Image
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Tuple, Dict, List
class VAETester:
    """Compare reconstruction quality across several pretrained VAEs.

    Each input image is padded, resized to 512x512, encoded and decoded by
    every loaded VAE, and scored by counting the pixels whose absolute
    reconstruction error exceeds a caller-supplied tolerance.
    """

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Build the preprocessing pipelines and load every VAE onto ``device``."""
        self.device = device
        # Encoder-input pipeline: pad 128px on left/right (assumes roughly
        # portrait input that this makes square -- TODO confirm against the
        # expected upload sizes), resize to the 512x512 resolution the
        # SD-family VAEs expect, convert uint8 -> float32 in [0, 1], then
        # shift to the [-1, 1] range the VAEs are trained on.
        self.input_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Same geometry/dtype pipeline WITHOUT the [-1, 1] normalization:
        # produces the [0, 1] reference image the diff score is measured against.
        self.base_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Exact inverse of Normalize(mean=0.5, std=0.5): maps [-1, 1] -> [0, 1].
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        # Load all VAE models once at initialization so requests are fast.
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
        """Load all available VAE models onto ``self.device`` in eval mode."""
        vae_configs = {
            "Stable Diffusion 3 Medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
            "Stable Diffusion v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
            "SD VAE FT-MSE": ("stabilityai/sd-vae-ft-mse", ""),
            "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae"),
        }
        vae_dict = {}
        for name, (path, subfolder) in vae_configs.items():
            vae = AutoencoderKL.from_pretrained(path, subfolder=subfolder)
            # .eval(): these models are used for inference only.
            vae_dict[name] = vae.to(self.device).eval()
        return vae_dict

    def process_image(self,
                      img: torch.Tensor,
                      vae: AutoencoderKL,
                      tolerance: float) -> Tuple[Image.Image, Image.Image, float]:
        """Round-trip ``img`` through one VAE and score the reconstruction.

        Args:
            img: uint8 image tensor (C, H, W), e.g. from torchvision's
                ``read_image`` -- TODO confirm channel count matches the pad
                assumptions above.
            vae: The VAE to evaluate.
            tolerance: Per-channel absolute-error threshold in [0, 1].

        Returns:
            Tuple of (diff_image, recon_image, diff_score): a black/white map
            of out-of-tolerance pixels, the reconstructed image, and the count
            of out-of-tolerance pixels.
        """
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()
        with torch.no_grad():
            encoded = vae.encode(img_transformed).latent_dist.sample()
            # FIX: the original multiplied by vae.config.scaling_factor and
            # immediately divided it back out before decoding -- a no-op
            # round-trip. Decode the sampled latents directly.
            decoded = vae.decode(encoded).sample
        # Undo the [-1, 1] normalization and clamp decoder overshoot to [0, 1].
        reconstructed = self.output_transform(decoded.squeeze(0)).cpu().clip(0, 1)
        diff = (original_base - reconstructed).abs()
        # A pixel counts as "different" if ANY channel exceeds the tolerance.
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()
        return diff_image, recon_image, diff_score

    def process_all_models(self,
                           img: torch.Tensor,
                           tolerance: float) -> Dict[str, Tuple[Image.Image, Image.Image, float]]:
        """Run ``process_image`` against every loaded VAE, keyed by model name."""
        return {
            name: self.process_image(img, vae, tolerance)
            for name, vae in self.vae_models.items()
        }
# Initialize tester once at import time so model downloads/loads happen at
# app startup rather than on the first user request.
tester = VAETester()
def test_all_vaes(image_path: str, tolerance: float) -> Tuple[List[Image.Image], List[Image.Image], List[str]]:
"""Gradio interface function to test all VAEs"""
try:
img_tensor = read_image(image_path)
results = tester.process_all_models(img_tensor, tolerance)
diff_images = []
recon_images = []
scores = []
for name in tester.vae_models.keys():
diff_img, recon_img, score = results[name]
diff_images.append(diff_img)
recon_images.append(recon_img)
scores.append(f"{name}: {score:.2f}")
return diff_images, recon_images, scores
except Exception as e:
error_msg = f"Error: {str(e)}"
return [None], [None], [error_msg]
# Gradio interface: input image + tolerance controls on the left,
# result galleries and per-model scores on the right.
with gr.Blocks(title="VAE Performance Tester") as demo:
    gr.Markdown("# VAE Performance Testing Tool")
    gr.Markdown("Upload an image to compare all VAE models simultaneously")
    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" hands the callback a path string, not an array.
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance"
            )
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                # One tile per VAE model (4 models -> 4 columns).
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Difference Scores", lines=4)
    # Wire the button to the callback; outputs map 1:1 to the components above.
    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )
if __name__ == "__main__":
    demo.launch()