"""Vector quantizer. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Reference: | |
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py | |
https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py | |
""" | |
from typing import Mapping, Text, Tuple

import torch
from einops import rearrange
from torch.cuda.amp import autocast


class VectorQuantizer(torch.nn.Module):
    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 ):
        super().__init__()
        self.commitment_cost = commitment_cost

        # Codebook of `codebook_size` entries, each a vector of dimension `token_size`.
        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm

    # Ensure quantization is performed using f32.
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        z = z.float()
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = rearrange(z, 'b h w c -> (b h w) c')

        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared Euclidean distance between each flattened token and every codebook
        # entry, expanded as ||z||^2 + ||e||^2 - 2 * z @ e^T; shape (b*h*w, codebook_size).
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)

        min_encoding_indices = torch.argmin(d, dim=1)  # shape: (b*h*w,)
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
            z = torch.nn.functional.normalize(z, dim=-1)

        # Compute loss for embedding.
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) ** 2)
        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
        loss = commitment_loss + codebook_loss

        # Preserve gradients through the quantization step (straight-through estimator).
        z_quantized = z + (z_quantized - z).detach()

        # Reshape back to match the original input shape.
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(
                z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        if len(indices.shape) == 1:
            # Hard integer indices: direct embedding lookup.
            z_quantized = self.embedding(indices)
        elif len(indices.shape) == 2:
            # Soft codes over the codebook (e.g. one-hot rows): weighted sum of entries.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        return z_quantized
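

# The following is an illustrative usage sketch, not part of the original module.
# It assumes the quantizer sits on an encoder's latent grid, i.e. the input has
# shape (batch, token_size, height, width); the tensor names and sizes below are
# made up for the example.
if __name__ == "__main__":
    quantizer = VectorQuantizer(codebook_size=1024, token_size=256, use_l2_norm=True)
    latents = torch.randn(2, 256, 16, 16)  # hypothetical encoder output (b, c, h, w)
    quantized, info = quantizer(latents)
    print(quantized.shape)                     # torch.Size([2, 256, 16, 16])
    print(info["min_encoding_indices"].shape)  # torch.Size([2, 16, 16])
    print(info["quantizer_loss"].item())       # scalar loss to add to the training objective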