# !pip install diffusers["torch"] transformers
import hydra
import torch
import yaml
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from accelerate import Accelerator
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
import logging
import itertools
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch import autocast
from torch.cuda.amp import GradScaler
import pdb
import math
from my_model import unet_2d_condition
from typing import Iterable, Optional
import os
import json
import numpy as np
import scipy
def freeze_params(params):
    """Disable gradients for all given parameters."""
    for param in params:
        param.requires_grad = False


def unfreeze_params(params):
    """Enable gradients for all given parameters."""
    for param in params:
        param.requires_grad = True
class EMAModel:
    """
    Exponential Moving Average (EMA) of model weights.
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        print("list parameters")
        # Keep a detached shadow copy of every parameter.
        self.shadow_params = [p.clone().detach() for p in parameters]
        print("finish clone parameters")
        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the update factor for the exponential moving average.
        The returned value is `1 - effective_decay`, which matches the update
        `s_param -= factor * (s_param - param)` used in `step`.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    def step(self, parameters):
        parameters = list(parameters)
        self.optimization_step += 1
        self.decay = self.get_decay(self.optimization_step)
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                tmp = self.decay * (s_param - param)
                s_param.sub_(tmp)
            else:
                s_param.copy_(param)
        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy the current averaged parameters into the given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move the internal shadow parameters to `device` / `dtype`."""
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
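# A minimal usage sketch for EMAModel (not part of the training loop in this file;
# the optimizer/loss names below are placeholders for illustration only):
#
#     ema_unet = EMAModel(unet.parameters())
#     for batch in train_loader:
#         loss = compute_loss(batch)          # hypothetical loss computation
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema_unet.step(unet.parameters())    # update the shadow weights
#     ema_unet.copy_to(unet.parameters())     # load averaged weights before eval/saving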
def compute_visor_loss(attn_maps_mid, attn_maps_up, obj_a_positions, obj_b_positions, relationship):
    """
    Spatial-relationship loss computed from cross-attention maps: compare the
    attention-weighted centroids of object A and object B.
    """
    loss = 0
    for attn_map_integrated in attn_maps_mid:
        # Keep only the conditional half of the (uncond, cond) batch.
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        # Coordinate grids used to compute attention-weighted centroids.
        weight_matrix_x = torch.zeros(size=(H, W)).cuda()
        weight_matrix_y = torch.zeros(size=(H, W)).cuda()
        for x_indx in range(W):
            weight_matrix_x[:, x_indx] = x_indx
        for y_indx in range(H):
            weight_matrix_y[y_indx, :] = y_indx

        obj_a_avg_x_total = 0
        obj_a_avg_y_total = 0
        for obj_a_position in obj_a_positions:
            ca_map_obj = attn_map[:, :, obj_a_position].reshape(b, H, W)
            obj_a_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_a_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_a_avg_x_total += obj_a_avg_x
            obj_a_avg_y_total += obj_a_avg_y
        # Average over tokens and normalize the centroid to [0, 1] by the grid size.
        obj_a_avg_x_total = (obj_a_avg_x_total / len(obj_a_positions)).mean() / W
        obj_a_avg_y_total = (obj_a_avg_y_total / len(obj_a_positions)).mean() / H
        print('mid: obj_a_avg_x_total', obj_a_avg_x_total)

        obj_b_avg_x_total = 0
        obj_b_avg_y_total = 0
        for obj_b_position in obj_b_positions:
            ca_map_obj = attn_map[:, :, obj_b_position].reshape(b, H, W)
            obj_b_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_b_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_b_avg_x_total += obj_b_avg_x
            obj_b_avg_y_total += obj_b_avg_y
        obj_b_avg_x_total = (obj_b_avg_x_total / len(obj_b_positions)).mean() / W
        obj_b_avg_y_total = (obj_b_avg_y_total / len(obj_b_positions)).mean() / H
        print('mid: obj_b_avg_x_total', obj_b_avg_x_total)

        if relationship == 0:
            loss += (obj_b_avg_x_total - obj_a_avg_x_total)
        elif relationship == 1:
            loss += (obj_a_avg_x_total - obj_b_avg_x_total)
        elif relationship == 2:
            loss += (obj_b_avg_y_total - obj_a_avg_y_total)
        elif relationship == 3:
            loss += (obj_a_avg_y_total - obj_b_avg_y_total)

    for attn_map_integrated in attn_maps_up[0]:
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        weight_matrix_x = torch.zeros(size=(H, W)).cuda()
        weight_matrix_y = torch.zeros(size=(H, W)).cuda()
        for x_indx in range(W):
            weight_matrix_x[:, x_indx] = x_indx
        for y_indx in range(H):
            weight_matrix_y[y_indx, :] = y_indx

        obj_a_avg_x_total = 0
        obj_a_avg_y_total = 0
        for obj_a_position in obj_a_positions:
            ca_map_obj = attn_map[:, :, obj_a_position].reshape(b, H, W)
            obj_a_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_a_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_a_avg_x_total += obj_a_avg_x
            obj_a_avg_y_total += obj_a_avg_y
        obj_a_avg_x_total = (obj_a_avg_x_total / len(obj_a_positions)).mean() / W
        obj_a_avg_y_total = (obj_a_avg_y_total / len(obj_a_positions)).mean() / H
        print('up: obj_a_avg_x_total', obj_a_avg_x_total)

        obj_b_avg_x_total = 0
        obj_b_avg_y_total = 0
        for obj_b_position in obj_b_positions:
            ca_map_obj = attn_map[:, :, obj_b_position].reshape(b, H, W)
            obj_b_avg_x = (ca_map_obj * weight_matrix_x.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_b_avg_y = (ca_map_obj * weight_matrix_y.unsqueeze(0)).reshape(b, -1).sum(-1) / ca_map_obj.reshape(b, -1).sum(-1)
            obj_b_avg_x_total += obj_b_avg_x
            obj_b_avg_y_total += obj_b_avg_y
        obj_b_avg_x_total = (obj_b_avg_x_total / len(obj_b_positions)).mean() / W
        obj_b_avg_y_total = (obj_b_avg_y_total / len(obj_b_positions)).mean() / H
        print('up: obj_b_avg_x_total', obj_b_avg_x_total)

        # Note: the sign convention in this branch is flipped relative to the
        # mid-block branch above.
        if relationship == 0:
            loss += (obj_a_avg_x_total - obj_b_avg_x_total)
        elif relationship == 1:
            loss += (obj_b_avg_x_total - obj_a_avg_x_total)
        elif relationship == 2:
            loss += (obj_a_avg_y_total - obj_b_avg_y_total)
        elif relationship == 3:
            loss += (obj_b_avg_y_total - obj_a_avg_y_total)

    loss = loss / (len(attn_maps_up[0]) + len(attn_maps_mid))
    return loss
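# The `relationship` codes appear to mirror the prompt templates used in train():
# 0 -> "{} to the left of {}", 1 -> "{} to the right of {}", 2 -> "{} above {}", 3 -> "{} below {}".
# For example, in the up-block branch relationship 0 is minimized when object A's
# attention centroid lies at a smaller x-coordinate than object B's.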
# NOTE: this script imports hydra and calls train() without arguments, so a Hydra entry
# point is assumed here; the exact config_path/config_name is a guess and may need to be
# adjusted to match the repository's actual config files.
@hydra.main(config_path="conf", config_name="base_config")
def train(cfg: DictConfig):
    # fix the randomness of torch
    print(cfg)
    with open('./conf/unet/origin_config.json') as f:
        unet_config = json.load(f)
    unet = unet_2d_condition.UNet2DConditionModel(**unet_config)
    # ckp = torch.load('/Users/shil5883/Downloads/diffusion_pytorch_model.bin', map_location='cpu')
    # prev_attn_map = torch.load('./attn_map.ckp', map_location='cpu')
    ckp = torch.load('/work/minghao/chess_gen/diffusion_pytorch_model.bin', map_location='cpu')
    prev_attn_map = torch.load('/work/minghao/chess_gen/visual_attn/2023-02-02/15-05-51/epoch_100_sche_constant_lr_1e-06_ac_1/attn_map.ckp', map_location='cpu')
    # prev_attn_map = torch.load('/work/minghao/chess_gen/visual_attn/2023-01-16/18-58-12/epoch_100_sche_constant_lr_1e-06_ac_1/attn_map.ckp', map_location='cpu')
    unet.load_state_dict(ckp)
    # Keep an unmodified copy of the UNet for the original (unconstrained) denoising path.
    unet_original = UNet2DConditionModel(**unet_config)
    unet_original.load_state_dict(ckp)
    date_now, time_now = datetime.now().strftime("%Y-%m-%d,%H-%M-%S").split(',')
    # cfg.general.save_path = os.path.join(cfg.general.save_path, date_now, time_now)
    # if not os.path.exists(cfg.general.save_path):
    #     os.makedirs(cfg.general.save_path)

    mixed_precision = 'fp16' if torch.cuda.is_available() else 'no'
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.training.accumulate_step,
        mixed_precision=mixed_precision,
        log_with="tensorboard",
        logging_dir='./',
    )

    # Initialize dataset and dataloader.
    if accelerator.is_main_process:
        print("Loading the dataset!!!!!")
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    # train_dataset = ICLEVERDataset(cfg.data.data_path, tokenizer, cfg, prefix='train')
    # val_dataset = ICLEVERDataset(cfg.data.data_path, tokenizer, cfg, prefix='val')
    # train_loader = DataLoader(train_dataset, batch_size=cfg.training.batch_size, shuffle=True, num_workers=2, pin_memory=False)
    # val_loader = DataLoader(val_dataset, batch_size=cfg.training.batch_size * 2, shuffle=True, num_workers=2, pin_memory=False)
    if accelerator.is_main_process:
        print("Complete loading the dataset!!!!!")
    if accelerator.is_main_process:
        print("Complete loading the noise scheduler!!!!!")

    with open("config.yaml", "w") as f:
        OmegaConf.save(cfg, f)
    if not os.path.exists(cfg.general.save_path) and accelerator.is_main_process:
        os.makedirs(cfg.general.save_path)
    if accelerator.is_main_process:
        print("Saved the config!!!!!")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load pretrained models and scheduler.
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    # boards_embedder.to(device)
    if accelerator.is_main_process:
        print("Moving the models to device!!!!!")

    # Scale the learning rate by accumulation steps, batch size, and number of processes.
    cfg.training.lr = (
        cfg.training.lr * cfg.training.accumulate_step * cfg.training.batch_size * accelerator.num_processes
    )
    # Move vae and unet to device; keep them in eval mode as we don't train them.
    vae.to(device)
    unet.to(device)
    text_encoder.to(device)
    # prev_attn_map.to(device)
    unet_original.to(device)
    vae.eval()
    unet.eval()
    text_encoder.eval()
    unet_original.eval()

    print("prepare the accelerator module at process: {}!!!!!".format(accelerator.process_index))
    # unet = accelerator.prepare(unet)
    print("done the accelerator module at process: {}!!!!!".format(accelerator.process_index))

    # Create EMA for the unet.
    ema_unet = None
    if cfg.training.use_ema:
        if accelerator.is_main_process:
            print("Using the EMA model!!!!!")
        print("start EMA at process: {}!!!!!".format(accelerator.process_index))
        ema_unet = EMAModel(unet.parameters())
        # ema_encoder = EMAModel(boards_embedder.parameters())

    # prompt = 'A traffic light below a sink'
    # Spatial prompt templates and the half-image box assigned to each template:
    # boxes are [x_min, y_min, x_max, y_max] in normalized coordinates
    # (0: left half, 1: top half, 2: right half, 3: bottom half), and
    # bboxes_template_list gives (box for object A, box for object B) per template.
    templates = ['{} to the left of {}', '{} to the right of {}', '{} above {}', '{} below {}']
    bboxes_template = [[0.0, 0.0, 0.5, 1.0], [0.0, 0.0, 1.0, 0.5], [0.5, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0]]
    bboxes_template_list = [[0, 2], [2, 0], [1, 3], [3, 1]]
    iteration_start = cfg.inference.start_pair
    iteration_now = iteration_start
    iteration_interval = cfg.inference.iteration_interval
    with open('./coco_paris.txt', 'r') as f:
        image_pairs = f.readlines()
    # `inference` is expected to be defined or imported elsewhere in this repository.
    for image_pair in tqdm(image_pairs[iteration_start: iteration_start + iteration_interval]):
        obj_a, obj_b = image_pair.strip().split(',')[0], image_pair.strip().split(',')[1]
        obj_a = 'A {}'.format(obj_a) if obj_a[0] not in ['a', 'e', 'i', 'o', 'u'] else 'An {}'.format(obj_a)
        obj_b = 'a {}'.format(obj_b) if obj_b[0] not in ['a', 'e', 'i', 'o', 'u'] else 'an {}'.format(obj_b)
        for idx, template in enumerate(templates):
            prompt = template.format(obj_a, obj_b)
            obj_a_len = len(obj_a.split(' ')) - 1
            obj_a_position = [2] if obj_a_len == 1 else [2, 3]
            obj_b_position = [obj_a_len + 1 + len(template.split(' ')) + i for i in range(len(obj_b.split(' ')) - 1)]
            obj_positions = [obj_a_position, obj_b_position]
            obj_a_boxes = [bboxes_template[bboxes_template_list[idx][0]].copy() for _ in range(len(obj_a.split(' ')) - 1)]
            obj_b_boxes = [bboxes_template[bboxes_template_list[idx][1]].copy() for _ in range(len(obj_b.split(' ')) - 1)]
            obj_boxes = [obj_a_boxes, obj_b_boxes]
            print(prompt, obj_positions, obj_boxes)
            inference(device, unet, unet_original, vae, tokenizer, text_encoder, prompt, cfg, prev_attn_map, bboxes=obj_boxes, object_positions=obj_positions, infer_iter=cfg.inference.infer_iter, pair_id=iteration_now)
        # Repeat with the two objects swapped.
        obj_b, obj_a = image_pair.strip().split(',')[0], image_pair.strip().split(',')[1]
        obj_a = 'A {}'.format(obj_a) if obj_a[0] not in ['a', 'e', 'i', 'o', 'u'] else 'An {}'.format(obj_a)
        obj_b = 'a {}'.format(obj_b) if obj_b[0] not in ['a', 'e', 'i', 'o', 'u'] else 'an {}'.format(obj_b)
        for idx, template in enumerate(templates):
            prompt = template.format(obj_a, obj_b)
            obj_a_len = len(obj_a.split(' ')) - 1
            obj_a_position = [2] if obj_a_len == 1 else [2, 3]
            obj_b_position = [obj_a_len + 1 + len(template.split(' ')) + i for i in range(len(obj_b.split(' ')) - 1)]
            obj_positions = [obj_a_position, obj_b_position]
            obj_a_boxes = [bboxes_template[bboxes_template_list[idx][0]].copy() for _ in range(len(obj_a.split(' ')) - 1)]
            obj_b_boxes = [bboxes_template[bboxes_template_list[idx][1]].copy() for _ in range(len(obj_b.split(' ')) - 1)]
            obj_boxes = [obj_a_boxes, obj_b_boxes]
            print(prompt, obj_positions, obj_boxes)
            inference(device, unet, unet_original, vae, tokenizer, text_encoder, prompt, cfg, prev_attn_map, bboxes=obj_boxes, object_positions=obj_positions, infer_iter=cfg.inference.infer_iter, pair_id=iteration_now)
        iteration_now += 1
def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """
    Layout loss: for each object, encourage the cross-attention mass of its token(s)
    to fall inside the object's bounding box(es).
    """
    loss = 0
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0).float().cuda()
    for attn_map_integrated in attn_maps_mid:
        # Keep only the conditional half of the (uncond, cond) batch.
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        for obj_idx in range(object_number):
            obj_loss = 0
            # Binary mask of the union of this object's boxes on the attention grid.
            mask = torch.zeros(size=(H, W)).cuda()
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min: y_max, x_min: x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                # Fraction of this token's attention mass that falls inside the box.
                activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += (obj_loss / len(object_positions[obj_idx]))
        # compute loss on padding tokens
        # activation_value = torch.zeros(size=(b, )).cuda()
        # for obj_idx in range(object_number):
        #     bbox = bboxes[obj_idx]
        #     ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
        #     activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
        #                         int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
        # loss += torch.mean((1 - activation_value) ** 2)

    for attn_map_integrated in attn_maps_up[0]:
        attn_map = attn_map_integrated.chunk(2)[1]
        b, i, j = attn_map.shape
        H = W = int(math.sqrt(i))
        for obj_idx in range(object_number):
            obj_loss = 0
            mask = torch.zeros(size=(H, W)).cuda()
            for obj_box in bboxes[obj_idx]:
                x_min, y_min, x_max, y_max = int(obj_box[0] * W), \
                    int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
                mask[y_min: y_max, x_min: x_max] = 1
            for obj_position in object_positions[obj_idx]:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
                obj_loss += torch.mean((1 - activation_value) ** 2)
            loss += (obj_loss / len(object_positions[obj_idx]))
        # compute loss on padding tokens
        # activation_value = torch.zeros(size=(b, )).cuda()
        # for obj_idx in range(object_number):
        #     bbox = bboxes[obj_idx]
        #     ca_map_obj = attn_map[:, :, padding_start:].reshape(b, H, W, -1)
        #     activation_value += ca_map_obj[:, int(bbox[0] * H): int(bbox[1] * H),
        #                         int(bbox[2] * W): int(bbox[3] * W), :].reshape(b, -1).sum(dim=-1) / ca_map_obj.reshape(b, -1).sum(dim=-1)
        # loss += torch.mean((1 - activation_value) ** 2)

    # Average over objects and over the number of attention layers used.
    loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
    return loss
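# Intuition: for each object token, activation_value is the fraction of that token's
# attention mass inside the object's box, so the term contributes (1 - activation_value)**2,
# e.g. 0.16 when only 60% of the mass is inside the box, and 0 when all of it is.
# Minimal sanity-check sketch (hypothetical shapes: an 8x8 mid map, a 16x16 up map,
# 77 text tokens; a GPU is required because the function allocates CUDA tensors):
#
#     attn_mid = [torch.rand(16, 64, 77).cuda()]       # dim 0 = uncond + cond halves
#     attn_up = [[torch.rand(16, 256, 77).cuda()]]      # only the first up block is used
#     bboxes = [[[0.0, 0.0, 0.5, 1.0]]]                 # one object, left half of the image
#     object_positions = [[2]]                          # prompt index of the object word
#     print(compute_ca_loss(attn_mid, attn_up, bboxes, object_positions))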
def plt_all_attn_map_in_one(attn_map_integrated_list_down, attn_map_integrated_list_mid, attn_map_integrated_list_up, image, prompt, cfg, t, prefix='all'):
    """
    Plot the per-token cross-attention maps of every down/mid/up layer in a single grid,
    with the generated image in the first row, and save the figure for timestep `t`.
    """
    prompt_split = prompt.split(' ')
    prompt_len = len(prompt_split) + 4
    total_layers = len(attn_map_integrated_list_down) + len(attn_map_integrated_list_mid) + len(attn_map_integrated_list_up)
    fig, axs = plt.subplots(nrows=total_layers + 1, ncols=prompt_len, figsize=(4 * prompt_len, 4 * total_layers))
    fig.suptitle(prompt, fontsize=32)
    fig.tight_layout()
    cnt = 1
    # First row: the generated image; the remaining axes in that row stay empty.
    ax = axs[0][0]
    ax.imshow(image)
    for prompt_idx in range(prompt_len):
        ax = axs[0][prompt_idx]
        ax.set_axis_off()
    for layer, attn_map_integrated in enumerate(attn_map_integrated_list_down):
        attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
        grid_size = int(math.sqrt(attn_map.shape[1]))
        for prompt_idx in range(prompt_len):
            ax = axs[cnt][prompt_idx]
            if prompt_idx == 0:
                ax.set_ylabel('down {}'.format(layer), rotation=0, size='large')
            mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
            ax.imshow(mask, cmap='YlGn')
            ax.set_axis_off()
        cnt += 1
    for layer, attn_map_integrated in enumerate(attn_map_integrated_list_mid):
        attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
        grid_size = int(math.sqrt(attn_map.shape[1]))
        for prompt_idx in range(prompt_len):
            ax = axs[cnt][prompt_idx]
            if prompt_idx == 0:
                ax.set_ylabel('mid {}'.format(layer), rotation=0, size='large')
            mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
            ax.imshow(mask, cmap='YlGn')
            ax.set_axis_off()
        cnt += 1
    for layer, attn_map_integrated in enumerate(attn_map_integrated_list_up):
        attn_map_uncond, attn_map = attn_map_integrated.chunk(2)
        grid_size = int(math.sqrt(attn_map.shape[1]))
        for prompt_idx in range(prompt_len):
            ax = axs[cnt][prompt_idx]
            if prompt_idx == 0:
                ax.set_ylabel('up {}'.format(layer), rotation=0, size='large')
            mask = attn_map.mean(dim=0)[:, prompt_idx].reshape(grid_size, grid_size).detach().cpu().numpy()
            ax.imshow(mask, cmap='YlGn')
            ax.set_axis_off()
        cnt += 1
    if not os.path.exists(cfg.general.save_path + "/{}".format(prefix)):
        os.makedirs(cfg.general.save_path + "/{}".format(prefix))
    plt.savefig(cfg.general.save_path + "/{}/step_{}.png".format(prefix, str(int(t)).zfill(4)))
    # generate_video()
    plt.close()
if __name__ == "__main__":
    train()
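# Example launch (assuming the Hydra entry point above and a config exposing the
# `general`, `training`, and `inference` groups used in this file; the script file
# name below is illustrative, not the repository's actual file name):
#
#     python generate_spatial_pairs.py inference.start_pair=0 inference.iteration_interval=10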