Spaces:
Running
on
A100
Running
on
A100
from abc import ABC, abstractmethod | |
from typing import Tuple | |
import torch | |
from diffusers.configuration_utils import ConfigMixin | |
from einops import rearrange | |
from torch import Tensor | |
from txt2img.common.torch_utils import append_dims | |
from txt2img.config.diffusion_parts import PatchifierConfig, PatchifierName | |
def pixart_alpha_patchify(
    latents: Tensor,
    patch_size: Tuple[int, int, int],
) -> Tensor:
    """Fold a 5-D latent tensor into a sequence of flattened 3-D patch tokens.

    The einops pattern names the input axes ``b c f h w`` (batch, channels,
    frames, height, width). Each ``(p1, p2, p3)`` patch along the frame,
    height and width axes becomes one token whose features are the patch's
    values across all channels.

    Args:
        latents: Tensor of shape ``(b, c, f, h, w)``. ``f``, ``h`` and ``w``
            must be divisible by ``patch_size[0]``, ``patch_size[1]`` and
            ``patch_size[2]`` respectively (einops raises otherwise).
        patch_size: Patch extent along frames, height and width. It is
            indexed positionally, so any sequence with at least three ints
            works.

    Returns:
        A single tensor of shape
        ``(b, (f // p1) * (h // p2) * (w // p3), c * p1 * p2 * p3)``.
    """
    # Token order is frame-major, then height, then width; the feature order
    # inside each token is channel-major. ``unpatchify`` must mirror this.
    return rearrange(
        latents,
        "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
        p1=patch_size[0],
        p2=patch_size[1],
        p3=patch_size[2],
    )
class SymmetricPatchifier(Patchifier):
    """Patchifier whose ``unpatchify`` exactly inverts ``patchify``.

    Both methods read ``self._patch_size``. This class does not assign it;
    presumably the ``Patchifier`` base sets it (verify against the base).
    It is indexed positionally as ``(p_f, p_h, p_w)``: the patch extent
    along frames, height and width.
    """

    def patchify(
        self,
        latents: Tensor,
    ) -> Tensor:
        """Flatten ``(b, c, f, h, w)`` latents into patch tokens.

        Delegates to ``pixart_alpha_patchify`` with ``self._patch_size``.

        Returns:
            A single tensor of shape ``(b, num_tokens, c * p_f * p_h * p_w)``.
        """
        return pixart_alpha_patchify(latents, self._patch_size)

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        """Undo ``patchify``, restoring a ``(b, c, f, h, w)`` tensor.

        Args:
            latents: Token tensor of shape
                ``(b, f' * h' * w', c * p_f * p_h * p_w)``, where
                ``f' = output_num_frames // p_f``,
                ``h' = output_height // p_h`` and
                ``w' = output_width // p_w``.
            output_height: Height of the unpatchified output.
            output_width: Width of the unpatchified output.
            output_num_frames: Number of frames in the unpatchified output.
            out_channels: Unused here; kept so the signature stays
                compatible with existing callers.

        Returns:
            A single tensor of shape
            ``(b, c, output_num_frames, output_height, output_width)`` when
            each output extent is divisible by its patch size.
        """
        patch_f = self._patch_size[0]
        patch_h = self._patch_size[1]
        patch_w = self._patch_size[2]
        # This pattern is the exact mirror of the one in
        # ``pixart_alpha_patchify``. The frame axis is unfolded by its patch
        # size just like height and width, so a temporal patch size other
        # than 1 round-trips correctly. With ``patch_f == 1`` the frame axis
        # passes through unchanged.
        return rearrange(
            latents,
            "b (f h w) (c p1 p2 p3) -> b c (f p1) (h p2) (w p3)",
            f=output_num_frames // patch_f,
            h=output_height // patch_h,
            w=output_width // patch_w,
            p1=patch_f,
            p2=patch_h,
            p3=patch_w,
        )