import os |
import sys |
import json |
import argparse |
import numpy as np |
import math |
from einops import rearrange |
import time |
import random |
import string |
import h5py |
from tqdm import tqdm |
import webdataset as wds |
import gc |
import matplotlib.pyplot as plt |
import torch |
import torch.nn as nn |
from torchvision import transforms |
from accelerate import Accelerator, DeepSpeedPlugin |
torch.backends.cuda.matmul.allow_tf32 = True |
import utils |
# Rank of this process, read from the RANK environment variable; falls back
# to 0 when the variable is not set.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)
# Number of CUDA devices visible to this process; forced to 1 when none is
# visible so the per-device batch-size division below stays valid.
num_devices = torch.cuda.device_count()
if num_devices == 0:
    num_devices = 1

if num_devices <= 1 and utils.is_interactive():
    # Interactive single-device session: pick a fixed batch size and export
    # the environment variables of a one-process distributed run.
    global_batch_size = batch_size = 32
    print(f"Setting batch_size to {batch_size}")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size)
else:
    # Launched run: GLOBAL_BATCH_SIZE must already be set (KeyError otherwise).
    # Parse it once so global_batch_size is an int in both branches; it used
    # to be left as a str here, unlike the interactive branch.
    global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
    batch_size = global_batch_size // num_devices
# Rank 0 rewrites the DeepSpeed config file with the batch sizes chosen above
# and forces fp16 (bf16 off). Every other rank sleeps for 10 seconds instead,
# presumably to give rank 0 time to finish writing; this is a delay, not a
# real synchronisation barrier.
ds_config_path = '/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json'
if local_rank != 0:
    time.sleep(10)
else:
    with open(ds_config_path, 'r') as file:
        config = json.load(file)
    config.update({
        'train_batch_size': int(os.environ["GLOBAL_BATCH_SIZE"]),
        'train_micro_batch_size_per_gpu': batch_size,
        'bf16': {'enabled': False},
        'fp16': {'enabled': True},
    })
    with open(ds_config_path, 'w') as file:
        json.dump(config, file)

deepspeed_plugin = DeepSpeedPlugin(ds_config_path)
accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != 'NO'

# Tensor dtype that matches the accelerator's mixed-precision mode; anything
# other than bf16 / fp16 falls back to float32.
data_type = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}.get(accelerator.mixed_precision, torch.float32)

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# From here on the module-level name `print` refers to accelerator.print.
print = accelerator.print
if utils.is_interactive():
    # Random 10-character run name, tagged as an interactive run.
    model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + "_interactive"
    print("model_name:", model_name)

    # Command line used in place of sys.argv when running interactively;
    # split on whitespace into an argv-style list.
    jupyter_args = (
        f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset "
        f"--model_name={model_name} "
        f"--subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 "
        f"--clip_scale=1. --blur_scale=100. --depth_scale=100. "
        f"--max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"
    ).split()
    print(jupyter_args)

    from IPython.display import clear_output
    # Enable IPython's autoreload extension in mode 2.
    ipy_shell = get_ipython()
    ipy_shell.run_line_magic('load_ext', 'autoreload')
    ipy_shell.run_line_magic('autoreload', '2')
# Command-line interface of the training script. Boolean options use
# argparse.BooleanOptionalAction, so each also accepts a `--no-<name>` form.
parser = argparse.ArgumentParser(description="Model Training Configuration")

# --- run identity / data ---
parser.add_argument("--model_name", default="testing", type=str,
                    help="name of model, used for ckpt saving and wandb logging (if enabled)")
parser.add_argument("--data_path", default="/fsx/proj-fmri/shared/natural-scenes-dataset", type=str,
                    help="Path to where NSD data is stored / where to download it to")
