quite slow to load the fp8 model
On an Nvidia A6000, I'm using the code below to load the fp8 model:
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
import time
import json
# Initialize a dictionary to store stats
stats = {}
# Measure the time taken to load and prepare the model into VRAM
start_time = time.time()
bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co./Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to("cuda")
stats['model_loading_time'] = time.time() - start_time
It took around 293s to load the model. Why is loading so slow?
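In case it helps to pinpoint the bottleneck, here is the same code with each step timed separately (nothing new, just measurement), so you can see whether the checkpoint loading, the quantize/freeze pass, or the move to CUDA dominates:
import time
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
timings = {}

# Load the fp8 single-file transformer checkpoint
t = time.time()
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co./Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype
)
timings["load_transformer"] = time.time() - t

# Quantize and freeze the transformer
t = time.time()
quantize(transformer, weights=qfloat8)
freeze(transformer)
timings["quantize_freeze_transformer"] = time.time() - t

# Load, quantize and freeze the T5 text encoder
t = time.time()
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
timings["text_encoder_2"] = time.time() - t

# Build the pipeline and move everything to the GPU
t = time.time()
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to("cuda")
timings["pipeline_and_to_cuda"] = time.time() - t

print(timings)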
I'm getting this error when running the code:
transformer = FluxTransformer2DModel.from_single_file(...
AttributeError: type object 'FluxTransformer2DModel' has no attribute 'from_single_file'
Which version of diffusers do you have installed?
Thanks!
It's 0.30.0
I installed the dependencies from https://github.com/black-forest-labs/flux/blob/main/pyproject.toml#L11-L24
Thanks, I think they just released .from_single_file.
I could make it work by installing the latest version from main (I basically just clicked through all the steps listed at https://huggingface.co./docs/diffusers/installation#install-from-source).
885s on a simple laptop with 16 GB RAM (no GPU support, since it's Windows with an AMD GPU).
The quantizing + freezing step takes a lot of time.
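If that's the bottleneck, one workaround is to quantize once, save the quantized state dict plus the quantization map, and restore them from disk on later runs instead of redoing the qfloat8 pass every time. A rough sketch, assuming your optimum-quanto version exposes quantization_map and requantize as described in its README, and using accelerate's init_empty_weights to rebuild the model from its config (the file names below are just placeholders):
import json
import torch
from accelerate import init_empty_weights
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize, quantization_map, requantize
from safetensors.torch import load_file, save_file

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

# One-time: load, quantize, freeze, then save the quantized weights and the quantization map
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co./Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors",
    torch_dtype=dtype
)
quantize(transformer, weights=qfloat8)
freeze(transformer)
save_file(transformer.state_dict(), "transformer-qfloat8.safetensors")
with open("transformer-qmap.json", "w") as f:
    json.dump(quantization_map(transformer), f)

# Later runs: rebuild an empty model from its config and restore the quantized weights
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config(
        FluxTransformer2DModel.load_config(bfl_repo, subfolder="transformer")
    )
state_dict = load_file("transformer-qfloat8.safetensors")
with open("transformer-qmap.json") as f:
    qmap = json.load(f)
requantize(transformer, state_dict, qmap, device=torch.device("cuda"))
That way later runs skip the quantize/freeze pass entirely and just read the already-quantized tensors from disk.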
Why have you quantized it again when it is already in fp8?
You could probably just import it like so:
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co./Kijai/flux-fp8/blob/main/flux1-schnell-fp8.safetensors", torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder_2=None, torch_dtype=torch.bfloat16)
What dtype though?
torch.bfloat16
But it's not 16, it's 8. I think what you need is load_in_8bit=True:
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder_2=None, load_in_8bit=True)
The Diffusers docs suggest importing it and then using optimum-quanto to quantize the weights and freeze the transformer in qfloat8:
https://huggingface.co./docs/diffusers/main/en/api/pipelines/flux#single-file-loading-for-the-fluxtransformer2dmodel
I am experiencing the same issue; it would be really nice if anyone could help. I'm using the example provided in the official docs: https://huggingface.co./docs/diffusers/main/en/api/pipelines/flux#single-file-loading-for-the-fluxtransformer2dmodel