---
license: apache-2.0
tags:
  - text-to-image
  - flux
datasets:
  - DucHaiten/pony-art
  - jordandavis/fashion_num_people
  - mattmdjaga/human_parsing_dataset
  - Voxel51/Describable-Textures-Dataset
  - twodgirl/vndb
---

Flux Latent Preview at Half-Size

The decoder turns Flux latents into a half-size preview image. Similar preview decoders already exist in the wild for the Flux Dev model.
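A preview decoder can be used to watch an image take shape while it is still being sampled. The sketch below is one way to wire that up, assuming the `callback_on_step_end` / `callback_on_step_end_tensor_inputs` hook that diffusers' `FluxPipeline` accepts; `make_preview_callback` and its `every` interval are hypothetical, not part of this repo.

```python
def make_preview_callback(decode_fn, every=5):
    """Build a callback_on_step_end hook that collects periodic previews.

    decode_fn is any function that turns packed Flux latents into an image.
    Returns (callback, previews), where previews is a list of
    (step_index, image) pairs filled in while the pipeline runs.
    """
    if every < 1:
        raise ValueError("every must be a positive integer")
    previews = []

    def callback(pipe, step_index, timestep, callback_kwargs):
        # With callback_on_step_end_tensor_inputs=['latents'], diffusers passes
        # the current packed latents under the "latents" key.
        if (step_index + 1) % every == 0:
            previews.append((step_index, decode_fn(callback_kwargs["latents"])))
        # The pipeline reads the returned dict back, so return it unchanged.
        return callback_kwargs

    return callback, previews


# Usage with the functions from the Inference section (not run here):
# callback, previews = make_preview_callback(lambda lat: preview_image(lat, pipe), every=5)
# pipe('cat playing piano', num_inference_steps=10,
#      callback_on_step_end=callback,
#      callback_on_step_end_tensor_inputs=['latents'])
```

Keeping the decode function as a parameter leaves the callback independent of which decoder (preview or full VAE) is used.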

The maximum supported resolution lies between 768 and 1024 px.
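The resolution limit is easier to reason about in terms of latent sizes. A minimal sketch, assuming the diffusers `FluxPipeline` conventions (VAE downsampling factor 8, 16 latent channels, 2×2 latent patches packed into tokens) and a preview at half the output resolution as the title suggests; `flux_latent_shapes` is a hypothetical helper, not part of this repo.

```python
def flux_latent_shapes(height, width, vae_scale_factor=8, latent_channels=16):
    """Shapes involved in decoding a Flux image of height x width pixels.

    Batch dimensions are omitted. Unlike diffusers (which warns and rounds),
    this sketch rejects sizes that are not multiples of 16.
    """
    patch = vae_scale_factor * 2  # 8x VAE downsampling, then 2x2 packing
    if height % patch or width % patch:
        raise ValueError(f"height and width must be multiples of {patch}")
    lat_h, lat_w = height // vae_scale_factor, width // vae_scale_factor
    return {
        # (tokens, features) as returned with output_type='latent'
        "packed": ((lat_h // 2) * (lat_w // 2), latent_channels * 4),
        # (channels, h, w) after FluxPipeline._unpack_latents
        "unpacked": (latent_channels, lat_h, lat_w),
        # (h, w) of the preview, assuming it is half the output size
        "preview": (height // 2, width // 2),
    }
```

For a 1024 × 1024 image this gives 4096 packed tokens of 64 features, a 16 × 128 × 128 latent and a 512 × 512 preview.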

Retraining the text encoder and the VAE decoder reduced the checkpoint size by around 10 GB, but it also set the model's capabilities back by roughly two years.

Inference

```python
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_model
from torchvision import transforms

from tea_model import TeaDecoder  # preview decoder architecture from this repo


@torch.no_grad()
def preview_image(latents, pipe, height=1024, width=1024):
    # Unpack Flux's packed latents (B, tokens, 64) into (B, 16, H/8, W/8).
    # height/width must match the generation size (1024 px is the pipeline default).
    latents = FluxPipeline._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    tea = TeaDecoder(ch_in=16)
    load_model(tea, './vae_decoder.safetensors')
    tea = tea.to(device='cuda').eval()
    # The decoder outputs values in [-1, 1]; map them to [0, 1].
    output = tea(latents.to(device='cuda', dtype=torch.float32)) / 2.0 + 0.5
    preview = transforms.ToPILImage()(output[0].clamp(0, 1).cpu())

    return preview


@torch.no_grad()
def full_size_image(latents, pipe, height=1024, width=1024):
    latents = FluxPipeline._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    # Undo the latent normalization before decoding with the full Flux VAE.
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    torch.cuda.empty_cache()
    pipe.vae.to(device='cuda')
    latents = latents.to(device='cuda', dtype=pipe.vae.dtype)
    pixel_values, = pipe.vae.decode(latents, return_dict=False)
    images = pipe.image_processor.postprocess(pixel_values.cpu(), output_type='pil')

    return images


if __name__ == '__main__':
    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                        torch_dtype=torch.bfloat16)
    # Move each component to the GPU only while it runs (needs accelerate).
    pipe.enable_model_cpu_offload()
    latents = pipe('cat playing piano', num_inference_steps=10, output_type='latent').images
    # Decode the same latents with the full VAE and with the preview decoder.
    full_size = full_size_image(latents, pipe)
    preview = preview_image(latents, pipe)
    full_size[0].save('cat_full.png')
    preview.save('cat.png')
```
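Both functions begin with `FluxPipeline._unpack_latents`, a private diffusers helper that turns the packed token sequence back into a latent image. As an illustration only, the NumPy sketch below mirrors that reshaping, assuming diffusers' Flux convention of 2×2 latent patches and a VAE downsampling factor of 8; `pack_latents_np` and `unpack_latents_np` are hypothetical names, not library functions.

```python
import numpy as np


def pack_latents_np(latents):
    """(B, C, H, W) -> (B, (H/2)*(W/2), C*4): group 2x2 latent patches into tokens."""
    b, c, h, w = latents.shape
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (B, H/2, W/2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)


def unpack_latents_np(packed, height, width, vae_scale_factor=8):
    """Inverse of pack_latents_np for an image of height x width pixels."""
    b, _, d = packed.shape
    h, w = height // vae_scale_factor, width // vae_scale_factor
    x = packed.reshape(b, h // 2, w // 2, d // 4, 2, 2)
    x = x.transpose(0, 3, 1, 4, 2, 5)  # (B, C, H/2, 2, W/2, 2)
    return x.reshape(b, d // 4, h, w)
```

The round trip is exact because both directions only reshape and transpose; no values are mixed, which is why the preview decoder and the full VAE can share the same unpacked input.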

Disclaimer

Use of this code and of this documentation requires citation of, and attribution to, the author via a link to their Hugging Face profile in all resulting work.