parser.add_argument("--subj", default=1, type=int, choices=[1,2,5,7])
parser.add_argument("--batch_size", default=32, type=int,
                    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser")

# --- logging / resuming ---
parser.add_argument("--wandb_log", default=True, action=argparse.BooleanOptionalAction,
                    help="whether to log to wandb")
parser.add_argument("--resume_from_ckpt", default=False, action=argparse.BooleanOptionalAction,
                    help="if not using wandb and want to resume from a ckpt")
parser.add_argument("--wandb_project", default="stability", type=str,
                    help="wandb project name")

# --- losses and auxiliary reconstruction heads ---
parser.add_argument("--mixup_pct", default=.33, type=float,
                    help="proportion of way through training when to switch from BiMixCo to SoftCLIP")
parser.add_argument("--blurry_recon", default=True, action=argparse.BooleanOptionalAction,
                    help="whether to output blurry reconstructions")
parser.add_argument("--depth_recon", default=True, action=argparse.BooleanOptionalAction,
                    help="whether to output depth reconstructions")
parser.add_argument("--blur_scale", default=100., type=float,
                    help="multiply loss from blurry recons by this number")
parser.add_argument("--depth_scale", default=100., type=float,
                    help="multiply loss from depth recons by this number")
parser.add_argument("--clip_scale", default=1., type=float,
                    help="multiply contrastive loss by this number")
parser.add_argument("--use_image_aug", default=True, action=argparse.BooleanOptionalAction,
                    help="whether to use image augmentation")

# --- optimisation schedule / model size ---
parser.add_argument("--num_epochs", default=120, type=int,
                    help="number of epochs of training")
parser.add_argument("--hidden_dim", default=4096, type=int)
parser.add_argument("--lr_scheduler_type", default='cycle', type=str, choices=['cycle','linear'])

# --- checkpointing / reproducibility ---
parser.add_argument("--ckpt_saving", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument("--ckpt_interval", default=5, type=int,
                    help="save backup ckpt and reconstruct every x epochs")
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--max_lr", default=3e-4, type=float)
parser.add_argument("--seq_len", default=2, type=int)
# Interactive sessions parse the synthetic jupyter_args list; scripts parse
# the real command line (parse_args(None) reads sys.argv).
args = parser.parse_args(jupyter_args if utils.is_interactive() else None)

# Expose every parsed option as a module-level variable of the same name
# (e.g. args.subj -> subj); the rest of the script reads those globals.
globals().update(vars(args))
# Checkpoint directory for this run, resolved against the current working
# directory; only created when checkpoint saving is enabled.
outdir = os.path.abspath(f'../train_logs/{model_name}')
if ckpt_saving and not os.path.exists(outdir):
    os.makedirs(outdir, exist_ok=True)

if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential

    # Augmentations applied to the target images, in this order.
    _aug_ops = [
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
    ]
    img_augment = AugmentationSequential(
        *_aug_ops,
        same_on_batch=False,
        data_keys=["input"],
    )
# Hard-coded sample counts of the train / test splits. Only subject 1 has
# counts here even though the CLI accepts subjects 1, 2, 5 and 7; fail with an
# explicit message instead of a NameError on the next line.
if subj == 1:
    num_train = 24958
    num_test = 2770
else:
    raise NotImplementedError(
        f"num_train / num_test are only defined for subj=1 (got subj={subj})"
    )

# The whole test split is loaded as a single batch.
test_batch_size = num_test
def my_split_by_node(urls):
    """Return ``urls`` unchanged.

    Passed as the ``nodesplitter`` of the WebDataset pipelines in this file,
    so the shard list is handed through without any per-node filtering.
    """
    return urls
def _behav_dataset(url):
    """Build the WebDataset pipeline shared by the train and test splits.

    Samples are shuffled with a fixed-seed RNG, decoded to torch, and reduced
    to the tuple (behav, past_behav, future_behav, olds_behav).
    """
    dataset = wds.WebDataset(url, resampled=False, nodesplitter=my_split_by_node)
    dataset = dataset.shuffle(750, initial=1500, rng=random.Random(42))
    dataset = dataset.decode("torch")
    dataset = dataset.rename(
        behav="behav.npy",
        past_behav="past_behav.npy",
        future_behav="future_behav.npy",
        olds_behav="olds_behav.npy",
    )
    return dataset.to_tuple("behav", "past_behav", "future_behav", "olds_behav")


# Training split: shards 0..36 of the selected subject.
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)
train_data = _behav_dataset(train_url)
train_dl = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True
)

# Test split: a single shard, loaded with batch size test_batch_size.
test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)
test_data = _behav_dataset(test_url)
test_dl = torch.utils.data.DataLoader(
    test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True
)
def _collect_behav_columns(dl):
    """Iterate ``dl`` once and gather two columns of every ``behav`` batch.

    Returns ``(last_i, vox_indices, image_ids)`` where ``last_i`` is the index
    of the last batch (-1 for an empty loader), ``vox_indices`` is
    ``behav[:, 0, 5]`` of all batches as int64 and ``image_ids`` is
    ``behav[:, 0, 0]`` of all batches as float64.

    Chunks are concatenated once at the end instead of growing an array with
    ``np.append`` per batch, and indices are stored as int64 because int16
    wraps silently for values above 32767.
    """
    # The leading empty float64 array keeps the result dtype float64 (as
    # repeated np.append onto [] did) and makes an empty loader well-defined.
    vox_chunks, img_chunks = [np.empty(0)], [np.empty(0)]
    last_i = -1
    for last_i, (behav, past_behav, future_behav, old_behav) in enumerate(dl):
        vox_chunks.append(np.ravel(behav[:,0,5].cpu().numpy()))
        img_chunks.append(np.ravel(behav[:,0,0].cpu().numpy()))
    return last_i, np.concatenate(vox_chunks).astype(np.int64), np.concatenate(img_chunks)


