# VAE Performance Tester: Hugging Face Space app (runs on ZeroGPU)
import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKL
from diffusers.utils.remote_utils import remote_decode
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image
from typing import Dict
import os
from huggingface_hub import login
# Get the Hugging Face token from the environment (e.g. a Space secret);
# needed for gated checkpoints such as FLUX.1-dev
hf_token = os.getenv("access_token")
login(token=hf_token)
class PadToSquare:
    """Custom transform to pad an image to square dimensions using edge padding"""

    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions (C, H, W)
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)  # (left, top, right, bottom)
        return transforms.functional.pad(img, padding, padding_mode="edge")
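# Illustrative example of the padding arithmetic above: a (3, 300, 500) tensor has
# max_side = 500, pad_h = 100, pad_w = 0, so padding = (0, 100, 0, 100) in
# (left, top, right, bottom) order, and the result is a (3, 500, 500) tensor whose
# added rows repeat the edge pixels.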
class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
        self.device = device
        # Pad to square, resize, scale to [0, 1], then normalize to [-1, 1] for the VAE encoder
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        # Same geometry, but kept in [0, 1] as the reference for pixel-wise comparison
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((img_size, img_size)),
            transforms.ToDtype(torch.float32, scale=True),
        ])
        # Inverse of the input normalization: maps decoder output from [-1, 1] back to [0, 1]
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])
        self.vae_models = self._load_all_vaes()
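    # The endpoints below appear to be Hugging Face Inference Endpoints hosting the
    # corresponding VAE decoders; remote_decode() sends latents to them and receives
    # the decoded image tensor back.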
    def _get_endpoint(self, base_name: str) -> str:
        """Helper method to get the endpoint for a given base model name"""
        endpoints = {
            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
            "FLUX.1-schnell": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
        }
        return endpoints[base_name]
    def _load_all_vaes(self) -> Dict[str, Dict]:
        """Load configurations for local and remote VAE models"""
        local_vaes = {
            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
            "FLUX.1-schnell": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae").to(self.device),
            "FLUX.1-dev": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
        }

        # Define the desired order of models
        order = [
            "stable-diffusion-v1-4",
            "sd-vae-ft-mse",
            "sd-vae-ft-mse (remote)",
            "sdxl-vae",
            "sdxl-vae (remote)",
            "stable-diffusion-3-medium",
            "FLUX.1-schnell",
            "FLUX.1-schnell (remote)",
            "FLUX.1-dev",
        ]

        # Construct the vae_models dictionary in the specified order
        vae_models = {}
        for name in order:
            if "(remote)" not in name:
                # Local model
                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
            else:
                # Remote model
                base_name = name.replace(" (remote)", "")
                vae_models[name] = {
                    "type": "remote",
                    "local_vae_key": base_name,
                    "endpoint": self._get_endpoint(base_name),
                }
        return vae_models
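    # Each entry of the returned dict is either
    #   {"type": "local",  "vae": <AutoencoderKL>}
    # or
    #   {"type": "remote", "local_vae_key": "<base model name>", "endpoint": "<endpoint URL>"};
    # process_image() below dispatches on the "type" field.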
    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
        """Process image through a single VAE (local or remote)"""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        if model_config["type"] == "local":
            vae = model_config["vae"]
            with torch.no_grad():
                encoded = vae.encode(img_transformed).latent_dist.sample()
                decoded = vae.decode(encoded).sample
        elif model_config["type"] == "remote":
            # Encode locally with the matching VAE, then decode on the remote endpoint
            local_vae = self.vae_models[model_config["local_vae_key"]]["vae"]
            with torch.no_grad():
                encoded = local_vae.encode(img_transformed).latent_dist.sample()
                decoded = remote_decode(
                    endpoint=model_config["endpoint"],
                    tensor=encoded,
                    do_scaling=False,
                    output_type="pt",
                    return_type="pt",
                    partial_postprocess=False,
                )

        # Map the decoder output back to [0, 1] and compare with the original pixel by pixel
        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)
        diff = (original_base - reconstructed).abs()
        # A pixel is flagged (white) if any channel deviates by more than `tolerance`
        bw_diff = (diff > tolerance).any(dim=0).float()
        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()  # number of flagged pixels
        return diff_image, recon_image, diff_score
    def process_all_models(self, img: torch.Tensor, tolerance: float):
        """Process image through all configured VAEs"""
        results = {}
        for name, model_config in self.vae_models.items():
            diff_img, recon_img, score = self.process_image(img, model_config, tolerance)
            results[name] = (diff_img, recon_img, score)
        return results
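# Illustrative usage outside the Gradio UI (not executed here; "my_image.png" is a placeholder):
#   tester = VAETester(img_size=512)
#   img = read_image("my_image.png")
#   results = tester.process_all_models(img, tolerance=0.1)
#   diff_img, recon_img, score = results["sd-vae-ft-mse"]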
@spaces.GPU(duration=15)  # On ZeroGPU, request a GPU for up to 15 seconds per call
def test_all_vaes(image_path: str, tolerance: float, img_size: int):
    """Gradio interface function to test all VAEs"""
    tester = VAETester(img_size=img_size)
    try:
        img_tensor = read_image(image_path)
        results = tester.process_all_models(img_tensor, tolerance)
        diff_images = []
        recon_images = []
        scores = []
        for name in tester.vae_models.keys():
            diff_img, recon_img, score = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:,.0f}")
        return diff_images, recon_images, "\n".join(scores)
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return [None], [None], error_msg
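# The return values above match what the output components below expect:
#   - each gr.Gallery takes a list of (PIL.Image, caption) tuples
#   - the gr.Textbox takes one newline-joined string, e.g. "sd-vae-ft-mse            : 12,345"
#     (numbers shown here are illustrative only)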
examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]
with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
    gr.Markdown("# VAE Comparison Tool")
    gr.Markdown("""
Upload an image or select an example to compare how different VAEs reconstruct it. Now includes remote VAEs via Hugging Face's remote decoding feature!

1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
2. Each VAE (local or remote) encodes the image into a latent space and decodes it back.
3. Outputs include:
   - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
   - **Reconstructed Images**: Outputs from each VAE.
   - **Sum of Differences**: Total pixels exceeding tolerance (lower is better).

Adjust tolerance to change sensitivity.
""")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged."
            )
            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
            submit_btn = gr.Button("Test All VAEs")
        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Sum of differences (lower is better)", lines=9, elem_classes="monospace-text")

    if examples:
        with gr.Row():
            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch()