{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "# from subprocess import call\n", "# command = \"jupyter nbconvert Train_with_autoencoder_MLPMixer.ipynb --to python\"\n", "# call(command,shell=True)" ] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": {}, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-10-28 20:46:20,021] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import math\n", "from einops import rearrange\n", "import time\n", "import random\n", "import string\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils" ] }, { "cell_type": "code", "execution_count": 3, "id": "c0267850-3785-4be6-b134-b2a52bf55113", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sanity check that the project utils module imported and exposes mixco\n", "utils.mixco" ] }, { "cell_type": "code", "execution_count": 4, "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n", "Setting batch_size to 128\n", "[2023-10-28 20:46:28,070] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented\n", "[2023-10-28 20:46:28,071] [INFO] [comm.py:594:init_distributed] cdb=None\n", "[2023-10-28 20:46:28,071] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" ] } ], "source": [ "### Multi-GPU config ###\n", "# NOTE(review): this reads the global RANK env var, not LOCAL_RANK; the two only coincide on a single node\n", "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", "    local_rank = 0\n", "else:\n", "    local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank) \n", "\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "\n", "# ## UNCOMMENT BELOW SECTION AND COMMENT OUT 
DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###\n", "# accelerator = Accelerator(split_batches=False, mixed_precision=\"fp16\")\n", "# global_batch_size = batch_size = 128\n", "# data_type = torch.float16 # change depending on your mixed_precision\n", "\n", "### DEEPSPEED INITIALIZATION ###\n", "if num_devices <= 1 and utils.is_interactive():\n", "    global_batch_size = batch_size = 128\n", "    print(f\"Setting batch_size to {batch_size}\")\n", "    # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", "    os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", "    os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", "    os.environ[\"RANK\"] = \"0\"\n", "    os.environ[\"LOCAL_RANK\"] = \"0\"\n", "    os.environ[\"WORLD_SIZE\"] = \"1\"\n", "    os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", "else:\n", "    # parse to int so global_batch_size has the same type in both branches (it is also logged to wandb later)\n", "    global_batch_size = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", "    batch_size = global_batch_size // num_devices\n", "\n", "# alter the deepspeed config according to your global and local batch size\n", "if local_rank == 0:\n", "    with open('deepspeed_config_stage2.json', 'r') as file:\n", "        config = json.load(file)\n", "    config['train_batch_size'] = global_batch_size\n", "    config['train_micro_batch_size_per_gpu'] = batch_size\n", "    config['bf16'] = {'enabled': False}\n", "    config['fp16'] = {'enabled': True}\n", "    with open('deepspeed_config_stage2.json', 'w') as file:\n", "        json.dump(config, file)\n", "else:\n", "    # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", "    time.sleep(10)\n", "deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", "accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 5, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of 
this process = 1896724\n", "device: cuda:0\n", "Distributed environment: DEEPSPEED Backend: nccl\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda:0\n", "\n", "Mixed precision type: fp16\n", "ds_config: {'bf16': {'enabled': False}, 'fp16': {'enabled': True}, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'stage3_gather_16bit_weights_on_model_save': True, 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_prefetch_bucket_size': 10000000.0, 'stage3_param_persistence_threshold': 100000.0, 'reduce_bucket_size': 10000000.0, 'sub_group_size': 1000000000.0, 'offload_optimizer': {'device': 'none', 'nvme_path': '/scratch', 'pin_memory': True}, 'offload_param': {'device': 'none', 'nvme_path': '/scratch', 'buffer_size': 4000000000.0, 'pin_memory': True}}, 'aio': {'block_size': 26214400, 'queue_depth': 64, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}, 'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'steps_per_print': inf, 'train_batch_size': 128, 'train_micro_batch_size_per_gpu': 128, 'wall_clock_breakdown': False, 'zero_allow_untested_optimizer': True}\n", "\n", "distributed = True num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "# any distributed_type other than 'NO' means we run under a distributed backend\n", "distributed = accelerator.state.distributed_type != 'NO'\n", "\n", "# torch dtype matching the mixed-precision mode (set automatically from the deepspeed config); fp32 if neither\n", "data_type = (\n", "    torch.bfloat16 if accelerator.mixed_precision == \"bf16\"\n", "    else torch.float16 if accelerator.mixed_precision == \"fp16\"\n", "    else torch.float32\n", ")\n", "\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, 
\"local rank =\", local_rank, \"world size =\", world_size, \"data_type =\", data_type)\n", "# rebind print so that from here on only the local_rank=0 process prints\n", "print = accelerator.print" ] }, { "cell_type": "markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 6, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model_name: 0qiAxQoaKN_interactive_bsl\n", "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=0qiAxQoaKN_interactive_bsl', '--subj=1', '--batch_size=128', '--no-blurry_recon', '--no-depth_recon', '--hidden_dim=4096', '--clip_scale=1.', '--blur_scale=100.', '--depth_scale=100.', '--max_lr=3e-4', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug', '--no-ckpt_saving']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", "    # create random model_name\n", "    model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))\n", "    model_name = model_name + \"_interactive_bsl\"\n", "    print(\"model_name:\", model_name)\n", "\n", "    # global_batch_size and batch_size should already be defined in the above cells\n", "    # other variables can be specified in the following string:\n", "    jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name={model_name} \\\n", " --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \\\n", " --clip_scale=1. --blur_scale=100. --depth_scale=100. 
\\\n", " --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving\"\n", "\n", "    jupyter_args = jupyter_args.split()\n", "    print(jupyter_args)\n", "    \n", "    from IPython.display import clear_output # function to clear print outputs in cell\n", "    %load_ext autoreload \n", "    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", "    %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 7, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": { "tags": [] }, "outputs": [], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", "    \"--model_name\", type=str, default=\"testing\",\n", "    help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", "    \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", "    help=\"Path to where NSD data is stored / where to download it to\",\n", ")\n", "parser.add_argument(\n", "    \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", "    \"--batch_size\", type=int, default=32,\n", "    help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", ")\n", "parser.add_argument(\n", "    \"--wandb_log\",action=argparse.BooleanOptionalAction,default=True,\n", "    help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", "    \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", "    help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", "    \"--wandb_project\",type=str,default=\"stability\",\n", "    help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", "    \"--mixup_pct\",type=float,default=.33,\n", "    help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", "    
\"--blurry_recon\",action=argparse.BooleanOptionalAction,default=True,\n", "    help=\"whether to output blurry reconstructions\",\n", ")\n", "parser.add_argument(\n", "    \"--depth_recon\",action=argparse.BooleanOptionalAction,default=True,\n", "    help=\"whether to output depth reconstructions\",\n", ")\n", "parser.add_argument(\n", "    \"--blur_scale\",type=float,default=100.,\n", "    help=\"multiply loss from blurry recons by this number\",\n", ")\n", "parser.add_argument(\n", "    \"--depth_scale\",type=float,default=100.,\n", "    help=\"multiply loss from depth recons by this number\",\n", ")\n", "parser.add_argument(\n", "    \"--clip_scale\",type=float,default=1.,\n", "    help=\"multiply contrastive loss by this number\",\n", ")\n", "parser.add_argument(\n", "    \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", "    help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", "    \"--num_epochs\",type=int,default=120,\n", "    help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", "    \"--hidden_dim\",type=int,default=4096,\n", ")\n", "parser.add_argument(\n", "    \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", "    \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", "    \"--ckpt_interval\",type=int,default=5,\n", "    help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", "    \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", "    \"--max_lr\",type=float,default=3e-4,\n", ")\n", "\n", "if utils.is_interactive():\n", "    args = parser.parse_args(jupyter_args)\n", "else:\n", "    args = parser.parse_args()\n", "\n", "# create global variables without the args prefix\n", "# NOTE: this also overwrites batch_size, which the Multi-GPU config cell already computed and wrote into the\n", "# DeepSpeed config as train_micro_batch_size_per_gpu; warn if --batch_size disagrees with that value\n", "ds_micro_batch_size = batch_size\n", "for attribute_name, attribute_value in vars(args).items():\n", "    globals()[attribute_name] = attribute_value\n", "if batch_size != ds_micro_batch_size:\n", "    print(f\"WARNING: --batch_size={batch_size} differs from the per-device batch size ({ds_micro_batch_size}) in the DeepSpeed config\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", 
"metadata": { "tags": [] }, "outputs": [], "source": [ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", "if not os.path.exists(outdir) and ckpt_saving:\n", " os.makedirs(outdir,exist_ok=True)\n", "if use_image_aug:\n", " import kornia\n", " from kornia.augmentation.container import AugmentationSequential\n", " img_augment = AugmentationSequential(\n", " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", " kornia.augmentation.Resize((224, 224)),\n", " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", " kornia.augmentation.RandomGrayscale(p=0.3),\n", " same_on_batch=False,\n", " data_keys=[\"input\"],\n", " )" ] }, { "cell_type": "markdown", "id": "42d13c25-1369-4c49-81d4-83d713586096", "metadata": { "tags": [] }, "source": [ "# Prep data, models, and dataloaders" ] }, { "cell_type": "markdown", "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", "metadata": {}, "source": [ "## Dataloader" ] }, { "cell_type": "code", "execution_count": 9, "id": "81084834-035f-4465-ad59-59e6b806a2f5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" ] } ], "source": [ "if subj==1:\n", " num_train = 24958\n", " num_test = 2770\n", "test_batch_size = num_test\n", "\n", "def my_split_by_node(urls): return urls\n", " \n", "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", "# train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..1}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", 
olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 10, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2770 2770\n", "---\n", "\n", "194 24960 24960\n" ] } ], "source": [ "test_vox_indices = []\n", "test_73k_images = []\n", "for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", " test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())\n", " test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())\n", "test_vox_indices = test_vox_indices.astype(np.int16)\n", "print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))\n", "print(\"---\\n\")\n", "\n", "train_vox_indices = []\n", "train_73k_images = []\n", "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())\n", " train_73k_images = 
np.append(train_73k_images, behav[:,0,0].cpu().numpy())\n", "train_vox_indices = train_vox_indices.astype(np.int16)\n", "print(train_i, (train_i+1) * batch_size, len(train_vox_indices))" ] }, { "cell_type": "markdown", "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", "metadata": {}, "source": [ "## Load data and images" ] }, { "cell_type": "code", "execution_count": 11, "id": "039dd330-7339-4f88-8f00-45f95e47baa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subj01 betas loaded into memory\n", "voxels torch.Size([27750, 15724])\n", "images torch.Size([73000, 3, 224, 224])\n" ] } ], "source": [ "# load betas\n", "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", "# f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')\n", "\n", "voxels = f['betas'][:]\n", "print(f\"subj0{subj} betas loaded into memory\")\n", "voxels = torch.Tensor(voxels).to(\"cpu\").to(data_type)\n", "print(\"voxels\", voxels.shape)\n", "num_voxels = voxels.shape[-1]\n", "\n", "# load orig images\n", "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", "images = f['images'][:]\n", "images = torch.Tensor(images).to(\"cpu\").to(data_type)\n", "print(\"images\", images.shape)" ] }, { "cell_type": "markdown", "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", "metadata": {}, "source": [ "## Load models" ] }, { "cell_type": "markdown", "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", "metadata": {}, "source": [ "### CLIP image embeddings model" ] }, { "cell_type": "code", "execution_count": 12, "id": "b0420dc0-199e-4c1a-857d-b1747058b467", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "from models import Clipper\n", "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", "clip_seq_dim = 257\n", "clip_emb_dim = 768 #1024\n", "# hidden_dim = 4096\n", 
"seq_len = 1 #2 #32 " ] }, { "cell_type": "markdown", "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", "metadata": {}, "source": [ "### SD VAE" ] }, { "cell_type": "code", "execution_count": 13, "id": "01baff79-8114-482b-b115-6f05aa8ad691", "metadata": { "tags": [] }, "outputs": [], "source": [ "# if blurry_recon:\n", "# from diffusers import AutoencoderKL\n", "# autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", "# # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", "# autoenc.eval()\n", "# autoenc.requires_grad_(False)\n", "# autoenc.to(device)\n", "# utils.count_params(autoenc)\n", "\n", "if blurry_recon:# or depth_recon:\n", " from diffusers import VQModel\n", " autoenc = VQModel.from_pretrained(\"/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae\", torch_dtype=data_type)\n", " autoenc.eval()\n", " autoenc.requires_grad_(False)\n", " autoenc.to(device)\n", " utils.count_params(autoenc)" ] }, { "cell_type": "markdown", "id": "120c8eee-9834-437d-bb60-b38faef50138", "metadata": {}, "source": [ "#### downsampled images" ] }, { "cell_type": "code", "execution_count": 14, "id": "6d1ba8dd-64c2-4ac9-947e-725b7f2e3e50", "metadata": { "tags": [] }, "outputs": [], "source": [ "if blurry_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30]].to(device)\n", " print(input_batch.shape)\n", "\n", " downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " print(re_upsampled_enc.shape)\n", " \n", " if utils.is_interactive(): 
display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "6390a3a8-2bef-4e81-9b82-e154d26a1e1d", "metadata": {}, "source": [ "#### MiDaS depth" ] }, { "cell_type": "code", "execution_count": 15, "id": "f35573e2-95bf-463d-8937-68ad4c2c3c20", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " from controlnet_aux.midas import MidasDetector\n", " \n", " midas_depth = MidasDetector.from_pretrained(\n", " \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\", cache_dir=\"/fsx/proj-fmri/shared/cache\").to(device)\n", " midas_depth.model.eval()\n", " midas_depth.model.requires_grad_(False)\n", " midas_depth.model.to(device)\n", " pass" ] }, { "cell_type": "code", "execution_count": 16, "id": "ba3f9207-b98e-45da-baa6-5cfcfb2ae958", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30,31]].float().to(device)\n", " print(input_batch.shape)\n", " \n", " midas_emb = midas_depth.model(input_batch).unsqueeze(1)\n", " print(midas_emb.shape)\n", "\n", " prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()\n", " print(prediction.shape)\n", " \n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", " midas_emb_size = prediction.flatten(1).shape[1]\n", " print(\"midas_emb\", prediction.shape, prediction.min(), prediction.max())\n", " print(\"midas_emb_size\", midas_emb_size)\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224))) \n", "\n", " if blurry_recon:\n", " prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)\n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 
1).expand_as(prediction)).half()\n", " prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215\n", " print(\"vae midas_emb\", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", "metadata": {}, "source": [ "### MindEye modules" ] }, { "cell_type": "code", "execution_count": 17, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 18, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "64,409,600 total\n", "64,409,600 trainable\n", "param counts:\n", "64,409,600 total\n", "64,409,600 trainable\n", "torch.Size([2, 1, 15724]) torch.Size([2, 1, 4096])\n" ] } ], "source": [ "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.out_features = out_features\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,1,voxels.shape[1]))\n", "print(b.shape, model.ridge(b).shape)" ] }, { "cell_type": "code", "execution_count": 
19, "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "950,287,384 total\n", "950,287,384 trainable\n", "param counts:\n", "1,014,696,984 total\n", "1,014,696,984 trainable\n", "b.shape torch.Size([2, 1, 4096])\n", "torch.Size([2, 257, 768]) torch.Size([1]) torch.Size([1])\n" ] } ], "source": [ "from functools import partial\n", "from diffusers.models.vae import Decoder\n", "class BrainNetwork(nn.Module):\n", "    # MLP-Mixer style backbone. Input x is (batch, seq_len, h), as in the fake-data test at the bottom of this cell.\n", "    # forward returns (c, b, d): c is the clip_proj output reshaped to (batch, -1, clip_size); b / d are the\n", "    # blurry / depth decoder outputs when the global blurry_recon / depth_recon flags are set, otherwise\n", "    # placeholder torch.Tensor([0.]) values. in_dim is accepted but unused (the lin0 projection is commented out).\n", "    def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):\n", "        super().__init__()\n", "        self.seq_len = seq_len\n", "        self.h = h\n", "        self.clip_size = clip_size\n", "        \n", "        # Initial linear layer to match the input dimensions to hidden dimensions\n", "        # self.lin0 = nn.Linear(in_dim, seq_len * h)\n", "        \n", "        # Mixer Blocks\n", "        self.mixer_blocks1 = nn.ModuleList([\n", "            self.mixer_block1(h, drop) for _ in range(n_blocks)\n", "        ])\n", "        self.mixer_blocks2 = nn.ModuleList([\n", "            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)\n", "        ])\n", "        \n", "        # Output linear layer\n", "        self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)\n", "\n", "        # low-rank matrices\n", "        # self.rank = 500\n", "        # self.U = nn.Parameter(torch.randn(self.rank, out_dim))\n", "        # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))\n", "        # self.S = nn.Parameter(torch.randn(out_dim))\n", "\n", "        self.clip_proj = nn.Sequential(\n", "            nn.LayerNorm(clip_size),\n", "            nn.GELU(),\n", "            nn.Linear(clip_size, 2048),\n", "            nn.LayerNorm(2048),\n", "            nn.GELU(),\n", "            nn.Linear(2048, 2048),\n", "            nn.LayerNorm(2048),\n", "            nn.GELU(),\n", "            nn.Linear(2048, clip_size)\n", "        )\n", "\n", "        if blurry_recon:\n", "            # self.blin1 = nn.Sequential(\n", "            #     nn.Linear(out_dim, 4096, bias=True),\n", "            #     nn.LayerNorm(4096),\n", "            #     nn.GELU(),\n", "            #     nn.Linear(4096, 4096))\n", "            # the 4096 outputs are reshaped to (256, 4, 4) in forward before the upsampling decoder\n", "            self.blin1 = nn.Linear(h*seq_len, 4096)\n", "            self.bgroupnorm = 
nn.GroupNorm(1, 256)\n", "            self.bupsampler = Decoder(\n", "                in_channels=256,\n", "                out_channels=128,\n", "                up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", "                block_out_channels=[32, 64, 128],\n", "                layers_per_block=1,\n", "            )\n", "\n", "        if depth_recon:\n", "            # self.dlin1 = nn.Sequential(\n", "            #         nn.Linear(h, midas_emb_size),\n", "            #         nn.Sigmoid(),\n", "            #     )\n", "            # same (256, 4, 4) reshape as the blurry branch, decoded to a single output channel\n", "            self.dlin1 = nn.Linear(h*seq_len, 4096)\n", "            self.dgroupnorm = nn.GroupNorm(1, 256)\n", "            self.dupsampler = Decoder(\n", "                in_channels=256,\n", "                out_channels=1,#128,\n", "                up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", "                block_out_channels=[32, 64, 128, 256],\n", "                layers_per_block=1,\n", "            )\n", "        \n", "    def mixer_block1(self, h, drop):\n", "        return nn.Sequential(\n", "            nn.LayerNorm(h),\n", "            self.mlp(h, h, drop),  # Channel mixing: acts on the hidden dim h (last axis of x)\n", "        )\n", "\n", "    def mixer_block2(self, seq_len, drop):\n", "        return nn.Sequential(\n", "            nn.LayerNorm(seq_len),\n", "            self.mlp(seq_len, seq_len, drop)  # Token mixing: applied to the permuted x, i.e. across the seq_len tokens\n", "        )\n", "    \n", "    def mlp(self, in_dim, out_dim, drop):\n", "        return nn.Sequential(\n", "            nn.Linear(in_dim, out_dim),\n", "            nn.GELU(),\n", "            nn.Dropout(drop),\n", "            nn.Linear(out_dim, out_dim),\n", "        )\n", "        \n", "    def forward(self, x):\n", "        # make empty tensors for blur and depth outputs\n", "        b,d = torch.Tensor([0.]), torch.Tensor([0.])\n", "        \n", "        # Initial linear layer\n", "        # x = self.lin0(x)\n", "        \n", "        # Reshape to seq_len by dim\n", "        # x = x.reshape(-1, self.seq_len, self.h)\n", "        \n", "        # Mixer blocks\n", "        # NOTE(review): residual1/residual2 each hold the previous output of the *same* block type, not the\n", "        # output of the sublayer just applied (e.g. block2's skip input is not block1's output). This differs\n", "        # from a standard MLP-Mixer; confirm it is intended before changing, since it alters the model.\n", "        residual1 = x\n", "        residual2 = x.permute(0,2,1)\n", "        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):\n", "            x = block1(x) + residual1\n", "            residual1 = x\n", "            x = x.permute(0,2,1)\n", "            \n", "            x = block2(x) + residual2\n", "            residual2 = x\n", "            x = x.permute(0,2,1)\n", "            \n", "        # Flatten\n", "        x = x.reshape(x.size(0), -1)\n", "        \n", "        c = self.clin1(x)\n", "\n", "        # low rank 
linear to out dim cuts # params by nearly half compared to full linear mapping\n", "        # c = (x @ (self.V/100) @ (self.U/100)) + self.S\n", "        \n", "        c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))\n", "\n", "        if blurry_recon:\n", "            b = self.blin1(x)\n", "            b = b.reshape(len(b), 256, 4, 4)\n", "            b = self.bgroupnorm(b)\n", "            b = self.bupsampler(b)\n", "            \n", "        if depth_recon:\n", "            d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)\n", "            d = d.reshape(len(d), 256, 4, 4)\n", "            d = self.dgroupnorm(d)\n", "            d = self.dupsampler(d)\n", "        \n", "        return c, b, d\n", "\n", "model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) \n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "# test that the model works on some fake data\n", "b = torch.randn((2,seq_len,hidden_dim))\n", "print(\"b.shape\",b.shape)\n", "with torch.no_grad():\n", "    clip_, blur_, depth_ = model.backbone(b)\n", "print(clip_.shape, blur_.shape, depth_.shape)" ] }, { "cell_type": "code", "execution_count": 20, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total_steps 2339\n", "\n", "Done with model preparations!\n", "param counts:\n", "1,014,696,984 total\n", "1,014,696,984 trainable\n" ] } ], "source": [ "# NOTE(review): no_decay entries are substring matches on parameter names. The LayerNorms BrainNetwork defines live\n", "# inside nn.Sequential containers, so their names are index-based (no 'LayerNorm'); only 'bias' matches them.\n", "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", "opt_grouped_parameters = [\n", "    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", "    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", "    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", "]\n", "\n", "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n", "\n", "if lr_scheduler_type == 'linear':\n", "    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", 
" optimizer,\n", " total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),\n", " last_epoch=-1\n", " )\n", "elif lr_scheduler_type == 'cycle':\n", " total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))\n", " print(\"total_steps\", total_steps)\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer, \n", " max_lr=max_lr,\n", " total_steps=total_steps,\n", " final_div_factor=1000,\n", " last_epoch=-1, pct_start=2/num_epochs\n", " )\n", " \n", "def save_ckpt(tag): \n", " ckpt_path = outdir+f'/{tag}.pth'\n", " print(f'saving {ckpt_path}',flush=True)\n", " unwrapped_model = accelerator.unwrap_model(model)\n", " try:\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': unwrapped_model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'lr_scheduler': lr_scheduler.state_dict(),\n", " 'train_losses': losses,\n", " 'test_losses': test_losses,\n", " 'lrs': lrs,\n", " }, ckpt_path)\n", " except:\n", " print(\"Couldn't save... moving on to prevent crashing.\")\n", " del unwrapped_model\n", " \n", "print(\"\\nDone with model preparations!\")\n", "utils.count_params(model)" ] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 21, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "wandb mindeyev2 run 0qiAxQoaKN_interactive_bsl\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb_config:\n", " {'model_name': '0qiAxQoaKN_interactive_bsl', 'global_batch_size': 128, 'batch_size': 128, 'num_epochs': 12, 'clip_scale': 1.0, 'blur_scale': 100.0, 'use_image_aug': False, 'max_lr': 0.0003, 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'ckpt_interval': 999, 'ckpt_saving': False, 'seed': 42, 'distributed': True, 'num_devices': 1, 'world_size': 1, 'train_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar', 'test_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar'}\n", "wandb_id: 0qiAxQoaKN_interactive_bsl\n" ] }, { "data": { "text/html": [ "wandb version 0.15.12 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231028_204841-0qiAxQoaKN_interactive_bsl" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run 0qiAxQoaKN_interactive_bsl to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://stability.wandb.io/ckadirt/mindeyev2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://stability.wandb.io/ckadirt/mindeyev2/runs/0qiAxQoaKN_interactive_bsl" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "wandb_log = True\n", "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", " import wandb\n", " wandb_project = 'mindeyev2'\n", " wandb_run = model_name\n", " wandb_notes = ''\n", " \n", " print(f\"wandb {wandb_project} run {wandb_run}\")\n", " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", " wandb_config = {\n", " \"model_name\": model_name,\n", " \"global_batch_size\": global_batch_size,\n", " \"batch_size\": batch_size,\n", " \"num_epochs\": num_epochs,\n", " \"clip_scale\": clip_scale,\n", " \"blur_scale\": blur_scale,\n", " \"use_image_aug\": use_image_aug,\n", " \"max_lr\": max_lr,\n", " \"mixup_pct\": mixup_pct,\n", " \"num_train\": num_train,\n", " \"num_test\": num_test,\n", " \"ckpt_interval\": ckpt_interval,\n", " \"ckpt_saving\": ckpt_saving,\n", " \"seed\": seed,\n", " \"distributed\": distributed,\n", " \"num_devices\": num_devices,\n", " \"world_size\": world_size,\n", " \"train_url\": train_url,\n", " \"test_url\": test_url,\n", " }\n", " print(\"wandb_config:\\n\",wandb_config)\n", " if True: # wandb_auto_resume\n", " print(\"wandb_id:\",model_name)\n", " wandb.init(\n", " id = model_name,\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " resume=\"allow\",\n", " )\n", " else:\n", " wandb.init(\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " )\n", "else:\n", " wandb_log = False" ] }, { "cell_type": "markdown", "id": 
"d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 22, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": {}, "outputs": [], "source": [ "epoch = 0\n", "losses, test_losses, lrs = [], [], []\n", "best_test_loss = 1e9\n", "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", "\n", "# Optionally resume from checkpoint #\n", "if resume_from_ckpt:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "elif wandb_log:\n", " if wandb.run.resumed:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... 
trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 23, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-10-28 20:48:51,902] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.5, git-hash=unknown, git-branch=unknown\n", "[2023-10-28 20:48:53,263] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False\n", "[2023-10-28 20:48:53,265] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer\n", "[2023-10-28 20:48:53,266] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer\n", "[2023-10-28 20:48:53,267] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW\n", "[2023-10-28 20:48:53,268] [INFO] [utils.py:54:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=\n", "[2023-10-28 20:48:53,269] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer\n", "[2023-10-28 20:48:53,269] [INFO] [stage_1_and_2.py:133:__init__] Reduce bucket size 10000000\n", "[2023-10-28 20:48:53,270] [INFO] [stage_1_and_2.py:134:__init__] Allgather bucket size 500,000,000\n", "[2023-10-28 20:48:53,271] [INFO] [stage_1_and_2.py:135:__init__] CPU Offload: False\n", "[2023-10-28 20:48:53,271] [INFO] [stage_1_and_2.py:136:__init__] Round robin gradient partitioning: False\n", "Rank: 0 partition count [1, 1, 1] and sizes[(64409600, False), (950031116, False), (256268, False)] \n", "[2023-10-28 20:48:55,761] [INFO] 
[utils.py:785:see_memory_usage] Before initializing optimizer states\n", "[2023-10-28 20:48:55,763] [INFO] [utils.py:786:see_memory_usage] MA 7.68 GB Max_MA 7.68 GB CA 7.71 GB Max_CA 8 GB \n", "[2023-10-28 20:48:55,764] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 72.0 GB, percent = 6.4%\n", "[2023-10-28 20:48:55,940] [INFO] [utils.py:785:see_memory_usage] After initializing optimizer states\n", "[2023-10-28 20:48:55,941] [INFO] [utils.py:786:see_memory_usage] MA 15.24 GB Max_MA 26.1 GB CA 26.62 GB Max_CA 27 GB \n", "[2023-10-28 20:48:55,942] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 72.0 GB, percent = 6.4%\n", "[2023-10-28 20:48:55,943] [INFO] [stage_1_and_2.py:488:__init__] optimizer state initialized\n", "[2023-10-28 20:48:56,073] [INFO] [utils.py:785:see_memory_usage] After initializing ZeRO optimizer\n", "[2023-10-28 20:48:56,074] [INFO] [utils.py:786:see_memory_usage] MA 15.24 GB Max_MA 15.24 GB CA 26.62 GB Max_CA 27 GB \n", "[2023-10-28 20:48:56,075] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 71.99 GB, percent = 6.4%\n", "[2023-10-28 20:48:56,078] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW\n", "[2023-10-28 20:48:56,078] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client LR scheduler\n", "[2023-10-28 20:48:56,079] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = None\n", "[2023-10-28 20:48:56,080] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[1.200000000000002e-05, 1.200000000000002e-05, 1.200000000000002e-05], mom=[(0.95, 0.999), (0.95, 0.999), (0.95, 0.999)]\n", "[2023-10-28 20:48:56,081] [INFO] [config.py:960:print] DeepSpeedEngine configuration:\n", "[2023-10-28 20:48:56,082] [INFO] [config.py:964:print] activation_checkpointing_config {\n", " \"partition_activations\": false, \n", " \"contiguous_memory_optimization\": false, \n", " \"cpu_checkpointing\": false, \n", " \"number_checkpoints\": null, \n", " 
\"synchronize_checkpoint_boundary\": false, \n", " \"profile\": false\n", "}\n", "[2023-10-28 20:48:56,082] [INFO] [config.py:964:print] aio_config ................... {'block_size': 26214400, 'queue_depth': 64, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}\n", "[2023-10-28 20:48:56,083] [INFO] [config.py:964:print] amp_enabled .................. False\n", "[2023-10-28 20:48:56,084] [INFO] [config.py:964:print] amp_params ................... False\n", "[2023-10-28 20:48:56,085] [INFO] [config.py:964:print] autotuning_config ............ {\n", " \"enabled\": false, \n", " \"start_step\": null, \n", " \"end_step\": null, \n", " \"metric_path\": null, \n", " \"arg_mappings\": null, \n", " \"metric\": \"throughput\", \n", " \"model_info\": null, \n", " \"results_dir\": \"autotuning_results\", \n", " \"exps_dir\": \"autotuning_exps\", \n", " \"overwrite\": true, \n", " \"fast\": true, \n", " \"start_profile_step\": 3, \n", " \"end_profile_step\": 5, \n", " \"tuner_type\": \"gridsearch\", \n", " \"tuner_early_stopping\": 5, \n", " \"tuner_num_trials\": 50, \n", " \"model_info_path\": null, \n", " \"mp_size\": 1, \n", " \"max_train_batch_size\": null, \n", " \"min_train_batch_size\": 1, \n", " \"max_train_micro_batch_size_per_gpu\": 1.024000e+03, \n", " \"min_train_micro_batch_size_per_gpu\": 1, \n", " \"num_tuning_micro_batch_sizes\": 3\n", "}\n", "[2023-10-28 20:48:56,085] [INFO] [config.py:964:print] bfloat16_enabled ............. False\n", "[2023-10-28 20:48:56,086] [INFO] [config.py:964:print] checkpoint_parallel_write_pipeline False\n", "[2023-10-28 20:48:56,087] [INFO] [config.py:964:print] checkpoint_tag_validation_enabled True\n", "[2023-10-28 20:48:56,087] [INFO] [config.py:964:print] checkpoint_tag_validation_fail False\n", "[2023-10-28 20:48:56,088] [INFO] [config.py:964:print] comms_config ................. \n", "[2023-10-28 20:48:56,088] [INFO] [config.py:964:print] communication_data_type ...... 
None\n", "[2023-10-28 20:48:56,089] [INFO] [config.py:964:print] compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}\n", "[2023-10-28 20:48:56,090] [INFO] [config.py:964:print] curriculum_enabled_legacy .... False\n", "[2023-10-28 20:48:56,091] [INFO] [config.py:964:print] curriculum_params_legacy ..... False\n", "[2023-10-28 20:48:56,091] [INFO] [config.py:964:print] data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}\n", "[2023-10-28 20:48:56,092] [INFO] [config.py:964:print] data_efficiency_enabled ...... False\n", "[2023-10-28 20:48:56,093] [INFO] [config.py:964:print] dataloader_drop_last ......... 
False\n", "[2023-10-28 20:48:56,093] [INFO] [config.py:964:print] disable_allgather ............ False\n", "[2023-10-28 20:48:56,094] [INFO] [config.py:964:print] dump_state ................... False\n", "[2023-10-28 20:48:56,095] [INFO] [config.py:964:print] dynamic_loss_scale_args ...... None\n", "[2023-10-28 20:48:56,095] [INFO] [config.py:964:print] eigenvalue_enabled ........... False\n", "[2023-10-28 20:48:56,096] [INFO] [config.py:964:print] eigenvalue_gas_boundary_resolution 1\n", "[2023-10-28 20:48:56,097] [INFO] [config.py:964:print] eigenvalue_layer_name ........ bert.encoder.layer\n", "[2023-10-28 20:48:56,097] [INFO] [config.py:964:print] eigenvalue_layer_num ......... 0\n", "[2023-10-28 20:48:56,098] [INFO] [config.py:964:print] eigenvalue_max_iter .......... 100\n", "[2023-10-28 20:48:56,099] [INFO] [config.py:964:print] eigenvalue_stability ......... 1e-06\n", "[2023-10-28 20:48:56,099] [INFO] [config.py:964:print] eigenvalue_tol ............... 0.01\n", "[2023-10-28 20:48:56,100] [INFO] [config.py:964:print] eigenvalue_verbose ........... False\n", "[2023-10-28 20:48:56,100] [INFO] [config.py:964:print] elasticity_enabled ........... False\n", "[2023-10-28 20:48:56,101] [INFO] [config.py:964:print] flops_profiler_config ........ {\n", " \"enabled\": false, \n", " \"recompute_fwd_factor\": 0.0, \n", " \"profile_step\": 1, \n", " \"module_depth\": -1, \n", " \"top_modules\": 1, \n", " \"detailed\": true, \n", " \"output_file\": null\n", "}\n", "[2023-10-28 20:48:56,102] [INFO] [config.py:964:print] fp16_auto_cast ............... False\n", "[2023-10-28 20:48:56,103] [INFO] [config.py:964:print] fp16_enabled ................. True\n", "[2023-10-28 20:48:56,103] [INFO] [config.py:964:print] fp16_master_weights_and_gradients False\n", "[2023-10-28 20:48:56,104] [INFO] [config.py:964:print] global_rank .................. 0\n", "[2023-10-28 20:48:56,105] [INFO] [config.py:964:print] grad_accum_dtype ............. 
None\n", "[2023-10-28 20:48:56,105] [INFO] [config.py:964:print] gradient_accumulation_steps .. 1\n", "[2023-10-28 20:48:56,106] [INFO] [config.py:964:print] gradient_clipping ............ 1.0\n", "[2023-10-28 20:48:56,107] [INFO] [config.py:964:print] gradient_predivide_factor .... 1.0\n", "[2023-10-28 20:48:56,107] [INFO] [config.py:964:print] hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8\n", "[2023-10-28 20:48:56,108] [INFO] [config.py:964:print] initial_dynamic_scale ........ 65536\n", "[2023-10-28 20:48:56,109] [INFO] [config.py:964:print] load_universal_checkpoint .... False\n", "[2023-10-28 20:48:56,109] [INFO] [config.py:964:print] loss_scale ................... 0\n", "[2023-10-28 20:48:56,110] [INFO] [config.py:964:print] memory_breakdown ............. False\n", "[2023-10-28 20:48:56,111] [INFO] [config.py:964:print] mics_hierarchial_params_gather False\n", "[2023-10-28 20:48:56,111] [INFO] [config.py:964:print] mics_shard_size .............. -1\n", "[2023-10-28 20:48:56,112] [INFO] [config.py:964:print] monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False\n", "[2023-10-28 20:48:56,113] [INFO] [config.py:964:print] nebula_config ................ {\n", " \"enabled\": false, \n", " \"persistent_storage_path\": null, \n", " \"persistent_time_interval\": 100, \n", " \"num_of_version_in_retention\": 2, \n", " \"enable_nebula_load\": true, \n", " \"load_path\": null\n", "}\n", "[2023-10-28 20:48:56,113] [INFO] [config.py:964:print] optimizer_legacy_fusion ...... False\n", "[2023-10-28 20:48:56,114] [INFO] [config.py:964:print] optimizer_name ............... 
None\n", "[2023-10-28 20:48:56,115] [INFO] [config.py:964:print] optimizer_params ............. None\n", "[2023-10-28 20:48:56,115] [INFO] [config.py:964:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}\n", "[2023-10-28 20:48:56,116] [INFO] [config.py:964:print] pld_enabled .................. False\n", "[2023-10-28 20:48:56,117] [INFO] [config.py:964:print] pld_params ................... False\n", "[2023-10-28 20:48:56,117] [INFO] [config.py:964:print] prescale_gradients ........... False\n", "[2023-10-28 20:48:56,118] [INFO] [config.py:964:print] scheduler_name ............... None\n", "[2023-10-28 20:48:56,119] [INFO] [config.py:964:print] scheduler_params ............. None\n", "[2023-10-28 20:48:56,119] [INFO] [config.py:964:print] sparse_attention ............. None\n", "[2023-10-28 20:48:56,120] [INFO] [config.py:964:print] sparse_gradients_enabled ..... False\n", "[2023-10-28 20:48:56,121] [INFO] [config.py:964:print] steps_per_print .............. inf\n", "[2023-10-28 20:48:56,121] [INFO] [config.py:964:print] train_batch_size ............. 128\n", "[2023-10-28 20:48:56,122] [INFO] [config.py:964:print] train_micro_batch_size_per_gpu 128\n", "[2023-10-28 20:48:56,123] [INFO] [config.py:964:print] use_node_local_storage ....... False\n", "[2023-10-28 20:48:56,123] [INFO] [config.py:964:print] wall_clock_breakdown ......... False\n", "[2023-10-28 20:48:56,124] [INFO] [config.py:964:print] world_size ................... 1\n", "[2023-10-28 20:48:56,125] [INFO] [config.py:964:print] zero_allow_untested_optimizer True\n", "[2023-10-28 20:48:56,125] [INFO] [config.py:964:print] zero_config .................. 
stage=2 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=10000000 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=False load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=5, buffer_size=4000000000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=10000000 param_persistence_threshold=100000 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True\n", "[2023-10-28 20:48:56,126] [INFO] [config.py:964:print] zero_enabled ................. True\n", "[2023-10-28 20:48:56,127] [INFO] [config.py:964:print] zero_force_ds_cpu_optimizer .. True\n", "[2023-10-28 20:48:56,127] [INFO] [config.py:964:print] zero_optimization_stage ...... 
2\n", "[2023-10-28 20:48:56,128] [INFO] [config.py:950:print_user_config] json = {\n", " \"bf16\": {\n", " \"enabled\": false\n", " }, \n", " \"fp16\": {\n", " \"enabled\": true\n", " }, \n", " \"zero_optimization\": {\n", " \"stage\": 2, \n", " \"contiguous_gradients\": true, \n", " \"stage3_gather_16bit_weights_on_model_save\": true, \n", " \"stage3_max_live_parameters\": 1.000000e+09, \n", " \"stage3_max_reuse_distance\": 1.000000e+09, \n", " \"stage3_prefetch_bucket_size\": 1.000000e+07, \n", " \"stage3_param_persistence_threshold\": 1.000000e+05, \n", " \"reduce_bucket_size\": 1.000000e+07, \n", " \"sub_group_size\": 1.000000e+09, \n", " \"offload_optimizer\": {\n", " \"device\": \"none\", \n", " \"nvme_path\": \"/scratch\", \n", " \"pin_memory\": true\n", " }, \n", " \"offload_param\": {\n", " \"device\": \"none\", \n", " \"nvme_path\": \"/scratch\", \n", " \"buffer_size\": 4.000000e+09, \n", " \"pin_memory\": true\n", " }\n", " }, \n", " \"aio\": {\n", " \"block_size\": 2.621440e+07, \n", " \"queue_depth\": 64, \n", " \"thread_count\": 1, \n", " \"single_submit\": false, \n", " \"overlap_events\": true\n", " }, \n", " \"gradient_accumulation_steps\": 1, \n", " \"gradient_clipping\": 1.0, \n", " \"steps_per_print\": inf, \n", " \"train_batch_size\": 128, \n", " \"train_micro_batch_size_per_gpu\": 128, \n", " \"wall_clock_breakdown\": false, \n", " \"zero_allow_untested_optimizer\": true\n", "}\n" ] } ], "source": [ "model, optimizer, train_dl, lr_scheduler = accelerator.prepare(\n", "model, optimizer, train_dl, lr_scheduler\n", ")\n", "# leaving out test_dl since we will only have local_rank 0 device do evals" ] }, { "cell_type": "code", "execution_count": 24, "id": "469e6313-425f-45ed-875a-ecd5df343e31", "metadata": {}, "outputs": [], "source": [ "def add_saturation(image, alpha=2):\n", " gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]\n", " gray_image = gray_image.unsqueeze(1).expand_as(image)\n", " 
saturated_image = alpha * image + (1 - alpha) * gray_image\n", " return torch.clamp(saturated_image, 0, 1)" ] }, { "cell_type": "code", "execution_count": 25, "id": "60be0d5f-3e94-4612-9373-61b53d836393", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0qiAxQoaKN_interactive_bsl starting with epoch 0 / 12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/12 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", " in <module>:119 \n", " \n", " 116 │ │ │ │ │ # utils.check_loss(pixcorr) \n", " 117 │ │ │ \n", " 118 │ │ │ utils.check_loss(loss) \n", " 119 │ │ │ accelerator.backward(loss) \n", " 120 │ │ │ optimizer.step() \n", " 121 │ │ │ \n", " 122 │ │ │ losses.append(loss.item()) \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/accelerator. \n", " py:1815 in backward \n", " \n", " 1812 │ │ │ # deepspeed handles loss scaling by gradient_accumulation_steps in its `back \n", " 1813 │ │ │ loss = loss / self.gradient_accumulation_steps \n", " 1814 │ │ if self.distributed_type == DistributedType.DEEPSPEED: \n", " 1815 │ │ │ self.deepspeed_engine_wrapped.backward(loss, **kwargs) \n", " 1816 │ │ elif self.distributed_type == DistributedType.MEGATRON_LM: \n", " 1817 │ │ │ return \n", " 1818 │ │ elif self.scaler is not None: \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/utils/deepsp \n", " eed.py:176 in backward \n", " \n", " 173 │ │ # - zero grad \n", " 174 │ │ # - checking overflow \n", " 175 │ │ # - lr_scheduler step (only if engine.lr_scheduler is not None) \n", " 176 │ │ self.engine.step() \n", " 177 │ │ # and this plugin overrides the above calls with no-ops when Accelerate runs und \n", " 178 │ │ # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabli \n", " 179 │ │ # training loop that works transparently under many training regimes. 
\n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/engin \n", " e.py:2053 in step \n", " \n", " 2050 │ │ │ │ │ and self.quantizer.any_precision_switch()): \n", " 2051 │ │ │ │ self._take_model_step(lr_kwargs, self.block_eigenvalue) \n", " 2052 │ │ │ else: \n", " 2053 │ │ │ │ self._take_model_step(lr_kwargs) \n", " 2054 │ │ │ \n", " 2055 │ │ │ report_progress = self.global_rank == 0 if self.global_rank else True \n", " 2056 \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/engin \n", " e.py:1960 in _take_model_step \n", " \n", " 1957 │ │ │ │ # https://nvidia.github.io/apex/advanced.html#gradient-clipping \n", " 1958 │ │ │ │ master_params = amp.master_params(self.optimizer) \n", " 1959 │ │ │ │ clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clippin \n", " 1960 │ │ self.optimizer.step() \n", " 1961 │ │ \n", " 1962 │ │ if hasattr(self.optimizer, '_global_grad_norm'): \n", " 1963 │ │ │ self._global_grad_norm = self.optimizer._global_grad_norm \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/ \n", " stage_1_and_2.py:1733 in step \n", " \n", " 1730 │ │ │ │ \n", " 1731 │ │ │ │ # Step 3:- run the optimizer if no offloading \n", " 1732 │ │ │ │ self.start_timers([OPTIMIZER_STEP]) \n", " 1733 │ │ │ │ self._optimizer_step(i) \n", " 1734 │ │ │ │ # Step 4:- get rid of the fp32 gradients. 
Not needed anymore \n", " 1735 │ │ │ │ self.single_partition_of_fp32_groups[i].grad = None \n", " 1736 │ │ │ │ del single_grad_partition \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/ \n", " stage_1_and_2.py:1638 in _optimizer_step \n", " \n", " 1635 │ │ # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no) \n", " 1636 │ │ #else: \n", " 1637 │ │ # self.optimizer.step() \n", " 1638 │ │ self.optimizer.step() \n", " 1639 │ │ self.optimizer.param_groups = original_param_groups \n", " 1640 \n", " 1641 def step(self, closure=None): \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/lr_schedule \n", " r.py:69 in wrapper \n", " \n", " 66 │ │ │ │ instance = instance_ref() \n", " 67 │ │ │ │ instance._step_count += 1 \n", " 68 │ │ │ │ wrapped = func.__get__(instance, cls) \n", " 69 │ │ │ │ return wrapped(*args, **kwargs) \n", " 70 │ │ │ \n", " 71 │ │ │ # Note that the returned function here is no longer a bound method, \n", " 72 │ │ │ # so attributes like `__func__` and `__self__` no longer exist. 
\n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/optimizer.p \n", " y:280 in wrapper \n", " \n", " 277 │ │ │ │ │ │ │ raise RuntimeError(f\"{func} must return None or a tuple of ( \n", " 278 │ │ │ │ │ │ │ │ │ │ │ f\"but got {result}.\") \n", " 279 │ │ │ │ \n", " 280 │ │ │ │ out = func(*args, **kwargs) \n", " 281 │ │ │ │ self._optimizer_step_code() \n", " 282 │ │ │ │ \n", " 283 │ │ │ │ # call optimizer step post hooks \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/optimizer.p \n", " y:33 in _use_grad \n", " \n", " 30 │ │ prev_grad = torch.is_grad_enabled() \n", " 31 │ │ try: \n", " 32 │ │ │ torch.set_grad_enabled(self.defaults['differentiable']) \n", " 33 │ │ │ ret = func(self, *args, **kwargs) \n", " 34 │ │ finally: \n", " 35 │ │ │ torch.set_grad_enabled(prev_grad) \n", " 36 │ │ return ret \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:17 \n", " 1 in step \n", " \n", " 168 │ │ │ │ state_steps, \n", " 169 │ │ │ ) \n", " 170 │ │ │ \n", " 171 │ │ │ adamw( \n", " 172 │ │ │ │ params_with_grad, \n", " 173 │ │ │ │ grads, \n", " 174 │ │ │ │ exp_avgs, \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:32 \n", " 1 in adamw \n", " \n", " 318 else: \n", " 319 │ │ func = _single_tensor_adamw \n", " 320 \n", " 321 func( \n", " 322 │ │ params, \n", " 323 │ │ grads, \n", " 324 │ │ exp_avgs, \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:56 \n", " 6 in _multi_tensor_adamw \n", " \n", " 563 │ │ │ else: \n", " 564 │ │ │ │ exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) \n", " 565 │ │ │ │ torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) \n", " 566 │ │ │ │ denom = torch._foreach_add(exp_avg_sq_sqrt, eps) \n", " 567 │ │ │ \n", " 568 │ │ │ torch._foreach_addcdiv_(device_params, 
device_exp_avgs, denom, step_size) \n", " 569 \n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "OutOfMemoryError: CUDA out of memory. Tried to allocate 3.54 GiB (GPU 0; 39.56 GiB total capacity; 26.52 GiB \n", "already allocated; 2.00 GiB free; 35.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory\n", "try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and \n", "PYTORCH_CUDA_ALLOC_CONF\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m119\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[2m# utils.check_loss(pixcorr)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ │ \u001b[0mutils.check_loss(loss) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m119 \u001b[2m│ │ │ \u001b[0maccelerator.backward(loss) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m120 \u001b[0m\u001b[2m│ │ │ \u001b[0moptimizer.step() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m121 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m122 \u001b[0m\u001b[2m│ │ │ \u001b[0mlosses.append(loss.item()) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/\u001b[0m\u001b[1;33maccelerator.\u001b[0m \u001b[31m│\u001b[0m\n", 
"\u001b[31m│\u001b[0m \u001b[1;33mpy\u001b[0m:\u001b[94m1815\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1812 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# deepspeed handles loss scaling by gradient_accumulation_steps in its `back\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1813 \u001b[0m\u001b[2m│ │ │ \u001b[0mloss = loss / \u001b[96mself\u001b[0m.gradient_accumulation_steps \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1814 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.distributed_type == DistributedType.DEEPSPEED: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1815 \u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.deepspeed_engine_wrapped.backward(loss, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1816 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melif\u001b[0m \u001b[96mself\u001b[0m.distributed_type == DistributedType.MEGATRON_LM: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1817 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1818 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melif\u001b[0m \u001b[96mself\u001b[0m.scaler \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/utils/\u001b[0m\u001b[1;33mdeepsp\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33meed.py\u001b[0m:\u001b[94m176\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m173 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# - zero grad\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ 
\u001b[0m\u001b[2m# - checking overflow\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# - lr_scheduler step (only if engine.lr_scheduler is not None)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m176 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.engine.step() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m177 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# and this plugin overrides the above calls with no-ops when Accelerate runs und\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m178 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabli\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m179 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# training loop that works transparently under many training regimes.\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/\u001b[0m\u001b[1;33mengin\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33me.py\u001b[0m:\u001b[94m2053\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2050 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[95mand\u001b[0m \u001b[96mself\u001b[0m.quantizer.any_precision_switch()): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2051 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._take_model_step(lr_kwargs, \u001b[96mself\u001b[0m.block_eigenvalue) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2052 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2053 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._take_model_step(lr_kwargs) \u001b[31m│\u001b[0m\n", 
"\u001b[31m│\u001b[0m \u001b[2m2054 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2055 \u001b[0m\u001b[2m│ │ │ \u001b[0mreport_progress = \u001b[96mself\u001b[0m.global_rank == \u001b[94m0\u001b[0m \u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.global_rank \u001b[94melse\u001b[0m \u001b[94mTrue\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2056 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/\u001b[0m\u001b[1;33mengin\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33me.py\u001b[0m:\u001b[94m1960\u001b[0m in \u001b[92m_take_model_step\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1957 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# https://nvidia.github.io/apex/advanced.html#gradient-clipping\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1958 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mmaster_params = amp.master_params(\u001b[96mself\u001b[0m.optimizer) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1959 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mclip_grad_norm_(parameters=master_params, max_norm=\u001b[96mself\u001b[0m.gradient_clippin \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1960 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1961 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1962 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m.optimizer, \u001b[33m'\u001b[0m\u001b[33m_global_grad_norm\u001b[0m\u001b[33m'\u001b[0m): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1963 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._global_grad_norm = 
\u001b[96mself\u001b[0m.optimizer._global_grad_norm \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33mstage_1_and_2.py\u001b[0m:\u001b[94m1733\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1730 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1731 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# Step 3:- run the optimizer if no offloading\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1732 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.start_timers([OPTIMIZER_STEP]) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1733 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._optimizer_step(i) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1734 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# Step 4:- get rid of the fp32 gradients. 
Not needed anymore\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1735 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.single_partition_of_fp32_groups[i].grad = \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1736 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mdel\u001b[0m single_grad_partition \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33mstage_1_and_2.py\u001b[0m:\u001b[94m1638\u001b[0m in \u001b[92m_optimizer_step\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1635 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1636 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m#else:\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1637 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# self.optimizer.step()\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1638 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1639 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.param_groups = original_param_groups \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1640 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1641 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mstep\u001b[0m(\u001b[96mself\u001b[0m, closure=\u001b[94mNone\u001b[0m): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m 
\u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33mlr_schedule\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33mr.py\u001b[0m:\u001b[94m69\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 66 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minstance = instance_ref() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 67 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minstance._step_count += \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 68 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mwrapped = func.\u001b[92m__get__\u001b[0m(instance, \u001b[96mcls\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 69 \u001b[2m│ │ │ │ \u001b[0m\u001b[94mreturn\u001b[0m wrapped(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 70 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 71 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# Note that the returned function here is no longer a bound method,\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 72 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# so attributes like `__func__` and `__self__` no longer exist.\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33moptimizer.p\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33my\u001b[0m:\u001b[94m280\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m277 \u001b[0m\u001b[2m│ │ │ │ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mRuntimeError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33m{\u001b[0mfunc\u001b[33m}\u001b[0m\u001b[33m 
must return None or a tuple of (\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m278 \u001b[0m\u001b[2m│ │ │ │ │ │ │ │ │ │ │ \u001b[0m\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mbut got \u001b[0m\u001b[33m{\u001b[0mresult\u001b[33m}\u001b[0m\u001b[33m.\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m279 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m280 \u001b[2m│ │ │ │ \u001b[0mout = func(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m281 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._optimizer_step_code() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m282 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m283 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# call optimizer step post hooks\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33moptimizer.p\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33my\u001b[0m:\u001b[94m33\u001b[0m in \u001b[92m_use_grad\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 30 \u001b[0m\u001b[2m│ │ \u001b[0mprev_grad = torch.is_grad_enabled() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 31 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mtry\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 32 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch.set_grad_enabled(\u001b[96mself\u001b[0m.defaults[\u001b[33m'\u001b[0m\u001b[33mdifferentiable\u001b[0m\u001b[33m'\u001b[0m]) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 33 \u001b[2m│ │ │ \u001b[0mret = func(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 34 
\u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfinally\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 35 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch.set_grad_enabled(prev_grad) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 36 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m ret \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m17\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[94m1\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m168 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mstate_steps, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m169 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m170 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m171 \u001b[2m│ │ │ \u001b[0madamw( \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m172 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mparams_with_grad, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m173 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mgrads, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mexp_avgs, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m32\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[94m1\u001b[0m in \u001b[92madamw\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m318 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m319 
\u001b[0m\u001b[2m│ │ \u001b[0mfunc = _single_tensor_adamw \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m320 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m321 \u001b[2m│ \u001b[0mfunc( \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m322 \u001b[0m\u001b[2m│ │ \u001b[0mparams, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m323 \u001b[0m\u001b[2m│ │ \u001b[0mgrads, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m324 \u001b[0m\u001b[2m│ │ \u001b[0mexp_avgs, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m56\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[94m6\u001b[0m in \u001b[92m_multi_tensor_adamw\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m563 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m564 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mexp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m565 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mtorch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m566 \u001b[2m│ │ │ │ \u001b[0mdenom = torch._foreach_add(exp_avg_sq_sqrt, eps) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m567 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m568 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m569 \u001b[0m \u001b[31m│\u001b[0m\n", 
"\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mOutOfMemoryError: \u001b[0mCUDA out of memory. Tried to allocate \u001b[1;36m3.54\u001b[0m GiB \u001b[1m(\u001b[0mGPU \u001b[1;36m0\u001b[0m; \u001b[1;36m39.56\u001b[0m GiB total capacity; \u001b[1;36m26.52\u001b[0m GiB \n", "already allocated; \u001b[1;36m2.00\u001b[0m GiB free; \u001b[1;36m35.94\u001b[0m GiB reserved in total by PyTorch\u001b[1m)\u001b[0m If reserved memory is >> allocated memory\n", "try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and \n", "PYTORCH_CUDA_ALLOC_CONF\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n", "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n", "test_image, test_voxel = None, None\n", "mse = nn.MSELoss()\n", "l1 = nn.L1Loss()\n", "\n", "for epoch in progress_bar:\n", " model.train()\n", " \n", " fwd_percent_correct = 0.\n", " bwd_percent_correct = 0.\n", " test_fwd_percent_correct = 0.\n", " test_bwd_percent_correct = 0.\n", "\n", " loss_clip_total = 0.\n", " loss_blurry_total = 0.\n", " loss_depth_total = 0.\n", " test_loss_clip_total = 0.\n", " test_loss_blurry_total = 0.\n", " test_loss_depth_total = 0.\n", "\n", " blurry_pixcorr = 0.\n", " test_blurry_pixcorr = 0. 
# needs >.456 to beat low-level subj01 results in mindeye v1\n", " \n", " for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " with torch.cuda.amp.autocast(dtype=data_type):\n", " optimizer.zero_grad()\n", " \n", " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", " image = images[behav[:,0,0].cpu().long()].to(device).float()\n", "\n", " for past in range(1):\n", " past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\n", " \n", " if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", " \n", " if use_image_aug: \n", " image = img_augment(image)\n", " \n", " clip_target = clip_model.embed_image(image)\n", " assert not torch.any(torch.isnan(clip_target))\n", " \n", " if epoch < int(mixup_pct * num_epochs):\n", " voxel, perm, betas, select = utils.mixco(voxel)\n", " # mix the *past* voxels with the same perm/betas/select as the current batch (not the already-mixed voxel)\n", " past_voxel, _, _, _ = utils.mixco(past_voxel, perm=perm, betas=betas, select=select)\n", " \n", " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", " \n", " # past_voxel_ridge = model.ridge(past_voxel)\n", " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)\n", " \n", " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", " 
\n", " if epoch < int(mixup_pct * num_epochs): \n", " loss_clip = utils.mixco_nce(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=.006, \n", " perm=perm, betas=betas, select=select)\n", " else:\n", " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n", " loss_clip = utils.soft_clip_loss(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=epoch_temp)\n", "\n", " loss_clip_total += loss_clip.item()\n", " loss_clip *= clip_scale\n", " loss = loss_clip\n", " \n", " if blurry_recon:\n", " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " \n", " loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", " loss_blurry_total += loss_blurry.item()\n", " loss_blurry *= blur_scale\n", " loss += loss_blurry\n", "\n", " if depth_recon:\n", " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", " loss_depth_total += loss_depth.item()\n", " loss_depth *= depth_scale\n", " loss += loss_depth\n", " \n", " # forward and backward top 1 accuracy \n", " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()\n", " bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()\n", " \n", " if blurry_recon:\n", " with torch.no_grad():\n", " # only doing pixcorr eval on a subset of the samples per batch because its costly & slow to compute autoenc.decode()\n", " 
random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)\n", " # random_samps = np.arange(batch_size//5)\n", " blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)\n", " # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean\n", " pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)\n", " # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)\n", " # loss += (1 - pixcorr)\n", " blurry_pixcorr += pixcorr.item()\n", " # utils.check_loss(pixcorr)\n", "\n", " utils.check_loss(loss)\n", " accelerator.backward(loss)\n", " optimizer.step()\n", " \n", " losses.append(loss.item())\n", " lrs.append(optimizer.param_groups[0]['lr'])\n", " \n", " if lr_scheduler_type is not None:\n", " lr_scheduler.step()\n", "\n", " model.eval()\n", " if local_rank==0:\n", " with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type): \n", " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): \n", " # all test samples should be loaded per batch such that test_i should never exceed 0\n", " assert len(behav) == num_test\n", " \n", " ## Average same-image repeats ##\n", " if test_image is None:\n", " voxel = voxels[behav[:,0,5].cpu().long()]\n", " image = behav[:,0,0].cpu().long()\n", " \n", " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n", " for im in unique_image:\n", " locs = torch.where(im == image)[0]\n", " if test_image is None:\n", " test_image = images[im][None]\n", " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n", " else:\n", " test_image = torch.vstack((test_image, images[im][None]))\n", " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n", " \n", " # random sample of 300\n", " random_indices = torch.arange(len(test_voxel))[:300]\n", " voxel = test_voxel[random_indices].to(device)\n", " image = 
test_image[random_indices].to(device)\n", " assert len(image) == 300\n", "\n", " if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", " \n", " clip_target = clip_model.embed_image(image.float())\n", " \n", " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", "\n", " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)\n", " \n", " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", " \n", " loss_clip = utils.soft_clip_loss(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=.006)\n", " test_loss_clip_total += loss_clip.item()\n", " loss_clip = loss_clip * clip_scale\n", " loss = loss_clip\n", "\n", " if blurry_recon:\n", " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " \n", " loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", " 
test_loss_blurry_total += loss_blurry.item()\n", " loss_blurry *= blur_scale\n", " loss += loss_blurry\n", " \n", " # halving the batch size because the decoder is computationally heavy\n", " blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)\n", " blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " pixcorr = utils.pixcorr(image, blurry_recon_images)\n", " loss += (1 - pixcorr)\n", " test_blurry_pixcorr += pixcorr.item()\n", "\n", " if depth_recon:\n", " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", " test_loss_depth_total += loss_depth.item()\n", " loss_depth *= depth_scale\n", " loss += loss_depth\n", " \n", " # forward and backward top 1 accuracy \n", " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()\n", " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()\n", "\n", " utils.check_loss(loss) \n", " test_losses.append(loss.item())\n", "\n", " # if utils.is_interactive(): clear_output(wait=True)\n", " print(\"---\")\n", " \n", " assert (test_i+1) == 1\n", " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n", " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n", " \"train/lr\": lrs[-1],\n", " \"train/num_steps\": len(losses),\n", " \"test/num_steps\": len(test_losses),\n", " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n", " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n", " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n", " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n", " 
\"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n", " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n", " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n", " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n", " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n", " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n", " \"train/loss_depth_total\": loss_depth_total / (train_i + 1),\n", " \"test/loss_depth_total\": test_loss_depth_total / (test_i + 1),\n", " }\n", " \n", " if blurry_recon: \n", " # transform blurry recon latents to images and plot it\n", " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", " jj=-1\n", " for j in [0,1,2,3]:\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " axes[jj].axis('off')\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " axes[jj].axis('off')\n", " \n", " if wandb_log:\n", " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", " plt.close()\n", " else:\n", " plt.show()\n", "\n", " if depth_recon:\n", " # transform blurry recon latents to images and plot it\n", " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", " # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " jj=-1\n", " for j in [0,1,2,3]:\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", " axes[jj].axis('off')\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", " axes[jj].axis('off')\n", " if wandb_log:\n", " logs[f\"test/depth_recons\"] = 
wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", " plt.close()\n", " else:\n", " plt.show()\n", " \n", " progress_bar.set_postfix(**logs)\n", " \n", " # Save model checkpoint and reconstruct\n", " if epoch % ckpt_interval == 0:\n", " if not utils.is_interactive():\n", " save_ckpt(f'last')\n", " \n", " if wandb_log: wandb.log(logs)\n", "\n", " # wait for other GPUs to catch up if needed\n", " accelerator.wait_for_everyone()\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", "\n", "print(\"\\n===Finished!===\\n\")\n", "if ckpt_saving:\n", " save_ckpt(f'last')\n", "if not utils.is_interactive():\n", " sys.exit(0)" ] }, { "cell_type": "code", "execution_count": 26, "id": "35cc1be7-bf76-4ad1-8c6a-de52bd013bf4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sat Oct 28 21:13:17 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", "| N/A 33C P0 50W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", "| N/A 30C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 2 NVIDIA A100-SXM... 
On | 00000000:20:1C.0 Off | 0 |\n", "| N/A 34C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", "| N/A 30C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", "| N/A 36C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", "| N/A 35C P0 72W / 400W | 38467MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", "| N/A 33C P0 50W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", "| N/A 31C P0 51W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 5 N/A N/A 1896724 C ...3/envs/mindeye/bin/python 38464MiB |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": null, "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a", "metadata": {}, "outputs": [], "source": [ "plt.plot(losses)\n", "plt.show()\n", "plt.plot(test_losses)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "f2690877-a431-44e8-a2ca-61f4b7397070", "metadata": {}, "source": [ "# Retrieve nearest neighbor in the training set using test set data" ] }, { "cell_type": "code", "execution_count": null, "id": "5b6b8feb-391d-437e-a5d9-a2088f1b1149", "metadata": {}, "outputs": [], "source": [ "annots = np.load(\"/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy\")" ] }, { "cell_type": "code", "execution_count": null, "id": "612ac5aa-6f0f-45ad-809e-03df905d184c", "metadata": {}, "outputs": [], "source": [ "ii=2\n", "all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))\n", "with torch.no_grad(), torch.cuda.amp.autocast():\n", " for batch in tqdm(range(0,len(all_indices),512)):\n", " if batch==0:\n", " clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", " else:\n", " target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", " clip_target = torch.vstack((clip_target,target))\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", "\n", " 
voxel = test_voxel[[ii]].to(device)\n", " image = test_image[[ii]].to(device)\n", "\n", " print(\"Original Image (test set)\")\n", " display(utils.torch_to_Image(image))\n", " \n", " clip_target = clip_model.embed_image(image).cpu()\n", " # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))\n", " \n", " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", " clip_voxels, _, _ = model.backbone(voxel_ridge) \n", " # retrieve with the brain-predicted embedding; it must NOT be overwritten by the ground-truth image embedding (clip_target), otherwise retrieval ignores the model entirely\n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", "\n", " print(\"clip_voxels_norm\", clip_voxels_norm.shape)\n", " print(\"clip_target_norm\", clip_target_norm.shape)\n", " \n", " sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(), \n", " clip_target_norm).flatten()).flip(0)\n", " picks = all_indices[sortt[:5]]\n", "\n", " print(\"\\nNearest neighbors in training set\")\n", " for ip,p in enumerate(picks):\n", " display(utils.torch_to_Image(images[[p]]))\n", " # print(utils.select_annotations([annots[int(p)]]))\n", " if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]\n", "\n", "print(\"\\n=====\\npredicted_caption:\\n\", predicted_caption)" ] }, { "cell_type": "markdown", "id": "1473ddaa-5f2b-4448-9194-c7b0801d05db", "metadata": {}, "source": [ "# Feed into Stable Diffusion XL for reconstructions" ] }, { "cell_type": "code", "execution_count": null, "id": "70e50e0d-c44f-4d56-939a-2943535e1747", "metadata": {}, "outputs": [], "source": [ "from diffusers import StableDiffusionXLPipeline\n", "pipe = StableDiffusionXLPipeline.from_pretrained(\n", " \"/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n", ")\n", "pipe.to(\"cuda\")\n", "pass" ] }, { "cell_type": "code", "execution_count": null, "id":
"479e6994-3eaa-47d2-89a3-422c464fab36", "metadata": {}, "outputs": [], "source": [ "prompt = predicted_caption\n", "# fixed seed so the SDXL reconstruction is reproducible across runs\n", "generator = torch.Generator(device=\"cuda\").manual_seed(0)\n", "recon = pipe(prompt=prompt, generator=generator).images[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "9dc48e1b-5842-4a29-963a-6469d943a72c", "metadata": { "tags": [] }, "outputs": [], "source": [ "print(\"Seen image\")\n", "display(utils.torch_to_Image(image))\n", "\n", "print(\"Reconstruction\")\n", "utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": true }, "toc-autonumbering": true, "vscode": { "interpreter": { "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" } } }, "nbformat": 4, "nbformat_minor": 5 }