test_i, test_vox_indices, test_73k_images = _collect_behav_columns(test_dl)
print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
print("---\n")

train_i, train_vox_indices, train_73k_images = _collect_behav_columns(train_dl)
print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
# Load the 'betas' dataset fully into memory. `[:]` copies the data into a
# NumPy array, so the file can be closed right away; the handle was
# previously left open.
with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:
    voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").to(data_type)
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]

# Same for the 'images' dataset; it stays on the CPU in the training dtype.
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]
images = torch.Tensor(images).to("cpu").to(data_type)
print("images", images.shape)
from models import Clipper

# CLIP image embedder, placed on the CUDA device indexed by local_rank.
clip_model = Clipper(
    "ViT-L/14",
    device=torch.device(f"cuda:{local_rank}"),
    hidden_state=True,
    norm_embs=True,
)
clip_seq_dim, clip_emb_dim = 257, 768

if blurry_recon:
    from diffusers import VQModel

    # Frozen VQ autoencoder used to build targets for the blurry-recon head.
    autoenc = VQModel.from_pretrained(
        "/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae",
        torch_dtype=data_type,
    )
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)

    # Smoke test on image 30: shrink to 8x8, blow back up to 128x128, encode,
    # and (interactively) show the decoded result.
    if utils.is_interactive():
        display(utils.torch_to_Image(images[[30]]))
    input_batch = images[[30]].to(device)
    print(input_batch.shape)
    downsampled_image = nn.functional.interpolate(
        input_batch, size=(8, 8), mode='bilinear', align_corners=False
    )
    re_upsampled_image = nn.functional.interpolate(
        downsampled_image, size=(128, 128), mode='nearest'
    )
    re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
    print(re_upsampled_enc.shape)
    if utils.is_interactive():
        decoded = autoenc.decode(re_upsampled_enc/0.18215).sample
        display(utils.torch_to_Image((decoded / 2 + 0.5).clamp(0,1)))
if depth_recon:
    from controlnet_aux.midas import MidasDetector

    # Frozen MiDaS depth estimator used to build targets for the depth head.
    midas_depth = MidasDetector.from_pretrained(
        "valhalla/t2iadapter-aux-models",
        filename="dpt_large_384.pt",
        model_type="dpt_large",
        cache_dir="/fsx/proj-fmri/shared/cache",
    ).to(device)
    midas_depth.model.eval()
    midas_depth.model.requires_grad_(False)
    midas_depth.model.to(device)

    # Smoke test on images 30 and 31: predict depth, resize to 32, and divide
    # each map by its own maximum.
    if utils.is_interactive():
        display(utils.torch_to_Image(images[[30]]))
    input_batch = images[[30,31]].float().to(device)
    print(input_batch.shape)

    midas_emb = midas_depth.model(input_batch).unsqueeze(1)
    print(midas_emb.shape)

    prediction = utils.resize(midas_emb, 32)
    print(prediction.shape)

    per_image_max = prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1)
    prediction = (prediction / per_image_max.expand_as(prediction)).half()
    midas_emb_size = prediction.flatten(1).shape[1]
    print("midas_emb", prediction.shape, prediction.min(), prediction.max())
    print("midas_emb_size", midas_emb_size)

    if utils.is_interactive():
        display(utils.torch_to_Image(utils.resize(prediction, 224)))

    if blurry_recon:
        # Also push the depth map (repeated to 3 channels, resized to 128)
        # through the VQ autoencoder.
        prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
        per_image_max = prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1)
        prediction = (prediction / per_image_max.expand_as(prediction)).half()
        prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
        print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())

        if utils.is_interactive():
            decoded_depth = autoenc.decode(prediction_enc/0.18215).sample
            display(utils.torch_to_Image((decoded_depth / 2 + 0.5).clamp(0,1)))
