Sapir's picture
wip
ebaff66
raw
history blame
1.39 kB
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor
from txt2img.common.torch_utils import append_dims
from txt2img.config.diffusion_parts import PatchifierConfig, PatchifierName
def pixart_alpha_patchify(
    latents: Tensor,
    patch_size: Tuple[int, int, int],
) -> Tensor:
    """Split a 5-D latent volume into non-overlapping 3-D patches and flatten them into tokens.

    Args:
        latents: Tensor of shape (b, c, f, h, w). The axes are read as batch,
            channels, frames, height and width, per the axis names in the
            rearrange pattern below. ``f``, ``h`` and ``w`` must each be
            divisible by the matching entry of ``patch_size``.
        patch_size: ``(pt, ph, pw)``, the patch extent along frames, height
            and width respectively.

    Returns:
        Tensor of shape ``(b, (f // pt) * (h // ph) * (w // pw), c * pt * ph * pw)``.
        Tokens are ordered frame-major, then row, then column. Each token's
        feature vector is laid out as ``(c, pt, ph, pw)``.

    Raises:
        ValueError: if ``patch_size`` does not have exactly three entries.
    """
    # Unpacking (rather than indexing) rejects malformed patch sizes instead
    # of silently ignoring extra entries.
    pt, ph, pw = patch_size
    return rearrange(
        latents,
        "b c (f pt) (h ph) (w pw) -> b (f h w) (c pt ph pw)",
        pt=pt,
        ph=ph,
        pw=pw,
    )
class SymmetricPatchifier(Patchifier):
    """Patchifier whose ``unpatchify`` exactly inverts ``patchify``.

    Both directions read ``self._patch_size`` as a ``(pt, ph, pw)`` triple:
    index 0 pairs with frames, 1 with height and 2 with width, matching
    ``pixart_alpha_patchify``. This class does not set ``_patch_size`` itself.
    NOTE(review): presumably the ``Patchifier`` base class initializes it —
    confirm.
    """

    def patchify(
        self,
        latents: Tensor,
    ) -> Tensor:
        """Flatten a (b, c, f, h, w) latent volume into patch tokens.

        Returns a tensor of shape ``(b, num_patches, c * pt * ph * pw)``.
        See ``pixart_alpha_patchify`` for the exact layout.
        """
        return pixart_alpha_patchify(latents, self._patch_size)

    def unpatchify(
        self, latents: Tensor, output_height: int, output_width: int, output_num_frames: int, out_channels: int
    ) -> Tensor:
        """Rebuild a (b, c, f, h, w) latent volume from patch tokens.

        This is the inverse of ``patchify``. All three spatial/temporal axes,
        including the temporal patch size, are folded back symmetrically.

        Args:
            latents: Tokens of shape ``(b, F * H * W, c * pt * ph * pw)``, where
                ``F = output_num_frames // pt``, ``H = output_height // ph`` and
                ``W = output_width // pw``.
            output_height: Target height of the reconstructed volume.
            output_width: Target width of the reconstructed volume.
            output_num_frames: Target number of frames of the reconstructed
                volume.
            out_channels: Accepted for interface compatibility but not used;
                the channel count ``c`` is inferred from ``latents``.

        Returns:
            Tensor of shape ``(b, c, F * pt, H * ph, W * pw)``. Each output
            size is the requested size rounded down to a multiple of its patch
            size.
        """
        pt, ph, pw = self._patch_size
        # The temporal patch size must be folded back out of the channel axis
        # and multiplied into the frame axis. Otherwise a temporal patch
        # larger than 1 would end up mixed into the channels.
        latents = rearrange(
            latents,
            "b (f h w) (c pt ph pw) -> b c (f pt) (h ph) (w pw)",
            f=output_num_frames // pt,
            h=output_height // ph,
            w=output_width // pw,
            pt=pt,
            ph=ph,
            pw=pw,
        )
        return latents