# NOTE(review): removed scraped page chrome ("Spaces:" / "Runtime error") that
# preceded the module source — it was web-listing noise, not part of the code.
import abc | |
import torch | |
from torch import inference_mode | |
from tqdm import tqdm | |
""" | |
Inversion code taken from: | |
1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion | |
2. The LEDITS demo: https://huggingface.co./spaces/editing-images/ledits/tree/main | |
""" | |
# Low-resource mode flag used by AttentionControl.__call__: when True the full
# attention tensor is passed to forward() and the first num_att_layers calls
# are skipped — presumably the cond/uncond UNet passes run as separate calls
# in this mode (confirm against the sampling loop that uses these hooks).
LOW_RESOURCE = True
def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
    """Invert a real image according to Algorithm 1 of the edit-friendly DDPM
    inversion paper (https://arxiv.org/pdf/2304.06140.pdf), based on the code
    in https://github.com/inbarhub/DDPM_inversion.

    Args:
        x0: input image tensor in the range expected by ``pipe.vae.encode``.
        pipe: Stable-Diffusion-style pipeline (vae, unet, scheduler, tokenizer,
            text_encoder).
        prompt_src: source prompt used for classifier-free guidance during the
            inversion ("" means unconditional only).
        num_diffusion_steps: number of scheduler timesteps for the inversion.
        cfg_scale_src: classifier-free guidance scale for ``prompt_src``.
        eta: DDPM eta; must be nonzero to obtain the edit-friendly noise maps.

    Returns:
        (zs, wts): the per-step noise maps and the intermediate inverted
        latents. The final latent ``wt`` from the forward process is discarded
        (the original header comment claiming three return values was wrong).
    """
    pipe.scheduler.set_timesteps(num_diffusion_steps)
    with inference_mode():
        # VAE-encode the image and apply the SD v1 latent scaling factor.
        w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
        wt, zs, wts = inversion_forward_process(pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
                                                prog_bar=True, num_inference_steps=num_diffusion_steps)
    return zs, wts