class MindEyeModule(nn.Module):
    """Empty container module; sub-modules are attached to it as attributes."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Identity: return ``x`` unchanged."""
        return x


model = MindEyeModule()

# Width of the time embedding that is concatenated to every voxel vector.
time_embedding_dim = 512
class RidgeRegression(torch.nn.Module):
    """Single linear layer mapping ``input_size`` features to ``out_features``.

    The module itself is a plain ``torch.nn.Linear`` and applies no
    regularisation of its own.
    """

    def __init__(self, input_size, out_features):
        super().__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(input_size, out_features)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        projected = self.linear(x)
        return projected
# The ridge layer consumes a voxel vector with a time embedding appended.
ridge_in_dim = voxels.shape[1] + time_embedding_dim
model.ridge = RidgeRegression(ridge_in_dim, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# Shape check with random inputs: (2, 1, voxels) joined with (2, 1, time_dim).
b = torch.randn(2, 1, voxels.shape[1])
time_emb_test = torch.randn(2, 1, time_embedding_dim)
ridge_probe = torch.cat((b, time_emb_test), dim=-1)
print(b.shape, model.ridge(ridge_probe).shape)

num_past_voxels = 15
from functools import partial |
from diffusers.models.vae import Decoder |
class BrainNetwork(nn.Module):
    """MLP-mixer style backbone with a CLIP head and optional blur/depth heads.

    The input is expected as ``(batch, seq_len, h)``: ``mixer_blocks1`` act on
    the last (``h``) axis and ``mixer_blocks2`` act on the ``seq_len`` axis
    after a transpose. Whether the blur and depth heads exist is decided by
    the module-level flags ``blurry_recon`` and ``depth_recon``, read both at
    construction time and in ``forward``.

    Args:
        out_dim: output width of the CLIP linear layer ``clin1``.
        in_dim: accepted for interface compatibility; not used by this class.
        seq_len: number of sequence positions in the input.
        h: hidden width of each sequence position.
        n_blocks: number of (block1, block2) mixer pairs.
        drop: dropout probability inside every mixer MLP.
        clip_size: width of one CLIP token; ``out_dim`` is reshaped into
            ``(-1, clip_size)`` tokens before the projector.
    """

    def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
        super().__init__()
        self.seq_len = seq_len
        self.h = h
        self.clip_size = clip_size

        # Mixer blocks over the hidden axis and over the sequence axis.
        self.mixer_blocks1 = nn.ModuleList([
            self.mixer_block1(h, drop) for _ in range(n_blocks)
        ])
        self.mixer_blocks2 = nn.ModuleList([
            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
        ])

        # CLIP head: flatten (seq_len, h) -> out_dim, then per-token projector.
        self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size)
        )

        if blurry_recon:
            # 4096 = 256 * 4 * 4, reshaped to (256, 4, 4) in forward().
            self.blin1 = nn.Linear(h*seq_len, 4096)
            self.bgroupnorm = nn.GroupNorm(1, 256)
            self.bupsampler = Decoder(
                in_channels=256,
                out_channels=128,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[32, 64, 128],
                layers_per_block=1,
            )

        if depth_recon:
            self.dlin1 = nn.Linear(h*seq_len, 4096)
            self.dgroupnorm = nn.GroupNorm(1, 256)
            self.dupsampler = Decoder(
                in_channels=256,
                out_channels=1,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[32, 64, 128, 256],
                layers_per_block=1,
            )

    def mixer_block1(self, h, drop):
        """LayerNorm + MLP over the hidden axis."""
        return nn.Sequential(
            nn.LayerNorm(h),
            self.mlp(h, h, drop),
        )

    def mixer_block2(self, seq_len, drop):
        """LayerNorm + MLP over the sequence axis."""
        return nn.Sequential(
            nn.LayerNorm(seq_len),
            self.mlp(seq_len, seq_len, drop)
        )

    def mlp(self, in_dim, out_dim, drop):
        """Two-layer MLP: Linear -> GELU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x, idx = None):
        """Run the mixer and all enabled heads.

        Args:
            x: tensor of shape ``(batch, seq_len, h)``.
            idx: unused; kept so existing callers that pass it keep working
                (it was only ever printed for debugging).

        Returns:
            ``(c, b, d)``: the CLIP output of shape
            ``(batch, out_dim // clip_size, clip_size)``, the blur-head output
            and the depth-head output. A disabled head yields the placeholder
            ``torch.Tensor([0.])``.
        """
        # Placeholders returned for heads that are switched off.
        b,d = torch.Tensor([0.]), torch.Tensor([0.])

        # Two residual streams: one in (batch, seq, h) layout and one in the
        # transposed (batch, h, seq) layout.
        residual1 = x
        residual2 = x.permute(0,2,1)
        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
            x = block1(x) + residual1
            residual1 = x
            x = x.permute(0,2,1)

            x = block2(x) + residual2
            residual2 = x
            x = x.permute(0,2,1)

        x = x.reshape(x.size(0), -1)

        c = self.clin1(x)
        c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))

        if blurry_recon:
            b = self.blin1(x)
            b = b.reshape(len(b), 256, 4, 4)
            b = self.bgroupnorm(b)
            b = self.bupsampler(b)

        if depth_recon:
            d = self.dlin1(x)
            d = d.reshape(len(d), 256, 4, 4)
            d = self.dgroupnorm(d)
            d = self.dupsampler(d)

        return c, b, d
class TimeEmbedding(nn.Module):
    """Learned lookup table with one vector per time index.

    Args:
        embedding_time_dim: width of each embedding vector.
        num_past_voxels: number of rows in the table, i.e. valid indices are
            ``0 .. num_past_voxels - 1``.
    """

    def __init__(self, embedding_time_dim=512, num_past_voxels=15):
        super().__init__()
        self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
        self.num_past_voxels = num_past_voxels
        self.embedding_time_dim = embedding_time_dim

    def forward(self, time):
        """Cast ``time`` to integer indices and return their embeddings."""
        return self.embedding_time(time.long())
# Build the time embedding from the module-level settings rather than
# repeating the literals 512 / 15: the ridge layer's input width is computed
# from time_embedding_dim, so the two must stay equal.
model.time_embedding = TimeEmbedding(
    embedding_time_dim=time_embedding_dim,
    num_past_voxels=num_past_voxels,
)
model.backbone = BrainNetwork(
    h=hidden_dim,
    in_dim=hidden_dim,
    seq_len=seq_len,
    clip_size=clip_emb_dim,
    out_dim=clip_emb_dim*clip_seq_dim,
)
utils.count_params(model.backbone)
utils.count_params(model)

# Shape check: one random (1, seq_len, hidden_dim) input through the backbone.
b = torch.randn((1,seq_len,hidden_dim))
print("b.shape",b.shape)
with torch.no_grad():
    clip_, blur_, depth_ = model.backbone(b)
print(clip_.shape, blur_.shape, depth_.shape)
# Backbone parameters whose name contains one of these substrings get no
# weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    # model.time_embedding is used in the forward pass but was in no group,
    # so its weights were never updated. It is appended as a new last group so
    # the indices of the existing groups do not change.
    # NOTE: optimizer state saved before this group existed will not load.
    {'params': list(model.time_embedding.parameters()), 'weight_decay': 1e-2},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

# Total number of scheduler steps, shared by both scheduler types.
total_steps = int(np.floor(num_epochs*(num_train/num_devices/batch_size)))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    print("total_steps", total_steps)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        # Ramp up over the first two epochs. Clamped to 1.0 because
        # OneCycleLR rejects pct_start > 1, which happened for num_epochs=1.
        last_epoch=-1, pct_start=min(2/num_epochs, 1.0)
    )
def save_ckpt(tag):
    """Best-effort save of the training state to ``{outdir}/{tag}.pth``.

    The checkpoint holds the current epoch, the unwrapped model weights, the
    optimizer and LR-scheduler state, and the loss / learning-rate histories,
    all read from module-level variables. A failure while saving is reported
    and swallowed so that training continues.

    Args:
        tag: file stem of the checkpoint, e.g. ``'last'``.
    """
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
            }, ckpt_path)
    except Exception as err:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # propagate, and report why the save failed.
        print(f"Couldn't save... moving on to prevent crashing. ({type(err).__name__}: {err})")
print("\nDone with model preparations!")
utils.count_params(model)

if not (local_rank == 0 and wandb_log):
    # Every rank other than 0, or wandb switched off: disable logging.
    wandb_log = False
else:
    import wandb

    # This assignment replaces whatever --wandb_project was parsed to.
    wandb_project = 'mindeyev2'
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')

    # Hyper-parameters and run facts recorded with the wandb run.
    wandb_config = dict(
        model_name=model_name,
        global_batch_size=global_batch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        clip_scale=clip_scale,
        blur_scale=blur_scale,
        use_image_aug=use_image_aug,
        max_lr=max_lr,
        mixup_pct=mixup_pct,
        num_train=num_train,
        num_test=num_test,
        ckpt_interval=ckpt_interval,
        ckpt_saving=ckpt_saving,
        seed=seed,
        distributed=distributed,
        num_devices=num_devices,
        world_size=world_size,
        train_url=train_url,
        test_url=test_url,
    )
    print("wandb_config:\n",wandb_config)

    wandb_init_kwargs = dict(
        project=wandb_project,
        name=wandb_run,
        config=wandb_config,
        notes=wandb_notes,
    )
    # Hard-wired off: the resumable-run variant below is never taken.
    wandb_auto_resume = False
    if wandb_auto_resume:
        print("wandb_id:",model_name)
        wandb.init(id=model_name, resume="allow", **wandb_init_kwargs)
    else:
        wandb.init(**wandb_init_kwargs)
# Training-state bookkeeping.
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
# One soft-CLIP temperature per epoch of the post-mixup phase.
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))


def _resume_from_last_ckpt():
    """Restore optimizer, scheduler and model from the last checkpoint.

    Tries ``last.pth`` in ``outdir`` first and falls back to
    ``last_backup.pth`` if loading it fails.

    Returns:
        The epoch number stored in the checkpoint.
    """
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except Exception:
        # Exception rather than a bare except, so Ctrl-C is not swallowed.
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    resumed_epoch = checkpoint['epoch']
    print("Epoch",resumed_epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    model.load_state_dict(checkpoint['model_state_dict'])
    return resumed_epoch


# Resume when asked for explicitly, or when wandb logging is on and wandb
# reports the run as resumed. `wandb` is only touched when wandb_log is true.
if resume_from_ckpt or (wandb_log and wandb.run.resumed):
    epoch = _resume_from_last_ckpt()
torch.cuda.empty_cache()

# Only the training loader is handed to the accelerator; the test loader is
# used as-is.
model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dl, lr_scheduler
)
def add_saturation(image, alpha=2):
    """Blend ``image`` with its grayscale version and clamp to [0, 1].

    Computes ``alpha * image + (1 - alpha) * gray``, where ``gray`` is the
    weighted channel sum 0.2989*R + 0.5870*G + 0.1140*B broadcast back over
    the channel axis.

    Args:
        image: tensor indexed as (batch, channel, H, W); channels 0, 1, 2 are
            read as R, G, B.
        alpha: blend factor; 1 leaves the image unchanged (apart from the
            clamp).

    Returns:
        Tensor of the same shape as ``image`` with values in [0, 1].
    """
    red, green, blue = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    luminance = luminance.unsqueeze(1).expand_as(image)
    blended = alpha * image + (1 - alpha) * luminance
    return torch.clamp(blended, 0, 1)
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
# Test images / voxels averaged over repeats; built on the first evaluation.
test_image, test_voxel = None, None
mse = nn.MSELoss()
l1 = nn.L1Loss()

# Main loop: one pass over train_dl per epoch, then a single-batch evaluation
# on rank 0, then logging and optional checkpointing.
for epoch in progress_bar:
    model.train()

    # Per-epoch running sums; divided by the batch count when logged.
    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    loss_depth_total = 0.
    test_loss_clip_total = 0.
    test_loss_blurry_total = 0.
    test_loss_depth_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0.

    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
        with torch.cuda.amp.autocast(dtype=data_type):
            optimizer.zero_grad()

            # Column 5 of behav indexes into `voxels`, column 0 into `images`.
            voxel = voxels[behav[:,0,5].cpu().long()].to(device)
            image = images[behav[:,0,0].cpu().long()].to(device).float()

            # Voxels of the seq_len-1 preceding entries and their time indices.
            past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device)
            past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device)

            # Targets for the auxiliary heads come from the un-augmented image.
            if blurry_recon:
                blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215

            if depth_recon:
                depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
                # Divide every depth map by its own maximum.
                depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
                depth_image_enc = depth_images

            if use_image_aug:
                image = img_augment(image)

            clip_target = clip_model.embed_image(image)
            assert not torch.any(torch.isnan(clip_target))

            # Mixup phase: the first mixup_pct of the epochs.
            if epoch < int(mixup_pct * num_epochs):
                voxel, perm, betas, select = utils.mixco(voxel)
                # NOTE(review): past_voxel is never read afterwards, and
                # past_15_voxels below are not mixed. Confirm this is intended.
                past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)

            # Zero out past entries whose last behav column equals 1.
            for p in range(seq_len-1):
                mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
                past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])

            past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
            past_15_times = past_15_times.repeat(voxel.shape[0], 1)
            past_15_times = past_15_times.reshape(-1)
            time_embeddings = model.time_embedding(past_15_times)

            past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)

            # The current voxel gets an all-zero vector in place of a time embedding.
            positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
            voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
            # Rows are stacked as [current ; past], then reshaped to
            # (batch, seq_len, hidden_dim).
            # NOTE(review): the past rows are flattened batch-major while the
            # view below assumes sequence-major order; the two agree for
            # seq_len == 2 but look misaligned for larger seq_len. Please verify.
            voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
            voxel_ridge = voxel_ridge.view( seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)

            clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge, idx = train_i)

            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if epoch < int(mixup_pct * num_epochs):
                loss_clip = utils.mixco_nce(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006,
                    perm=perm, betas=betas, select=select)
            else:
                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=epoch_temp)

            # The running totals hold the unscaled losses.
            loss_clip_total += loss_clip.item()
            loss_clip *= clip_scale
            loss = loss_clip

            if blurry_recon:
                downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
                re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
                re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215

                loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
                loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
                loss_blurry_total += loss_blurry.item()
                loss_blurry *= blur_scale
                loss += loss_blurry

            if depth_recon:
                loss_depth = l1(depth_image_enc_, depth_image_enc)
                loss_depth_total += loss_depth.item()
                loss_depth *= depth_scale
                loss += loss_depth

            # Top-1 retrieval accuracy in both directions within the batch.
            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
            fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()
            bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()

            if blurry_recon:
                with torch.no_grad():
                    # Pixel correlation on a random fifth of the batch only.
                    random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)
                    blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
                    pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)
                    blurry_pixcorr += pixcorr.item()

            utils.check_loss(loss)
            accelerator.backward(loss)
            optimizer.step()

            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])

            if lr_scheduler_type is not None:
                lr_scheduler.step()

    model.eval()
    # Evaluation, logging and checkpointing happen on rank 0 only.
    if local_rank==0:
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
            for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
                # The whole test split must arrive as one batch.
                assert len(behav) == num_test

                # First evaluation only: average the voxels of all repeats of
                # each unique image.
                if test_image is None:
                    voxel = voxels[behav[:,0,5].cpu().long()]
                    image = behav[:,0,0].cpu().long()

                    unique_image, sort_indices = torch.unique(image, return_inverse=True)
                    for im in unique_image:
                        locs = torch.where(im == image)[0]
                        if test_image is None:
                            test_image = images[im][None]
                            test_voxel = torch.mean(voxel[locs],axis=0)[None]
                        else:
                            test_image = torch.vstack((test_image, images[im][None]))
                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))

                # Despite the name, this takes the first 300 entries.
                random_indices = torch.arange(len(test_voxel))[:300]
                voxel = test_voxel[random_indices].to(device)
                image = test_image[random_indices].to(device)
                assert len(image) == 300

                # NOTE(review): past_behav rows are per raw test sample, while
                # voxel/image rows are per unique image, so these rows may not
                # belong to the same trials. Please verify.
                current_past_behav = past_behav[random_indices]

                past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device)
                past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device)

                if blurry_recon:
                    blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215

                if depth_recon:
                    depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
                    depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
                    depth_image_enc = depth_images

                clip_target = clip_model.embed_image(image.float())

                past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
                past_15_times = past_15_times.repeat(voxel.shape[0], 1)
                past_15_times = past_15_times.reshape(-1)
                time_embeddings = model.time_embedding(past_15_times)
                past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)

                positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
                voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
                voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
                voxel_ridge = voxel_ridge.view(seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)

                clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)

                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006)
                test_loss_clip_total += loss_clip.item()
                loss_clip = loss_clip * clip_scale
                loss = loss_clip

                if blurry_recon:
                    downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
                    re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
                    re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215

                    loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
                    loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
                    test_loss_blurry_total += loss_blurry.item()
                    loss_blurry *= blur_scale
                    loss += loss_blurry

                    # Decode the latents in two halves and stack the results.
                    blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)
                    blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))
                    pixcorr = utils.pixcorr(image, blurry_recon_images)
                    loss += (1 - pixcorr)
                    test_blurry_pixcorr += pixcorr.item()

                if depth_recon:
                    loss_depth = l1(depth_image_enc_, depth_image_enc)
                    test_loss_depth_total += loss_depth.item()
                    loss_depth *= depth_scale
                    loss += loss_depth

                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()

                utils.check_loss(loss)
                test_losses.append(loss.item())

            print("---")

            # Exactly one test batch must have been processed.
            assert (test_i+1) == 1
            logs = {"train/loss": np.mean(losses[-(train_i+1):]),
                "test/loss": np.mean(test_losses[-(test_i+1):]),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "test/num_steps": len(test_losses),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
                "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
                "train/loss_clip_total": loss_clip_total / (train_i + 1),
                "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
                "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
                "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
                "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
                "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
                "train/loss_depth_total": loss_depth_total / (train_i + 1),
                "test/loss_depth_total": test_loss_depth_total / (test_i + 1),
                }

            if blurry_recon:
                # For the first four test items: decoded target latent next to
                # decoded predicted latent.
                fig, axes = plt.subplots(1, 8, figsize=(10, 4))
                jj=-1
                for j in [0,1,2,3]:
                    jj+=1
                    axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
                    axes[jj].axis('off')
                    jj+=1
                    axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
                    axes[jj].axis('off')

                if wandb_log:
                    logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
                    plt.close()
                else:
                    plt.show()

            if depth_recon:
                # Same layout for depth: target map next to predicted map.
                fig, axes = plt.subplots(1, 8, figsize=(10, 4))
                jj=-1
                for j in [0,1,2,3]:
                    jj+=1
                    axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))
                    axes[jj].axis('off')
                    jj+=1
                    axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))
                    axes[jj].axis('off')

                if wandb_log:
                    logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
                    plt.close()
                else:
                    plt.show()

            progress_bar.set_postfix(**logs)

            # Periodic checkpoint. Now honours --no-ckpt_saving: the save used
            # to be attempted regardless, even though outdir is only created
            # when ckpt_saving is set.
            if ckpt_saving and epoch % ckpt_interval == 0:
                if not utils.is_interactive():
                    save_ckpt(f'last')

            if wandb_log: wandb.log(logs)

    # All ranks meet here before the next epoch starts.
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    gc.collect()
print("\n===Finished!===\n")
if ckpt_saving:
    save_ckpt(f'last')
if not utils.is_interactive():
    sys.exit(0)

# ---- Everything below runs in interactive sessions only. ----

# Loss curves.
plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()

annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")

# Index of the test item to inspect.
ii=2
# Unique training image ids. They were accumulated as floats, so cast to
# int64: a float array cannot be used to index the `images` tensor.
all_indices = np.unique(train_73k_images).astype(np.int64)
with torch.no_grad(), torch.cuda.amp.autocast():
    # CLIP embeddings of every unique training image, in chunks of 512.
    for batch in tqdm(range(0,len(all_indices),512)):
        if batch==0:
            clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
        else:
            target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
            clip_target = torch.vstack((clip_target,target))
    clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

    voxel = test_voxel[[ii]].to(device)
    image = test_image[[ii]].to(device)

    print("Original Image (test set)")
    display(utils.torch_to_Image(image))

    clip_target = clip_model.embed_image(image).cpu()

    # NOTE(review): model.ridge is built for voxel + time-embedding inputs and
    # the backbone for seq_len positions, as in the training loop. This call
    # passes raw voxels with a single position and looks stale; please verify
    # before relying on this cell.
    voxel_ridge = model.ridge(voxel).unsqueeze(1)
    clip_voxels, _, _ = model.backbone(voxel_ridge)
    # Keep the embedding predicted from the voxels. It used to be overwritten
    # on the next line with the normalised embedding of the ground-truth
    # image, so the retrieval never used the model output.
    clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)

    print("clip_voxels_norm", clip_voxels_norm.shape)
    print("clip_target_norm", clip_target_norm.shape)

    # Training images sorted by descending cosine similarity to the prediction.
    sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(),
                                                            clip_target_norm).flatten()).flip(0)
    picks = all_indices[sortt[:5].cpu().numpy()]

    print("\nNearest neighbors in training set")
    for ip,p in enumerate(picks):
        display(utils.torch_to_Image(images[[p]]))
        # The annotation of the closest image becomes the caption.
        if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]

print("\n=====\npredicted_caption:\n", predicted_caption)

from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")

# Generate an image from the retrieved caption.
prompt = predicted_caption
recon = pipe(prompt=prompt).images[0]

print("Seen image")
display(utils.torch_to_Image(image))

print("Reconstruction")
utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))