Spaces:
Runtime error
Runtime error
import math | |
from typing import Callable | |
from einops import rearrange, repeat | |
import jax | |
import jax.numpy as jnp | |
from jax import Array as Tensor | |
from flax import nnx | |
from flux.model import Flux | |
from flux.modules.conditioner import HFEmbedder | |
def get_noise( | |
num_samples: int, | |
height: int, | |
width: int, | |
device, | |
dtype: jnp.dtype, | |
seed: int, | |
): | |
# return torch.randn( | |
# num_samples, | |
# 16, | |
# # allow for packing | |
# 2 * math.ceil(height / 16), | |
# 2 * math.ceil(width / 16), | |
# device=device, | |
# dtype=dtype, | |
# generator=torch.Generator(device=device).manual_seed(seed), | |
# ) | |
# rngs = nnx.Rngs(seed) | |
key = jax.random.key(seed) | |
return jax.random.normal( | |
# rngs(), | |
key, | |
( | |
num_samples, | |
2 * math.ceil(height / 16), | |
2 * math.ceil(width / 16), | |
16, | |
), | |
dtype=dtype | |
) | |
def prepare_tokens(t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str]) -> tuple[Tensor, Tensor]: | |
if isinstance(prompt, str): | |
prompt = [prompt] | |
t5_tokens = t5.tokenize(prompt) | |
clip_tokens = clip.tokenize(prompt) | |
return t5_tokens, clip_tokens | |
# return { | |
# "t5": t5_tokens, | |
# "clip": clip_tokens, | |
# } | |
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, t5_tokens: Tensor, clip_tokens: Tensor) -> dict[str, Tensor]: | |
# bs, c, h, w = img.shape | |
bs, h, w, c = img.shape | |
if bs == 1: | |
bs = t5_tokens.shape[0] | |
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) | |
img = rearrange(img, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2) | |
if img.shape[0] == 1 and bs > 1: | |
img = repeat(img, "1 ... -> bs ...", bs=bs) | |
# img_ids = torch.zeros(h // 2, w // 2, 3) | |
img_ids = jnp.zeros((h // 2, w // 2, 3)) | |
# img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] | |
# img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] | |
img_ids = img_ids.at[..., 1].set(img_ids[..., 1]+jnp.arange(h // 2)[:, None]) | |
img_ids = img_ids.at[..., 2].set(img_ids[..., 2]+jnp.arange(w // 2)[None, :]) | |
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) | |
# if isinstance(prompt, str): | |
# prompt = [prompt] | |
txt = t5(t5_tokens) | |
if txt.shape[0] == 1 and bs > 1: | |
txt = repeat(txt, "1 ... -> bs ...", bs=bs) | |
# txt_ids = torch.zeros(bs, txt.shape[1], 3) | |
txt_ids = jnp.zeros((bs, txt.shape[1], 3)) | |
vec = clip(clip_tokens) | |
if vec.shape[0] == 1 and bs > 1: | |
vec = repeat(vec, "1 ... -> bs ...", bs=bs) | |
# return { | |
# "img": img, | |
# "img_ids": img_ids.to(img.device), | |
# "txt": txt.to(img.device), | |
# "txt_ids": txt_ids.to(img.device), | |
# "vec": vec.to(img.device), | |
# } | |
return { | |
"img": img, | |
"img_ids": img_ids, | |
"txt": txt, | |
"txt_ids": txt_ids, | |
"vec": vec, | |
} | |
def time_shift(mu: float, sigma: float, t: Tensor): | |
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) | |
# return jnp.exp(mu) / (jnp.exp(mu) + (1 / t - 1) ** sigma) | |
def get_lin_function( | |
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 | |
) -> Callable[[float], float]: | |
m = (y2 - y1) / (x2 - x1) | |
b = y1 - m * x1 | |
return lambda x: m * x + b | |
def get_schedule( | |
num_steps: int, | |
image_seq_len: int, | |
base_shift: float = 0.5, | |
max_shift: float = 1.15, | |
shift: bool = True, | |
) -> Tensor: | |
# extra step for zero | |
# timesteps = torch.linspace(1, 0, num_steps + 1) | |
timesteps = jnp.linspace(1, 0, num_steps + 1) | |
# shifting the schedule to favor high timesteps for higher signal images | |
if shift: | |
# estimate mu based on linear estimation between two points | |
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) | |
timesteps = time_shift(mu, 1.0, timesteps) | |
return timesteps#.tolist() | |
DEBUG=False | |
def denoise_for( | |
model: Flux, | |
# model input | |
img: Tensor, | |
img_ids: Tensor, | |
txt: Tensor, | |
txt_ids: Tensor, | |
vec: Tensor, | |
# sampling parameters | |
timesteps: Tensor, | |
guidance: float = 4.0, | |
): | |
# this is ignored for schnell | |
# guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) | |
guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype) | |
timesteps = timesteps.tolist() | |
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): | |
# t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) | |
t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype) | |
pred = model( | |
img=img, | |
img_ids=img_ids, | |
txt=txt, | |
txt_ids=txt_ids, | |
y=vec, | |
timesteps=t_vec, | |
guidance=guidance_vec, | |
) | |
img = img + (t_prev - t_curr) * pred | |
return img | |
# @nnx.jit | |
def denoise( | |
model: Flux, | |
# model input | |
img: Tensor, | |
img_ids: Tensor, | |
txt: Tensor, | |
txt_ids: Tensor, | |
vec: Tensor, | |
# sampling parameters | |
timesteps: Tensor, | |
guidance: float = 4.0, | |
): | |
# this is ignored for schnell | |
# guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) | |
guidance_vec = jnp.full((img.shape[0],), guidance, dtype=img.dtype) | |
def scan_func(acc, t_prev): | |
img, t_curr = acc | |
dtype = img.dtype | |
t_vec = jnp.full((img.shape[0],), t_curr, dtype=img.dtype) | |
pred = model( | |
img=img, | |
img_ids=img_ids, | |
txt=txt, | |
txt_ids=txt_ids, | |
y=vec, | |
timesteps=t_vec, | |
guidance=guidance_vec, | |
) | |
img = img + (t_prev - t_curr) * pred | |
return (img.astype(dtype), t_prev), pred | |
acc,pred=scan_func((img, timesteps[0]), timesteps[1:]) | |
return acc[0] | |
def unpack(x: Tensor, height: int, width: int) -> Tensor: | |
# return rearrange( | |
# x, | |
# "b (h w) (c ph pw) -> b c (h ph) (w pw)", | |
# h=math.ceil(height / 16), | |
# w=math.ceil(width / 16), | |
# ph=2, | |
# pw=2, | |
# ) | |
return rearrange( | |
x, | |
"b (h w) (c ph pw) -> b (h ph) (w pw) c", | |
h=math.ceil(height / 16), | |
w=math.ceil(width / 16), | |
ph=2, | |
pw=2, | |
) | |