# flux-dev-flax: flux/util.py
# Last change by lnyan: commit d4607d7 ("Update").
import os
from dataclasses import dataclass
import numpy as np
import jax
from jax import Array as Tensor
import jax.numpy as jnp
from flax import nnx
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder
@dataclass
class ModelSpec:
    """Static description of one FLUX.1 variant.

    Bundles the transformer and autoencoder hyper-parameters with the
    locations of their weights: either local checkpoint paths or a Hugging
    Face repository plus file names the loaders can download from.
    """

    params: FluxParams  # hyper-parameters used to build the Flux transformer
    ae_params: AutoEncoderParams  # hyper-parameters used to build the AutoEncoder
    ckpt_path: str | None  # local transformer checkpoint; when None the loader may download repo_flow
    ae_path: str | None  # local autoencoder checkpoint; when None the loader may download repo_ae
    repo_id: str | None  # Hugging Face repository holding both checkpoint files
    repo_flow: str | None  # file name of the transformer weights inside repo_id
    repo_ae: str | None  # file name of the autoencoder weights inside repo_id
def _flux_params(guidance_embed: bool) -> FluxParams:
    """Return a fresh set of transformer hyper-parameters shared by FLUX.1 variants.

    Only ``guidance_embed`` differs between dev (guidance-distilled) and
    schnell; a new instance is built on every call so variants never share
    mutable fields such as ``axes_dim``.
    """
    return FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=guidance_embed,
    )


def _ae_params() -> AutoEncoderParams:
    """Return a fresh set of autoencoder hyper-parameters (identical for all variants)."""
    return AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )


# Registry of supported model variants. Local checkpoint paths are read from
# the environment at import time (FLUX_DEV / FLUX_SCHNELL for the
# transformer, AE for the autoencoder); unset variables leave them None.
configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=_flux_params(guidance_embed=True),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=_flux_params(guidance_embed=False),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
}
try:
    import ml_dtypes
except ImportError:  # optional: without it, fall back to a float32 round-trip
    ml_dtypes = None


def from_torch_bf16(x):
    """Convert a ``torch.bfloat16`` tensor into a ``jnp.bfloat16`` array.

    When ``ml_dtypes`` is importable, the raw 16-bit payload is reinterpreted
    bit-for-bit, with no rounding and no intermediate float32 copy. Otherwise
    the tensor is widened to float32 and cast back. That path is also exact,
    because every bfloat16 value is representable in float32, but it uses
    more memory.

    Args:
        x: CPU tensor with dtype ``torch.bfloat16``.

    Returns:
        A JAX array with the same shape and dtype ``jnp.bfloat16``.
    """
    if ml_dtypes is not None:
        # torch.int16 exists in every torch release (torch.uint16 only from
        # 2.3 on). A same-width dtype view keeps the bf16 bit pattern intact.
        raw_bits = x.view(torch.int16).numpy()
        return jnp.asarray(raw_bits.view(ml_dtypes.bfloat16))
    return jnp.asarray(x.float().numpy()).astype(jnp.bfloat16)
