diff --git "a/src/MLPencoder.ipynb" "b/src/MLPencoder.ipynb" new file mode 100644--- /dev/null +++ "b/src/MLPencoder.ipynb" @@ -0,0 +1,2138 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d004931d-1ff4-4d85-9d23-1bb1eab2111e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Sep 8 05:45:08 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 30C P0 51W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 38C P0 157W / 400W | 40431MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 43C P0 176W / 400W | 28503MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 38C P0 166W / 400W | 38049MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 40C P0 178W / 400W | 37159MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 39C P0 158W / 400W | 37641MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 39C P0 167W / 400W | 38395MiB / 40960MiB | 98% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 48C P0 381W / 400W | 35381MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 1 N/A N/A 919504 C ...ari/llama_env/bin/python3 40428MiB |\n", + "| 2 N/A N/A 919505 C ...ari/llama_env/bin/python3 28500MiB |\n", + "| 3 N/A N/A 919506 C ...ari/llama_env/bin/python3 38046MiB |\n", + "| 4 N/A N/A 919507 C ...ari/llama_env/bin/python3 37156MiB |\n", + "| 5 N/A N/A 919508 C ...ari/llama_env/bin/python3 37638MiB |\n", + "| 6 N/A N/A 919509 C ...ari/llama_env/bin/python3 38392MiB |\n", + "| 7 N/A N/A 919510 C ...ari/llama_env/bin/python3 35378MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ec021849-d426-4450-a140-f8647a764d2e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-09-08 05:54:30,058] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "from einops import rearrange\n", + "from transformers import MusicgenForConditionalGeneration\n", + "import pytorch_lightning as pl\n", + "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n", + "import lightning as L\n", + "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n", + "from pytorch_lightning.loggers import WandbLogger\n", + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5f1916fe-163f-4cac-ac0c-c16a74bcae89", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# create the datasets and dataloaders\n", + "train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n", + "test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n", + "\n", + "train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125\n", + "test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125\n", + "\n", + "class VoxelsDataset(data.Dataset):\n", + " def __init__(self, voxels_path, embeddings_path):\n", + " # transpose the two dimensions of the voxels data to match the embeddings data\n", + " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n", + " self.embeddings = torch.from_numpy(np.load(embeddings_path))\n", + " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", + " self.len = 
len(self.voxels) // 10\n", + " print(\"The len is \", self.len )\n", + "\n", + " def __getitem__(self, index):\n", + " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", + " voxels = self.voxels[index*10:(index+1)*10]\n", + " embeddings = self.embeddings[index]\n", + " return voxels, embeddings\n", + "\n", + " def __len__(self):\n", + " return self.len\n", + " \n", + "class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):\n", + " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):\n", + " super().__init__()\n", + " self.train_voxels_path = train_voxels_path\n", + " self.train_embeddings_path = train_embeddings_path\n", + " self.test_voxels_path = test_voxels_path\n", + " self.test_embeddings_path = test_embeddings_path\n", + " self.batch_size = batch_size\n", + "\n", + " def setup(self, stage=None):\n", + " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n", + " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n", + "\n", + " def train_dataloader(self):\n", + " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n", + "\n", + " def val_dataloader(self):\n", + " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "39fa231d-7e28-4813-8f5d-edf8fbce1774", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'bn = BrainNetwork(in_dim=4096)\\nrr = RidgeRegression(60784, 4096)\\n\\ntest_input = torch.randn(3, 60784)\\nout1 = rr(test_input)\\nout2 = bn(out1)\\nprint(out2.shape, out1.shape)\\n '" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class RidgeRegression(torch.nn.Module):\n", + " # make sure to add weight_decay when initializing optimizer\n", + " def __init__(self, input_size, out_features): \n", + " super(RidgeRegression, self).__init__()\n", + " self.linear = torch.nn.Linear(input_size, out_features)\n", + " def forward(self, x):\n", + " return self.linear(x)\n", + " \n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768*128, in_dim=60784, clip_size=128, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", + " super().__init__()\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop2)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.clip_size = clip_size\n", + " self.use_projector = use_projector\n", + " if use_projector:\n", + " self.projector = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " residual = 
x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " if self.use_projector:\n", + " x = self.projector(x.reshape(len(x), -1, self.clip_size))\n", + " x = rearrange(x, 'b e t -> b t e')\n", + " return x\n", + " return x\n", + " \n", + "\"\"\"bn = BrainNetwork(in_dim=4096)\n", + "rr = RidgeRegression(60784, 4096)\n", + "\n", + "test_input = torch.randn(3, 60784)\n", + "out1 = rr(test_input)\n", + "out2 = bn(out1)\n", + "print(out2.shape, out1.shape)\n", + " \"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "495e8e78-c44f-4410-b3c2-61faa6beec77", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"b2m_test = B2M().to('cuda')\\ntest_b2m_input = torch.randn(4, 60784).to('cuda')\\naudio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\\naudio_codes = torch.from_numpy(audio_codes)\\naudio_codes = audio_codes[:4, :, :]\\naudio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\\ntest_b2m_output = b2m_test(test_b2m_input, audio_codes)\\nprint(test_b2m_output.shape)\"" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class B2M(pl.LightningModule):\n", + " def __init__(self, input_size = 60784, mapping_size = 4096, num_codebooks = 4):\n", + " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n", + " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n", + " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. 
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n", + " super().__init__()\n", + " self.brain_network = BrainNetwork(h = mapping_size)\n", + " self.ridge_regression = RidgeRegression(input_size=input_size, out_features=mapping_size)\n", + " self.loss = nn.CrossEntropyLoss()\n", + " self.pseudo_text_encoder = nn.Sequential(\n", + " self.ridge_regression,\n", + " self.brain_network\n", + " )\n", + " self.musicgen_decoder = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")\n", + " self.pad_token_id = self.musicgen_decoder.generation_config.pad_token_id\n", + " self.num_codebooks = num_codebooks\n", + " self.test_outptus = []\n", + " self.train_outptus = []\n", + "\n", + " for param in self.musicgen_decoder.parameters():\n", + " param.requires_grad = False\n", + "\n", + " def forward(self, x, decoder_input_ids=None):\n", + " # x is [batch_size, 60784]\n", + " # decoder input ids is [batch_size * num_codebooks, 750] 750 is the length of the audiocodes for 15 seconds of audio\n", + " # first we pass the voxels through the pseudo text encoder\n", + " pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n", + " # x is [batch_size, 128, 768]\n", + " # now we pass the output through the musicgen projector to get [batch_size, 128, 1024]\n", + " projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n", + "\n", + " if decoder_input_ids is None:\n", + " # if no decoder input ids are given, we create a tensor of the size [batch_size * num_codebooks, 1] filled with the pad token id\n", + " decoder_input_ids = (\n", + " torch.ones((x.shape[0] * self.num_codebooks, 1), dtype=torch.long)\n", + " * self.pad_token_id\n", + " )\n", + " \n", + " # now we pass the projected pseudo encoded fmri through the musicgen decoder\n", + " logits = self.musicgen_decoder.decoder(\n", + " input_ids = decoder_input_ids,\n", + " encoder_hidden_states = projected_pseudo_encoded_fmri,\n", + " ).logits\n", + "\n", + " return logits\n", + "\n", + " \n", + " def training_step(self, batch, batch_idx):\n", + " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 4, 750]\n", + " # take the last scan from the voxels\n", + " voxels = voxels[:, -1, :]\n", + " # convert the embeddings to long and combine the batch and codebook dimensions\n", + " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", + "\n", + "\n", + " #take just the first 200 embeddings\n", + " #embeddings = embeddings[:, :200]\n", + " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", + " #voxels = voxels[:, 0:2, :]\n", + " #voxels = voxels.mean(dim=1)\n", + " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", + "\n", + "\n", + " # use the decoder input ids to get the logits\n", + " decoder_input_ids = embeddings[:, :-1]\n", + " logits = self(voxels, decoder_input_ids)\n", + "\n", + " # get the loss\n", + " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", + "\n", + "\n", + " acuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", + " self.log('train_loss', loss, sync_dist=True)\n", + " self.log('train_accuracy', acuracy, sync_dist=True)\n", + " discrete_outputs = logits.argmax(dim=2)\n", + " self.train_outptus.append(discrete_outputs)\n", + " return loss\n", + " \n", + " def tokens_accuracy(self, outputs, embeddings):\n", + " # outputs is [batch_size, 750, 2048]\n", + " # embeddings is 
[batch_size, 750]\n", + " # we need to get the index of the maximum value of each token\n", + " outputs = outputs.argmax(dim=2)\n", + " # now we need to compare the outputs with the embeddings\n", + " return (outputs == embeddings).float().mean()\n", + " \n", + " def on_train_epoch_end(self):\n", + " self.train_outptus = torch.cat(self.train_outptus)\n", + " # save the outputs with the current epoch name\n", + " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n", + " self.train_outptus = []\n", + " \n", + " def on_validation_epoch_end(self):\n", + " self.test_outptus = torch.cat(self.test_outptus)\n", + " # save the outputs with the current epoch name\n", + " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n", + " self.test_outptus = []\n", + "\n", + " \n", + " def validation_step(self, batch, batch_idx):\n", + " voxels, embeddings = batch\n", + " # take the last scan from the voxels\n", + " voxels = voxels[:, -1, :]\n", + " # convert the embeddings to long and combine the batch and codebook dimensions\n", + " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", + "\n", + " # take just the first 200 embeddings\n", + " #embeddings = embeddings[:, :200]\n", + " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", + " #voxels = voxels[:, 0:2, :]\n", + " #voxels = voxels.mean(dim=1)\n", + " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", + "\n", + " # use the decoder input ids to get the logits\n", + " decoder_input_ids = embeddings[:, :-1]\n", + " logits = self(voxels, decoder_input_ids)\n", + "\n", + " # get the loss\n", + " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", + "\n", + " acuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", + " self.log('val_loss', loss, sync_dist=True)\n", + " self.log('val_accuracy', acuracy, sync_dist=True)\n", + " discrete_outputs = logits.argmax(dim=2)\n", + " self.test_outptus.append(discrete_outputs)\n", + " return loss\n", + " \n", + " \n", + "\n", + " def configure_optimizers(self):\n", + " # we just want to train the pseudo text encoder, but we need to zero the gradients of the musicgen decoder\n", + " optimizer = torch.optim.AdamW(\n", + " [\n", + " {'params': self.pseudo_text_encoder.parameters(), 'lr': 3e-6, 'weight_decay': 1e-4},\n", + " {'params': self.musicgen_decoder.parameters(), 'lr': 0},\n", + " ],\n", + " )\n", + " return optimizer\n", + "\n", + "\n", + "# create the model\n", + "\"\"\"b2m_test = B2M().to('cuda')\n", + "test_b2m_input = torch.randn(4, 60784).to('cuda')\n", + "audio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\n", + "audio_codes = torch.from_numpy(audio_codes)\n", + "audio_codes = audio_codes[:4, :, :]\n", + "audio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\n", + "test_b2m_output = b2m_test(test_b2m_input, audio_codes)\n", + "print(test_b2m_output.shape)\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0646c340-c3df-42c0-ad22-2b677fc85da6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "b2m = B2M()\n", + "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n", + "\n", + "wandb.finish()\n", + "\n", + "wandb_logger = 
WandbLogger(project='brain2music', entity='ckadirt')\n", + "\n", + "# define the trainer\n", + "trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n", + "\n", + "# train the model\n", + "trainer.fit(b2m, datamodule=data_module)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "4f458adf-1e89-4b8b-9514-08bcc9e8ef56", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Sep 8 05:08:00 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 31C P0 71W / 400W | 29323MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 40C P0 204W / 400W | 40217MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 47C P0 243W / 400W | 40421MiB / 40960MiB | 99% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 42C P0 194W / 400W | 37547MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 45C P0 282W / 400W | 29223MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 39C P0 179W / 400W | 24387MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 43C P0 232W / 400W | 29819MiB / 40960MiB | 88% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 38C P0 166W / 400W | 28583MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 0 N/A N/A 895711 C ...3/envs/mindeye/bin/python 29320MiB |\n", + "| 1 N/A N/A 899054 C ...ari/llama_env/bin/python3 40214MiB |\n", + "| 2 N/A N/A 899055 C ...ari/llama_env/bin/python3 40418MiB |\n", + "| 3 N/A N/A 899056 C ...ari/llama_env/bin/python3 37544MiB |\n", + "| 4 N/A N/A 899057 C ...ari/llama_env/bin/python3 29220MiB |\n", + "| 5 N/A N/A 899058 C ...ari/llama_env/bin/python3 24384MiB |\n", + "| 6 N/A N/A 899059 C ...ari/llama_env/bin/python3 29816MiB |\n", + "| 7 N/A N/A 899060 C ...ari/llama_env/bin/python3 28580MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "e4c9f9f5-a984-4fcf-8938-861f5bfb38d6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([240, 749])\n" + ] + } + ], + "source": [ + "# load the saved validation outputs from b2m/src/outputs_validation39.pt\n", + "val_outputs = torch.load('/fsx/proj-fmri/ckadirt/b2m/src/outputs_validation39.pt')\n", + "print(val_outputs.shape)\n", + "example1 = val_outputs[0:4].unsqueeze(0).unsqueeze(0)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "80e45d62-27d4-4ebe-a86c-56a9cc4ac0de", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "data = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "ce72b901-e670-48ca-8c92-2221b46f6145", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 753., 609., 248., 248., 248., 1271., 248., 1350., 1350.,\n", + " 1350., 1350., 367., 297., 1251., 702., 1103., 1106., 1600.,\n", + " 1106., 457., 1638., 903., 1103., 1673., 1103., 902., 149.,\n", + " 1544., 1458., 1544., 773., 1470., 1470., 890., 1457., 1038.,\n", + " 2008., 1506., 457., 1126., 1047., 1103., 1933., 3., 560.,\n", + " 714., 271., 1442., 1710., 949., 1508., 957., 685., 399.,\n", + " 1103., 1667., 1555., 1529., 494., 1436., 1883., 29., 225.,\n", + " 846., 773., 569., 677., 71., 888., 1693., 1401., 888.,\n", + " 1016., 792., 569., 1590., 71., 1600., 314., 272., 1756.,\n", + " 1917., 1264., 917., 1021., 178., 1205., 974., 457., 457.,\n", + " 1106., 569., 1562., 271., 1977., 367., 345., 893., 1842.,\n", + " 1401., 1152., 1152., 1152., 1152., 50., 1418., 1748., 1188.,\n", + " 2043., 1666., 1796., 1512., 457., 812., 1600., 1764., 879.,\n", + " 1194., 1457., 1842., 1883., 104., 666., 352., 612., 1710.,\n", + " 1458., 315., 1990., 741., 1047., 675., 514., 1051., 1132.,\n", + " 1115., 315., 1347., 1670., 1875., 1194., 71., 1786., 196.,\n", + " 2043., 1818., 1477., 996., 1083., 967., 128., 1629., 1562.,\n", + " 1875., 237., 712., 1279., 29., 675., 1207., 1303., 1622.,\n", + " 1622., 1622., 1622., 1557., 675., 1303., 261., 675., 1245.,\n", + " 1245., 675., 675., 714., 378., 1673., 1145., 1673., 1673.,\n", + " 2013., 1990., 974., 457., 1124., 1562., 3., 1721., 
846.,\n", + " 378., 271., 271., 675., 1782., 876., 1918., 483., 1419.,\n", + " 1693., 1562., 50., 1198., 1786., 1492., 1670., 1292., 104.,\n", + " 1286., 620., 776., 828., 1629., 762., 71., 1095., 367.,\n", + " 2043., 1501., 1152., 320., 1271., 801., 1671., 1418., 1213.,\n", + " 71., 40., 1419., 1179., 178., 1106., 1016., 1652., 902.,\n", + " 1074., 366., 1484., 42., 1457., 1453., 1241., 849., 1888.,\n", + " 1775., 1888., 82., 1671., 836., 82., 82., 472., 247.,\n", + " 1883., 29., 398., 642., 1904., 1436., 2002., 71., 71.,\n", + " 71., 71., 938., 1122., 1106., 736., 367., 1531., 778.,\n", + " 1388., 1949., 207., 1418., 721., 1130., 989., 1303., 1402.,\n", + " 1047., 569., 569., 569., 272., 320., 548., 976., 314.,\n", + " 1095., 1559., 1378., 828., 1742., 1378., 620., 1271., 776.,\n", + " 801., 1213., 1358., 1883., 1157., 104., 1744., 648., 1271.,\n", + " 1373., 1956., 685., 290., 938., 965., 208., 823., 1287.,\n", + " 893., 1470., 775., 1158., 775., 1990., 986., 1152., 1358.,\n", + " 312., 1111., 1564., 50., 237., 1646., 937., 1052., 1917.,\n", + " 1742., 1702., 297., 259., 1734., 54., 1933., 71., 1875.,\n", + " 1875., 1702., 1875., 237., 237., 237., 164., 1602., 220.,\n", + " 457., 1798., 457., 1798., 1798., 1514., 548., 523., 949.,\n", + " 1599., 714., 1673., 893., 714., 1016., 1016., 1638., 1562.,\n", + " 714., 1016., 1103., 871., 569., 1047., 1047., 612., 1646.,\n", + " 1875., 747., 714., 718., 1562., 1673., 1555., 1763., 1127.,\n", + " 793., 1817., 657., 1106., 457., 1106., 458., 1599., 1106.,\n", + " 29., 1863., 1103., 1599., 714., 812., 1194., 1508., 1194.,\n", + " 1562., 986., 714., 1308., 1097., 1207., 1095., 1763., 1127.,\n", + " 642., 1419., 290., 42., 1246., 400., 1462., 1194., 773.,\n", + " 741., 832., 1368., 2042., 937., 71., 1541., 42., 2008.,\n", + " 920., 872., 1473., 890., 234., 234., 1781., 1742., 1742.,\n", + " 71., 508., 237., 1790., 1428., 347., 1907., 1642., 457.,\n", + " 1106., 987., 1106., 1907., 1562., 1378., 1638., 1106., 1106.,\n", + " 399., 1106., 1106., 569., 1194., 695., 1286., 237., 2043.,\n", + " 1891., 1933., 1194., 1445., 1670., 1933., 1426., 1198., 937.,\n", + " 1977., 677., 1907., 1907., 1492., 1514., 1798., 1562., 1823.,\n", + " 644., 1629., 1842., 4., 712., 949., 686., 607., 1343.,\n", + " 1907., 1907., 1106., 1973., 1051., 1974., 974., 1869., 913.,\n", + " 1499., 272., 1322., 789., 457., 1869., 354., 45., 1869.,\n", + " 846., 1103., 1393., 1798., 1424., 1766., 457., 1126., 1642.,\n", + " 1775., 1775., 1144., 84., 84., 1142., 1549., 1992., 1989.,\n", + " 1047., 500., 637., 1499., 1451., 1732., 297., 261., 1124.,\n", + " 29., 866., 457., 149., 1608., 1047., 1343., 1433., 1047.,\n", + " 1044., 1837., 1984., 1807., 869., 1598., 1188., 50., 569.,\n", + " 272., 1401., 1670., 367., 744., 1817., 178., 1904., 1562.,\n", + " 352., 71., 1670., 980., 1629., 736., 1130., 328., 302.,\n", + " 893., 1378., 1652., 1469., 1786., 1742., 328., 677., 1652.,\n", + " 1652., 1188., 2008., 1733., 1002., 1106., 1103., 464., 1052.,\n", + " 1782., 1508., 1106., 677., 1481., 1666., 1921., 548., 1106.,\n", + " 902., 1106., 893., 1106., 893., 1933., 893., 237., 1051.,\n", + " 626., 919., 1106., 1484., 1288., 1103., 902., 986., 162.,\n", + " 421., 1733., 281., 2031., 366., 494., 1594., 825., 730.,\n", + " 1702., 642., 1478., 1366., 869., 1106., 1106., 1047., 320.,\n", + " 1380., 569., 569., 569., 1380., 104., 1324., 1600., 1132.,\n", + " 1337., 1047., 612., 1555., 2001., 2027., 2027., 71., 800.,\n", + " 30., 2003., 1817., 1047., 1764., 714., 1457., 840., 1246.,\n", + " 1978., 297., 
1798., 1127., 1106., 948., 1047., 457., 1158.,\n", + " 2001., 1308., 1316., 634., 1481., 1047., 762., 964., 1913.,\n", + " 1638., 1562., 1582., 1095., 936., 658., 1978., 1512., 1753.,\n", + " 1974., 1564., 2007., 890., 1562., 890., 1801., 1801., 1378.,\n", + " 71., 1481., 1863., 921., 1600., 121., 648., 1132., 1132.,\n", + " 1512., 1529., 263., 1529., 1564., 1484., 642., 642., 1629.,\n", + " 40., 1444., 1078., 1078., 1436., 118., 118., 1522., 1904.,\n", + " 237., 658., 658., 1711., 1095., 778., 121., 248., 999.,\n", + " 648., 648., 247., 648., 648., 8., 8., 1453., 609.,\n", + " 609., 83., 166.], dtype=float32)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "cf8382ae-ed92-43bb-a65c-7537161956d5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([609, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 
237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", + " 237, 237, 237, 237, 237, 237, 237], device='cuda:0')" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example1[0][0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f5842cd4-17b7-47c6-8939-31634f3db3cc", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+       " in <module>:1                                                                                    \n",
+       "                                                                                                  \n",
+       " 1 sampled.unsqueeze(0).unsqueeze(0).detach().clone()                                           \n",
+       "   2                                                                                              \n",
+       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "RuntimeError: CUDA error: device-side assert triggered\n",
+       "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+       "incorrect.\n",
+       "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+       "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+       "\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 sampled.unsqueeze(\u001b[94m0\u001b[0m).unsqueeze(\u001b[94m0\u001b[0m).detach().clone() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n", + "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", + "incorrect.\n", + "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n", + "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sampled.unsqueeze(0).unsqueeze(0).detach().clone()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f802e748-ff63-4b9e-8147-77d62bb87c4b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [6,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [7,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [8,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [9,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [10,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [11,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [12,0,0] Assertion 
`srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [13,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [14,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [15,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [16,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [17,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [18,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [19,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [20,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [21,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [22,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [23,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [24,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [25,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [26,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [27,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [28,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + 
"../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [37,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [38,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [39,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [40,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [41,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [42,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [43,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [44,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [45,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [46,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [47,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [48,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [49,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [50,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [51,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [52,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [53,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [54,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [55,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [56,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: 
[0,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [58,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [59,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [60,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [62,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [63,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [70,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [71,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [72,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [73,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [74,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [75,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [76,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [77,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [78,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [79,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", 
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [80,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [81,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [82,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [83,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [84,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [85,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [86,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [87,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [88,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [89,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [90,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [91,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [92,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [93,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [94,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: 
block: [0,0,0], thread: [102,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [103,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [104,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [105,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [106,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [107,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [108,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [109,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [110,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [111,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [112,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [113,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [114,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [115,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [116,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [117,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [118,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [119,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [120,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [121,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [122,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [123,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [124,0,0] Assertion `srcIndex < 
srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n", + "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n" + ] + }, + { + "data": { + "text/html": [ + "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+       " in <module>:1                                                                                    \n",
+       "                                                                                                  \n",
+       " 1 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(0).u     \n",
+       "   2                                                                                              \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:742 in decode                                                           \n",
+       "                                                                                                  \n",
+       "   739 │   │   if chunk_length is None:                                                           \n",
+       "   740 │   │   │   if len(audio_codes) != 1:                                                      \n",
+       "   741 │   │   │   │   raise ValueError(f\"Expected one frame, got {len(audio_codes)}\")            \n",
+       " 742 │   │   │   audio_values = self._decode_frame(audio_codes[0], audio_scales[0])             \n",
+       "   743 │   │   else:                                                                              \n",
+       "   744 │   │   │   decoded_frames = []                                                            \n",
+       "   745                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:707 in _decode_frame                                                    \n",
+       "                                                                                                  \n",
+       "   704 def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -   \n",
+       "   705 │   │   codes = codes.transpose(0, 1)                                                      \n",
+       "   706 │   │   embeddings = self.quantizer.decode(codes)                                          \n",
+       " 707 │   │   outputs = self.decoder(embeddings)                                                 \n",
+       "   708 │   │   if scale is not None:                                                              \n",
+       "   709 │   │   │   outputs = outputs * scale.view(-1, 1, 1)                                       \n",
+       "   710 │   │   return outputs                                                                     \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n",
+       " .py:1501 in _call_impl                                                                           \n",
+       "                                                                                                  \n",
+       "   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   \n",
+       "   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   \n",
+       "   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   \n",
+       " 1501 │   │   │   return forward_call(*args, **kwargs)                                          \n",
+       "   1502 │   │   # Do not call functions when jit is used                                          \n",
+       "   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             \n",
+       "   1504 │   │   backward_pre_hooks = []                                                           \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:336 in forward                                                          \n",
+       "                                                                                                  \n",
+       "   333                                                                                        \n",
+       "   334 def forward(self, hidden_states):                                                      \n",
+       "   335 │   │   for layer in self.layers:                                                          \n",
+       " 336 │   │   │   hidden_states = layer(hidden_states)                                           \n",
+       "   337 │   │   return hidden_states                                                               \n",
+       "   338                                                                                            \n",
+       "   339                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n",
+       " .py:1501 in _call_impl                                                                           \n",
+       "                                                                                                  \n",
+       "   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   \n",
+       "   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   \n",
+       "   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   \n",
+       " 1501 │   │   │   return forward_call(*args, **kwargs)                                          \n",
+       "   1502 │   │   # Do not call functions when jit is used                                          \n",
+       "   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             \n",
+       "   1504 │   │   backward_pre_hooks = []                                                           \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:162 in forward                                                          \n",
+       "                                                                                                  \n",
+       "   159 │   │   │   # Asymmetric padding required for odd strides                                  \n",
+       "   160 │   │   │   padding_right = padding_total // 2                                             \n",
+       "   161 │   │   │   padding_left = padding_total - padding_right                                   \n",
+       " 162 │   │   │   hidden_states = self._pad1d(                                                   \n",
+       "   163 │   │   │   │   hidden_states, (padding_left, padding_right + extra_padding), mode=self.   \n",
+       "   164 │   │   │   )                                                                              \n",
+       "   165                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:143 in _pad1d                                                           \n",
+       "                                                                                                  \n",
+       "   140 │   │   if length <= max_pad:                                                              \n",
+       "   141 │   │   │   extra_pad = max_pad - length + 1                                               \n",
+       "   142 │   │   │   hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))               \n",
+       " 143 │   │   padded = nn.functional.pad(hidden_states, paddings, mode, value)                   \n",
+       "   144 │   │   end = padded.shape[-1] - extra_pad                                                 \n",
+       "   145 │   │   return padded[..., :end]                                                           \n",
+       "   146                                                                                            \n",
+       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "RuntimeError: CUDA error: device-side assert triggered\n",
+       "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+       "incorrect.\n",
+       "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+       "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+       "\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(\u001b[94m0\u001b[0m).u \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m742\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m739 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m chunk_length \u001b[95mis\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m740 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(audio_codes) != \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m741 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mExpected one frame, got \u001b[0m\u001b[33m{\u001b[0m\u001b[96mlen\u001b[0m(audio_codes)\u001b[33m}\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m742 \u001b[2m│ │ │ \u001b[0maudio_values = \u001b[96mself\u001b[0m._decode_frame(audio_codes[\u001b[94m0\u001b[0m], audio_scales[\u001b[94m0\u001b[0m]) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m743 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m744 \u001b[0m\u001b[2m│ │ │ \u001b[0mdecoded_frames = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m745 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m707\u001b[0m in \u001b[92m_decode_frame\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m704 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_decode_frame\u001b[0m(\u001b[96mself\u001b[0m, codes: torch.Tensor, scale: Optional[torch.Tensor] = \u001b[94mNone\u001b[0m) - \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m705 \u001b[0m\u001b[2m│ │ \u001b[0mcodes = codes.transpose(\u001b[94m0\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m706 \u001b[0m\u001b[2m│ │ \u001b[0membeddings = \u001b[96mself\u001b[0m.quantizer.decode(codes) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m707 \u001b[2m│ │ \u001b[0moutputs = \u001b[96mself\u001b[0m.decoder(embeddings) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m708 \u001b[0m\u001b[2m│ │ 
\u001b[0m\u001b[94mif\u001b[0m scale \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m709 \u001b[0m\u001b[2m│ │ │ \u001b[0moutputs = outputs * scale.view(-\u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m710 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m outputs \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m336\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m333 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m334 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mforward\u001b[0m(\u001b[96mself\u001b[0m, hidden_states): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m335 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfor\u001b[0m layer \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.layers: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m336 \u001b[2m│ │ │ \u001b[0mhidden_states = layer(hidden_states) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m337 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m hidden_states \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m338 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m339 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m 
\u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m162\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m159 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# Asymmetric padding required for odd strides\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m160 \u001b[0m\u001b[2m│ │ │ \u001b[0mpadding_right = padding_total // \u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m161 \u001b[0m\u001b[2m│ │ │ \u001b[0mpadding_left = padding_total - padding_right \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m162 \u001b[2m│ │ │ \u001b[0mhidden_states = \u001b[96mself\u001b[0m._pad1d( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m163 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mhidden_states, (padding_left, padding_right + extra_padding), mode=\u001b[96mself\u001b[0m. 
\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m164 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m165 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m143\u001b[0m in \u001b[92m_pad1d\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m140 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m length <= max_pad: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m141 \u001b[0m\u001b[2m│ │ │ \u001b[0mextra_pad = max_pad - length + \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m142 \u001b[0m\u001b[2m│ │ │ \u001b[0mhidden_states = nn.functional.pad(hidden_states, (\u001b[94m0\u001b[0m, extra_pad)) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m143 \u001b[2m│ │ \u001b[0mpadded = nn.functional.pad(hidden_states, paddings, mode, value) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m144 \u001b[0m\u001b[2m│ │ \u001b[0mend = padded.shape[-\u001b[94m1\u001b[0m] - extra_pad \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m145 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m padded[..., :end] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m146 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n", + "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", + "incorrect.\n", + "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n", + "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(0).unsqueeze(0).detach().clone(), audio_scales = [None])" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "61a216c2-7357-4a06-85aa-13db8acbba79", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Sep 8 05:39:28 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 30C P0 70W / 400W | 29365MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... 
On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 41C P0 240W / 400W | 33897MiB / 40960MiB | 92% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 44C P0 208W / 400W | 37457MiB / 40960MiB | 82% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 38C P0 189W / 400W | 40233MiB / 40960MiB | 76% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 41C P0 242W / 400W | 34031MiB / 40960MiB | 95% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 38C P0 201W / 400W | 30723MiB / 40960MiB | 100% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 39C P0 183W / 400W | 38225MiB / 40960MiB | 97% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 38C P0 197W / 400W | 40121MiB / 40960MiB | 79% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 0 N/A N/A 895711 C ...3/envs/mindeye/bin/python 29362MiB |\n", + "| 1 N/A N/A 919504 C ...ari/llama_env/bin/python3 33894MiB |\n", + "| 2 N/A N/A 919505 C ...ari/llama_env/bin/python3 37454MiB |\n", + "| 3 N/A N/A 919506 C ...ari/llama_env/bin/python3 40230MiB |\n", + "| 4 N/A N/A 919507 C ...ari/llama_env/bin/python3 34028MiB |\n", + "| 5 N/A N/A 919508 C ...ari/llama_env/bin/python3 30720MiB |\n", + "| 6 N/A N/A 919509 C ...ari/llama_env/bin/python3 38222MiB |\n", + "| 7 N/A N/A 919510 C ...ari/llama_env/bin/python3 40118MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "f44e9028-39c7-4a55-a6e0-ffc84de2b093", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 479360])" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "decoded.audio_values.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "a2fb3ad8-7a2b-4baf-adb5-485b1deb4828", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Audio\n", + "Audio(decoded[0][0].cpu().numpy(), rate=32000)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "48e8680b-bdc3-4c41-af8f-511ca0c0acf8", + 
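"metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Editor's note -- a hedged sketch, not run here. The next cell fails because\n", + "# MusicgenForConditionalGeneration.generate() does not accept encoder_hidden_states\n", + "# as a kwarg. Per the validation path visible in the traceback below, precomputed\n", + "# encoder embeddings would instead be passed as encoder_outputs; `enc` is a\n", + "# hypothetical name, the other identifiers are this notebook's own:\n", + "# from transformers.modeling_outputs import BaseModelOutput\n", + "# enc = BaseModelOutput(last_hidden_state=projected_pseudo_encoded_fmri)\n", + "# gepe = b2m.musicgen_decoder.generate(encoder_outputs=enc, max_length=752)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "48e8680b-bdc3-4c41-af8f-511ca0c0acf8", + 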
"metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+       " in <module>:2                                                                                    \n",
+       "                                                                                                  \n",
+       "   1 projected_pseudo_encoded_fmri = torch.rand((1,15,1024))                                      \n",
+       " 2 gepe = b2m.musicgen_decoder.generate(                                                        \n",
+       "   3 │   │   │   encoder_hidden_states = projected_pseudo_encoded_fmri,                           \n",
+       "   4 │   │   │   max_length = 752                                                                 \n",
+       "   5 │   │   )                                                                                    \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/utils/_contextlib \n",
+       " .py:115 in decorate_context                                                                      \n",
+       "                                                                                                  \n",
+       "   112 @functools.wraps(func)                                                                 \n",
+       "   113 def decorate_context(*args, **kwargs):                                                 \n",
+       "   114 │   │   with ctx_factory():                                                                \n",
+       " 115 │   │   │   return func(*args, **kwargs)                                                   \n",
+       "   116                                                                                        \n",
+       "   117 return decorate_context                                                                \n",
+       "   118                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/mus \n",
+       " icgen/modeling_musicgen.py:2261 in generate                                                      \n",
+       "                                                                                                  \n",
+       "   2258 │   │   generation_config = copy.deepcopy(generation_config)                              \n",
+       "   2259 │   │   model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be m  \n",
+       "   2260 │   │   generation_config.validate()                                                      \n",
+       " 2261 │   │   self._validate_model_kwargs(model_kwargs.copy())                                  \n",
+       "   2262 │   │                                                                                     \n",
+       "   2263 │   │   if model_kwargs.get(\"encoder_outputs\") is not None and type(model_kwargs[\"encode  \n",
+       "   2264 │   │   │   # wrap the unconditional outputs as a BaseModelOutput for compatibility with  \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation \n",
+       " /utils.py:1249 in _validate_model_kwargs                                                         \n",
+       "                                                                                                  \n",
+       "   1246 │   │   │   │   unused_model_args.append(key)                                             \n",
+       "   1247 │   │                                                                                     \n",
+       "   1248 │   │   if unused_model_args:                                                             \n",
+       " 1249 │   │   │   raise ValueError(                                                             \n",
+       "   1250 │   │   │   │   f\"The following `model_kwargs` are not used by the model: {unused_model_  \n",
+       "   1251 │   │   │   │   \" generate arguments will also show up in this list)\"                     \n",
+       "   1252 │   │   │   )                                                                             \n",
+       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "ValueError: The following `model_kwargs` are not used by the model: ['encoder_hidden_states'] (note: typos in the \n",
+       "generate arguments will also show up in this list)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0mprojected_pseudo_encoded_fmri = torch.rand((\u001b[94m1\u001b[0m,\u001b[94m15\u001b[0m,\u001b[94m1024\u001b[0m)) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2 gepe = b2m.musicgen_decoder.generate( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m\u001b[2m│ │ │ \u001b[0mencoder_hidden_states = projected_pseudo_encoded_fmri, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m\u001b[2m│ │ │ \u001b[0mmax_length = \u001b[94m752\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/utils/\u001b[0m\u001b[1;33m_contextlib\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m115\u001b[0m in \u001b[92mdecorate_context\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m112 \u001b[0m\u001b[2m│ \u001b[0m\u001b[1;95m@functools\u001b[0m.wraps(func) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecorate_context\u001b[0m(*args, **kwargs): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m ctx_factory(): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m115 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m func(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m decorate_context \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/mus\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33micgen/\u001b[0m\u001b[1;33mmodeling_musicgen.py\u001b[0m:\u001b[94m2261\u001b[0m in \u001b[92mgenerate\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2258 \u001b[0m\u001b[2m│ │ \u001b[0mgeneration_config = copy.deepcopy(generation_config) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2259 \u001b[0m\u001b[2m│ │ \u001b[0mmodel_kwargs = generation_config.update(**kwargs) \u001b[2m# All unused kwargs must be m\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2260 \u001b[0m\u001b[2m│ │ \u001b[0mgeneration_config.validate() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2261 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m._validate_model_kwargs(model_kwargs.copy()) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2262 \u001b[0m\u001b[2m│ │ \u001b[0m 
\u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2263 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m model_kwargs.get(\u001b[33m\"\u001b[0m\u001b[33mencoder_outputs\u001b[0m\u001b[33m\"\u001b[0m) \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m \u001b[95mand\u001b[0m \u001b[96mtype\u001b[0m(model_kwargs[\u001b[33m\"\u001b[0m\u001b[33mencode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2264 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# wrap the unconditional outputs as a BaseModelOutput for compatibility with\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mutils.py\u001b[0m:\u001b[94m1249\u001b[0m in \u001b[92m_validate_model_kwargs\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1246 \u001b[0m\u001b[2m│ │ │ │ \u001b[0munused_model_args.append(key) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1247 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1248 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m unused_model_args: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1249 \u001b[2m│ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1250 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mThe following `model_kwargs` are not used by the model: \u001b[0m\u001b[33m{\u001b[0munused_model_ \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1251 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33m generate arguments will also show up in this list)\u001b[0m\u001b[33m\"\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1252 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mValueError: \u001b[0mThe following `model_kwargs` are not used by the model: \u001b[1m[\u001b[0m\u001b[32m'encoder_hidden_states'\u001b[0m\u001b[1m]\u001b[0m \u001b[1m(\u001b[0mnote: typos in the \n", + "generate arguments will also show up in this list\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "projected_pseudo_encoded_fmri = torch.rand((1,15,1024))\n", + "gepe = b2m.musicgen_decoder.generate(\n", + " encoder_hidden_states = projected_pseudo_encoded_fmri,\n", + " max_length = 752\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "059f21d7-58ed-4b35-93d6-1ea983ac3161", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "b2m = B2M()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ae335238-50f3-456a-9dca-3a0958aaa7b7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ckpt = torch.load('/fsx/proj-fmri/ckadirt/b2m/src/brain2music/jggbeix7/checkpoints/epoch=39-step=4800.ckpt')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ff5199bf-a9af-491e-9870-7b203bea49fd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "b2m.load_state_dict(ckpt['state_dict'])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a6e4985e-514d-4ed2-bcc7-11edd2ba4387", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "B2M(\n", + " (brain_network): BrainNetwork(\n", + " (mlp): ModuleList(\n", + " (0-3): 4 x Sequential(\n", + " (0): Linear(in_features=4096, out_features=4096, bias=True)\n", + " (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " (3): Dropout(p=0.15, inplace=False)\n", + " )\n", + " )\n", + " (lin1): Linear(in_features=4096, out_features=98304, bias=True)\n", + " (projector): Sequential(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Linear(in_features=128, out_features=2048, bias=True)\n", + " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", + " (4): GELU(approximate='none')\n", + " (5): Linear(in_features=2048, out_features=2048, bias=True)\n", + " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", + " (7): GELU(approximate='none')\n", + " (8): Linear(in_features=2048, out_features=128, bias=True)\n", + " )\n", + " )\n", + " (ridge_regression): RidgeRegression(\n", + " (linear): Linear(in_features=60784, out_features=4096, bias=True)\n", + " )\n", + " (loss): CrossEntropyLoss()\n", + " (pseudo_text_encoder): Sequential(\n", + " (0): RidgeRegression(\n", + " (linear): Linear(in_features=60784, out_features=4096, bias=True)\n", + " )\n", + " (1): BrainNetwork(\n", + " (mlp): ModuleList(\n", + " (0-3): 4 x Sequential(\n", + " (0): Linear(in_features=4096, out_features=4096, bias=True)\n", + " (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", + " (2): GELU(approximate='none')\n", + " (3): Dropout(p=0.15, inplace=False)\n", + " )\n", + " )\n", + " (lin1): Linear(in_features=4096, out_features=98304, bias=True)\n", + " (projector): Sequential(\n", + " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", + " (1): GELU(approximate='none')\n", + " (2): Linear(in_features=128, out_features=2048, bias=True)\n", + " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", + " (4): GELU(approximate='none')\n", + " (5): Linear(in_features=2048, out_features=2048, bias=True)\n", + " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", + " (7): GELU(approximate='none')\n", + " (8): Linear(in_features=2048, out_features=128, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (musicgen_decoder): MusicgenForConditionalGeneration(\n", + " (text_encoder): T5EncoderModel(\n", + " (shared): Embedding(32128, 768)\n", + " (encoder): T5Stack(\n", + " (embed_tokens): Embedding(32128, 768)\n", + " (block): ModuleList(\n", + " (0): T5Block(\n", + " (layer): ModuleList(\n", + " (0): T5LayerSelfAttention(\n", + " (SelfAttention): T5Attention(\n", + " (q): Linear(in_features=768, out_features=768, bias=False)\n", + " (k): Linear(in_features=768, out_features=768, bias=False)\n", + " (v): Linear(in_features=768, out_features=768, bias=False)\n", + " (o): Linear(in_features=768, out_features=768, bias=False)\n", + " (relative_attention_bias): Embedding(32, 12)\n", + " )\n", + " (layer_norm): T5LayerNorm()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (1): T5LayerFF(\n", + " (DenseReluDense): T5DenseActDense(\n", + " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", + " (wo): Linear(in_features=3072, out_features=768, 
bias=False)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (act): ReLU()\n", + " )\n", + " (layer_norm): T5LayerNorm()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (1-11): 11 x T5Block(\n", + " (layer): ModuleList(\n", + " (0): T5LayerSelfAttention(\n", + " (SelfAttention): T5Attention(\n", + " (q): Linear(in_features=768, out_features=768, bias=False)\n", + " (k): Linear(in_features=768, out_features=768, bias=False)\n", + " (v): Linear(in_features=768, out_features=768, bias=False)\n", + " (o): Linear(in_features=768, out_features=768, bias=False)\n", + " )\n", + " (layer_norm): T5LayerNorm()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (1): T5LayerFF(\n", + " (DenseReluDense): T5DenseActDense(\n", + " (wi): Linear(in_features=768, out_features=3072, bias=False)\n", + " (wo): Linear(in_features=3072, out_features=768, bias=False)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (act): ReLU()\n", + " )\n", + " (layer_norm): T5LayerNorm()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (final_layer_norm): T5LayerNorm()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (audio_encoder): EncodecModel(\n", + " (encoder): EncodecEncoder(\n", + " (layers): ModuleList(\n", + " (0): EncodecConv1d(\n", + " (conv): Conv1d(1, 64, kernel_size=(7,), stride=(1,))\n", + " )\n", + " (1): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(64, 128, kernel_size=(8,), stride=(4,))\n", + " )\n", + " (4): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (5): ELU(alpha=1.0)\n", + " (6): EncodecConv1d(\n", + " (conv): Conv1d(128, 256, kernel_size=(8,), stride=(4,))\n", + " )\n", + " (7): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(256, 128, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (8): ELU(alpha=1.0)\n", + " (9): EncodecConv1d(\n", + " (conv): Conv1d(256, 512, kernel_size=(10,), stride=(5,))\n", + " )\n", + " (10): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(512, 256, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (11): ELU(alpha=1.0)\n", + " (12): EncodecConv1d(\n", + " (conv): Conv1d(512, 1024, kernel_size=(16,), stride=(8,))\n", + " )\n", + " (13): EncodecLSTM(\n", + " (lstm): LSTM(1024, 1024, num_layers=2)\n", + " )\n", + " (14): 
ELU(alpha=1.0)\n", + " (15): EncodecConv1d(\n", + " (conv): Conv1d(1024, 128, kernel_size=(7,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (decoder): EncodecDecoder(\n", + " (layers): ModuleList(\n", + " (0): EncodecConv1d(\n", + " (conv): Conv1d(128, 1024, kernel_size=(7,), stride=(1,))\n", + " )\n", + " (1): EncodecLSTM(\n", + " (lstm): LSTM(1024, 1024, num_layers=2)\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConvTranspose1d(\n", + " (conv): ConvTranspose1d(1024, 512, kernel_size=(16,), stride=(8,))\n", + " )\n", + " (4): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(512, 256, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (5): ELU(alpha=1.0)\n", + " (6): EncodecConvTranspose1d(\n", + " (conv): ConvTranspose1d(512, 256, kernel_size=(10,), stride=(5,))\n", + " )\n", + " (7): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(256, 128, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (8): ELU(alpha=1.0)\n", + " (9): EncodecConvTranspose1d(\n", + " (conv): ConvTranspose1d(256, 128, kernel_size=(8,), stride=(4,))\n", + " )\n", + " (10): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (11): ELU(alpha=1.0)\n", + " (12): EncodecConvTranspose1d(\n", + " (conv): ConvTranspose1d(128, 64, kernel_size=(8,), stride=(4,))\n", + " )\n", + " (13): EncodecResnetBlock(\n", + " (block): ModuleList(\n", + " (0): ELU(alpha=1.0)\n", + " (1): EncodecConv1d(\n", + " (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))\n", + " )\n", + " (2): ELU(alpha=1.0)\n", + " (3): EncodecConv1d(\n", + " (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (shortcut): Identity()\n", + " )\n", + " (14): ELU(alpha=1.0)\n", + " (15): EncodecConv1d(\n", + " (conv): Conv1d(64, 1, kernel_size=(7,), stride=(1,))\n", + " )\n", + " )\n", + " )\n", + " (quantizer): EncodecResidualVectorQuantizer(\n", + " (layers): ModuleList(\n", + " (0-3): 4 x EncodecVectorQuantization(\n", + " (codebook): EncodecEuclideanCodebook()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (decoder): MusicgenForCausalLM(\n", + " (model): MusicgenModel(\n", + " (decoder): MusicgenDecoder(\n", + " (embed_tokens): ModuleList(\n", + " (0-3): 4 x Embedding(2049, 1024)\n", + " )\n", + " (embed_positions): MusicgenSinusoidalPositionalEmbedding()\n", + " (layers): ModuleList(\n", + " (0-23): 24 x MusicgenDecoderLayer(\n", + " (self_attn): MusicgenAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " )\n", + " (activation_fn): 
GELUActivation()\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): MusicgenAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=False)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=False)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (lm_heads): ModuleList(\n", + " (0-3): 4 x Linear(in_features=1024, out_features=2048, bias=False)\n", + " )\n", + " )\n", + " (enc_to_dec_proj): Linear(in_features=768, out_features=1024, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b2m.to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7aad33a3-351d-472b-86f7-55ca5ac90a26", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fmri = np.load('/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy')\n", + "exafmri = fmri[:, 9]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d4d9c347-bed2-45d0-ba84-e033f18ba4a7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "texa = torch.from_numpy(exafmri).unsqueeze(0).to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1a5d2300-f587-43d2-aeee-4575777710ca", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 60784])\n" + ] + } + ], + "source": [ + "print(texa.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "97d2a4a0-9dac-4e8d-8cbf-9c5da2bd589e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 128, 1024])\n" + ] + } + ], + "source": [ + " def generate(self, x, max_length = 256):\n", + " with torch.no_grad():\n", + " # x is [batch_size, 60784]\n", + " # first we pass the voxels through the pseudo text encoder\n", + " pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n", + " # x is [batch_size, 128, 768]\n", + " # now we pass the output through the musicgen projector to get [batch_size, 128, 1024]\n", + " projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n", + " print(projected_pseudo_encoded_fmri.shape)\n", + " decoder_input_ids = (\n", + " torch.ones((x.shape[0] * self.num_codebooks, 1), dtype=torch.long)\n", + " * self.pad_token_id\n", + " ).to(x.device)\n", + " for i in range(max_length):\n", + " # now we pass the projected pseudo encoded fmri through the musicgen decoder\n", + " #print(i)\n", + " logits = self.musicgen_decoder.decoder(\n", + " input_ids = decoder_input_ids,\n", + " encoder_hidden_states = projected_pseudo_encoded_fmri,\n", + " ).logits\n", + " # get the next token\n", + " #print(logits.shape)\n", + " next_token = logits[:,-1,:].argmax(dim=-1)\n", + " #print(next_token.shape)\n", + " # add the next token to the decoder input ids\n", + " decoder_input_ids = 
torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)\n", + " #print(decoder_input_ids.shape)\n", + " return decoder_input_ids\n", + " \n", + " sampled = generate(b2m, texa)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "fc41e095-263f-4724-91f1-d54e33f3a9b5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "torch.save(sampled, './samplet.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5b8eee6f-06e5-4608-974d-bc20ae297076", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+       " in <module>:1                                                                                    \n",
+       "                                                                                                  \n",
+       " 1 sampled.to('cpu')                                                                            \n",
+       "   2                                                                                              \n",
+       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "RuntimeError: CUDA error: device-side assert triggered\n",
+       "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+       "incorrect.\n",
+       "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+       "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+       "\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 sampled.to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n", + "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", + "incorrect.\n", + "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n", + "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sampled.to('cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "45aff4da-82bf-48d3-bc74-2456cd63d021", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+       " in <module>:4                                                                                    \n",
+       "                                                                                                  \n",
+       "   1 b2m = b2m.to('cpu')                                                                          \n",
+       "   2 ss = torch.load('./samplet.pt').to('cpu')                                                    \n",
+       "   3 with torch.no_grad():                                                                        \n",
+       " 4 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(0).un     \n",
+       "   5                                                                                              \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:742 in decode                                                           \n",
+       "                                                                                                  \n",
+       "   739 │   │   if chunk_length is None:                                                           \n",
+       "   740 │   │   │   if len(audio_codes) != 1:                                                      \n",
+       "   741 │   │   │   │   raise ValueError(f\"Expected one frame, got {len(audio_codes)}\")            \n",
+       " 742 │   │   │   audio_values = self._decode_frame(audio_codes[0], audio_scales[0])             \n",
+       "   743 │   │   else:                                                                              \n",
+       "   744 │   │   │   decoded_frames = []                                                            \n",
+       "   745                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:706 in _decode_frame                                                    \n",
+       "                                                                                                  \n",
+       "   703                                                                                        \n",
+       "   704 def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -   \n",
+       "   705 │   │   codes = codes.transpose(0, 1)                                                      \n",
+       " 706 │   │   embeddings = self.quantizer.decode(codes)                                          \n",
+       "   707 │   │   outputs = self.decoder(embeddings)                                                 \n",
+       "   708 │   │   if scale is not None:                                                              \n",
+       "   709 │   │   │   outputs = outputs * scale.view(-1, 1, 1)                                       \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:435 in decode                                                           \n",
+       "                                                                                                  \n",
+       "   432 │   │   quantized_out = torch.tensor(0.0, device=codes.device)                             \n",
+       "   433 │   │   for i, indices in enumerate(codes):                                                \n",
+       "   434 │   │   │   layer = self.layers[i]                                                         \n",
+       " 435 │   │   │   quantized = layer.decode(indices)                                              \n",
+       "   436 │   │   │   quantized_out = quantized_out + quantized                                      \n",
+       "   437 │   │   return quantized_out                                                               \n",
+       "   438                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:391 in decode                                                           \n",
+       "                                                                                                  \n",
+       "   388 │   │   return embed_in                                                                    \n",
+       "   389                                                                                        \n",
+       "   390 def decode(self, embed_ind):                                                           \n",
+       " 391 │   │   quantize = self.codebook.decode(embed_ind)                                         \n",
+       "   392 │   │   quantize = quantize.permute(0, 2, 1)                                               \n",
+       "   393 │   │   return quantize                                                                    \n",
+       "   394                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc \n",
+       " odec/modeling_encodec.py:372 in decode                                                           \n",
+       "                                                                                                  \n",
+       "   369 │   │   return embed_ind                                                                   \n",
+       "   370                                                                                        \n",
+       "   371 def decode(self, embed_ind):                                                           \n",
+       " 372 │   │   quantize = nn.functional.embedding(embed_ind, self.embed)                          \n",
+       "   373 │   │   return quantize                                                                    \n",
+       "   374                                                                                            \n",
+       "   375                                                                                            \n",
+       "                                                                                                  \n",
+       " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py: \n",
+       " 2210 in embedding                                                                                \n",
+       "                                                                                                  \n",
+       "   2207 │   │   #   torch.embedding_renorm_                                                       \n",
+       "   2208 │   │   # remove once script supports set_grad_enabled                                    \n",
+       "   2209 │   │   _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)                    \n",
+       " 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)        \n",
+       "   2211                                                                                           \n",
+       "   2212                                                                                           \n",
+       "   2213 def embedding_bag(                                                                        \n",
+       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+       "IndexError: index out of range in self\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m4\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0mb2m = b2m.to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0mss = torch.load(\u001b[33m'\u001b[0m\u001b[33m./samplet.pt\u001b[0m\u001b[33m'\u001b[0m).to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m\u001b[94mwith\u001b[0m torch.no_grad(): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m4 \u001b[2m│ \u001b[0mdecoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(\u001b[94m0\u001b[0m).un \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m742\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m739 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m chunk_length \u001b[95mis\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m740 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(audio_codes) != \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m741 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mExpected one frame, got \u001b[0m\u001b[33m{\u001b[0m\u001b[96mlen\u001b[0m(audio_codes)\u001b[33m}\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m742 \u001b[2m│ │ │ \u001b[0maudio_values = \u001b[96mself\u001b[0m._decode_frame(audio_codes[\u001b[94m0\u001b[0m], audio_scales[\u001b[94m0\u001b[0m]) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m743 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m744 \u001b[0m\u001b[2m│ │ │ \u001b[0mdecoded_frames = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m745 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m706\u001b[0m in \u001b[92m_decode_frame\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m703 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m704 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_decode_frame\u001b[0m(\u001b[96mself\u001b[0m, codes: torch.Tensor, scale: 
Optional[torch.Tensor] = \u001b[94mNone\u001b[0m) - \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m705 \u001b[0m\u001b[2m│ │ \u001b[0mcodes = codes.transpose(\u001b[94m0\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m706 \u001b[2m│ │ \u001b[0membeddings = \u001b[96mself\u001b[0m.quantizer.decode(codes) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m707 \u001b[0m\u001b[2m│ │ \u001b[0moutputs = \u001b[96mself\u001b[0m.decoder(embeddings) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m708 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m scale \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m709 \u001b[0m\u001b[2m│ │ │ \u001b[0moutputs = outputs * scale.view(-\u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m435\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m432 \u001b[0m\u001b[2m│ │ \u001b[0mquantized_out = torch.tensor(\u001b[94m0.0\u001b[0m, device=codes.device) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m433 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfor\u001b[0m i, indices \u001b[95min\u001b[0m \u001b[96menumerate\u001b[0m(codes): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m434 \u001b[0m\u001b[2m│ │ │ \u001b[0mlayer = \u001b[96mself\u001b[0m.layers[i] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m435 \u001b[2m│ │ │ \u001b[0mquantized = layer.decode(indices) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m436 \u001b[0m\u001b[2m│ │ │ \u001b[0mquantized_out = quantized_out + quantized \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m437 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m quantized_out \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m438 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m391\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m388 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m embed_in \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m389 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m390 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecode\u001b[0m(\u001b[96mself\u001b[0m, embed_ind): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m391 \u001b[2m│ │ \u001b[0mquantize = \u001b[96mself\u001b[0m.codebook.decode(embed_ind) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m392 \u001b[0m\u001b[2m│ │ \u001b[0mquantize = quantize.permute(\u001b[94m0\u001b[0m, \u001b[94m2\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m393 \u001b[0m\u001b[2m│ 
│ \u001b[0m\u001b[94mreturn\u001b[0m quantize \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m394 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m372\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m369 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m embed_ind \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m370 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m371 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecode\u001b[0m(\u001b[96mself\u001b[0m, embed_ind): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m372 \u001b[2m│ │ \u001b[0mquantize = nn.functional.embedding(embed_ind, \u001b[96mself\u001b[0m.embed) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m373 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m quantize \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m374 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m375 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/\u001b[0m\u001b[1;33mfunctional.py\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[94m2210\u001b[0m in \u001b[92membedding\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2207 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# torch.embedding_renorm_\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2208 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# remove once script supports set_grad_enabled\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2209 \u001b[0m\u001b[2m│ │ \u001b[0m_no_grad_embedding_renorm_(weight, \u001b[96minput\u001b[0m, max_norm, norm_type) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2210 \u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m torch.embedding(weight, \u001b[96minput\u001b[0m, padding_idx, scale_grad_by_freq, sparse) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2211 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2212 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2213 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92membedding_bag\u001b[0m( \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mIndexError: \u001b[0mindex out of range in self\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "b2m = b2m.to('cpu')\n", + "ss = torch.load('./samplet.pt').to('cpu')\n", + "with torch.no_grad():\n", + " decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(0).unsqueeze(0), audio_scales = [None])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7cca542f-a48e-43e0-bddb-aadf7a8c1068", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 257])" + ] + }, + "execution_count": 17, + "metadata": {}, + 
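+  {
+   "cell_type": "markdown",
+   "id": "note-codebook-range",
+   "metadata": {},
+   "source": [
+    "A plausible shared root cause for both tracebacks above (an assumption, not verified here): `ss` contains sampled code ids that fall outside EnCodec's codebooks. For the 32 kHz MusicGen checkpoints each quantizer codebook holds 2048 entries (valid ids 0-2047), while the LM vocabulary also includes a special pad/BOS id 2048, so any such id makes `nn.functional.embedding` index out of range. On GPU this raises the device-side assert and poisons the CUDA context, which is why even the unrelated `sampled.to('cpu')` call then fails; on CPU the same lookup surfaces as the `IndexError` above. The sketch below checks the id range and clamps out-of-range ids as a crude workaround; `codebook_size` is read from the model config, and `b2m`/`ss` are the objects already defined in this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "sketch-clamp-codes",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch (assumption): verify the sampled ids exceed the codebook range,\n",
+    "# then clamp them before decoding. Clamping is only a crude workaround; a\n",
+    "# proper fix would mask or resample the special pad id during generation.\n",
+    "codebook_size = b2m.musicgen_decoder.config.audio_encoder.codebook_size  # 2048 for 32 kHz checkpoints\n",
+    "print(ss.min().item(), ss.max().item())  # any id >= codebook_size explains the IndexError\n",
+    "\n",
+    "ss_clamped = ss.clamp(0, codebook_size - 1)\n",
+    "with torch.no_grad():\n",
+    "    decoded = b2m.musicgen_decoder.audio_encoder.decode(\n",
+    "        audio_codes=ss_clamped.unsqueeze(0).unsqueeze(0),  # (frames, batch, codebooks, steps)\n",
+    "        audio_scales=[None],\n",
+    "    )"
+   ]
+  },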
"output_type": "execute_result" + } + ], + "source": [ + "ss.shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}