def inversion_forward_process(model, x0,
                              etas=None,
                              prog_bar=False,
                              prompt="",
                              cfg_scale=3.5,
                              num_inference_steps=50, eps=None
                              ):
    """Edit-friendly DDPM inversion forward pass (Algorithm 1, arXiv:2304.06140).

    Walks the diffusion process from x0 towards x_T and, in the stochastic
    variant (etas != 0), extracts the noise maps ``zs`` that make the reverse
    process reproduce the pre-sampled trajectory exactly.

    Args:
        model: diffusion pipeline exposing unet, scheduler, tokenizer, text_encoder.
        x0: input latent; indexing below assumes shape (1, C, H, W).
        etas: None/0 selects the deterministic DDIM inversion; a scalar or a
            per-step list enables the stochastic edit-friendly variant.
        prog_bar: show a tqdm progress bar over the timesteps.
        prompt: source prompt; "" disables classifier-free guidance.
        cfg_scale: classifier-free guidance scale.
        num_inference_steps: number of scheduler timesteps.
        eps: unused (dead parameter kept for interface compatibility).

    Returns:
        (xt, zs, xts): the final inverted latent, the noise maps, and the full
        latent trajectory. ``zs`` and ``xts`` are None when etas is None/0.
    """
    if not prompt == "":
        text_embeddings = encode_text(model, prompt)
    uncond_embedding = encode_text(model, "")
    timesteps = model.scheduler.timesteps.to(model.device)
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    if etas is None or (type(etas) in [int, float] and etas == 0):
        eta_is_zero = True
        zs = None
        # FIX: xts was previously left undefined on this path, so the final
        # `return xt, zs, xts` raised NameError whenever eta == 0.
        xts = None
    else:
        eta_is_zero = False
        if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
        # pre-sample the whole trajectory x_1..x_T from q(x_t | x_0)
        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
        alpha_bar = model.scheduler.alphas_cumprod
        zs = torch.zeros(size=variance_noise_shape, device=model.device)
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xt = x0
    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
    for t in op:
        idx = t_to_idx[int(t)]
        # 1. predict noise residual (uncond pass, plus cond pass when prompted)
        if not eta_is_zero:
            # in the stochastic variant each x_t comes from the pre-sampled trajectory
            xt = xts[idx][None]
        with torch.no_grad():
            out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
            if not prompt == "":
                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
        if not prompt == "":
            ## classifier free guidance
            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
        else:
            noise_pred = out.sample
        if eta_is_zero:
            # 2. compute more noisy image and set x_t -> x_t+1 (plain DDIM forward step)
            xt = forward_step(model, noise_pred, t, xt)
        else:
            xtm1 = xts[idx + 1][None]
            # pred of x0 (invert the epsilon-prediction parameterization)
            pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
            # direction to xt
            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
            alpha_prod_t_prev = model.scheduler.alphas_cumprod[
                prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
            variance = get_variance(model, t)
            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
            # solve the DDPM posterior-mean equation for the noise map z
            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
            zs[idx] = z
            # correction to avoid error accumulation
            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
            xts[idx + 1] = xtm1
    if not zs is None:
        # the last noise map is never consumed by the reverse process
        zs[-1] = torch.zeros_like(zs[-1])
    return xt, zs, xts
def encode_text(model, prompts):
    """Tokenize *prompts* and run them through the pipeline's text encoder.

    Returns the first element of the encoder output (the token embeddings),
    moved through ``model.device``.
    """
    tokenized = model.tokenizer(
        prompts,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(model.device)
    with torch.no_grad():
        embeddings = model.text_encoder(input_ids)[0]
    return embeddings
def sample_xts_from_x0(model, x0, num_inference_steps=50):
    """Sample a trajectory x_1..x_T from q(x_1:T | x_0).

    Each x_t is drawn independently as
    sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I).

    Args:
        model: pipeline exposing scheduler (alphas_cumprod, timesteps) and
            unet (in_channels, sample_size).
        x0: input latent; the final ``torch.cat`` assumes shape (1, C, H, W).
        num_inference_steps: number of timesteps to sample.

    Returns:
        Tensor of shape (num_inference_steps + 1, C, H, W): row idx holds the
        sample for timesteps[idx] (most-noisy first), and the last row is x0
        itself.
    """
    alpha_bar = model.scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
    # NOTE: the original also computed `alphas`/`betas` here — both were dead
    # locals and have been removed.
    variance_noise_shape = (
        num_inference_steps,
        model.unet.in_channels,
        model.unet.sample_size,
        model.unet.sample_size)
    timesteps = model.scheduler.timesteps.to(model.device)
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xts = torch.zeros(variance_noise_shape).to(x0.device)
    for t in reversed(timesteps):
        idx = t_to_idx[int(t)]
        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
    # append x0 itself as the last entry (index num_inference_steps)
    xts = torch.cat([xts, x0], dim=0)
    return xts
def forward_step(model, model_output, timestep, sample):
    """One deterministic forward (noising) step: x_t -> x_{t+1}.

    Recovers the "predicted x_0" of formula (12) in
    https://arxiv.org/pdf/2010.02502.pdf from the epsilon prediction, then
    re-noises it to the next timestep via the scheduler.
    """
    step_size = model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
    # clamp so we never index past the training schedule
    next_timestep = min(model.scheduler.config.num_train_timesteps - 2, timestep + step_size)

    # compute alphas/betas at the current timestep
    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t

    # invert the epsilon-prediction parameterization to estimate x_0
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5

    # re-noise the x_0 estimate up to the next (noisier) timestep
    return model.scheduler.add_noise(pred_original_sample,
                                     model_output,
                                     torch.LongTensor([next_timestep]))
def get_variance(model, timestep):
    """Scheduler posterior variance at *timestep* (the DDIM sigma^2 term)."""
    step_size = model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
    prev_timestep = timestep - step_size
    alphas_cumprod = model.scheduler.alphas_cumprod
    alpha_prod_t = alphas_cumprod[timestep]
    # fall back to the scheduler's final alpha when stepping past t=0
    if prev_timestep >= 0:
        alpha_prod_t_prev = alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = model.scheduler.final_alpha_cumprod
    return ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * (1 - alpha_prod_t / alpha_prod_t_prev)
class AttentionControl(abc.ABC):
    """Base class for intercepting attention maps during UNet forward passes.

    Instances are installed by register_attention_control; __call__ fires once
    per attention layer and counts layers/steps so subclasses know when a full
    diffusion step has completed.
    """

    def step_callback(self, x_t):
        # Hook: optionally transform the latent between diffusion steps.
        return x_t

    def between_steps(self):
        # Hook: invoked once after every full diffusion step.
        return

    @property
    def num_uncond_att_layers(self):
        # FIX: this must be a @property — __call__ compares it against ints
        # (`>=`, `+`); as a plain method those comparisons raised TypeError.
        # In LOW_RESOURCE mode the first num_att_layers calls (presumably the
        # separate unconditional pass) are skipped entirely.
        return self.num_att_layers if LOW_RESOURCE else 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Subclasses implement the actual attention-map edit/recording.
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # batch is [uncond | cond]: only the conditional half is edited
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # a full diffusion step's worth of layers has passed
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # set later by register_attention_control
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):
    """AttentionControl that records attention maps and averages them over steps."""

    @staticmethod
    def get_empty_store():
        # FIX: must be a @staticmethod — it declares no parameters but is
        # called as self.get_empty_store(), which previously raised TypeError.
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        # Fold this step's maps into the running per-layer sums.
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # Divide the accumulated sums by the number of completed steps.
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
def register_attention_control(model, controller):
    """Monkey-patch every CrossAttention module inside model.unet so that its
    attention probabilities pass through *controller*, and record on the
    controller how many layers were patched (num_att_layers)."""
    def ca_forward(self, place_in_unet):
        # Build the replacement forward() for one CrossAttention module.
        # `self` here is the attention module being patched, not the outer model.
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            # newer modules wrap the output projection in a ModuleList
            to_out = self.to_out[0]
        else:
            to_out = self.to_out
        def forward(x, context=None, mask=None):
            # NOTE(review): relies on the pre-0.13 diffusers CrossAttention API
            # (to_q/to_k/to_v, reshape_heads_to_batch_dim) — confirm the pinned
            # diffusers version supports these attributes.
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            # cross-attention iff an encoder context is supplied
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if mask is not None:
                # expand the mask per head and blank out masked positions
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)
            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # hand the attention map to the controller (it may modify it)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)
        return forward
    class DummyController:
        # no-op stand-in used when no controller is supplied
        def __call__(self, *args):
            return args[0]
        def __init__(self):
            self.num_att_layers = 0
    if controller is None:
        controller = DummyController()
    def register_recr(net_, count, place_in_unet):
        # Recursively patch CrossAttention submodules; returns the running count.
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count
    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        # tag each patched layer with its location in the UNet (down/mid/up)
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")
    # tell the controller how many layers fire per diffusion step
    controller.num_att_layers = cross_att_count