def load_from_torch(graph, state, state_dict:dict):
    """Copy tensors from a PyTorch ``state_dict`` into an ``nnx`` state.

    Each dotted torch key (e.g. ``"blocks.0.attn.weight"``) is walked through
    ``state``, and in parallel through the graph definition ``graph``, one
    path component at a time. A numeric component first descends into a
    ``"layers"`` sub-state when one exists, then indexes by integer.

    The final component is matched to a leaf as follows:
    - by exact name, when it exists;
    - otherwise by the single non-``bias`` entry of the container;
    - otherwise by mapping ``weight`` -> ``kernel``.

    Tensors are converted to JAX arrays and transposed to Flax layout for
    nodes whose attributes contain ``kernel_size`` or ``dot_general``. They
    are then shape-checked and written into ``state`` in place.

    Loading is best-effort. Any per-key failure (missing path component,
    unresolvable leaf name, shape mismatch, ...) is printed together with the
    key, and the loop moves on to the next key.

    Args:
        graph: graph definition from ``nnx.split``; navigated via
            ``.subgraphs`` and inspected via ``.attributes``.
        state: state from the same ``nnx.split`` call; mutated in place.
        state_dict: mapping of dotted torch parameter names to tensors.

    Returns:
        The same ``state`` object, updated.
    """
    cnt=0  # keys whose leaf name could not be resolved
    torch_cnt=0  # torch keys visited
    flax_cnt=0  # entries written back into state
    val_cnt=0  # raw jax.Array entries encountered (see NOTE below)
    print(f"Torch states: #{len(state_dict)}; Flax states: #{len(state.flat_state())}")
    def convert_to_jax(tensor):
        # bfloat16 needs a dedicated path (see from_torch_bf16); other dtypes
        # go through NumPy directly.
        if tensor.dtype==torch.bfloat16:
            return from_torch_bf16(tensor)
        else:
            return jnp.asarray(tensor.numpy())
    for key in sorted(state_dict.keys()):
        ptr=state  # cursor into the nnx state
        node=graph  # matching cursor into the graph definition
        torch_cnt+=1
        try:
            # Walk every component except the leaf name.
            for loc in key.split(".")[:-1]:
                if loc.isnumeric():
                    # Torch module lists are flat ("x.0"); the Flax side may
                    # nest them under an extra "layers" container.
                    if "layers" in ptr:
                        ptr=ptr["layers"]
                        node=node.subgraphs["layers"]
                    loc=int(loc)
                ptr=ptr[loc]
                node=node.subgraphs[loc]
            last=key.split(".")[-1]
            if last not in ptr._mapping.keys():
                # Name differs between frameworks: pick the only non-bias
                # entry, or translate torch "weight" to Flax "kernel".
                ptr_keys=list(ptr._mapping.keys())
                ptr_keys=list(filter(lambda x:x!="bias", ptr_keys))
                if len(ptr_keys)==1:
                    ptr_key=ptr_keys[0]
                elif last=="weight" and "kernel" in ptr_keys:
                    ptr_key="kernel"
                else:
                    cnt+=1
                    raise Exception(f"Mismatched: {key}: {ptr_keys} ")
                val=ptr[ptr_key].value
                # No shape check here: kernels are still in torch layout and
                # are transposed (and then checked) further below.
            else:
                if isinstance(ptr[last], jax.Array):
                    val=ptr[last]
                else:
                    val=ptr[last].value
                ptr_key=last
                assert state_dict[key].shape==val.shape, f"{key} mismatched"
            if isinstance(ptr[ptr_key], jax.Array):
                assert state_dict[key].shape==val.shape, f"Array: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
                kernel=convert_to_jax(state_dict[key])
                val_cnt+=1
                # NOTE(review): the converted value is discarded here.
                # `continue` skips the write-back below, so raw jax.Array
                # entries are counted but never updated. Confirm whether
                # this is intentional.
                continue
            elif ptr_key=="bias":
                assert state_dict[key].shape==val.shape, f"Bias: [{node.type}]mismatched {state_dict[key].shape} {val.shape}"
                kernel=nnx.Param(convert_to_jax(state_dict[key])).to_state()
            else:
                if 'kernel_size' in node.attributes:
                    # Presumably a convolution: move the two leading axes
                    # (out, in) to the end, reversed, and keep the spatial
                    # axes in front, i.e. (out, in, *spatial) ->
                    # (*spatial, in, out). This assumes PyTorch's conv weight
                    # layout. Kernels of other ranks are left as-is.
                    kernel=convert_to_jax(state_dict[key])
                    if len(kernel.shape)==3:
                        kernel=jnp.transpose(kernel, (2, 1, 0))
                    elif len(kernel.shape)==4:
                        kernel=jnp.transpose(kernel, (2, 3, 1, 0))
                    elif len(kernel.shape)==5:
                        kernel=jnp.transpose(kernel, (2, 3, 4, 1, 0))
                elif 'dot_general' in node.attributes:
                    # Presumably a linear layer: torch stores (out, in), Flax (in, out).
                    kernel=convert_to_jax(state_dict[key])
                    kernel=jnp.transpose(kernel, (1, 0))
                else:
                    kernel=convert_to_jax(state_dict[key])
                assert val.shape==kernel.shape, f"[{node.type}]mismatched {val.shape} {kernel.shape}"
                kernel=nnx.Param(kernel).to_state()
            # Replace the entry in place inside the state container.
            ptr._mapping[ptr_key]=kernel
            flax_cnt+=1
        except Exception as e:
            # Best-effort: report the failing key and continue with the rest.
            print(e, f"{key}")
    print(cnt, torch_cnt, flax_cnt, val_cnt)
    return state
