diff --git "a/src/.ipynb_checkpoints/MLPencoder-checkpoint.ipynb" "b/src/.ipynb_checkpoints/MLPencoder-checkpoint.ipynb" new file mode 100644--- /dev/null +++ "b/src/.ipynb_checkpoints/MLPencoder-checkpoint.ipynb" @@ -0,0 +1,1012 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d004931d-1ff4-4d85-9d23-1bb1eab2111e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Sep 8 03:46:02 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 43C P0 77W / 400W | 5553MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 40C P0 75W / 400W | 15405MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 45C P0 79W / 400W | 15405MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 42C P0 86W / 400W | 15657MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 44C P0 84W / 400W | 15581MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 43C P0 80W / 400W | 15543MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 39C P0 56W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 0 N/A N/A 873489 C ...ari/llama_env/bin/python3 5550MiB |\n", + "| 1 N/A N/A 873490 C ...ari/llama_env/bin/python3 15402MiB |\n", + "| 2 N/A N/A 873491 C ...ari/llama_env/bin/python3 15402MiB |\n", + "| 3 N/A N/A 873492 C ...ari/llama_env/bin/python3 15654MiB |\n", + "| 4 N/A N/A 873493 C ...ari/llama_env/bin/python3 15578MiB |\n", + "| 5 N/A N/A 873494 C ...ari/llama_env/bin/python3 15540MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ec021849-d426-4450-a140-f8647a764d2e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-09-08 03:46:06,264] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "from einops import rearrange\n", + "from transformers import MusicgenForConditionalGeneration\n", + "import pytorch_lightning as pl\n", + "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n", + "import lightning as L\n", + "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n", + "from pytorch_lightning.loggers import WandbLogger\n", + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5f1916fe-163f-4cac-ac0c-c16a74bcae89", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# create the datasets and dataloaders\n", + "train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n", + "test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n", + "\n", + "train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125\n", + "test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125\n", + "\n", + "class VoxelsDataset(data.Dataset):\n", + " def __init__(self, voxels_path, embeddings_path):\n", + " # transpose the two dimensions of the voxels data to match the embeddings data\n", + " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n", + " self.embeddings = torch.from_numpy(np.load(embeddings_path))\n", + " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", + " self.len = len(self.voxels) // 10\n", + " print(\"The len is \", self.len )\n", + "\n", + " def 
__getitem__(self, index):\n", + " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", + " voxels = self.voxels[index*10:(index+1)*10]\n", + " embeddings = self.embeddings[index]\n", + " return voxels, embeddings\n", + "\n", + " def __len__(self):\n", + " return self.len\n", + " \n", + "class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):\n", + " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):\n", + " super().__init__()\n", + " self.train_voxels_path = train_voxels_path\n", + " self.train_embeddings_path = train_embeddings_path\n", + " self.test_voxels_path = test_voxels_path\n", + " self.test_embeddings_path = test_embeddings_path\n", + " self.batch_size = batch_size\n", + "\n", + " def setup(self, stage=None):\n", + " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n", + " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n", + "\n", + " def train_dataloader(self):\n", + " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n", + "\n", + " def val_dataloader(self):\n", + " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "39fa231d-7e28-4813-8f5d-edf8fbce1774", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'bn = BrainNetwork(in_dim=4096)\\nrr = RidgeRegression(60784, 4096)\\n\\ntest_input = torch.randn(3, 60784)\\nout1 = rr(test_input)\\nout2 = bn(out1)\\nprint(out2.shape, out1.shape)\\n '" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class RidgeRegression(torch.nn.Module):\n", + " # make sure to add weight_decay when initializing optimizer\n", + " def __init__(self, input_size, out_features): \n", + " super(RidgeRegression, self).__init__()\n", + " self.linear = torch.nn.Linear(input_size, out_features)\n", + " def forward(self, x):\n", + " return self.linear(x)\n", + " \n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768*128, in_dim=60784, clip_size=128, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", + " super().__init__()\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop2)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.clip_size = clip_size\n", + " self.use_projector = use_projector\n", + " if use_projector:\n", + " self.projector = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " residual = x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", 
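+ " # residual connection: add the block's input back to its output, then update the running residual\n",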
+ " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " if self.use_projector:\n", + " x = self.projector(x.reshape(len(x), -1, self.clip_size))\n", + " x = rearrange(x, 'b e t -> b t e')\n", + " return x\n", + " return x\n", + " \n", + "\"\"\"bn = BrainNetwork(in_dim=4096)\n", + "rr = RidgeRegression(60784, 4096)\n", + "\n", + "test_input = torch.randn(3, 60784)\n", + "out1 = rr(test_input)\n", + "out2 = bn(out1)\n", + "print(out2.shape, out1.shape)\n", + " \"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "495e8e78-c44f-4410-b3c2-61faa6beec77", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"b2m_test = B2M().to('cuda')\\ntest_b2m_input = torch.randn(4, 60784).to('cuda')\\naudio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\\naudio_codes = torch.from_numpy(audio_codes)\\naudio_codes = audio_codes[:4, :, :]\\naudio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\\ntest_b2m_output = b2m_test(test_b2m_input, audio_codes)\\nprint(test_b2m_output.shape)\"" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class B2M(pl.LightningModule):\n", + " def __init__(self, input_size = 60784, mapping_size = 4096, num_codebooks = 4):\n", + " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n", + " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n", + " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. 
+ " super().__init__()\n", + " self.brain_network = BrainNetwork(h = mapping_size)\n", + " self.ridge_regression = RidgeRegression(input_size=input_size, out_features=mapping_size)\n", + " self.loss = nn.CrossEntropyLoss()\n", + " # buffers that collect the discrete token predictions during each epoch (saved by the epoch-end hooks)\n", + " self.train_outputs = []\n", + " self.test_outputs = []\n", + " self.pseudo_text_encoder = nn.Sequential(\n", + " self.ridge_regression,\n", + " self.brain_network\n", + " )\n", + " self.musicgen_decoder = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")\n", + " self.pad_token_id = self.musicgen_decoder.generation_config.pad_token_id\n", + " self.num_codebooks = num_codebooks\n", + "\n", + " def forward(self, x, decoder_input_ids=None):\n", + " # x is [batch_size, 60784]\n", + " # decoder_input_ids is [batch_size * num_codebooks, 750]; 750 is the length of the audio codes for 15 seconds of audio\n", + " # first we pass the voxels through the pseudo text encoder\n", + " pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n", + " # pseudo_encoded_fmri is [batch_size, 128, 768]\n", + " # now we pass the output through the musicgen projector to get [batch_size, 128, 1024]\n", + " projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n", + "\n", + " if decoder_input_ids is None:\n", + " # if no decoder input ids are given, we create a tensor of size [batch_size * num_codebooks, 1] filled with the pad token id, on the same device as x\n", + " decoder_input_ids = (\n", + " torch.ones((x.shape[0] * self.num_codebooks, 1), dtype=torch.long, device=x.device)\n", + " * self.pad_token_id\n", + " )\n", + " \n", + " # now we pass the projected pseudo encoded fmri through the musicgen decoder\n", + " logits = self.musicgen_decoder.decoder(\n", + " input_ids = decoder_input_ids,\n", + " encoder_hidden_states = projected_pseudo_encoded_fmri,\n", + " ).logits\n", + "\n", + " return logits\n", + "\n", + " \n", + " def training_step(self, batch, batch_idx):\n", + " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 4, 750]\n", + " # take the last scan from the voxels\n", + " voxels = voxels[:, -1, :]\n", + " # convert the embeddings to long and combine the batch and codebook dimensions\n", + " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", + "\n", + "\n", + " #take just the first 200 embeddings\n", + " #embeddings = embeddings[:, :200]\n", + " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", + " #voxels = voxels[:, 0:2, :]\n", + " #voxels = voxels.mean(dim=1)\n", + " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", + "\n", + "\n", + " # teacher forcing: feed tokens [0, T-1) as decoder input and predict tokens [1, T)\n", + " decoder_input_ids = embeddings[:, :-1]\n", + " logits = self(voxels, decoder_input_ids)\n", + "\n", + " # get the loss against the shifted targets\n", + " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", + "\n", + "\n", + " accuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", + " self.log('train_loss', loss, sync_dist=True)\n", + " self.log('train_accuracy', accuracy, sync_dist=True)\n", + " discrete_outputs = logits.argmax(dim=2)\n", + " self.train_outputs.append(discrete_outputs)\n", + " return loss\n", + " \n",
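+ " # worked example of the shift above, assuming T = 750 tokens per codebook: decoder_input_ids = embeddings[:, :-1] has length 749,\n", + " # so logits are [batch * codebooks, 749, 2048] and the targets embeddings[:, 1:] also have length 749; comparing logits against the\n", + " # unshifted 750-token embeddings is what raised the 749-vs-750 RuntimeError recorded below\n",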
+ " def tokens_accuracy(self, outputs, embeddings):\n", + " # outputs is [batch_size * num_codebooks, seq_len, 2048]\n", + " # embeddings is [batch_size * num_codebooks, seq_len]\n", + " # we need to get the index of the maximum value of each token\n", + " outputs = outputs.argmax(dim=2)\n", + " # now we compare the predicted tokens with the ground-truth embeddings\n", + " return (outputs == embeddings).float().mean()\n", + " \n", + " def on_train_epoch_end(self):\n", + " # save the outputs with the current epoch name, then reset the buffer\n", + " outputs = torch.cat(self.train_outputs)\n", + " torch.save(outputs, 'outputs_train'+str(self.current_epoch)+'.pt')\n", + " self.train_outputs = []\n", + " \n", + " def on_validation_epoch_end(self):\n", + " # save the outputs with the current epoch name, then reset the buffer\n", + " outputs = torch.cat(self.test_outputs)\n", + " torch.save(outputs, 'outputs_validation'+str(self.current_epoch)+'.pt')\n", + " self.test_outputs = []\n", + "\n", + " \n", + " def validation_step(self, batch, batch_idx):\n", + " voxels, embeddings = batch\n", + " # take the last scan from the voxels\n", + " voxels = voxels[:, -1, :]\n", + " # convert the embeddings to long and combine the batch and codebook dimensions\n", + " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", + "\n", + " # take just the first 200 embeddings\n", + " #embeddings = embeddings[:, :200]\n", + " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", + " #voxels = voxels[:, 0:2, :]\n", + " #voxels = voxels.mean(dim=1)\n", + " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", + "\n", + " # teacher forcing, as in training_step\n", + " decoder_input_ids = embeddings[:, :-1]\n", + " logits = self(voxels, decoder_input_ids)\n", + "\n", + " # get the loss against the shifted targets\n", + " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", + "\n", + " accuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", + " self.log('val_loss', loss, sync_dist=True)\n", + " self.log('val_accuracy', accuracy, sync_dist=True)\n", + " discrete_outputs = logits.argmax(dim=2)\n", + " self.test_outputs.append(discrete_outputs)\n", + " return loss\n", + " \n", + " \n", + "\n", + " def configure_optimizers(self):\n", + " # we just want to train the pseudo text encoder; the AdamW weight_decay supplies the ridge penalty noted in RidgeRegression\n", + " optimizer = torch.optim.AdamW(\n", + " [\n", + " {'params': self.pseudo_text_encoder.parameters()},\n", + " ],\n", + " lr=1e-4,\n", + " weight_decay=1e-4\n", + " )\n", + " return optimizer\n", + "\n", + "# create the model\n", + "\"\"\"b2m_test = B2M().to('cuda')\n", + "test_b2m_input = torch.randn(4, 60784).to('cuda')\n", + "audio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\n", + "audio_codes = torch.from_numpy(audio_codes)\n", + "audio_codes = audio_codes[:4, :, :]\n", + "audio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\n", + "test_b2m_output = b2m_test(test_b2m_input, audio_codes)\n", + "print(test_b2m_output.shape)\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "0646c340-c3df-42c0-ad22-2b677fc85da6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Waiting for W&B process to finish... (success)." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run cool-leaf-48 at: https://stability.wandb.io/ckadirt/brain2music/runs/7nj59774
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20230908_034627-7nj59774/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "wandb version 0.15.10 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in ./wandb/run-20230908_035326-na1yng1f" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run driven-glade-49 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://stability.wandb.io/ckadirt/brain2music" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://stability.wandb.io/ckadirt/brain2music/runs/na1yng1f" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /admin/home-ckadirt/miniconda3/envs/mindeye/lib/pyth ...\n", + " rank_zero_warn(\n", + "Using 16bit Automatic Mixed Precision (AMP)\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The len is 480\n", + "The len is 60\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "-------------------------------------------------------------------------\n", + "0 | brain_network | BrainNetwork | 474 M \n", + "1 | ridge_regression | RidgeRegression | 248 M \n", + "2 | loss | CrossEntropyLoss | 0 \n", + "3 | pseudo_text_encoder | Sequential | 723 M \n", + "4 | musicgen_decoder | MusicgenForConditionalGeneration | 588 M \n", + "-------------------------------------------------------------------------\n", + "1.3 B Trainable params\n", + "2.1 M Non-trainable params\n", + "1.3 B Total params\n", + "5,250.392 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", + " in <module>:12 \n", + " \n", + " 9 trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, \n", + " 10 \n", + " 11 # train the model \n", + " 12 trainer.fit(b2m, datamodule=data_module) \n", + " 13 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/trainer.py:529 in fit \n", + " \n", + " 526 │ │ \"\"\" \n", + " 527 │ │ model = _maybe_unwrap_optimized(model) \n", + " 528 │ │ self.strategy._lightning_module = model \n", + " 529 │ │ call._call_and_handle_interrupt( \n", + " 530 │ │ │ self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, \n", + " 531 │ │ ) \n", + " 532 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/call.py:42 in _call_and_handle_interrupt \n", + " \n", + " 39 try: \n", + " 40 │ │ if trainer.strategy.launcher is not None: \n", + " 41 │ │ │ return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, \n", + " 42 │ │ return trainer_fn(*args, **kwargs) \n", + " 43 \n", + " 44 except _TunerExitException: \n", + " 45 │ │ _call_teardown_hook(trainer) \n", + " \n", + " 
/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/trainer.py:568 in _fit_impl \n", + " \n", + " 565 │ │ │ model_provided=True, \n", + " 566 │ │ │ model_connected=self.lightning_module is not None, \n", + " 567 │ │ ) \n", + " 568 │ │ self._run(model, ckpt_path=ckpt_path) \n", + " 569 │ │ \n", + " 570 │ │ assert self.state.stopped \n", + " 571 │ │ self.training = False \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/trainer.py:973 in _run \n", + " \n", + " 970 │ │ # ---------------------------- \n", + " 971 │ │ # RUN THE TRAINER \n", + " 972 │ │ # ---------------------------- \n", + " 973 │ │ results = self._run_stage() \n", + " 974 │ │ \n", + " 975 │ │ # ---------------------------- \n", + " 976 │ │ # POST-Training CLEAN UP \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/trainer.py:1014 in _run_stage \n", + " \n", + " 1011 │ │ │ return self.predict_loop.run() \n", + " 1012 │ │ if self.training: \n", + " 1013 │ │ │ with isolate_rng(): \n", + " 1014 │ │ │ │ self._run_sanity_check() \n", + " 1015 │ │ │ with torch.autograd.set_detect_anomaly(self._detect_anomaly): \n", + " 1016 │ │ │ │ self.fit_loop.run() \n", + " 1017 │ │ │ return None \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/trainer.py:1043 in _run_sanity_check \n", + " \n", + " 1040 │ │ │ call._call_callback_hooks(self, \"on_sanity_check_start\") \n", + " 1041 │ │ │ \n", + " 1042 │ │ │ # run eval step \n", + " 1043 │ │ │ val_loop.run() \n", + " 1044 │ │ │ \n", + " 1045 │ │ │ call._call_callback_hooks(self, \"on_sanity_check_end\") \n", + " 1046 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops \n", + " /utilities.py:177 in _decorator \n", + " \n", + " 174 │ │ else: \n", + " 175 │ │ │ context_manager = torch.no_grad \n", + " 176 │ │ with context_manager(): \n", + " 177 │ │ │ return loop_run(self, *args, **kwargs) \n", + " 178 \n", + " 179 return _decorator \n", + " 180 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops \n", + " /evaluation_loop.py:115 in run \n", + " \n", + " 112 │ │ │ │ │ self._store_dataloader_outputs() \n", + " 113 │ │ │ │ previous_dataloader_idx = dataloader_idx \n", + " 114 │ │ │ │ # run step hooks \n", + " 115 │ │ │ │ self._evaluation_step(batch, batch_idx, dataloader_idx) \n", + " 116 │ │ │ except StopIteration: \n", + " 117 │ │ │ │ # this needs to wrap the `*_step` call too (not just `next`) for `datalo \n", + " 118 │ │ │ │ break \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops \n", + " /evaluation_loop.py:375 in _evaluation_step \n", + " \n", + " 372 │ │ self.batch_progress.increment_started() \n", + " 373 │ │ \n", + " 374 │ │ hook_name = \"test_step\" if trainer.testing else \"validation_step\" \n", + " 375 │ │ output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values()) \n", + " 376 │ │ \n", + " 377 │ │ self.batch_progress.increment_processed() \n", + " 378 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train \n", + " er/call.py:291 in _call_strategy_hook \n", + " \n", + " 288 │ │ return None \n", + " 289 \n", + " 290 with 
trainer.profiler.profile(f\"[Strategy]{trainer.strategy.__class__.__name__}.{hoo \n", + " 291 │ │ output = fn(*args, **kwargs) \n", + " 292 \n", + " 293 # restore current_fx when nested context \n", + " 294 pl_module._current_fx_name = prev_fx_name \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/strat \n", + " egies/strategy.py:379 in validation_step \n", + " \n", + " 376 │ │ \"\"\" \n", + " 377 │ │ with self.precision_plugin.val_step_context(): \n", + " 378 │ │ │ assert isinstance(self.model, ValidationStep) \n", + " 379 │ │ │ return self.model.validation_step(*args, **kwargs) \n", + " 380 \n", + " 381 def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: \n", + " 382 │ │ \"\"\"The actual test step. \n", + " \n", + " in validation_step:116 \n", + " \n", + " 113 │ │ # get the loss \n", + " 114 │ │ loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebook \n", + " 115 │ │ \n", + " 116 │ │ acuracy = self.tokens_accuracy(logits, embeddings) \n", + " 117 │ │ self.log('val_loss', loss, sync_dist=True) \n", + " 118 │ │ self.log('val_accuracy', acuracy, sync_dist=True) \n", + " 119 │ │ discrete_outputs = logits.argmax(dim=2) \n", + " \n", + " in tokens_accuracy:80 \n", + " \n", + " 77 │ │ # we need to get the index of the maximum value of each token \n", + " 78 │ │ outputs = outputs.argmax(dim=2) \n", + " 79 │ │ # now we need to compare the outputs with the embeddings \n", + " 80 │ │ return (outputs == embeddings).float().mean() \n", + " 81 \n", + " 82 def on_train_epoch_end(self): \n", + " 83 │ │ self.train_outptus = torch.cat(self.train_outptus) \n", + "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "RuntimeError: The size of tensor a (749) must match the size of tensor b (750) at non-singleton dimension 1\n", + "\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m12\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 9 \u001b[0mtrainer = pl.Trainer(devices=\u001b[94m1\u001b[0m, accelerator=\u001b[33m\"\u001b[0m\u001b[33mgpu\u001b[0m\u001b[33m\"\u001b[0m, max_epochs=\u001b[94m400\u001b[0m, logger=wandb_logger, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m10 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m11 \u001b[0m\u001b[2m# train the model\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m12 trainer.fit(b2m, datamodule=data_module) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m13 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m529\u001b[0m in \u001b[92mfit\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 526 \u001b[0m\u001b[2;33m│ │ \u001b[0m\u001b[33m\"\"\"\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 527 \u001b[0m\u001b[2m│ 
│ \u001b[0mmodel = _maybe_unwrap_optimized(model) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 528 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.strategy._lightning_module = model \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 529 \u001b[2m│ │ \u001b[0mcall._call_and_handle_interrupt( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 530 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, \u001b[96mself\u001b[0m._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 531 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 532 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mcall.py\u001b[0m:\u001b[94m42\u001b[0m in \u001b[92m_call_and_handle_interrupt\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 39 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mtry\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 40 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m trainer.strategy.launcher \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 41 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 42 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m trainer_fn(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 43 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 44 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mexcept\u001b[0m _TunerExitException: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 45 \u001b[0m\u001b[2m│ │ \u001b[0m_call_teardown_hook(trainer) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m568\u001b[0m in \u001b[92m_fit_impl\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 565 \u001b[0m\u001b[2m│ │ │ \u001b[0mmodel_provided=\u001b[94mTrue\u001b[0m, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 566 \u001b[0m\u001b[2m│ │ │ \u001b[0mmodel_connected=\u001b[96mself\u001b[0m.lightning_module \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 567 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 568 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m._run(model, ckpt_path=ckpt_path) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 569 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 570 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94massert\u001b[0m \u001b[96mself\u001b[0m.state.stopped \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 571 \u001b[0m\u001b[2m│ │ 
\u001b[0m\u001b[96mself\u001b[0m.training = \u001b[94mFalse\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m973\u001b[0m in \u001b[92m_run\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 970 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 971 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# RUN THE TRAINER\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 972 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 973 \u001b[2m│ │ \u001b[0mresults = \u001b[96mself\u001b[0m._run_stage() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 974 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 975 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 976 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# POST-Training CLEAN UP\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m1014\u001b[0m in \u001b[92m_run_stage\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1011 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[96mself\u001b[0m.predict_loop.run() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1012 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.training: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1013 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mwith\u001b[0m isolate_rng(): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1014 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._run_sanity_check() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1015 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mwith\u001b[0m torch.autograd.set_detect_anomaly(\u001b[96mself\u001b[0m._detect_anomaly): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1016 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.fit_loop.run() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1017 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m1043\u001b[0m in \u001b[92m_run_sanity_check\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1040 \u001b[0m\u001b[2m│ │ │ \u001b[0mcall._call_callback_hooks(\u001b[96mself\u001b[0m, 
\u001b[33m\"\u001b[0m\u001b[33mon_sanity_check_start\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1041 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1042 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# run eval step\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1043 \u001b[2m│ │ │ \u001b[0mval_loop.run() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1044 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1045 \u001b[0m\u001b[2m│ │ │ \u001b[0mcall._call_callback_hooks(\u001b[96mself\u001b[0m, \u001b[33m\"\u001b[0m\u001b[33mon_sanity_check_end\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1046 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mutilities.py\u001b[0m:\u001b[94m177\u001b[0m in \u001b[92m_decorator\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ │ \u001b[0mcontext_manager = torch.no_grad \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m176 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m context_manager(): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m177 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m loop_run(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m178 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m179 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m _decorator \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m180 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mevaluation_loop.py\u001b[0m:\u001b[94m115\u001b[0m in \u001b[92mrun\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m112 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._store_dataloader_outputs() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mprevious_dataloader_idx = dataloader_idx \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# run step hooks\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m115 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._evaluation_step(batch, batch_idx, dataloader_idx) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mexcept\u001b[0m \u001b[96mStopIteration\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# this needs to wrap the `*_step` call too (not just `next`) for `datalo\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ │ │ 
\u001b[0m\u001b[94mbreak\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mevaluation_loop.py\u001b[0m:\u001b[94m375\u001b[0m in \u001b[92m_evaluation_step\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m372 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.batch_progress.increment_started() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m373 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m374 \u001b[0m\u001b[2m│ │ \u001b[0mhook_name = \u001b[33m\"\u001b[0m\u001b[33mtest_step\u001b[0m\u001b[33m\"\u001b[0m \u001b[94mif\u001b[0m trainer.testing \u001b[94melse\u001b[0m \u001b[33m\"\u001b[0m\u001b[33mvalidation_step\u001b[0m\u001b[33m\"\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m375 \u001b[2m│ │ \u001b[0moutput = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values()) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m376 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m377 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.batch_progress.increment_processed() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m378 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mcall.py\u001b[0m:\u001b[94m291\u001b[0m in \u001b[92m_call_strategy_hook\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m288 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m289 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m290 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mwith\u001b[0m trainer.profiler.profile(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33m[Strategy]\u001b[0m\u001b[33m{\u001b[0mtrainer.strategy.\u001b[91m__class__\u001b[0m.\u001b[91m__name__\u001b[0m\u001b[33m}\u001b[0m\u001b[33m.\u001b[0m\u001b[33m{\u001b[0mhoo \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m291 \u001b[2m│ │ \u001b[0moutput = fn(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m292 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m293 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# restore current_fx when nested context\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m294 \u001b[0m\u001b[2m│ \u001b[0mpl_module._current_fx_name = prev_fx_name \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/strat\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33megies/\u001b[0m\u001b[1;33mstrategy.py\u001b[0m:\u001b[94m379\u001b[0m in \u001b[92mvalidation_step\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m 
\u001b[2m376 \u001b[0m\u001b[2;33m│ │ \u001b[0m\u001b[33m\"\"\"\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m377 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m \u001b[96mself\u001b[0m.precision_plugin.val_step_context(): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m378 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94massert\u001b[0m \u001b[96misinstance\u001b[0m(\u001b[96mself\u001b[0m.model, ValidationStep) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m379 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[96mself\u001b[0m.model.validation_step(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m380 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m381 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mtest_step\u001b[0m(\u001b[96mself\u001b[0m, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m382 \u001b[0m\u001b[2;90m│ │ \u001b[0m\u001b[33m\"\"\"The actual test step.\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92mvalidation_step\u001b[0m:\u001b[94m116\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# get the loss\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ \u001b[0mloss = \u001b[96mself\u001b[0m.loss(rearrange(logits, \u001b[33m'\u001b[0m\u001b[33m(b c) t d -> (b c t) d\u001b[0m\u001b[33m'\u001b[0m, c=\u001b[96mself\u001b[0m.num_codebook \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m115 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m116 \u001b[2m│ │ \u001b[0macuracy = \u001b[96mself\u001b[0m.tokens_accuracy(logits, embeddings) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.log(\u001b[33m'\u001b[0m\u001b[33mval_loss\u001b[0m\u001b[33m'\u001b[0m, loss, sync_dist=\u001b[94mTrue\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.log(\u001b[33m'\u001b[0m\u001b[33mval_accuracy\u001b[0m\u001b[33m'\u001b[0m, acuracy, sync_dist=\u001b[94mTrue\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m119 \u001b[0m\u001b[2m│ │ \u001b[0mdiscrete_outputs = logits.argmax(dim=\u001b[94m2\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92mtokens_accuracy\u001b[0m:\u001b[94m80\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 77 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# we need to get the index of the maximum value of each token\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 78 \u001b[0m\u001b[2m│ │ \u001b[0moutputs = outputs.argmax(dim=\u001b[94m2\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 79 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# now we need to compare the outputs with the embeddings\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 80 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m (outputs == embeddings).float().mean() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 81 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + 
"\u001b[31m│\u001b[0m \u001b[2m 82 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mon_train_epoch_end\u001b[0m(\u001b[96mself\u001b[0m): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 83 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.train_outptus = torch.cat(\u001b[96mself\u001b[0m.train_outptus) \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mThe size of tensor a \u001b[1m(\u001b[0m\u001b[1;36m749\u001b[0m\u001b[1m)\u001b[0m must match the size of tensor b \u001b[1m(\u001b[0m\u001b[1;36m750\u001b[0m\u001b[1m)\u001b[0m at non-singleton dimension \u001b[1;36m1\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "b2m = B2M()\n", + "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n", + "\n", + "wandb.finish()\n", + "\n", + "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n", + "\n", + "# define the trainer\n", + "trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n", + "\n", + "# train the model\n", + "trainer.fit(b2m, datamodule=data_module)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4f458adf-1e89-4b8b-9514-08bcc9e8ef56", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Sep 8 03:55:15 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 52C P0 205W / 400W | 32829MiB / 40960MiB | 97% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 54C P0 297W / 400W | 38991MiB / 40960MiB | 98% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 59C P0 354W / 400W | 39627MiB / 40960MiB | 89% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 50C P0 180W / 400W | 39719MiB / 40960MiB | 95% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 53C P0 190W / 400W | 35069MiB / 40960MiB | 98% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... 
On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 51C P0 182W / 400W | 34235MiB / 40960MiB | 98% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 40C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 38C P0 56W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 0 N/A N/A 874054 C ...3/envs/mindeye/bin/python 848MiB |\n", + "| 0 N/A N/A 877194 C ...ari/llama_env/bin/python3 31978MiB |\n", + "| 1 N/A N/A 877195 C ...ari/llama_env/bin/python3 38988MiB |\n", + "| 2 N/A N/A 877196 C ...ari/llama_env/bin/python3 39876MiB |\n", + "| 3 N/A N/A 877197 C ...ari/llama_env/bin/python3 39716MiB |\n", + "| 4 N/A N/A 877198 C ...ari/llama_env/bin/python3 35066MiB |\n", + "| 5 N/A N/A 877199 C ...ari/llama_env/bin/python3 34232MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}