rizavelioglu committed
Commit 3050d2d · 1 Parent(s): 85dc5d9

add explanations, fix img processing, add another vae, examples

app.py CHANGED
@@ -3,7 +3,7 @@ import torch
 from diffusers import AutoencoderKL
 import torchvision.transforms.v2 as transforms
 from torchvision.io import read_image
-from typing import Tuple, Dict, List
+from typing import Dict
 import os
 from huggingface_hub import login
 
@@ -11,17 +11,27 @@ from huggingface_hub import login
 hf_token = os.getenv("access_token")
 login(token=hf_token)
 
+class PadToSquare:
+    """Custom transform to pad an image to square dimensions"""
+    def __call__(self, img):
+        _, h, w = img.shape  # Get the original dimensions
+        max_side = max(h, w)
+        pad_h = (max_side - h) // 2
+        pad_w = (max_side - w) // 2
+        padding = (pad_w, pad_h, max_side - w - pad_w, max_side - h - pad_h)
+        return transforms.functional.pad(img, padding, padding_mode="edge")
+
 class VAETester:
     def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
         self.device = device
         self.input_transform = transforms.Compose([
-            transforms.Pad(padding=[128, 0], padding_mode="edge"),
+            PadToSquare(),
             transforms.Resize((512, 512), antialias=True),
             transforms.ToDtype(torch.float32, scale=True),
             transforms.Normalize(mean=[0.5], std=[0.5]),
         ])
         self.base_transform = transforms.Compose([
-            transforms.Pad(padding=[128, 0], padding_mode="edge"),
+            PadToSquare(),
             transforms.Resize((512, 512), antialias=True),
             transforms.ToDtype(torch.float32, scale=True),
         ])
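
Note on the new transform: `PadToSquare` splits the required padding evenly between the two edges, sending any odd remainder to the right/bottom, so a 1024x768 portrait input gets the same 128-pixel side padding the old hardcoded `Pad(padding=[128, 0])` applied, while other aspect ratios are now handled correctly too. A minimal sketch of that behavior (assuming only the `PadToSquare` class added above):

```python
import torch

# Portrait 1024x768 (H x W), like the fashion images under examples/
img = torch.zeros(3, 1024, 768, dtype=torch.uint8)
print(PadToSquare()(img).shape)  # torch.Size([3, 1024, 1024]); 128 px added left
                                 # and right, matching the old Pad(padding=[128, 0])

# Landscape input, which the old hardcoded padding mishandled
img = torch.zeros(3, 600, 800, dtype=torch.uint8)
print(PadToSquare()(img).shape)  # torch.Size([3, 800, 800]); 100 px added top and bottom
```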
@@ -33,9 +43,10 @@ class VAETester:
     def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
         """Load all available VAE models"""
         vae_configs = {
-            "Stable Diffusion 3 Medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
-            "Stable Diffusion v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
-            "SD VAE FT-MSE": ("stabilityai/sd-vae-ft-mse", ""),
+            "stable-diffusion-v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
+            "sd-vae-ft-mse": ("stabilityai/sd-vae-ft-mse", ""),
+            "sdxl-vae": ("stabilityai/sdxl-vae", ""),
+            "stable-diffusion-3-medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
             "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae")
         }
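
Note: each entry maps a display name to a `(repo_id, subfolder)` pair; the loading loop itself is unchanged context outside this hunk. With `diffusers` it presumably reduces to the pattern below (a sketch; `load_vae` is a hypothetical helper, not a function from app.py):

```python
from diffusers import AutoencoderKL

def load_vae(repo_id: str, subfolder: str, device: str) -> AutoencoderKL:
    # Standalone VAE repos (sd-vae-ft-mse, sdxl-vae) use an empty subfolder;
    # full pipeline repos keep their VAE under a "vae" subfolder.
    kwargs = {"subfolder": subfolder} if subfolder else {}
    return AutoencoderKL.from_pretrained(repo_id, **kwargs).to(device).eval()

vae = load_vae("stabilityai/sdxl-vae", "", "cpu")
```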
 
@@ -79,7 +90,6 @@ class VAETester:
             results[name] = (diff_img, recon_img, score)
         return results
 
-
 # Initialize tester
 tester = VAETester()
 
@@ -96,21 +106,31 @@ def test_all_vaes(image_path: str, tolerance: float):
 
         for name in tester.vae_models.keys():
             diff_img, recon_img, score = results[name]
-            diff_images.append(diff_img)
-            recon_images.append(recon_img)
-            scores.append(f"{name}: {score:.2f}")
+            diff_images.append((diff_img, name))
+            recon_images.append((recon_img, name))
+            scores.append(f"{name:<25}: {score:.1f}")
 
-        return diff_images, recon_images, scores
+        return diff_images, recon_images, "\n".join(scores)
 
     except Exception as e:
         error_msg = f"Error: {str(e)}"
-        return [None], [None], [error_msg]
+        return [None], [None], error_msg
 
+examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]
+
 # Gradio interface
-with gr.Blocks(title="VAE Performance Tester") as demo:
-    gr.Markdown("# VAE Performance Testing Tool")
-    gr.Markdown("Upload an image to compare all VAE models simultaneously")
+with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
+    gr.Markdown("# VAE Comparison Tool")
+    gr.Markdown("""
+    Upload an image or select an example to compare how different VAEs reconstruct it. Here's what happens:
+    1. The image is padded to a square and resized to 512x512 pixels.
+    2. Each VAE encodes the image into a latent space and decodes it back.
+    3. The tool then generates:
+       - **Difference Maps**: Black-and-white images showing where the reconstruction differs from the original (white areas indicate differences above the tolerance threshold).
+       - **Reconstructed Images**: The outputs from each VAE.
+       - **Sum of Differences**: A numerical score for each VAE, measuring the total difference in pixels exceeding the tolerance.
+    Use the tolerance slider to adjust the sensitivity.
+    """)
 
     with gr.Row():
         with gr.Column(scale=1):
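
Note: the description added above is concrete enough to restate as code. A hypothetical sketch of the per-pixel comparison it describes (the actual computation lives in unchanged parts of app.py; `diff_map_and_score` is an illustrative name):

```python
import torch

def diff_map_and_score(original: torch.Tensor, recon: torch.Tensor, tol: float):
    """original, recon: (3, 512, 512) tensors scaled to [0, 1]."""
    abs_diff = (original - recon).abs().mean(dim=0)   # per-pixel difference
    mask = abs_diff > tol                             # white where above tolerance
    return mask.float(), abs_diff[mask].sum().item()  # map + "sum of differences"
```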
@@ -120,7 +140,8 @@ with gr.Blocks(title="VAE Performance Tester") as demo:
                 maximum=0.5,
                 value=0.1,
                 step=0.01,
-                label="Difference Tolerance"
+                label="Difference Tolerance",
+                info="Low tolerance (e.g., 0.01): Highly sensitive, flags small deviations. High tolerance (e.g., 0.5): Less sensitive, flags only large deviations, showing fewer differences.",
             )
             submit_btn = gr.Button("Test All VAEs")
 
@@ -128,7 +149,15 @@ with gr.Blocks(title="VAE Performance Tester") as demo:
     with gr.Row():
         diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
         recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
-        scores_output = gr.Textbox(label="Difference Scores", lines=4)
+        scores_output = gr.Textbox(label="Sum of difference (lower is better reconstruction)", lines=5, elem_classes="monospace-text")
+
+    if examples:
+        with gr.Column():
+            example_gallery = gr.Examples(
+                examples=examples,
+                inputs=image_input,
+                label="Example Images"
+            )
 
     submit_btn.click(
         fn=test_all_vaes,
@@ -138,3 +167,4 @@ with gr.Blocks(title="VAE Performance Tester") as demo:
 
 if __name__ == "__main__":
     demo.launch()
+

examples/01967_00.jpg ADDED
examples/03032_00.jpg ADDED
examples/048395_0.jpg ADDED
examples/048399_0.jpg ADDED
examples/048400_0.jpg ADDED
examples/048410_0.jpg ADDED
examples/048436_0.jpg ADDED
examples/051807_0.jpg ADDED
examples/051808_0.jpg ADDED
examples/051836_0.jpg ADDED
examples/053055_0.jpg ADDED
examples/053114_0.jpg ADDED
examples/053137_0.jpg ADDED
examples/07089_00.jpg ADDED
examples/13136_00.jpg ADDED
examples/13331_00.jpg ADDED
examples/13988_00.jpg ADDED
examples/14009_00.jpg ADDED
examples/14022_00.jpg ADDED
examples/14533_00.jpg ADDED