def load_state_dict(model, state_dict):
    """Copy a PyTorch ``state_dict`` into an ``nnx`` model.

    The model is split into its graph definition and state. The state is
    filled from ``state_dict`` by ``load_from_torch`` and pushed back into
    the model with ``nnx.update``. The same model object is returned.
    """
    graph_def, model_state = nnx.split(model)
    nnx.update(model, load_from_torch(graph_def, model_state, state_dict))
    return model
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the missing and/or unexpected keys from a checkpoint load.

    Each non-empty list is printed as a header line followed by one
    tab-indented key per line. When both lists are non-empty, the two
    reports are separated by a 79-character rule.
    """
    reports = [
        f"Got {len(keys)} {label} keys:\n\t" + "\n\t".join(keys)
        for label, keys in (("missing", missing), ("unexpected", unexpected))
        if len(keys) > 0
    ]
    for position, report in enumerate(reports):
        if position > 0:
            print("\n" + "-" * 79 + "\n")
        print(report)
def patch_dtype(model,dtype,patch_param=False):
    """Force floating-point computation (and optionally storage) to ``dtype``.

    For every module reported by ``model.iter_modules()``:
    - a ``dtype`` attribute that is ``None`` or floating is set to ``dtype``;
    - with ``patch_param``, a floating ``param_dtype`` attribute is set too.

    With ``patch_param``, every ``nnx.Module`` found by ``nnx.iter_graph`` also
    has its floating ``nnx.Variable`` values and floating raw ``jax.Array``
    attributes cast to ``dtype``. ``Variable``s holding ``None`` are skipped.

    The model is modified in place and returned.
    """
    def _is_float(dt):
        return jnp.issubdtype(dt, jnp.floating)

    for _, module in model.iter_modules():
        if hasattr(module, "dtype") and (module.dtype is None or _is_float(module.dtype)):
            module.dtype = dtype
        if patch_param and hasattr(module, "param_dtype") and _is_float(module.param_dtype):
            module.param_dtype = dtype

    if not patch_param:
        return model

    for _, node in nnx.iter_graph(model):
        if not isinstance(node, nnx.Module):
            continue
        for attr_name, attr in vars(node).items():
            if isinstance(attr, nnx.Variable):
                # Empty variables carry no array to cast.
                if attr.value is not None and _is_float(attr.value.dtype):
                    attr.value = attr.value.astype(dtype)
            elif isinstance(attr, jax.Array) and _is_float(attr.dtype):
                setattr(node, attr_name, attr.astype(dtype))
    return model
def load_flow_model(name: str, device: str = "none", hf_download: bool = True):
    """Build the Flux transformer for variant ``name`` and load its weights.

    The checkpoint comes from the spec's ``ckpt_path``. When that is unset,
    ``hf_download`` is true and the spec names a repo and file, the file is
    downloaded with ``hf_hub_download``. The model is created in bfloat16.
    If a checkpoint path is available, the safetensors weights are copied in.

    ``device`` is accepted for signature compatibility and is not used.
    """
    print("Init model")
    spec = configs[name]
    ckpt_path = spec.ckpt_path
    can_download = hf_download and spec.repo_id is not None and spec.repo_flow is not None
    if ckpt_path is None and can_download:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_flow)

    model = Flux(spec.params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0))
    model = patch_dtype(model, jnp.bfloat16)
    if ckpt_path is None:
        return model

    print("Loading checkpoint")
    # load_sft doesn't support torch.device, so tensors are read on CPU.
    torch_weights = load_sft(ckpt_path, device="cpu")
    return load_state_dict(model, torch_weights)
def load_t5(device: str = "none", max_length: int = 512) -> HFEmbedder:
    """Build the ``lnyan/t5-v1_1-xxl-encoder`` text embedder in bfloat16.

    ``max_length`` values of 64, 128, 256 and 512 should work, provided the
    prompt is short enough. ``device`` is kept for signature compatibility
    and is not used.
    """
    t5_repo = "lnyan/t5-v1_1-xxl-encoder"
    return HFEmbedder(t5_repo, max_length=max_length, dtype=jnp.bfloat16)
def load_clip(device: str = "none") -> HFEmbedder:
    """Build the ``openai/clip-vit-large-patch14`` text embedder in bfloat16.

    The sequence length is fixed at 77. ``device`` is kept for signature
    compatibility and is not used.
    """
    clip_repo = "openai/clip-vit-large-patch14"
    return HFEmbedder(clip_repo, max_length=77, dtype=jnp.bfloat16)
def load_ae(name: str, device: str = "none", hf_download: bool = True) -> AutoEncoder:
    """Build the autoencoder for variant ``name`` and load its weights.

    The checkpoint comes from the spec's ``ae_path``. When that is unset,
    ``hf_download`` is true and the spec names a repo and AE file, the file
    is downloaded first. The autoencoder is created in bfloat16. If a
    checkpoint path is available, the safetensors weights are copied in.

    ``device`` is accepted for signature compatibility and is not used.
    """
    spec = configs[name]
    ckpt_path = spec.ae_path
    repo_known = all(field is not None for field in (spec.repo_id, spec.repo_ae))
    if ckpt_path is None and repo_known and hf_download:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_ae)

    print("Init AE")
    ae = patch_dtype(
        AutoEncoder(spec.ae_params, dtype=jnp.bfloat16, rngs=nnx.Rngs(0)),
        jnp.bfloat16,
    )
    if ckpt_path is not None:
        # Tensors are read on CPU and then copied into the JAX model.
        ae = load_state_dict(ae, load_sft(ckpt_path, device="cpu"))
    return ae
class WatermarkEmbedder:
    """Embed a fixed bit-string watermark into images using ``imwatermark``.

    Images are encoded one at a time with the ``"dwtDct"`` method, in the
    BGR channel order the watermarking library expects.
    """

    def __init__(self, watermark):
        """
        Args:
            watermark: sequence of 0/1 ints, passed to
                ``WatermarkEncoder.set_watermark("bits", ...)``.
        """
        self.watermark = watermark
        # Count the bits of the watermark actually being embedded, not the
        # module-level default constant.
        self.num_bits = len(watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: Tensor) -> Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, H, W, RGB) channels-last array in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # (n, b, h, w, c) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255].
        # The watermarking library expects cv2 BGR input, so flip RGB -> BGR.
        image_np = np.array(rearrange((255 * image), "n b h w c -> (n b) h w c"))[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Flip back to RGB and restore the (n, b, h, w, c) layout.
        image = jnp.asarray(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b h w c", n=n))
        # Positional bounds work with both the old (a_min/a_max) and the
        # newer (min/max) jnp.clip signatures.
        image = jnp.clip(image / 255, 0.0, 1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image
# Fixed watermark payload, chosen at random and written as a 48-digit binary
# literal. Its two leading zeros are not significant: the bit list derived
# below starts at the first 1 and therefore holds 46 bits.
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# format(x, "b") gives the binary digits without the "0b" prefix (same as
# bin(x)[2:]); each digit character becomes the int 0 or 1.
WATERMARK_BITS = list(map(int, format(WATERMARK_MESSAGE, "b")))
# Module-level embedder built from WATERMARK_BITS. Constructing it creates a
# WatermarkEncoder at import time.
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)