rizavelioglu committed
Commit 2a90016 · 1 Parent(s): 1259fea

add app file

Files changed (1)
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
+ import gradio as gr
+ import torch
+ from diffusers import AutoencoderKL
+ from PIL import Image
+ import torchvision.transforms.v2 as transforms
+ from torchvision.io import read_image
+ from typing import Tuple, Dict, List
+
+
+ class VAETester:
+     def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+         self.device = device
+         # Square the input by edge-padding 128 px left/right (sized for
+         # 768x1024 portrait inputs), resize to the VAEs' native 512x512,
+         # scale to [0, 1], then normalize to the [-1, 1] range VAEs expect.
+         self.input_transform = transforms.Compose([
+             transforms.Pad(padding=[128, 0], padding_mode="edge"),
+             transforms.Resize((512, 512), antialias=True),
+             transforms.ToDtype(torch.float32, scale=True),
+             transforms.Normalize(mean=[0.5], std=[0.5]),
+         ])
+         # Same geometry, but kept in [0, 1] for pixel-space comparison.
+         self.base_transform = transforms.Compose([
+             transforms.Pad(padding=[128, 0], padding_mode="edge"),
+             transforms.Resize((512, 512), antialias=True),
+             transforms.ToDtype(torch.float32, scale=True),
+         ])
+         # Inverse of Normalize(0.5, 0.5): maps decoder output from [-1, 1] back to [0, 1].
+         self.output_transform = transforms.Normalize(mean=[-1], std=[2])
+
+         # Load all VAE models at initialization
+         self.vae_models = self._load_all_vaes()
+
+     def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
+         """Load all available VAE models."""
+         vae_configs = {
+             "Stable Diffusion 3 Medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
+             "Stable Diffusion v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
+             "SD VAE FT-MSE": ("stabilityai/sd-vae-ft-mse", ""),  # VAE lives at the repo root
+             "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae"),
+         }
+
+         vae_dict = {}
+         for name, (path, subfolder) in vae_configs.items():
+             vae_dict[name] = AutoencoderKL.from_pretrained(path, subfolder=subfolder).to(self.device)
+         return vae_dict
+
+     def process_image(self,
+                       img: torch.Tensor,
+                       vae: AutoencoderKL,
+                       tolerance: float) -> Tuple[Image.Image, Image.Image, float]:
+         """Run one image through a single VAE and measure reconstruction error."""
+         img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
+         original_base = self.base_transform(img).cpu()
+
+         with torch.no_grad():
+             encoded = vae.encode(img_transformed).latent_dist.sample()
+             # Scaling and unscaling by scaling_factor cancel out here; they are
+             # kept to mirror how diffusion pipelines hand latents to the denoiser.
+             encoded_scaled = encoded * vae.config.scaling_factor
+             decoded = vae.decode(encoded_scaled / vae.config.scaling_factor).sample
+
+         decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
+         reconstructed = decoded_transformed.clip(0, 1)
+
+         # Binary map of pixels whose error exceeds the tolerance in any channel.
+         diff = (original_base - reconstructed).abs()
+         bw_diff = (diff > tolerance).any(dim=0).float()
+
+         diff_image = transforms.ToPILImage()(bw_diff)
+         recon_image = transforms.ToPILImage()(reconstructed)
+         diff_score = bw_diff.sum().item()  # number of out-of-tolerance pixels
+
+         return diff_image, recon_image, diff_score
+
+     def process_all_models(self,
+                            img: torch.Tensor,
+                            tolerance: float) -> Dict[str, Tuple[Image.Image, Image.Image, float]]:
+         """Process an image through all loaded VAEs."""
+         results = {}
+         for name, vae in self.vae_models.items():
+             diff_img, recon_img, score = self.process_image(img, vae, tolerance)
+             results[name] = (diff_img, recon_img, score)
+         return results
+
+
+ # Initialize tester
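+ # Note: all four VAEs are loaded eagerly onto a single device, which is
+ # memory-hungry and slow on CPU-only hosts; the SD3 and FLUX.1-dev repos
+ # are also gated on the Hub, so a logged-in HF token may be required.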
+ tester = VAETester()
+
+
+ def test_all_vaes(image_path: str, tolerance: float) -> Tuple[List[Image.Image], List[Image.Image], str]:
+     """Gradio callback: run the uploaded image through every VAE."""
+     try:
+         img_tensor = read_image(image_path)
+         results = tester.process_all_models(img_tensor, tolerance)
+
+         diff_images = []
+         recon_images = []
+         scores = []
+
+         for name in tester.vae_models.keys():
+             diff_img, recon_img, score = results[name]
+             diff_images.append(diff_img)
+             recon_images.append(recon_img)
+             scores.append(f"{name}: {score:.0f}")
+
+         # Join the scores so the Textbox output renders one model per line.
+         return diff_images, recon_images, "\n".join(scores)
+
+     except Exception as e:
+         # Return empty galleries plus the error message, rather than [None]
+         # entries that the galleries cannot render.
+         return [], [], f"Error: {e}"
+
+
+ # Gradio interface
+ with gr.Blocks(title="VAE Performance Tester") as demo:
+     gr.Markdown("# VAE Performance Testing Tool")
+     gr.Markdown("Upload an image to compare all VAE models simultaneously.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_input = gr.Image(type="filepath", label="Input Image", height=512)
+             tolerance_slider = gr.Slider(
+                 minimum=0.01,
+                 maximum=0.5,
+                 value=0.1,
+                 step=0.01,
+                 label="Difference Tolerance",
+             )
+             submit_btn = gr.Button("Test All VAEs")
+
+         with gr.Column(scale=3):
+             with gr.Row():
+                 diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
+                 recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
+             scores_output = gr.Textbox(label="Difference Scores", lines=4)
+
+     submit_btn.click(
+         fn=test_all_vaes,
+         inputs=[image_input, tolerance_slider],
+         outputs=[diff_gallery, recon_gallery, scores_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
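For a quick sanity check of the same encode/decode round trip outside the Gradio UI, a minimal sketch (assuming `pip install torch torchvision diffusers`; `example.jpg` is a placeholder path, and the FT-MSE VAE is used here because its repo is not gated):

import torch
import torchvision.transforms.v2 as transforms
from diffusers import AutoencoderKL
from torchvision.io import read_image

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Same preprocessing as app.py, minus the padding: resize, scale to [0, 1],
# then map to the [-1, 1] range the VAE expects.
prep = transforms.Compose([
    transforms.Resize((512, 512), antialias=True),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
x = prep(read_image("example.jpg")).unsqueeze(0)  # placeholder input path

with torch.no_grad():
    recon = vae.decode(vae.encode(x).latent_dist.sample()).sample

# Mean absolute error in [0, 1] pixel space.
err = (((recon + 1) / 2).clamp(0, 1) - (x + 1) / 2).abs().mean()
print(f"mean |error|: {err.item():.4f}")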