# Source: cross-image-attention/utils/ddpm_inversion.py
# Author: yuvalalaluf; commit 82ef366 ("initial commit")
import abc
import torch
from torch import inference_mode
from tqdm import tqdm
"""
Inversion code taken from:
1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion
2. The LEDITS demo: https://huggingface.co/spaces/editing-images/ledits/tree/main
"""
# Switches how AttentionControl.__call__ treats each attention call.
# True:  the first `num_att_layers` calls of every step pass through untouched,
#        and later calls run the whole attention tensor through `forward`.
# False: every call runs `forward` on the second half of the attention batch only.
LOW_RESOURCE: bool = True
def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
    """
    Invert a real image following Algorithm 1 of https://arxiv.org/pdf/2304.06140.pdf.

    Based on the code in https://github.com/inbarhub/DDPM_inversion.

    Args:
        x0: image tensor handed to ``pipe.vae.encode`` (expected value range and
            layout are whatever the VAE expects; TODO confirm against the caller).
        pipe: diffusion pipeline providing ``scheduler``, ``vae``, ``unet``,
            ``tokenizer``, ``text_encoder`` and ``device``.
        prompt_src: source prompt used for classifier-free guidance; "" disables it.
        num_diffusion_steps: number of scheduler timesteps to configure and invert over.
        cfg_scale_src: classifier-free guidance scale for the source prompt.
        eta: DDPM eta forwarded to ``inversion_forward_process``.

    Returns:
        ``(zs, wts)``: the extracted noise maps and the intermediate noisy latents.
        The final inverted latent is computed but deliberately not returned.
    """
    pipe.scheduler.set_timesteps(num_diffusion_steps)

    # Encode the image to a latent without recording autograd state. 0.18215 is
    # presumably the Stable Diffusion v1 VAE latent scaling factor.
    with inference_mode():
        posterior = pipe.vae.encode(x0).latent_dist
        w0 = (posterior.mode() * 0.18215).float()

    # Forward (inversion) process: recover the noise maps and noisy latents.
    _, zs, wts = inversion_forward_process(
        pipe,
        w0,
        etas=eta,
        prompt=prompt_src,
        cfg_scale=cfg_scale_src,
        prog_bar=True,
        num_inference_steps=num_diffusion_steps,
    )
    return zs, wts
