# flux-dev-flax / flux/sampling.py — lnyan, commit d4607d7 ("Update")
import math
from typing import Callable
from einops import rearrange, repeat
import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx
from flux.model import Flux
from flux.modules.conditioner import HFEmbedder
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device,
    dtype: jnp.dtype,
    seed: int,
):
    """Draw the initial latent noise for sampling.

    Args:
        num_samples: batch size.
        height, width: target image size in pixels; each latent side is
            2 * ceil(size / 16) so the result can be packed into 2x2 patches.
        device: unused; kept only for signature compatibility with the torch
            original (JAX handles placement via its own device semantics).
        dtype: dtype of the returned array.
        seed: PRNG seed; the same seed always yields the same noise.

    Returns:
        Array of shape (num_samples, 2*ceil(height/16), 2*ceil(width/16), 16),
        channels-last (NHWC), unlike the NCHW torch original.
    """
    key = jax.random.key(seed)
    return jax.random.normal(
        key,
        (
            num_samples,
            2 * math.ceil(height / 16),
            2 * math.ceil(width / 16),
            16,
        ),
        dtype=dtype,
    )
def prepare_tokens(t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str]) -> tuple[Tensor, Tensor]:
    """Tokenize `prompt` with both text encoders.

    A bare string is promoted to a one-element batch; returns
    (t5_tokens, clip_tokens).
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    return t5.tokenize(prompts), clip.tokenize(prompts)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, t5_tokens: Tensor, clip_tokens: Tensor) -> dict[str, Tensor]:
    """Assemble all model inputs for Flux from a latent image and tokenized prompts.

    `img` is channels-last: (batch, height, width, channels). When the image
    batch is 1 the effective batch size is taken from the token batch, and
    singleton inputs are tiled up to match.
    """
    batch, height, width, _ = img.shape
    if batch == 1:
        batch = t5_tokens.shape[0]

    # Fold each 2x2 latent patch into the token/channel dimension.
    img = rearrange(img, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and batch > 1:
        img = repeat(img, "1 ... -> bs ...", bs=batch)

    # Positional ids per patch: channel 0 stays zero, 1 holds the row index,
    # 2 holds the column index.
    img_ids = jnp.zeros((height // 2, width // 2, 3))
    img_ids = img_ids.at[..., 1].add(jnp.arange(height // 2)[:, None])
    img_ids = img_ids.at[..., 2].add(jnp.arange(width // 2)[None, :])
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch)

    # Text stream: T5 sequence embeddings plus all-zero positional ids.
    txt = t5(t5_tokens)
    if txt.shape[0] == 1 and batch > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=batch)
    txt_ids = jnp.zeros((batch, txt.shape[1], 3))

    # Pooled CLIP embedding used as the conditioning vector.
    vec = clip(clip_tokens)
    if vec.shape[0] == 1 and batch > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=batch)

    return {
        "img": img,
        "img_ids": img_ids,
        "txt": txt,
        "txt_ids": txt_ids,
        "vec": vec,
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps t in (0, 1]: exp(mu) / (exp(mu) + (1/t - 1)**sigma).

    mu = 0 with sigma = 1 is the identity; larger mu pushes values toward 1.
    """
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1.0 / t - 1.0) ** sigma)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the line through (x1, y1) and (x2, y2) as a callable of x."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> Tensor:
    """Build the sampling schedule: num_steps + 1 timesteps from 1 down to 0."""
    # One extra point so the final step lands exactly on zero.
    schedule = jnp.linspace(1, 0, num_steps + 1)
    if not shift:
        return schedule
    # Favor high timesteps for longer sequences (higher-resolution images):
    # mu is linearly interpolated from the image sequence length.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, schedule)
# NOTE(review): DEBUG is not referenced anywhere in this chunk — presumably a
# debugging switch read elsewhere; confirm before removing.
DEBUG=False
def denoise_for(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: Tensor,
    guidance: float = 4.0,
):
    """Euler-integrate the model's predicted velocity over consecutive
    timestep pairs (plain Python loop reference implementation)."""
    # Broadcast guidance per batch element; this is ignored for schnell.
    guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)
    schedule = timesteps.tolist()
    for idx, t_cur in enumerate(schedule[:-1]):
        t_next = schedule[idx + 1]
        t_vec = jnp.full((img.shape[0],), t_cur, dtype=img.dtype)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        # Euler step toward the next timestep in the schedule.
        img = img + (t_next - t_cur) * pred
    return img
# @nnx.jit
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: Tensor,
    guidance: float = 4.0,
):
    """Euler-step denoising driven by nnx.scan — the XLA-friendly counterpart
    of `denoise_for` above; same inputs, same result."""
    # this is ignored for schnell
    # guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype)
    # Carry is (current latent, current timestep); scanning over timesteps[1:]
    # feeds each iteration the *next* timestep, mirroring the pairwise zip in
    # denoise_for.
    @nnx.scan
    def scan_func(acc, t_prev):
        img, t_curr = acc
        # NOTE(review): the astype below keeps the scan carry at a fixed dtype
        # across iterations — presumably guarding against the model returning a
        # wider dtype; confirm against Flux's output dtype.
        dtype = img.dtype
        t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        # Euler step; with the decreasing schedule from get_schedule this
        # moves the latent toward t = 0.
        img = img + (t_prev - t_curr) * pred
        return (img.astype(dtype), t_prev), pred
    acc,pred=scan_func((img, timesteps[0]), timesteps[1:])
    # acc[0] is the final latent; the stacked per-step preds are discarded.
    return acc[0]
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Invert the 2x2 patch packing of `prepare`: token sequence back to a
    channels-last (NHWC) latent image for the given pixel height/width."""
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b (h ph) (w pw) c",
        h=grid_h,
        w=grid_w,
        ph=2,
        pw=2,
    )