rizavelioglu committed
Commit 5558320 · Parent: 08198f0

add image size dropdown

Files changed (1): app.py (+14 -11)
app.py CHANGED
@@ -24,17 +24,17 @@ class PadToSquare:
         return transforms.functional.pad(img, padding, padding_mode="edge")
 
 class VAETester:
-    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", img_size: int = 512):
         self.device = device
         self.input_transform = transforms.Compose([
             PadToSquare(),
-            transforms.Resize((512, 512), antialias=True),
+            transforms.Resize((img_size, img_size)),
             transforms.ToDtype(torch.float32, scale=True),
             transforms.Normalize(mean=[0.5], std=[0.5]),
         ])
         self.base_transform = transforms.Compose([
             PadToSquare(),
-            transforms.Resize((512, 512), antialias=True),
+            transforms.Resize((img_size, img_size)),
             transforms.ToDtype(torch.float32, scale=True),
         ])
         self.output_transform = transforms.Normalize(mean=[-1], std=[2])
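
For orientation, the reworked pipeline can be exercised in isolation. A minimal sketch, assuming the app imports torchvision's v2 transforms (where `ToDtype` lives) and omitting the app's own `PadToSquare` step; the input tensor is illustrative:

```python
import torch
from torchvision.transforms import v2 as transforms  # assumed import style

# Illustrative square input; in the app, PadToSquare() runs first.
img = torch.randint(0, 256, (3, 640, 640), dtype=torch.uint8)

img_size = 1024  # now a constructor parameter instead of a hard-coded 512
input_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),    # [0, 1] -> [-1, 1] for the VAE
])

out = input_transform(img)
print(out.shape)  # torch.Size([3, 1024, 1024])
```

Note the commit also drops the explicit `antialias=True`; in torchvision's v2 transforms antialiasing is enabled by default for tensor inputs, so this should be behavior-preserving.
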
@@ -67,8 +67,7 @@ class VAETester:
 
         with torch.no_grad():
             encoded = vae.encode(img_transformed).latent_dist.sample()
-            encoded_scaled = encoded * vae.config.scaling_factor
-            decoded = vae.decode(encoded_scaled / vae.config.scaling_factor).sample
+            decoded = vae.decode(encoded).sample
 
         decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
         reconstructed = decoded_transformed.clip(0, 1)
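
The two removed lines multiplied the sampled latents by `vae.config.scaling_factor` and then divided the factor straight back out before decoding, so the pair was a round-trip no-op; the scaling factor only matters when latents are handed to a diffusion model. A minimal sketch of the simplified encode/decode path, assuming a diffusers `AutoencoderKL` (the checkpoint id is illustrative):

```python
import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; the app loops over several VAEs.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()

x = torch.randn(1, 3, 512, 512)  # normalized image in [-1, 1]
with torch.no_grad():
    encoded = vae.encode(x).latent_dist.sample()
    # (encoded * f) / f == encoded, hence decoding the latents directly:
    decoded = vae.decode(encoded).sample

print(encoded.shape, decoded.shape)  # e.g. [1, 4, 64, 64] and [1, 3, 512, 512]
```
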
@@ -92,12 +91,12 @@ class VAETester:
             results[name] = (diff_img, recon_img, score)
         return results
 
-# Initialize tester
-tester = VAETester()
 
 @spaces.GPU(duration=5)
-def test_all_vaes(image_path: str, tolerance: float):
+def test_all_vaes(image_path: str, tolerance: float, img_size: int):
     """Gradio interface function to test all VAEs"""
+    # Initialize tester
+    tester = VAETester(img_size=img_size)
     try:
         img_tensor = read_image(image_path)
         results = tester.process_all_models(img_tensor, tolerance)
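
Moving `tester = VAETester()` from module scope into the handler is what makes the new `img_size` argument effective: the size-dependent transforms are rebuilt for whatever resolution each request selects (`spaces.GPU(duration=5)` is Hugging Face ZeroGPU's decorator requesting a short GPU slot per call). Schematically, the pattern is:

```python
# Sketch of the pattern: size-dependent state is built per request,
# so each call can honor a different dropdown selection.
class Worker:
    def __init__(self, img_size: int = 512):
        self.img_size = img_size  # stands in for the size-dependent transforms

def handler(img_size: int) -> str:
    worker = Worker(img_size=img_size)  # fresh instance per call
    return f"processing at {worker.img_size}x{worker.img_size}"

print(handler(1024))  # processing at 1024x1024
```
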
@@ -110,7 +109,7 @@ def test_all_vaes(image_path: str, tolerance: float):
             diff_img, recon_img, score = results[name]
             diff_images.append((diff_img, name))
             recon_images.append((recon_img, name))
-            scores.append(f"{name:<25}: {score:.1f}")
+            scores.append(f"{name:<25}: {score:,.0f}")
 
         return diff_images, recon_images, "\n".join(scores)
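
The score format switches from one decimal place (`:.1f`) to a thousands-grouped integer (`:,.0f`), which reads better for large scores. For reference, with an illustrative value:

```python
score = 1234567.89  # illustrative value
print(f"{score:.1f}")   # old format -> 1234567.9
print(f"{score:,.0f}")  # new format -> 1,234,568
```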
 
@@ -125,7 +124,7 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family
     gr.Markdown("# VAE Comparison Tool")
     gr.Markdown("""
     Upload an image or select an example to compare how different VAEs reconstruct it. Here's what happens:
-    1. The image is padded to a square and resized to 512x512 pixels.
+    1. The image is padded to a square and resized to `512x512` pixels (adjustable via the `Image Size` dropdown).
     2. Each VAE encodes the image into a latent space and decodes it back.
     3. The tool then generates:
     - **Difference Maps**: Black-and-white images showing where the reconstruction differs from the original (white areas indicate differences above the tolerance threshold).

@@ -145,6 +144,10 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family
                 label="Difference Tolerance",
                 info="Low tolerance (e.g., 0.01): Highly sensitive, flags small deviations. High tolerance (e.g., 0.5): Less sensitive, flags only large deviations, showing fewer differences.",
             )
+            img_size = gr.Dropdown(
+                label="Image Size",
+                choices=[512, 1024],
+            )
             submit_btn = gr.Button("Test All VAEs")
 
         with gr.Column(scale=3):

@@ -163,7 +166,7 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family
 
     submit_btn.click(
         fn=test_all_vaes,
-        inputs=[image_input, tolerance_slider],
+        inputs=[image_input, tolerance_slider, img_size],
         outputs=[diff_gallery, recon_gallery, scores_output]
     )
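
End to end, the dropdown is simply a third input to the click handler. A stripped-down sketch of the wiring with a placeholder handler (component names mirror the diff; everything else here is assumed):

```python
import gradio as gr

def test_all_vaes(image_path: str, tolerance: float, img_size: int) -> str:
    # Placeholder: the real handler runs every VAE at the chosen resolution.
    return f"would test {image_path} at {img_size}x{img_size}, tolerance={tolerance}"

with gr.Blocks() as demo:
    image_input = gr.Image(type="filepath")
    tolerance_slider = gr.Slider(0.01, 0.5, value=0.1, label="Difference Tolerance")
    img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
    submit_btn = gr.Button("Test All VAEs")
    scores_output = gr.Textbox(label="Scores")
    submit_btn.click(
        fn=test_all_vaes,
        inputs=[image_input, tolerance_slider, img_size],
        outputs=[scores_output],
    )

demo.launch()
```

One caveat worth flagging: the committed `gr.Dropdown` sets no `value=`, and Gradio's `Dropdown` defaults to no selection, so `img_size` may arrive as `None` until the user picks a size; passing `value=512` as in the sketch would pin the default.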
 
 