def inversion_forward_process(model, x0,
                              etas=None,
                              prog_bar=False,
                              prompt="",
                              cfg_scale=3.5,
                              num_inference_steps=50, eps=None
                              ):
    """
    Run the edit-friendly DDPM inversion forward pass.

    This follows Algorithm 1 of https://arxiv.org/pdf/2304.06140.pdf.

    Args:
        model: pipeline exposing ``unet``, ``scheduler``, ``tokenizer``,
            ``text_encoder`` and ``device``. ``model.scheduler.timesteps`` must
            already be configured (``invert`` calls ``set_timesteps`` first).
        x0: clean latent; presumably shaped (1, C, S, S) (TODO confirm).
        etas: DDPM eta. It may be a scalar or a per-step sequence indexed like
            ``model.scheduler.timesteps``. ``None`` or ``0`` selects the
            deterministic path, which pushes ``x0`` up the schedule with
            ``forward_step``.
        prog_bar: wrap the timestep loop in ``tqdm``.
        prompt: source prompt. When non-empty, classifier-free guidance is
            applied with ``cfg_scale``.
        cfg_scale: classifier-free guidance scale.
        num_inference_steps: number of steps; sizes the ``zs``/``xts`` tensors.
        eps: unused; kept for API compatibility.

    Returns:
        ``(xt, zs, xts)``:
        - ``xt``: the latent at the last (noisiest) visited timestep.
        - ``zs``: (num_inference_steps, C, S, S) noise maps, with ``zs[idx]``
          belonging to ``timesteps[idx]``. ``zs[-1]`` is zeroed. It is ``None``
          on the deterministic path.
        - ``xts``: the num_inference_steps + 1 noisy latents (last entry is
          ``x0``), corrected in place during the pass. It is ``None`` on the
          deterministic path.
    """
    use_cfg = prompt != ""
    if use_cfg:
        text_embeddings = encode_text(model, prompt)
    uncond_embedding = encode_text(model, "")
    timesteps = model.scheduler.timesteps.to(model.device)
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)

    eta_is_zero = etas is None or (isinstance(etas, (int, float)) and etas == 0)
    if eta_is_zero:
        # Deterministic path: no trajectory is pre-sampled and no noise maps are
        # extracted. Both names must still be bound because they are returned below.
        zs = None
        xts = None
    else:
        if isinstance(etas, (int, float)):
            etas = [etas] * model.scheduler.num_inference_steps
        # xts[idx] is a noisy version of x0 at timesteps[idx]; xts[-1] is x0 itself.
        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
        alpha_bar = model.scheduler.alphas_cumprod
        zs = torch.zeros(size=variance_noise_shape, device=model.device)

    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xt = x0
    # Visit the timesteps in reversed scheduler order.
    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
    for t in op:
        idx = t_to_idx[int(t)]
        if not eta_is_zero:
            xt = xts[idx][None]

        # 1. predict noise residual (unconditional, plus conditional when a prompt is given)
        with torch.no_grad():
            out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
            if use_cfg:
                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
        if use_cfg:
            # classifier-free guidance
            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
        else:
            noise_pred = out.sample

        if eta_is_zero:
            # 2. compute a noisier latent deterministically and set x_t -> x_t+1
            xt = forward_step(model, noise_pred, t, xt)
        else:
            # Pre-sampled latent at the next entry of xts (x0 for the last index).
            xtm1 = xts[idx + 1][None]
            # predicted x0
            pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
            # direction pointing to x_t
            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
            alpha_prod_t_prev = alpha_bar[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
            variance = get_variance(model, t)
            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** 0.5 * noise_pred
            mu_xt = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
            # Noise map that makes the stochastic step land exactly on xtm1.
            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
            zs[idx] = z
            # correction to avoid error accumulation
            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
            xts[idx + 1] = xtm1

    if zs is not None:
        zs[-1] = torch.zeros_like(zs[-1])
    return xt, zs, xts
def encode_text(model, prompts):
    """
    Tokenize ``prompts`` and return the first output of ``model.text_encoder``.

    Inputs are padded and truncated to ``model.tokenizer.model_max_length``.
    Token ids are moved to ``model.device``, and the encoder runs under
    ``torch.no_grad``.
    """
    tokenizer = model.tokenizer
    tokens = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokens.input_ids.to(model.device)
    with torch.no_grad():
        encoder_outputs = model.text_encoder(input_ids)
    return encoder_outputs[0]
def sample_xts_from_x0(model, x0, num_inference_steps=50):
    """
    Samples from P(x_1:T|x_0).

    For every timestep ``t = model.scheduler.timesteps[idx]``, the entry
    ``xts[idx]`` is ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.
    Each step draws independent Gaussian noise.

    Args:
        model: pipeline exposing ``scheduler`` (``alphas_cumprod``,
            ``timesteps``), ``unet`` (``in_channels``, ``sample_size``) and
            ``device``.
        x0: clean latent with a leading batch dimension of 1, i.e.
            (1, C, S, S) with C/S matching the UNet. This shape is required by
            the per-index assignment and the final concatenation.
        num_inference_steps: number of noisy latents to draw; should equal
            ``len(model.scheduler.timesteps)``.

    Returns:
        Tensor of shape (num_inference_steps + 1, C, S, S) on ``x0.device``.
        The last entry is ``x0`` itself.
    """
    alpha_bar = model.scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    timesteps = model.scheduler.timesteps.to(model.device)
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xts = torch.zeros(variance_noise_shape, device=x0.device)
    # Noise is drawn in reversed-timestep order. Keep this order so that a fixed
    # RNG seed keeps producing the same trajectory.
    for t in reversed(timesteps):
        idx = t_to_idx[int(t)]
        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
    return torch.cat([xts, x0], dim=0)
def forward_step(model, model_output, timestep, sample):
    """
    Move ``sample`` one inference stride up the noise schedule.

    The step is deterministic:
      1. estimate x0 from ``sample`` and the predicted noise ``model_output``
         (formula (12) of https://arxiv.org/pdf/2010.02502.pdf);
      2. re-noise that estimate with the same ``model_output`` at the next
         timestep via ``model.scheduler.add_noise``.

    The next timestep is ``timestep + num_train_timesteps // num_inference_steps``,
    capped at ``num_train_timesteps - 2``.
    """
    scheduler = model.scheduler
    num_train = scheduler.config.num_train_timesteps
    stride = num_train // scheduler.num_inference_steps
    next_timestep = min(num_train - 2, timestep + stride)

    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    x0_estimate = (sample - (1 - alpha_prod_t) ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    return scheduler.add_noise(x0_estimate, model_output, torch.LongTensor([next_timestep]))
def get_variance(model, timestep):
    """
    Return the DDIM posterior variance between ``timestep`` and the previous
    inference timestep.

    Formula: ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)``, where ``a`` is
    ``model.scheduler.alphas_cumprod``. The previous timestep is one stride
    (``num_train_timesteps // num_inference_steps``) earlier. When that falls
    below 0, ``model.scheduler.final_alpha_cumprod`` is used instead.
    """
    scheduler = model.scheduler
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_t = timestep - stride
    a_t = scheduler.alphas_cumprod[timestep]
    if prev_t >= 0:
        a_prev = scheduler.alphas_cumprod[prev_t]
    else:
        a_prev = scheduler.final_alpha_cumprod
    return ((1 - a_prev) / (1 - a_t)) * (1 - a_t / a_prev)
class AttentionControl(abc.ABC):
    """
    Abstract hook invoked once per attention-layer call with the attention
    probabilities.

    Subclasses implement ``forward(attn, is_cross, place_in_unet)`` to inspect
    or replace the maps. This base class counts layer calls in
    ``cur_att_layer``. After ``num_att_layers + num_uncond_att_layers`` calls it
    resets that counter, increments ``cur_step`` and runs ``between_steps``.
    ``num_att_layers`` starts at -1 and is expected to be filled in by whoever
    registers the controller.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        # Default: leave the latent untouched.
        return x_t

    def between_steps(self):
        # Hook for subclasses; runs after every completed step.
        return None

    @property
    def num_uncond_att_layers(self):
        # With LOW_RESOURCE, each step also contains `num_att_layers` extra calls
        # that are passed through untouched.
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if not LOW_RESOURCE:
                # Only the second half of the batch is handed to `forward`
                # (presumably the conditional branch of a CFG batch; confirm).
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
            else:
                attn = self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        calls_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == calls_per_step:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):
    """
    Attention controller that records attention maps and sums them over steps.

    During a step, maps with at most 32 ** 2 query positions (``attn.shape[1]``)
    are appended to ``step_store`` under ``"<place>_<cross|self>"``.
    ``between_steps`` then folds ``step_store`` into ``attention_store``:
    the first step is adopted as-is, and later steps are added element-wise
    in place. ``get_average_attention`` divides the sums by ``cur_step``.
    """

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        slots = ("down_cross", "mid_cross", "up_cross",
                 "down_self", "mid_self", "up_self")
        return {slot: [] for slot in slots}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        kind = "cross" if is_cross else "self"
        slot = f"{place_in_unet}_{kind}"
        # Keep only maps with <= 1024 query positions to bound memory usage.
        if attn.shape[1] <= 32 ** 2:
            self.step_store[slot].append(attn)
        return attn

    def between_steps(self):
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for slot, accumulated in self.attention_store.items():
                current = self.step_store[slot]
                for i in range(len(accumulated)):
                    accumulated[i] += current[i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        steps = self.cur_step
        return {slot: [maps / steps for maps in stored]
                for slot, stored in self.attention_store.items()}

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
def register_attention_control(model, controller):
    """
    Route the attention probabilities of every ``CrossAttention`` module in
    ``model.unet`` through ``controller``.

    Each sub-module whose class name is exactly ``'CrossAttention'`` has its
    ``forward`` replaced on the instance. The new forward computes softmax
    attention, passes it to ``controller(attn, is_cross, place_in_unet)`` and
    uses the returned tensor to weight the values. ``place_in_unet`` is
    ``"down"``, ``"up"`` or ``"mid"``, depending on the name of the top-level
    UNet child that contains the module.

    The patched forward uses the legacy diffusers ``CrossAttention`` attributes
    ``to_q``/``to_k``/``to_v``/``to_out``, ``heads``, ``scale``,
    ``reshape_heads_to_batch_dim`` and ``reshape_batch_dim_to_heads``.

    Args:
        model: object exposing ``unet`` (a module with ``named_children``).
        controller: callable ``(attn, is_cross, place_in_unet) -> attn``.
            ``None`` installs a pass-through controller. Its ``num_att_layers``
            is set to the number of patched modules.
    """

    def ca_forward(self, place_in_unet):
        # When to_out is a ModuleList, only its first entry is applied.
        if type(self.to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, context=None, mask=None):
            batch_size, _, _ = x.shape
            h = self.heads
            q = self.to_q(x)
            # Self-attention when no context is given.
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # Legacy diffusers' reshape_heads_to_batch_dim lays rows out
                # batch-major (b0h0, b0h1, ..., b1h0, ...), so each sample's mask
                # must be repeated per head consecutively. `.repeat(h, 1, 1)` would
                # tile it as (m0, m1, ..., m0, m1, ...) and, for batch_size > 1,
                # pair heads with another sample's mask.
                mask = mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~mask, max_neg_value)
            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward

    class DummyController:
        # Pass-through used when no controller is supplied.
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Depth-first search; stops descending once a CrossAttention is patched.
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, sub_net in model.unet.named_children():
        if "down" in name:
            cross_att_count += register_recr(sub_net, 0, "down")
        elif "up" in name:
            cross_att_count += register_recr(sub_net, 0, "up")
        elif "mid" in name:
            cross_att_count += register_recr(sub_net, 0, "mid")
    controller.num_att_layers = cross_att_count