import gradio as gr
import torch
from diffusers import AutoencoderKL
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image, ImageReadMode
from typing import Dict
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; the token is read from the "access_token"
# environment variable (required for gated checkpoints such as FLUX.1-dev)
hf_token = os.getenv("access_token")
login(token=hf_token)

class PadToSquare:
    """Custom transform to pad an image to square dimensions"""
    def __call__(self, img):
        _, h, w = img.shape  # Get the original dimensions
        max_side = max(h, w)
        pad_h = (max_side - h) // 2
        pad_w = (max_side - w) // 2
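        # torchvision's pad expects (left, top, right, bottom); any odd remainder goes to the right/bottom edges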
        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
        return transforms.functional.pad(img, padding, padding_mode="edge")

class VAETester:
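    """Loads several pretrained VAEs and compares how each reconstructs a given image."""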
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
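        # input_transform maps images into the [-1, 1] range the VAEs expect;
        # base_transform keeps a [0, 1] copy of the original for comparison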
        self.input_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        self.base_transform = transforms.Compose([
            PadToSquare(),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
        ])
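        # Inverse of Normalize(mean=0.5, std=0.5): (x - (-1)) / 2 maps decoder output from [-1, 1] back to [0, 1]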
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])

        # Load all VAE models at initialization
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
        """Load all available VAE models"""
        vae_configs = {
            "stable-diffusion-v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
            "sd-vae-ft-mse": ("stabilityai/sd-vae-ft-mse", ""),
            "sdxl-vae": ("stabilityai/sdxl-vae", ""),
            "stable-diffusion-3-medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
            "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae")
        }

        vae_dict = {}
        for name, (path, subfolder) in vae_configs.items():
            vae_dict[name] = AutoencoderKL.from_pretrained(path, subfolder=subfolder).to(self.device)
        return vae_dict

    def process_image(self,
                      img: torch.Tensor,
                      vae: AutoencoderKL,
                      tolerance: float):
        """Process image through a single VAE"""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

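        # Scale the latents by the VAE's scaling_factor and unscale them before decoding,
        # mirroring how diffusion pipelines handle latents; the two steps cancel for a pure
        # reconstruction round-trip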
        with torch.no_grad():
            encoded = vae.encode(img_transformed).latent_dist.sample()
            encoded_scaled = encoded * vae.config.scaling_factor
            decoded = vae.decode(encoded_scaled / vae.config.scaling_factor).sample

        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)

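        # Flag a pixel (1.0) if any channel's reconstruction error exceeds the tolerance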
        diff = (original_base - reconstructed).abs()
        bw_diff = (diff > tolerance).any(dim=0).float()

        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
        diff_score = bw_diff.sum().item()

        return diff_image, recon_image, diff_score

    def process_all_models(self,
                           img: torch.Tensor,
                           tolerance: float):
        """Process image through all loaded VAEs"""
        results = {}
        for name, vae in self.vae_models.items():
            diff_img, recon_img, score = self.process_image(img, vae, tolerance)
            results[name] = (diff_img, recon_img, score)
        return results

# Initialize tester
tester = VAETester()


def test_all_vaes(image_path: str, tolerance: float):
    """Gradio interface function to test all VAEs"""
    try:
        img_tensor = read_image(image_path, mode=ImageReadMode.RGB)  # force 3 channels; RGBA/grayscale inputs would otherwise break the VAEs
        results = tester.process_all_models(img_tensor, tolerance)

        diff_images = []
        recon_images = []
        scores = []

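        # gr.Gallery accepts (image, caption) pairs, so each output is labeled with its VAE's name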
        for name in tester.vae_models.keys():
            diff_img, recon_img, score = results[name]
            diff_images.append((diff_img, name))
            recon_images.append((recon_img, name))
            scores.append(f"{name:<25}: {score:.1f}")

        return diff_images, recon_images, "\n".join(scores)

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return [], [], error_msg

examples_dir = "examples"
examples = [os.path.join(examples_dir, f) for f in sorted(os.listdir(examples_dir))] if os.path.isdir(examples_dir) else []

# Gradio interface
with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
    gr.Markdown("# VAE Comparison Tool")
    gr.Markdown("""
        Upload an image or select an example to compare how different VAEs reconstruct it. Here's what happens:
        1. The image is padded to a square and resized to 512x512 pixels.
        2. Each VAE encodes the image into a latent space and decodes it back.
        3. The tool then generates:
           - **Difference Maps**: Black-and-white images showing where the reconstruction differs from the original (white areas indicate differences above the tolerance threshold).
           - **Reconstructed Images**: The outputs from each VAE.
           - **Sum of Differences**: A per-VAE count of the pixels whose reconstruction error exceeds the tolerance (lower means a closer reconstruction).
        Use the tolerance slider to adjust the sensitivity.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance",
                info="Low tolerance (e.g., 0.01): Highly sensitive, flags small deviations. High tolerance (e.g., 0.5): Less sensitive, flags only large deviations, showing fewer differences.",
            )
            submit_btn = gr.Button("Test All VAEs")

        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Sum of differences (lower is better reconstruction)", lines=5, elem_classes="monospace-text")

        if examples:
            with gr.Column():
                example_gallery = gr.Examples(
                    examples=examples,
                    inputs=image_input,
                    label="Example Images"
                )

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch()