import gradio as gr
import torch
from diffusers import AutoencoderKL
import torchvision.transforms.v2 as transforms
from torchvision.io import read_image, ImageReadMode
from typing import Dict, Tuple
from PIL import Image
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; FLUX.1-dev is a gated model, so a
# token is needed to download its VAE. Skip login if no token is set.
hf_token = os.getenv("access_token")
if hf_token:
    login(token=hf_token)

class VAETester:
    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
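        # Pad 128 px on the left/right (edge replication), resize to 512x512,
        # and normalize to [-1, 1], the input range these VAEs expect. The
        # fixed padding is a hard-coded choice that suits portrait-ish inputs.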
        self.input_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
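        # Same geometry without normalization, kept in [0, 1] as the reference
        # image for the difference computation.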
        self.base_transform = transforms.Compose([
            transforms.Pad(padding=[128, 0], padding_mode="edge"),
            transforms.Resize((512, 512), antialias=True),
            transforms.ToDtype(torch.float32, scale=True),
        ])
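        # Inverse of the input normalization: maps decoded [-1, 1] output back
        # to [0, 1].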
        self.output_transform = transforms.Normalize(mean=[-1], std=[2])

        # Load all VAE models at initialization
        self.vae_models = self._load_all_vaes()

    def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
        """Load all available VAE models"""
        vae_configs = {
            "Stable Diffusion 3 Medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
            "Stable Diffusion v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
            "SD VAE FT-MSE": ("stabilityai/sd-vae-ft-mse", ""),
            "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae")
        }

        vae_dict = {}
        for name, (path, subfolder) in vae_configs.items():
            vae_dict[name] = AutoencoderKL.from_pretrained(path, subfolder=subfolder).to(self.device)
        return vae_dict

    def process_image(self,
                      img: torch.Tensor,
                      vae: AutoencoderKL,
                      tolerance: float) -> Tuple[Image.Image, Image.Image, float]:
        """Process image through a single VAE"""
        img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
        original_base = self.base_transform(img).cpu()

        with torch.no_grad():
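            # Sample a latent from the encoder's posterior; the scaling_factor
            # multiply/divide below mirrors the diffusion-pipeline convention
            # and cancels out in this pure reconstruction round trip.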
            encoded = vae.encode(img_transformed).latent_dist.sample()
            encoded_scaled = encoded * vae.config.scaling_factor
            decoded = vae.decode(encoded_scaled / vae.config.scaling_factor).sample

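        # Undo the [-1, 1] normalization and clamp to the valid [0, 1] range.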
        decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
        reconstructed = decoded_transformed.clip(0, 1)

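        # Flag pixels where any channel deviates from the reference by more
        # than the tolerance; the binary mask doubles as the difference map.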
        diff = (original_base - reconstructed).abs()
        bw_diff = (diff > tolerance).any(dim=0).float()

        diff_image = transforms.ToPILImage()(bw_diff)
        recon_image = transforms.ToPILImage()(reconstructed)
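        # Score = number of flagged pixels (lower means a closer reconstruction).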
        diff_score = bw_diff.sum().item()

        return diff_image, recon_image, diff_score

    def process_all_models(self,
                           img: torch.Tensor,
                           tolerance: float) -> Dict[str, Tuple[Image.Image, Image.Image, float]]:
        """Process image through all loaded VAEs"""
        results = {}
        for name, vae in self.vae_models.items():
            diff_img, recon_img, score = self.process_image(img, vae, tolerance)
            results[name] = (diff_img, recon_img, score)
        return results


# Load all VAEs once at startup so each request only runs encode/decode
tester = VAETester()


def test_all_vaes(image_path: str, tolerance: float):
    """Gradio interface function to test all VAEs"""
    try:
        img_tensor = read_image(image_path, mode=ImageReadMode.RGB)  # force 3 channels (drops alpha, expands grayscale)
        results = tester.process_all_models(img_tensor, tolerance)

        diff_images = []
        recon_images = []
        scores = []

        for name in tester.vae_models.keys():
            diff_img, recon_img, score = results[name]
            diff_images.append(diff_img)
            recon_images.append(recon_img)
            scores.append(f"{name}: {score:.0f}")

        # Galleries take lists of images; the Textbox expects a single string.
        return diff_images, recon_images, "\n".join(scores)

    except Exception as e:
        return [], [], f"Error: {e}"


# Gradio interface
with gr.Blocks(title="VAE Performance Tester") as demo:
    gr.Markdown("# VAE Performance Testing Tool")
    gr.Markdown("Upload an image to compare all VAE models simultaneously")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Input Image", height=512)
            tolerance_slider = gr.Slider(
                minimum=0.01,
                maximum=0.5,
                value=0.1,
                step=0.01,
                label="Difference Tolerance"
            )
            submit_btn = gr.Button("Test All VAEs")

        with gr.Column(scale=3):
            with gr.Row():
                diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
            scores_output = gr.Textbox(label="Difference Scores", lines=4)

    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider],
        outputs=[diff_gallery, recon_gallery, scores_output]
    )

if __name__ == "__main__":
    demo.launch()
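
# A quick smoke test without the UI (hypothetical local image path):
#   tester = VAETester()
#   img = read_image("sample.png", mode=ImageReadMode.RGB)
#   print({name: res[2] for name, res in tester.process_all_models(img, 0.1).items()})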