rizavelioglu committed
Commit 46241ec · 1 Parent(s): 8b4895b

add support for remote VAE-decoding

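For context, the change below swaps the decode step of several VAEs from a local `vae.decode(...)` call to Hugging Face's remote VAE decoding. A minimal sketch of that flow, assuming a recent diffusers release that ships `diffusers.utils.remote_utils.remote_decode`; the endpoint URL is the sdxl-vae endpoint used in the diff, while the dummy input and printed shape are illustrative only, not the app's exact code:

# Sketch: encode locally, decode on a hosted VAE endpoint (illustrative, not the app code).
import torch
from diffusers import AutoencoderKL
from diffusers.utils.remote_utils import remote_decode

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)

# Dummy image batch in [-1, 1]; in the app this comes from the padded/resized upload.
image = torch.rand(1, 3, 512, 512, device=device) * 2 - 1

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()

# Only the latents travel to the remote decoder; the decoded tensor comes back as torch ("pt").
decoded = remote_decode(
    endpoint="https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",  # sdxl-vae endpoint from the diff
    tensor=latents,
    do_scaling=False,          # latents above were not multiplied by the VAE scaling factor
    output_type="pt",
    return_type="pt",
    partial_postprocess=False,
)
print(decoded.shape)  # (1, 3, 512, 512); values roughly in [-1, 1]

The app then maps the returned tensor from [-1, 1] to [0, 1] exactly as it does for the locally decoded models, so local and remote results stay comparable.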
Files changed (1)
  1. app.py +85 -59
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
 from diffusers import AutoencoderKL
+from diffusers.utils.remote_utils import remote_decode
 import torchvision.transforms.v2 as transforms
 from torchvision.io import read_image
 from typing import Dict
@@ -38,64 +39,99 @@ class VAETester:
             transforms.ToDtype(torch.float32, scale=True),
         ])
         self.output_transform = transforms.Normalize(mean=[-1], std=[2])
-
-        # Load all VAE models at initialization
         self.vae_models = self._load_all_vaes()

-    def _load_all_vaes(self) -> Dict[str, AutoencoderKL]:
-        """Load all available VAE models"""
-        vae_configs = {
-            "stable-diffusion-v1-4": ("CompVis/stable-diffusion-v1-4", "vae"),
-            "sd-vae-ft-mse": ("stabilityai/sd-vae-ft-mse", ""),
-            "sdxl-vae": ("stabilityai/sdxl-vae", ""),
-            "stable-diffusion-3-medium": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
-            "FLUX.1-dev": ("black-forest-labs/FLUX.1-dev", "vae")
+    def _get_endpoint(self, base_name: str) -> str:
+        """Helper method to get the endpoint for a given base model name"""
+        endpoints = {
+            "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
+            "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
+            "FLUX.1-schnell": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
         }
-
-        vae_dict = {}
-        for name, (path, subfolder) in vae_configs.items():
-            vae_dict[name] = AutoencoderKL.from_pretrained(path, subfolder=subfolder).to(self.device)
-        return vae_dict
-
-    def process_image(self,
-                      img: torch.Tensor,
-                      vae: AutoencoderKL,
-                      tolerance: float):
-        """Process image through a single VAE"""
+        return endpoints[base_name]
+
+    def _load_all_vaes(self) -> Dict[str, Dict]:
+        """Load configurations for local and remote VAE models"""
+        local_vaes = {
+            "stable-diffusion-v1-4": AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device),
+            "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
+            "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
+            "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
+            "FLUX.1-schnell": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae").to(self.device),
+            "FLUX.1-dev": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
+        }
+        # Define the desired order of models
+        order = [
+            "stable-diffusion-v1-4",
+            "sd-vae-ft-mse",
+            "sd-vae-ft-mse (remote)",
+            "sdxl-vae",
+            "sdxl-vae (remote)",
+            "stable-diffusion-3-medium",
+            "FLUX.1-schnell",
+            "FLUX.1-schnell (remote)",
+            "FLUX.1-dev",
+        ]
+
+        # Construct the vae_models dictionary in the specified order
+        vae_models = {}
+        for name in order:
+            if "(remote)" not in name:
+                # Local model
+                vae_models[name] = {"type": "local", "vae": local_vaes[name]}
+            else:
+                # Remote model
+                base_name = name.replace(" (remote)", "")
+                vae_models[name] = {
+                    "type": "remote",
+                    "local_vae_key": base_name,
+                    "endpoint": self._get_endpoint(base_name),
+                }
+
+        return vae_models
+
+    def process_image(self, img: torch.Tensor, model_config: Dict, tolerance: float):
+        """Process image through a single VAE (local or remote)"""
         img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()

-        with torch.no_grad():
-            encoded = vae.encode(img_transformed).latent_dist.sample()
-            decoded = vae.decode(encoded).sample
-
+        if model_config["type"] == "local":
+            vae = model_config["vae"]
+            with torch.no_grad():
+                encoded = vae.encode(img_transformed).latent_dist.sample()
+                decoded = vae.decode(encoded).sample
+        elif model_config["type"] == "remote":
+            local_vae = self.vae_models[model_config["local_vae_key"]]["vae"]
+            with torch.no_grad():
+                encoded = local_vae.encode(img_transformed).latent_dist.sample()
+                decoded = remote_decode(
+                    endpoint=model_config["endpoint"],
+                    tensor=encoded,
+                    do_scaling=False,
+                    output_type="pt",
+                    return_type="pt",
+                    partial_postprocess=False,
+                )
         decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
         reconstructed = decoded_transformed.clip(0, 1)
-
         diff = (original_base - reconstructed).abs()
         bw_diff = (diff > tolerance).any(dim=0).float()
-
         diff_image = transforms.ToPILImage()(bw_diff)
         recon_image = transforms.ToPILImage()(reconstructed)
         diff_score = bw_diff.sum().item()
-
         return diff_image, recon_image, diff_score

-    def process_all_models(self,
-                           img: torch.Tensor,
-                           tolerance: float):
-        """Process image through all loaded VAEs"""
+    def process_all_models(self, img: torch.Tensor, tolerance: float):
+        """Process image through all configured VAEs"""
         results = {}
-        for name, vae in self.vae_models.items():
-            diff_img, recon_img, score = self.process_image(img, vae, tolerance)
+        for name, model_config in self.vae_models.items():
+            diff_img, recon_img, score = self.process_image(img, model_config, tolerance)
             results[name] = (diff_img, recon_img, score)
         return results

-
-@spaces.GPU(duration=10)
+@spaces.GPU(duration=15)
 def test_all_vaes(image_path: str, tolerance: float, img_size: int):
     """Gradio interface function to test all VAEs"""
-    # Initialize tester
     tester = VAETester(img_size=img_size)
     try:
         img_tensor = read_image(image_path)
@@ -112,25 +148,23 @@ def test_all_vaes(image_path: str, tolerance: float, img_size: int):
         scores.append(f"{name:<25}: {score:,.0f}")

         return diff_images, recon_images, "\n".join(scores)
-
     except Exception as e:
         error_msg = f"Error: {str(e)}"
         return [None], [None], error_msg

 examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples/"))]

-# Gradio interface
 with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
     gr.Markdown("# VAE Comparison Tool")
     gr.Markdown("""
-    Upload an image or select an example to compare how different VAEs reconstruct it. Here's what happens:
-    1. The image is padded to a square and resized to `512x512` pixels (can change using `Image Size` dropdown).
-    2. Each VAE encodes the image into a latent space and decodes it back.
-    3. The tool then generates:
-       - **Difference Maps**: Black-and-white images showing where the reconstruction differs from the original (white areas indicate differences above the tolerance threshold).
-       - **Reconstructed Images**: The outputs from each VAE.
-       - **Sum of Differences**: A numerical score for each VAE, measuring the total difference in pixels exceeding the tolerance.
-    Use the tolerance slider to adjust the sensitivity.
+    Upload an image or select an example to compare how different VAEs reconstruct it. Now includes remote VAEs via Hugging Face's remote decoding feature!
+    1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
+    2. Each VAE (local or remote) encodes the image into a latent space and decodes it back.
+    3. Outputs include:
+       - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
+       - **Reconstructed Images**: Outputs from each VAE.
+       - **Sum of Differences**: Total pixels exceeding tolerance (lower is better).
+    Adjust tolerance to change sensitivity.
     """)

     with gr.Row():
@@ -142,27 +176,20 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family
                 value=0.1,
                 step=0.01,
                 label="Difference Tolerance",
-                info="Low tolerance (e.g., 0.01): Highly sensitive, flags small deviations. High tolerance (e.g., 0.5): Less sensitive, flags only large deviations, showing fewer differences.",
-            )
-            img_size = gr.Dropdown(
-                label="Image Size",
-                choices=[512, 1024],
+                info="Low (0.01): Sensitive to small changes. High (0.5): Only large changes flagged."
             )
+            img_size = gr.Dropdown(label="Image Size", choices=[512, 1024], value=512)
             submit_btn = gr.Button("Test All VAEs")

         with gr.Column(scale=3):
             with gr.Row():
                 diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
                 recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
-            scores_output = gr.Textbox(label="Sum of difference (lower is better reconstruction)", lines=5, elem_classes="monospace-text")
+            scores_output = gr.Textbox(label="Sum of differences (lower is better)", lines=9, elem_classes="monospace-text")

     if examples:
         with gr.Row():
-            example_gallery = gr.Examples(
-                examples=examples,
-                inputs=image_input,
-                label="Example Images"
-            )
+            gr.Examples(examples=examples, inputs=image_input, label="Example Images")

     submit_btn.click(
         fn=test_all_vaes,
@@ -172,4 +199,3 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family

 if __name__ == "__main__":
     demo.launch()
-
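For readers skimming the metric the UI reports: the "Sum of Differences" is the number of pixels whose per-channel absolute error exceeds the tolerance after the decoded tensor is mapped back from [-1, 1] to [0, 1]. A small self-contained sketch of that computation, using random stand-in tensors instead of real VAE outputs:

import torch

tolerance = 0.1
original = torch.rand(3, 512, 512)     # padded/resized input in [0, 1]
decoded = original * 2 - 1             # stand-in for a perfect VAE reconstruction in [-1, 1]

# Equivalent to transforms.Normalize(mean=[-1], std=[2]): (x + 1) / 2 maps [-1, 1] back to [0, 1].
reconstructed = ((decoded + 1) / 2).clip(0, 1)

diff = (original - reconstructed).abs()           # per-channel absolute error
bw_diff = (diff > tolerance).any(dim=0).float()   # white pixel where any channel exceeds tolerance
score = bw_diff.sum().item()                      # the "Sum of Differences" shown per VAE
print(score)                                      # 0.0 here, since the reconstruction is exact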