diff --git "a/src/MLPencoder.ipynb" "b/src/MLPencoder.ipynb"
new file mode 100644--- /dev/null
+++ "b/src/MLPencoder.ipynb"
@@ -0,0 +1,2138 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "d004931d-1ff4-4d85-9d23-1bb1eab2111e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fri Sep 8 05:45:08 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 30C P0 51W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 157W / 400W | 40431MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 43C P0 176W / 400W | 28503MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 166W / 400W | 38049MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 40C P0 178W / 400W | 37159MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 39C P0 158W / 400W | 37641MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 39C P0 167W / 400W | 38395MiB / 40960MiB | 98% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 48C P0 381W / 400W | 35381MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| 1 N/A N/A 919504 C ...ari/llama_env/bin/python3 40428MiB |\n",
+ "| 2 N/A N/A 919505 C ...ari/llama_env/bin/python3 28500MiB |\n",
+ "| 3 N/A N/A 919506 C ...ari/llama_env/bin/python3 38046MiB |\n",
+ "| 4 N/A N/A 919507 C ...ari/llama_env/bin/python3 37156MiB |\n",
+ "| 5 N/A N/A 919508 C ...ari/llama_env/bin/python3 37638MiB |\n",
+ "| 6 N/A N/A 919509 C ...ari/llama_env/bin/python3 38392MiB |\n",
+ "| 7 N/A N/A 919510 C ...ari/llama_env/bin/python3 35378MiB |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Print the nvidia-smi status table (per-GPU memory use, utilisation and running processes).\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "ec021849-d426-4450-a140-f8647a764d2e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2023-09-08 05:54:30,058] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from functools import partial\n",
+ "from einops import rearrange\n",
+ "from transformers import MusicgenForConditionalGeneration\n",
+ "import pytorch_lightning as pl\n",
+ "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
+ "import lightning as L\n",
+ "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
+ "from pytorch_lightning.loggers import WandbLogger\n",
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5f1916fe-163f-4cac-ac0c-c16a74bcae89",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# create the datasets and dataloaders\n",
+ "# Root folder of the .npy files. Set the B2M_DATA_DIR environment variable to run the\n",
+ "# notebook on another machine; the default is the original hard-coded location.\n",
+ "DATA_DIR = os.environ.get('B2M_DATA_DIR', '/fsx/proj-fmri/ckadirt/b2m/data')\n",
+ "\n",
+ "# NOTE(review): the shape notes below are the original author's and were not re-checked\n",
+ "# against the files; the embedding shapes in particular should be confirmed.\n",
+ "train_voxels_path = f'{DATA_DIR}/sub-001_Resp_Training.npy'  # path to training voxels 65000 * 4800\n",
+ "test_voxels_path = f'{DATA_DIR}/sub-001_Resp_Test_Mean.npy'  # path to test voxels 65000 * 600\n",
+ "\n",
+ "train_embeddings_path = f'{DATA_DIR}/encodec32khz_training_embeds_sorted.npy'  # path to training embeddings 480 * 2 * 1125\n",
+ "test_embeddings_path = f'{DATA_DIR}/encodec32khz_testing_embeds_sorted.npy'  # path to test embeddings 600 * 2 * 1125\n",
+ "\n",
+ "class VoxelsDataset(data.Dataset):\n",
+ "    \"\"\"Pairs each stimulus' block of consecutive fMRI scans with that stimulus' audio codes.\n",
+ "\n",
+ "    The array loaded from `voxels_path` is converted to float and its two axes are swapped,\n",
+ "    so the first axis of `self.voxels` indexes scans. Item `i` is the slice of\n",
+ "    `samples_per_stimulus` consecutive scans starting at `i * samples_per_stimulus`,\n",
+ "    together with `self.embeddings[i]`.\n",
+ "\n",
+ "    Args:\n",
+ "        voxels_path: .npy file with the voxel responses (2-D; transposed on load).\n",
+ "        embeddings_path: .npy file indexed by stimulus along its first axis.\n",
+ "        samples_per_stimulus: number of consecutive scans belonging to one stimulus.\n",
+ "            Defaults to 10 (per the original note: 15 s stimuli, one scan every 1.5 s).\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: if `samples_per_stimulus` is not positive, or if there are fewer\n",
+ "            embeddings than stimuli.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, voxels_path, embeddings_path, samples_per_stimulus=10):\n",
+ "        if samples_per_stimulus <= 0:\n",
+ "            raise ValueError(f'samples_per_stimulus must be positive, got {samples_per_stimulus}')\n",
+ "        self.samples_per_stimulus = samples_per_stimulus\n",
+ "        # transpose the two dimensions of the voxels data to match the embeddings data\n",
+ "        self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
+ "        self.embeddings = torch.from_numpy(np.load(embeddings_path))\n",
+ "        # trailing scans that do not fill a whole stimulus are ignored\n",
+ "        self.len = len(self.voxels) // samples_per_stimulus\n",
+ "        # fail here rather than with an IndexError in the middle of an epoch\n",
+ "        if len(self.embeddings) < self.len:\n",
+ "            raise ValueError(\n",
+ "                f'{len(self.embeddings)} embeddings cannot cover {self.len} stimuli'\n",
+ "            )\n",
+ "        print(\"The len is \", self.len )\n",
+ "\n",
+ "    def __getitem__(self, index):\n",
+ "        \"\"\"Return `(voxels, embeddings)` for stimulus `index`; negative indices count from the end.\"\"\"\n",
+ "        if index < 0:\n",
+ "            index += self.len\n",
+ "        # without this check an out-of-range index would silently yield an empty voxel slice\n",
+ "        if not 0 <= index < self.len:\n",
+ "            raise IndexError(f'index {index} is out of range for {self.len} stimuli')\n",
+ "        start = index * self.samples_per_stimulus\n",
+ "        voxels = self.voxels[start:start + self.samples_per_stimulus]\n",
+ "        embeddings = self.embeddings[index]\n",
+ "        return voxels, embeddings\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        \"\"\"Number of complete stimuli available.\"\"\"\n",
+ "        return self.len\n",
+ " \n",
+ "class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):\n",
+ "    \"\"\"Lightning data module that wraps a training and a test `VoxelsDataset`.\n",
+ "\n",
+ "    The four paths are only stored by the constructor; the files are read in `setup`.\n",
+ "    The test dataset is the one served by `val_dataloader`.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):\n",
+ "        super().__init__()\n",
+ "        self.batch_size = batch_size\n",
+ "        self.train_voxels_path, self.train_embeddings_path = train_voxels_path, train_embeddings_path\n",
+ "        self.test_voxels_path, self.test_embeddings_path = test_voxels_path, test_embeddings_path\n",
+ "\n",
+ "    def setup(self, stage=None):\n",
+ "        \"\"\"Build both datasets; `stage` is accepted but not inspected.\"\"\"\n",
+ "        self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
+ "        self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
+ "\n",
+ "    def _make_loader(self, dataset, shuffle):\n",
+ "        \"\"\"DataLoader over `dataset` using the configured batch size.\"\"\"\n",
+ "        return data.DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)\n",
+ "\n",
+ "    def train_dataloader(self):\n",
+ "        \"\"\"Shuffled loader over the training dataset.\"\"\"\n",
+ "        return self._make_loader(self.train_dataset, shuffle=True)\n",
+ "\n",
+ "    def val_dataloader(self):\n",
+ "        \"\"\"Unshuffled loader over the test dataset.\"\"\"\n",
+ "        return self._make_loader(self.test_dataset, shuffle=False)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "39fa231d-7e28-4813-8f5d-edf8fbce1774",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'bn = BrainNetwork(in_dim=4096)\\nrr = RidgeRegression(60784, 4096)\\n\\ntest_input = torch.randn(3, 60784)\\nout1 = rr(test_input)\\nout2 = bn(out1)\\nprint(out2.shape, out1.shape)\\n '"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "class RidgeRegression(torch.nn.Module):\n",
+ "    \"\"\"A single linear layer.\n",
+ "\n",
+ "    The layer itself applies no penalty: the ridge (L2) term has to be supplied by\n",
+ "    passing `weight_decay` to the optimizer that trains it.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, input_size, out_features):\n",
+ "        super().__init__()\n",
+ "        self.linear = torch.nn.Linear(in_features=input_size, out_features=out_features)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        \"\"\"Apply the linear map to the last dimension of `x`.\"\"\"\n",
+ "        projected = self.linear(x)\n",
+ "        return projected\n",
+ " \n",
+ "class BrainNetwork(nn.Module):\n",
+ "    \"\"\"Residual MLP followed by a linear head and an optional token-wise projector.\n",
+ "\n",
+ "    `forward` expects inputs whose last dimension is `h`. The input goes through\n",
+ "    `n_blocks` residual blocks (Linear -> norm/activation -> Dropout), then through\n",
+ "    `lin1` (h -> out_dim). With `use_projector` the flat output is reshaped to\n",
+ "    (batch, -1, clip_size), passed through the projector MLP on its last axis and its\n",
+ "    last two axes are swapped; otherwise the flat (batch, out_dim) tensor is returned.\n",
+ "\n",
+ "    Args:\n",
+ "        out_dim: output width of `lin1`.\n",
+ "        in_dim: accepted for interface compatibility; not used by this class.\n",
+ "        clip_size: size of the last axis fed to the projector.\n",
+ "        h: hidden width of the residual blocks.\n",
+ "        n_blocks: number of residual blocks.\n",
+ "        norm_type: 'bn' selects BatchNorm1d + in-place ReLU, anything else LayerNorm + GELU.\n",
+ "        act_first: if True the activation is placed before the normalisation.\n",
+ "        use_projector: whether to apply the projector and the final axis swap.\n",
+ "        drop1: accepted for interface compatibility; not used by this class.\n",
+ "        drop2: dropout probability inside each residual block.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, out_dim=768*128, in_dim=60784, clip_size=128, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n",
+ "        super().__init__()\n",
+ "        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n",
+ "        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n",
+ "        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n",
+ "        self.mlp = nn.ModuleList([\n",
+ "            nn.Sequential(\n",
+ "                nn.Linear(h, h),\n",
+ "                *[item() for item in act_and_norm],\n",
+ "                nn.Dropout(drop2)\n",
+ "            ) for _ in range(n_blocks)\n",
+ "        ])\n",
+ "        self.lin1 = nn.Linear(h, out_dim, bias=True)\n",
+ "        self.n_blocks = n_blocks\n",
+ "        self.clip_size = clip_size\n",
+ "        self.use_projector = use_projector\n",
+ "        if use_projector:\n",
+ "            self.projector = nn.Sequential(\n",
+ "                nn.LayerNorm(clip_size),\n",
+ "                nn.GELU(),\n",
+ "                nn.Linear(clip_size, 2048),\n",
+ "                nn.LayerNorm(2048),\n",
+ "                nn.GELU(),\n",
+ "                nn.Linear(2048, 2048),\n",
+ "                nn.LayerNorm(2048),\n",
+ "                nn.GELU(),\n",
+ "                nn.Linear(2048, clip_size)\n",
+ "            )\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        \"\"\"Run the residual blocks, the linear head and (optionally) the projector.\"\"\"\n",
+ "        for res_block in range(self.n_blocks):\n",
+ "            # Out-of-place residual add: an in-place `+=` would overwrite the block's output\n",
+ "            # tensor, which autograd can still need for backward (e.g. the in-place ReLU of\n",
+ "            # the 'bn' variant when dropout passes its input straight through).\n",
+ "            x = self.mlp[res_block](x) + x\n",
+ "        x = x.reshape(len(x), -1)\n",
+ "        x = self.lin1(x)\n",
+ "        if self.use_projector:\n",
+ "            x = self.projector(x.reshape(len(x), -1, self.clip_size))\n",
+ "            x = rearrange(x, 'b e t -> b t e')\n",
+ "        return x\n",
+ " \n",
+ "# Disabled smoke test kept as a bare string literal: nothing inside it is executed, and\n",
+ "# because it is the last expression of the cell the string itself is echoed as the cell output.\n",
+ "\"\"\"bn = BrainNetwork(in_dim=4096)\n",
+ "rr = RidgeRegression(60784, 4096)\n",
+ "\n",
+ "test_input = torch.randn(3, 60784)\n",
+ "out1 = rr(test_input)\n",
+ "out2 = bn(out1)\n",
+ "print(out2.shape, out1.shape)\n",
+ " \"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "495e8e78-c44f-4410-b3c2-61faa6beec77",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"b2m_test = B2M().to('cuda')\\ntest_b2m_input = torch.randn(4, 60784).to('cuda')\\naudio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\\naudio_codes = torch.from_numpy(audio_codes)\\naudio_codes = audio_codes[:4, :, :]\\naudio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\\ntest_b2m_output = b2m_test(test_b2m_input, audio_codes)\\nprint(test_b2m_output.shape)\""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "class B2M(pl.LightningModule):\n",
+ "    \"\"\"Brain-to-music model: a trainable fMRI encoder in front of a frozen MusicGen decoder.\n",
+ "\n",
+ "    `pseudo_text_encoder` (RidgeRegression followed by BrainNetwork) maps one voxel vector\n",
+ "    to the tensor that is passed through MusicGen's `enc_to_dec_proj` and then used as the\n",
+ "    `encoder_hidden_states` of the MusicGen decoder. Every MusicGen parameter has\n",
+ "    `requires_grad` set to False, so only the encoder is trained.\n",
+ "\n",
+ "    Args:\n",
+ "        input_size: number of input features of the ridge layer (voxels per scan).\n",
+ "        mapping_size: output width of the ridge layer and hidden width of BrainNetwork.\n",
+ "        num_codebooks: number of codebooks folded into the batch axis of the token ids.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, input_size = 60784, mapping_size = 4096, num_codebooks = 4):\n",
+ "        super().__init__()\n",
+ "        self.brain_network = BrainNetwork(h = mapping_size)\n",
+ "        self.ridge_regression = RidgeRegression(input_size=input_size, out_features=mapping_size)\n",
+ "        self.loss = nn.CrossEntropyLoss()\n",
+ "        self.pseudo_text_encoder = nn.Sequential(\n",
+ "            self.ridge_regression,\n",
+ "            self.brain_network\n",
+ "        )\n",
+ "        self.musicgen_decoder = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")\n",
+ "        self.pad_token_id = self.musicgen_decoder.generation_config.pad_token_id\n",
+ "        self.num_codebooks = num_codebooks\n",
+ "        # per-epoch buffers of argmax token predictions, flushed to disk by the epoch-end hooks\n",
+ "        self.test_outptus = []\n",
+ "        self.train_outptus = []\n",
+ "\n",
+ "        # freeze MusicGen\n",
+ "        # NOTE(review): freezing does not put the decoder in eval mode; if MusicGen contains\n",
+ "        # dropout it presumably stays active while Lightning trains this module - confirm.\n",
+ "        for param in self.musicgen_decoder.parameters():\n",
+ "            param.requires_grad = False\n",
+ "\n",
+ "    def forward(self, x, decoder_input_ids=None):\n",
+ "        \"\"\"Return the MusicGen decoder logits conditioned on the encoded voxels.\n",
+ "\n",
+ "        Args:\n",
+ "            x: voxel batch; its last dimension must match `input_size`.\n",
+ "            decoder_input_ids: token ids of shape (batch_size * num_codebooks, steps).\n",
+ "                When None, a single pad token per row is used.\n",
+ "        \"\"\"\n",
+ "        # first we pass the voxels through the pseudo text encoder\n",
+ "        pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n",
+ "        # then through the MusicGen encoder-to-decoder projection\n",
+ "        projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n",
+ "\n",
+ "        if decoder_input_ids is None:\n",
+ "            # created on the input's device: a CPU tensor here would not match a model on GPU\n",
+ "            decoder_input_ids = torch.full(\n",
+ "                (x.shape[0] * self.num_codebooks, 1),\n",
+ "                self.pad_token_id,\n",
+ "                dtype=torch.long,\n",
+ "                device=x.device,\n",
+ "            )\n",
+ "\n",
+ "        # now we pass the projected pseudo encoded fmri through the musicgen decoder\n",
+ "        logits = self.musicgen_decoder.decoder(\n",
+ "            input_ids = decoder_input_ids,\n",
+ "            encoder_hidden_states = projected_pseudo_encoded_fmri,\n",
+ "        ).logits\n",
+ "\n",
+ "        return logits\n",
+ "\n",
+ "    def _shared_step(self, batch, stage, outputs_store):\n",
+ "        \"\"\"Teacher-forced step shared by training and validation.\n",
+ "\n",
+ "        Logs `<stage>_loss` and `<stage>_accuracy`, appends the argmax token predictions to\n",
+ "        `outputs_store` and returns the loss.\n",
+ "        \"\"\"\n",
+ "        # original author's note: the sizes are [batch_size, 10, 65000] and [batch_size, 4, 750]\n",
+ "        # (not verified here)\n",
+ "        voxels, embeddings = batch\n",
+ "        # take the last scan from the voxels\n",
+ "        voxels = voxels[:, -1, :]\n",
+ "        # convert the embeddings to long and combine the batch and codebook dimensions\n",
+ "        embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n",
+ "\n",
+ "        # the decoder sees tokens [0, T-1) and is scored against tokens [1, T)\n",
+ "        decoder_input_ids = embeddings[:, :-1]\n",
+ "        targets = embeddings[:, 1:]\n",
+ "        logits = self(voxels, decoder_input_ids)\n",
+ "\n",
+ "        # get the loss\n",
+ "        loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(targets, '(b c) t -> (b c t)', c=self.num_codebooks))\n",
+ "\n",
+ "        accuracy = self.tokens_accuracy(logits, targets)\n",
+ "        self.log(f'{stage}_loss', loss, sync_dist=True)\n",
+ "        self.log(f'{stage}_accuracy', accuracy, sync_dist=True)\n",
+ "        outputs_store.append(logits.argmax(dim=2).detach())\n",
+ "        return loss\n",
+ "\n",
+ "    def _save_epoch_outputs(self, outputs, prefix):\n",
+ "        \"\"\"Concatenate the buffered predictions and save them as `<prefix><epoch>.pt`.\"\"\"\n",
+ "        # torch.cat raises on an empty list, e.g. when an epoch ran no batches\n",
+ "        if not outputs:\n",
+ "            return\n",
+ "        torch.save(torch.cat(outputs), prefix+str(self.current_epoch)+'.pt')\n",
+ "\n",
+ "    def training_step(self, batch, batch_idx):\n",
+ "        \"\"\"One training step; see `_shared_step`.\"\"\"\n",
+ "        return self._shared_step(batch, 'train', self.train_outptus)\n",
+ "\n",
+ "    def tokens_accuracy(self, outputs, embeddings):\n",
+ "        \"\"\"Fraction of positions whose argmax over the last axis of `outputs` equals `embeddings`.\n",
+ "\n",
+ "        `outputs` are logits with the vocabulary on axis 2; `embeddings` are the target token\n",
+ "        ids with the same leading two axes.\n",
+ "        \"\"\"\n",
+ "        # we need to get the index of the maximum value of each token\n",
+ "        outputs = outputs.argmax(dim=2)\n",
+ "        # now we need to compare the outputs with the embeddings\n",
+ "        return (outputs == embeddings).float().mean()\n",
+ "\n",
+ "    def on_train_epoch_end(self):\n",
+ "        \"\"\"Save this epoch's training predictions and reset the buffer.\"\"\"\n",
+ "        self._save_epoch_outputs(self.train_outptus, 'outputs_train')\n",
+ "        self.train_outptus = []\n",
+ "\n",
+ "    def on_validation_epoch_end(self):\n",
+ "        \"\"\"Save this epoch's validation predictions and reset the buffer.\"\"\"\n",
+ "        self._save_epoch_outputs(self.test_outptus, 'outputs_validation')\n",
+ "        self.test_outptus = []\n",
+ "\n",
+ "    def validation_step(self, batch, batch_idx):\n",
+ "        \"\"\"One validation step; see `_shared_step`.\"\"\"\n",
+ "        return self._shared_step(batch, 'val', self.test_outptus)\n",
+ "\n",
+ "    def configure_optimizers(self):\n",
+ "        \"\"\"AdamW over the encoder; the frozen MusicGen parameters are listed with lr=0.\"\"\"\n",
+ "        optimizer = torch.optim.AdamW(\n",
+ "            [\n",
+ "                {'params': self.pseudo_text_encoder.parameters(), 'lr': 3e-6, 'weight_decay': 1e-4},\n",
+ "                {'params': self.musicgen_decoder.parameters(), 'lr': 0},\n",
+ "            ],\n",
+ "        )\n",
+ "        return optimizer\n",
+ "\n",
+ "\n",
+ "# Disabled smoke test for B2M, kept as a bare string literal: nothing inside it is executed,\n",
+ "# and as the last expression of the cell the string is echoed as the cell output.\n",
+ "\"\"\"b2m_test = B2M().to('cuda')\n",
+ "test_b2m_input = torch.randn(4, 60784).to('cuda')\n",
+ "audio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\n",
+ "audio_codes = torch.from_numpy(audio_codes)\n",
+ "audio_codes = audio_codes[:4, :, :]\n",
+ "audio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\n",
+ "test_b2m_output = b2m_test(test_b2m_input, audio_codes)\n",
+ "print(test_b2m_output.shape)\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0646c340-c3df-42c0-ad22-2b677fc85da6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Build the model and the data module (batch size 4 overrides the data module's default of 8).\n",
+ "b2m = B2M()\n",
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
+ "\n",
+ "# Finish any W&B run that is still active (presumably left over from a previous execution\n",
+ "# of this cell) before a new logger is created.\n",
+ "wandb.finish()\n",
+ "\n",
+ "# NOTE(review): project and entity are hard-coded to the original author's W&B account.\n",
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
+ "\n",
+ "# define the trainer: one GPU, 16-bit mixed precision, up to 400 epochs, metrics logged every 10 steps\n",
+ "trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
+ "\n",
+ "# train the model\n",
+ "trainer.fit(b2m, datamodule=data_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "4f458adf-1e89-4b8b-9514-08bcc9e8ef56",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fri Sep 8 05:08:00 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 31C P0 71W / 400W | 29323MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 40C P0 204W / 400W | 40217MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 47C P0 243W / 400W | 40421MiB / 40960MiB | 99% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 42C P0 194W / 400W | 37547MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 282W / 400W | 29223MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 39C P0 179W / 400W | 24387MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 43C P0 232W / 400W | 29819MiB / 40960MiB | 88% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 166W / 400W | 28583MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| 0 N/A N/A 895711 C ...3/envs/mindeye/bin/python 29320MiB |\n",
+ "| 1 N/A N/A 899054 C ...ari/llama_env/bin/python3 40214MiB |\n",
+ "| 2 N/A N/A 899055 C ...ari/llama_env/bin/python3 40418MiB |\n",
+ "| 3 N/A N/A 899056 C ...ari/llama_env/bin/python3 37544MiB |\n",
+ "| 4 N/A N/A 899057 C ...ari/llama_env/bin/python3 29220MiB |\n",
+ "| 5 N/A N/A 899058 C ...ari/llama_env/bin/python3 24384MiB |\n",
+ "| 6 N/A N/A 899059 C ...ari/llama_env/bin/python3 29816MiB |\n",
+ "| 7 N/A N/A 899060 C ...ari/llama_env/bin/python3 28580MiB |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Print the nvidia-smi status table again (same command as the first cell of the notebook).\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "e4c9f9f5-a984-4fcf-8938-861f5bfb38d6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([240, 749])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Load the token predictions that B2M.on_validation_epoch_end saved for epoch 39.\n",
+ "# NOTE: despite its name, `train_outputs` holds the *validation* outputs; the original\n",
+ "# comment here pointed at outputs_train39.pt, which is not the file being read.\n",
+ "train_outputs = torch.load('/fsx/proj-fmri/ckadirt/b2m/src/outputs_validation39.pt')\n",
+ "print(train_outputs.shape)\n",
+ "# first four rows with two leading singleton axes added, i.e. shape [1, 1, 4, 749] for the\n",
+ "# [240, 749] tensor printed above (presumably the four codebooks of one stimulus - confirm)\n",
+ "example1 = train_outputs[0:4].unsqueeze(0).unsqueeze(0)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "80e45d62-27d4-4ebe-a86c-56a9cc4ac0de",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# WARNING: this rebinds the global name `data`, which the import cell bound to\n",
+ "# `torch.utils.data`. After this cell runs, code that calls `data.Dataset` or\n",
+ "# `data.DataLoader` (the dataset and data-module classes above) fails until the\n",
+ "# import cell is re-run. TODO: rename this variable together with the cell that reads it.\n",
+ "data = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "ce72b901-e670-48ca-8c92-2221b46f6145",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 753., 609., 248., 248., 248., 1271., 248., 1350., 1350.,\n",
+ " 1350., 1350., 367., 297., 1251., 702., 1103., 1106., 1600.,\n",
+ " 1106., 457., 1638., 903., 1103., 1673., 1103., 902., 149.,\n",
+ " 1544., 1458., 1544., 773., 1470., 1470., 890., 1457., 1038.,\n",
+ " 2008., 1506., 457., 1126., 1047., 1103., 1933., 3., 560.,\n",
+ " 714., 271., 1442., 1710., 949., 1508., 957., 685., 399.,\n",
+ " 1103., 1667., 1555., 1529., 494., 1436., 1883., 29., 225.,\n",
+ " 846., 773., 569., 677., 71., 888., 1693., 1401., 888.,\n",
+ " 1016., 792., 569., 1590., 71., 1600., 314., 272., 1756.,\n",
+ " 1917., 1264., 917., 1021., 178., 1205., 974., 457., 457.,\n",
+ " 1106., 569., 1562., 271., 1977., 367., 345., 893., 1842.,\n",
+ " 1401., 1152., 1152., 1152., 1152., 50., 1418., 1748., 1188.,\n",
+ " 2043., 1666., 1796., 1512., 457., 812., 1600., 1764., 879.,\n",
+ " 1194., 1457., 1842., 1883., 104., 666., 352., 612., 1710.,\n",
+ " 1458., 315., 1990., 741., 1047., 675., 514., 1051., 1132.,\n",
+ " 1115., 315., 1347., 1670., 1875., 1194., 71., 1786., 196.,\n",
+ " 2043., 1818., 1477., 996., 1083., 967., 128., 1629., 1562.,\n",
+ " 1875., 237., 712., 1279., 29., 675., 1207., 1303., 1622.,\n",
+ " 1622., 1622., 1622., 1557., 675., 1303., 261., 675., 1245.,\n",
+ " 1245., 675., 675., 714., 378., 1673., 1145., 1673., 1673.,\n",
+ " 2013., 1990., 974., 457., 1124., 1562., 3., 1721., 846.,\n",
+ " 378., 271., 271., 675., 1782., 876., 1918., 483., 1419.,\n",
+ " 1693., 1562., 50., 1198., 1786., 1492., 1670., 1292., 104.,\n",
+ " 1286., 620., 776., 828., 1629., 762., 71., 1095., 367.,\n",
+ " 2043., 1501., 1152., 320., 1271., 801., 1671., 1418., 1213.,\n",
+ " 71., 40., 1419., 1179., 178., 1106., 1016., 1652., 902.,\n",
+ " 1074., 366., 1484., 42., 1457., 1453., 1241., 849., 1888.,\n",
+ " 1775., 1888., 82., 1671., 836., 82., 82., 472., 247.,\n",
+ " 1883., 29., 398., 642., 1904., 1436., 2002., 71., 71.,\n",
+ " 71., 71., 938., 1122., 1106., 736., 367., 1531., 778.,\n",
+ " 1388., 1949., 207., 1418., 721., 1130., 989., 1303., 1402.,\n",
+ " 1047., 569., 569., 569., 272., 320., 548., 976., 314.,\n",
+ " 1095., 1559., 1378., 828., 1742., 1378., 620., 1271., 776.,\n",
+ " 801., 1213., 1358., 1883., 1157., 104., 1744., 648., 1271.,\n",
+ " 1373., 1956., 685., 290., 938., 965., 208., 823., 1287.,\n",
+ " 893., 1470., 775., 1158., 775., 1990., 986., 1152., 1358.,\n",
+ " 312., 1111., 1564., 50., 237., 1646., 937., 1052., 1917.,\n",
+ " 1742., 1702., 297., 259., 1734., 54., 1933., 71., 1875.,\n",
+ " 1875., 1702., 1875., 237., 237., 237., 164., 1602., 220.,\n",
+ " 457., 1798., 457., 1798., 1798., 1514., 548., 523., 949.,\n",
+ " 1599., 714., 1673., 893., 714., 1016., 1016., 1638., 1562.,\n",
+ " 714., 1016., 1103., 871., 569., 1047., 1047., 612., 1646.,\n",
+ " 1875., 747., 714., 718., 1562., 1673., 1555., 1763., 1127.,\n",
+ " 793., 1817., 657., 1106., 457., 1106., 458., 1599., 1106.,\n",
+ " 29., 1863., 1103., 1599., 714., 812., 1194., 1508., 1194.,\n",
+ " 1562., 986., 714., 1308., 1097., 1207., 1095., 1763., 1127.,\n",
+ " 642., 1419., 290., 42., 1246., 400., 1462., 1194., 773.,\n",
+ " 741., 832., 1368., 2042., 937., 71., 1541., 42., 2008.,\n",
+ " 920., 872., 1473., 890., 234., 234., 1781., 1742., 1742.,\n",
+ " 71., 508., 237., 1790., 1428., 347., 1907., 1642., 457.,\n",
+ " 1106., 987., 1106., 1907., 1562., 1378., 1638., 1106., 1106.,\n",
+ " 399., 1106., 1106., 569., 1194., 695., 1286., 237., 2043.,\n",
+ " 1891., 1933., 1194., 1445., 1670., 1933., 1426., 1198., 937.,\n",
+ " 1977., 677., 1907., 1907., 1492., 1514., 1798., 1562., 1823.,\n",
+ " 644., 1629., 1842., 4., 712., 949., 686., 607., 1343.,\n",
+ " 1907., 1907., 1106., 1973., 1051., 1974., 974., 1869., 913.,\n",
+ " 1499., 272., 1322., 789., 457., 1869., 354., 45., 1869.,\n",
+ " 846., 1103., 1393., 1798., 1424., 1766., 457., 1126., 1642.,\n",
+ " 1775., 1775., 1144., 84., 84., 1142., 1549., 1992., 1989.,\n",
+ " 1047., 500., 637., 1499., 1451., 1732., 297., 261., 1124.,\n",
+ " 29., 866., 457., 149., 1608., 1047., 1343., 1433., 1047.,\n",
+ " 1044., 1837., 1984., 1807., 869., 1598., 1188., 50., 569.,\n",
+ " 272., 1401., 1670., 367., 744., 1817., 178., 1904., 1562.,\n",
+ " 352., 71., 1670., 980., 1629., 736., 1130., 328., 302.,\n",
+ " 893., 1378., 1652., 1469., 1786., 1742., 328., 677., 1652.,\n",
+ " 1652., 1188., 2008., 1733., 1002., 1106., 1103., 464., 1052.,\n",
+ " 1782., 1508., 1106., 677., 1481., 1666., 1921., 548., 1106.,\n",
+ " 902., 1106., 893., 1106., 893., 1933., 893., 237., 1051.,\n",
+ " 626., 919., 1106., 1484., 1288., 1103., 902., 986., 162.,\n",
+ " 421., 1733., 281., 2031., 366., 494., 1594., 825., 730.,\n",
+ " 1702., 642., 1478., 1366., 869., 1106., 1106., 1047., 320.,\n",
+ " 1380., 569., 569., 569., 1380., 104., 1324., 1600., 1132.,\n",
+ " 1337., 1047., 612., 1555., 2001., 2027., 2027., 71., 800.,\n",
+ " 30., 2003., 1817., 1047., 1764., 714., 1457., 840., 1246.,\n",
+ " 1978., 297., 1798., 1127., 1106., 948., 1047., 457., 1158.,\n",
+ " 2001., 1308., 1316., 634., 1481., 1047., 762., 964., 1913.,\n",
+ " 1638., 1562., 1582., 1095., 936., 658., 1978., 1512., 1753.,\n",
+ " 1974., 1564., 2007., 890., 1562., 890., 1801., 1801., 1378.,\n",
+ " 71., 1481., 1863., 921., 1600., 121., 648., 1132., 1132.,\n",
+ " 1512., 1529., 263., 1529., 1564., 1484., 642., 642., 1629.,\n",
+ " 40., 1444., 1078., 1078., 1436., 118., 118., 1522., 1904.,\n",
+ " 237., 658., 658., 1711., 1095., 778., 121., 248., 999.,\n",
+ " 648., 648., 247., 648., 648., 8., 8., 1453., 609.,\n",
+ " 609., 83., 166.], dtype=float32)"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Display element [0][0] of the training codes array in full (presumably the first\n",
+ "# codebook of the first stimulus - confirm the axis order); the values are stored as float32.\n",
+ "data[0][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "cf8382ae-ed92-43bb-a65c-7537161956d5",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([609, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n",
+ " 237, 237, 237, 237, 237, 237, 237], device='cuda:0')"
+ ]
+ },
+ "execution_count": 54,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "example1[0][0][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "f5842cd4-17b7-47c6-8939-31634f3db3cc",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:1 │\n",
+ "│ │\n",
+ "│ ❱ 1 sampled.unsqueeze(0).unsqueeze(0).detach().clone() │\n",
+ "│ 2 │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "RuntimeError: CUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 sampled.unsqueeze(\u001b[94m0\u001b[0m).unsqueeze(\u001b[94m0\u001b[0m).detach().clone() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sampled.unsqueeze(0).unsqueeze(0).detach().clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "f802e748-ff63-4b9e-8147-77d62bb87c4b",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [6,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [7,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [8,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [9,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [10,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [11,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [12,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [13,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [14,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [15,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [16,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [17,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [18,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [19,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [20,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [21,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [22,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [23,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [24,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [25,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [26,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [27,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [28,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [29,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [30,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [37,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [38,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [39,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [40,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [41,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [42,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [43,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [44,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [45,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [46,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [47,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [48,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [49,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [50,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [51,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [52,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [53,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [54,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [55,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [56,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [58,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [59,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [60,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [62,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [63,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [70,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [71,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [72,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [73,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [74,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [75,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [76,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [77,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [78,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [79,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [80,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [81,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [82,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [83,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [84,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [85,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [86,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [87,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [88,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [89,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [90,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [91,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [92,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [93,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [94,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [102,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [103,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [104,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [105,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [106,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [107,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [108,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [109,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [110,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [111,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [112,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [113,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [114,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [115,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [116,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [117,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [118,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [119,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [120,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [121,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [122,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [123,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n",
+ "../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [0,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:1 │\n",
+ "│ │\n",
+ "│ ❱ 1 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(0).u │\n",
+ "│ 2 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:742 in decode │\n",
+ "│ │\n",
+ "│ 739 │ │ if chunk_length is None: │\n",
+ "│ 740 │ │ │ if len(audio_codes) != 1: │\n",
+ "│ 741 │ │ │ │ raise ValueError(f\"Expected one frame, got {len(audio_codes)}\") │\n",
+ "│ ❱ 742 │ │ │ audio_values = self._decode_frame(audio_codes[0], audio_scales[0]) │\n",
+ "│ 743 │ │ else: │\n",
+ "│ 744 │ │ │ decoded_frames = [] │\n",
+ "│ 745 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:707 in _decode_frame │\n",
+ "│ │\n",
+ "│ 704 │ def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) - │\n",
+ "│ 705 │ │ codes = codes.transpose(0, 1) │\n",
+ "│ 706 │ │ embeddings = self.quantizer.decode(codes) │\n",
+ "│ ❱ 707 │ │ outputs = self.decoder(embeddings) │\n",
+ "│ 708 │ │ if scale is not None: │\n",
+ "│ 709 │ │ │ outputs = outputs * scale.view(-1, 1, 1) │\n",
+ "│ 710 │ │ return outputs │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module │\n",
+ "│ .py:1501 in _call_impl │\n",
+ "│ │\n",
+ "│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │\n",
+ "│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │\n",
+ "│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │\n",
+ "│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │\n",
+ "│ 1502 │ │ # Do not call functions when jit is used │\n",
+ "│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │\n",
+ "│ 1504 │ │ backward_pre_hooks = [] │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:336 in forward │\n",
+ "│ │\n",
+ "│ 333 │ │\n",
+ "│ 334 │ def forward(self, hidden_states): │\n",
+ "│ 335 │ │ for layer in self.layers: │\n",
+ "│ ❱ 336 │ │ │ hidden_states = layer(hidden_states) │\n",
+ "│ 337 │ │ return hidden_states │\n",
+ "│ 338 │\n",
+ "│ 339 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module │\n",
+ "│ .py:1501 in _call_impl │\n",
+ "│ │\n",
+ "│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │\n",
+ "│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │\n",
+ "│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │\n",
+ "│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │\n",
+ "│ 1502 │ │ # Do not call functions when jit is used │\n",
+ "│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │\n",
+ "│ 1504 │ │ backward_pre_hooks = [] │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:162 in forward │\n",
+ "│ │\n",
+ "│ 159 │ │ │ # Asymmetric padding required for odd strides │\n",
+ "│ 160 │ │ │ padding_right = padding_total // 2 │\n",
+ "│ 161 │ │ │ padding_left = padding_total - padding_right │\n",
+ "│ ❱ 162 │ │ │ hidden_states = self._pad1d( │\n",
+ "│ 163 │ │ │ │ hidden_states, (padding_left, padding_right + extra_padding), mode=self. │\n",
+ "│ 164 │ │ │ ) │\n",
+ "│ 165 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:143 in _pad1d │\n",
+ "│ │\n",
+ "│ 140 │ │ if length <= max_pad: │\n",
+ "│ 141 │ │ │ extra_pad = max_pad - length + 1 │\n",
+ "│ 142 │ │ │ hidden_states = nn.functional.pad(hidden_states, (0, extra_pad)) │\n",
+ "│ ❱ 143 │ │ padded = nn.functional.pad(hidden_states, paddings, mode, value) │\n",
+ "│ 144 │ │ end = padded.shape[-1] - extra_pad │\n",
+ "│ 145 │ │ return padded[..., :end] │\n",
+ "│ 146 │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "RuntimeError: CUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(\u001b[94m0\u001b[0m).u \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m742\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m739 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m chunk_length \u001b[95mis\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m740 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(audio_codes) != \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m741 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mExpected one frame, got \u001b[0m\u001b[33m{\u001b[0m\u001b[96mlen\u001b[0m(audio_codes)\u001b[33m}\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m742 \u001b[2m│ │ │ \u001b[0maudio_values = \u001b[96mself\u001b[0m._decode_frame(audio_codes[\u001b[94m0\u001b[0m], audio_scales[\u001b[94m0\u001b[0m]) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m743 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m744 \u001b[0m\u001b[2m│ │ │ \u001b[0mdecoded_frames = [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m745 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m707\u001b[0m in \u001b[92m_decode_frame\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m704 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_decode_frame\u001b[0m(\u001b[96mself\u001b[0m, codes: torch.Tensor, scale: Optional[torch.Tensor] = \u001b[94mNone\u001b[0m) - \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m705 \u001b[0m\u001b[2m│ │ \u001b[0mcodes = codes.transpose(\u001b[94m0\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m706 \u001b[0m\u001b[2m│ │ \u001b[0membeddings = \u001b[96mself\u001b[0m.quantizer.decode(codes) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m707 \u001b[2m│ │ \u001b[0moutputs = \u001b[96mself\u001b[0m.decoder(embeddings) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m708 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m scale \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m709 \u001b[0m\u001b[2m│ │ │ \u001b[0moutputs = outputs * scale.view(-\u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m710 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m outputs \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m336\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m333 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m334 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mforward\u001b[0m(\u001b[96mself\u001b[0m, hidden_states): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m335 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfor\u001b[0m layer \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.layers: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m336 \u001b[2m│ │ │ \u001b[0mhidden_states = layer(hidden_states) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m337 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m hidden_states \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m338 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m339 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m162\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m159 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# Asymmetric padding required for odd strides\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m160 \u001b[0m\u001b[2m│ │ │ \u001b[0mpadding_right = padding_total // \u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m161 \u001b[0m\u001b[2m│ │ │ \u001b[0mpadding_left = padding_total - padding_right \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m162 \u001b[2m│ │ │ \u001b[0mhidden_states = \u001b[96mself\u001b[0m._pad1d( \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m163 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mhidden_states, (padding_left, padding_right + extra_padding), mode=\u001b[96mself\u001b[0m. \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m164 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m165 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m143\u001b[0m in \u001b[92m_pad1d\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m140 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m length <= max_pad: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m141 \u001b[0m\u001b[2m│ │ │ \u001b[0mextra_pad = max_pad - length + \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m142 \u001b[0m\u001b[2m│ │ │ \u001b[0mhidden_states = nn.functional.pad(hidden_states, (\u001b[94m0\u001b[0m, extra_pad)) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m143 \u001b[2m│ │ \u001b[0mpadded = nn.functional.pad(hidden_states, paddings, mode, value) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m144 \u001b[0m\u001b[2m│ │ \u001b[0mend = padded.shape[-\u001b[94m1\u001b[0m] - extra_pad \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m145 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m padded[..., :end] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m146 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(0).unsqueeze(0).detach().clone(), audio_scales = [None])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 94,
+ "id": "61a216c2-7357-4a06-85aa-13db8acbba79",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fri Sep 8 05:39:28 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 30C P0 70W / 400W | 29365MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 41C P0 240W / 400W | 33897MiB / 40960MiB | 92% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 44C P0 208W / 400W | 37457MiB / 40960MiB | 82% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 189W / 400W | 40233MiB / 40960MiB | 76% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 41C P0 242W / 400W | 34031MiB / 40960MiB | 95% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 201W / 400W | 30723MiB / 40960MiB | 100% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 39C P0 183W / 400W | 38225MiB / 40960MiB | 97% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 197W / 400W | 40121MiB / 40960MiB | 79% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| 0 N/A N/A 895711 C ...3/envs/mindeye/bin/python 29362MiB |\n",
+ "| 1 N/A N/A 919504 C ...ari/llama_env/bin/python3 33894MiB |\n",
+ "| 2 N/A N/A 919505 C ...ari/llama_env/bin/python3 37454MiB |\n",
+ "| 3 N/A N/A 919506 C ...ari/llama_env/bin/python3 40230MiB |\n",
+ "| 4 N/A N/A 919507 C ...ari/llama_env/bin/python3 34028MiB |\n",
+ "| 5 N/A N/A 919508 C ...ari/llama_env/bin/python3 30720MiB |\n",
+ "| 6 N/A N/A 919509 C ...ari/llama_env/bin/python3 38222MiB |\n",
+ "| 7 N/A N/A 919510 C ...ari/llama_env/bin/python3 40118MiB |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "f44e9028-39c7-4a55-a6e0-ffc84de2b093",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1, 479360])"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoded.audio_values.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "a2fb3ad8-7a2b-4baf-adb5-485b1deb4828",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "# Read the playback rate from the audio codec's config instead of hard-coding 32000.\n",
+ "sampling_rate = b2m.musicgen_decoder.config.audio_encoder.sampling_rate\n",
+ "# `audio_values[0]` is the first item of the decoded batch; `.detach()` keeps the NumPy\n",
+ "# conversion working even if the waveform still tracks gradients.\n",
+ "Audio(decoded.audio_values[0].detach().cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "id": "48e8680b-bdc3-4c41-af8f-511ca0c0acf8",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:2 │\n",
+ "│ │\n",
+ "│ 1 projected_pseudo_encoded_fmri = torch.rand((1,15,1024)) │\n",
+ "│ ❱ 2 gepe = b2m.musicgen_decoder.generate( │\n",
+ "│ 3 │ │ │ encoder_hidden_states = projected_pseudo_encoded_fmri, │\n",
+ "│ 4 │ │ │ max_length = 752 │\n",
+ "│ 5 │ │ ) │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/utils/_contextlib │\n",
+ "│ .py:115 in decorate_context │\n",
+ "│ │\n",
+ "│ 112 │ @functools.wraps(func) │\n",
+ "│ 113 │ def decorate_context(*args, **kwargs): │\n",
+ "│ 114 │ │ with ctx_factory(): │\n",
+ "│ ❱ 115 │ │ │ return func(*args, **kwargs) │\n",
+ "│ 116 │ │\n",
+ "│ 117 │ return decorate_context │\n",
+ "│ 118 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/mus │\n",
+ "│ icgen/modeling_musicgen.py:2261 in generate │\n",
+ "│ │\n",
+ "│ 2258 │ │ generation_config = copy.deepcopy(generation_config) │\n",
+ "│ 2259 │ │ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be m │\n",
+ "│ 2260 │ │ generation_config.validate() │\n",
+ "│ ❱ 2261 │ │ self._validate_model_kwargs(model_kwargs.copy()) │\n",
+ "│ 2262 │ │ │\n",
+ "│ 2263 │ │ if model_kwargs.get(\"encoder_outputs\") is not None and type(model_kwargs[\"encode │\n",
+ "│ 2264 │ │ │ # wrap the unconditional outputs as a BaseModelOutput for compatibility with │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation │\n",
+ "│ /utils.py:1249 in _validate_model_kwargs │\n",
+ "│ │\n",
+ "│ 1246 │ │ │ │ unused_model_args.append(key) │\n",
+ "│ 1247 │ │ │\n",
+ "│ 1248 │ │ if unused_model_args: │\n",
+ "│ ❱ 1249 │ │ │ raise ValueError( │\n",
+ "│ 1250 │ │ │ │ f\"The following `model_kwargs` are not used by the model: {unused_model_ │\n",
+ "│ 1251 │ │ │ │ \" generate arguments will also show up in this list)\" │\n",
+ "│ 1252 │ │ │ ) │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "ValueError: The following `model_kwargs` are not used by the model: ['encoder_hidden_states'] (note: typos in the \n",
+ "generate arguments will also show up in this list)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0mprojected_pseudo_encoded_fmri = torch.rand((\u001b[94m1\u001b[0m,\u001b[94m15\u001b[0m,\u001b[94m1024\u001b[0m)) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2 gepe = b2m.musicgen_decoder.generate( \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m\u001b[2m│ │ │ \u001b[0mencoder_hidden_states = projected_pseudo_encoded_fmri, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m\u001b[2m│ │ │ \u001b[0mmax_length = \u001b[94m752\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/utils/\u001b[0m\u001b[1;33m_contextlib\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m115\u001b[0m in \u001b[92mdecorate_context\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m112 \u001b[0m\u001b[2m│ \u001b[0m\u001b[1;95m@functools\u001b[0m.wraps(func) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecorate_context\u001b[0m(*args, **kwargs): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m ctx_factory(): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m115 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m func(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m decorate_context \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/mus\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33micgen/\u001b[0m\u001b[1;33mmodeling_musicgen.py\u001b[0m:\u001b[94m2261\u001b[0m in \u001b[92mgenerate\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2258 \u001b[0m\u001b[2m│ │ \u001b[0mgeneration_config = copy.deepcopy(generation_config) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2259 \u001b[0m\u001b[2m│ │ \u001b[0mmodel_kwargs = generation_config.update(**kwargs) \u001b[2m# All unused kwargs must be m\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2260 \u001b[0m\u001b[2m│ │ \u001b[0mgeneration_config.validate() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2261 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m._validate_model_kwargs(model_kwargs.copy()) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2262 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2263 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m model_kwargs.get(\u001b[33m\"\u001b[0m\u001b[33mencoder_outputs\u001b[0m\u001b[33m\"\u001b[0m) \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m \u001b[95mand\u001b[0m \u001b[96mtype\u001b[0m(model_kwargs[\u001b[33m\"\u001b[0m\u001b[33mencode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2264 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# wrap the unconditional outputs as a BaseModelOutput for compatibility with\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mutils.py\u001b[0m:\u001b[94m1249\u001b[0m in \u001b[92m_validate_model_kwargs\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1246 \u001b[0m\u001b[2m│ │ │ │ \u001b[0munused_model_args.append(key) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1247 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1248 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m unused_model_args: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1249 \u001b[2m│ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m( \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1250 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mThe following `model_kwargs` are not used by the model: \u001b[0m\u001b[33m{\u001b[0munused_model_ \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1251 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33m generate arguments will also show up in this list)\u001b[0m\u001b[33m\"\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1252 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mValueError: \u001b[0mThe following `model_kwargs` are not used by the model: \u001b[1m[\u001b[0m\u001b[32m'encoder_hidden_states'\u001b[0m\u001b[1m]\u001b[0m \u001b[1m(\u001b[0mnote: typos in the \n",
+ "generate arguments will also show up in this list\u001b[1m)\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "projected_pseudo_encoded_fmri = torch.rand((1,15,1024))\n",
+ "gepe = b2m.musicgen_decoder.generate(\n",
+ " encoder_hidden_states = projected_pseudo_encoded_fmri,\n",
+ " max_length = 752\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "059f21d7-58ed-4b35-93d6-1ea983ac3161",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "b2m = B2M()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ae335238-50f3-456a-9dca-3a0958aaa7b7",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Load the checkpoint onto the CPU: without `map_location`, torch.load restores every tensor to the\n",
+ "# device it was saved from, which can fail or waste memory on a busy shared GPU.\n",
+ "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
+ "ckpt = torch.load('/fsx/proj-fmri/ckadirt/b2m/src/brain2music/jggbeix7/checkpoints/epoch=39-step=4800.ckpt', map_location='cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ff5199bf-a9af-491e-9870-7b203bea49fd",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b2m.load_state_dict(ckpt['state_dict'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "a6e4985e-514d-4ed2-bcc7-11edd2ba4387",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "B2M(\n",
+ " (brain_network): BrainNetwork(\n",
+ " (mlp): ModuleList(\n",
+ " (0-3): 4 x Sequential(\n",
+ " (0): Linear(in_features=4096, out_features=4096, bias=True)\n",
+ " (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
+ " (2): GELU(approximate='none')\n",
+ " (3): Dropout(p=0.15, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (lin1): Linear(in_features=4096, out_features=98304, bias=True)\n",
+ " (projector): Sequential(\n",
+ " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): GELU(approximate='none')\n",
+ " (2): Linear(in_features=128, out_features=2048, bias=True)\n",
+ " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
+ " (4): GELU(approximate='none')\n",
+ " (5): Linear(in_features=2048, out_features=2048, bias=True)\n",
+ " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
+ " (7): GELU(approximate='none')\n",
+ " (8): Linear(in_features=2048, out_features=128, bias=True)\n",
+ " )\n",
+ " )\n",
+ " (ridge_regression): RidgeRegression(\n",
+ " (linear): Linear(in_features=60784, out_features=4096, bias=True)\n",
+ " )\n",
+ " (loss): CrossEntropyLoss()\n",
+ " (pseudo_text_encoder): Sequential(\n",
+ " (0): RidgeRegression(\n",
+ " (linear): Linear(in_features=60784, out_features=4096, bias=True)\n",
+ " )\n",
+ " (1): BrainNetwork(\n",
+ " (mlp): ModuleList(\n",
+ " (0-3): 4 x Sequential(\n",
+ " (0): Linear(in_features=4096, out_features=4096, bias=True)\n",
+ " (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
+ " (2): GELU(approximate='none')\n",
+ " (3): Dropout(p=0.15, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (lin1): Linear(in_features=4096, out_features=98304, bias=True)\n",
+ " (projector): Sequential(\n",
+ " (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
+ " (1): GELU(approximate='none')\n",
+ " (2): Linear(in_features=128, out_features=2048, bias=True)\n",
+ " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
+ " (4): GELU(approximate='none')\n",
+ " (5): Linear(in_features=2048, out_features=2048, bias=True)\n",
+ " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
+ " (7): GELU(approximate='none')\n",
+ " (8): Linear(in_features=2048, out_features=128, bias=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (musicgen_decoder): MusicgenForConditionalGeneration(\n",
+ " (text_encoder): T5EncoderModel(\n",
+ " (shared): Embedding(32128, 768)\n",
+ " (encoder): T5Stack(\n",
+ " (embed_tokens): Embedding(32128, 768)\n",
+ " (block): ModuleList(\n",
+ " (0): T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (relative_attention_bias): Embedding(32, 12)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (1-11): 11 x T5Block(\n",
+ " (layer): ModuleList(\n",
+ " (0): T5LayerSelfAttention(\n",
+ " (SelfAttention): T5Attention(\n",
+ " (q): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (k): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (v): Linear(in_features=768, out_features=768, bias=False)\n",
+ " (o): Linear(in_features=768, out_features=768, bias=False)\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (1): T5LayerFF(\n",
+ " (DenseReluDense): T5DenseActDense(\n",
+ " (wi): Linear(in_features=768, out_features=3072, bias=False)\n",
+ " (wo): Linear(in_features=3072, out_features=768, bias=False)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (act): ReLU()\n",
+ " )\n",
+ " (layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (final_layer_norm): T5LayerNorm()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (audio_encoder): EncodecModel(\n",
+ " (encoder): EncodecEncoder(\n",
+ " (layers): ModuleList(\n",
+ " (0): EncodecConv1d(\n",
+ " (conv): Conv1d(1, 64, kernel_size=(7,), stride=(1,))\n",
+ " )\n",
+ " (1): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 128, kernel_size=(8,), stride=(4,))\n",
+ " )\n",
+ " (4): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (5): ELU(alpha=1.0)\n",
+ " (6): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 256, kernel_size=(8,), stride=(4,))\n",
+ " )\n",
+ " (7): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(256, 128, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (8): ELU(alpha=1.0)\n",
+ " (9): EncodecConv1d(\n",
+ " (conv): Conv1d(256, 512, kernel_size=(10,), stride=(5,))\n",
+ " )\n",
+ " (10): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(512, 256, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (11): ELU(alpha=1.0)\n",
+ " (12): EncodecConv1d(\n",
+ " (conv): Conv1d(512, 1024, kernel_size=(16,), stride=(8,))\n",
+ " )\n",
+ " (13): EncodecLSTM(\n",
+ " (lstm): LSTM(1024, 1024, num_layers=2)\n",
+ " )\n",
+ " (14): ELU(alpha=1.0)\n",
+ " (15): EncodecConv1d(\n",
+ " (conv): Conv1d(1024, 128, kernel_size=(7,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (decoder): EncodecDecoder(\n",
+ " (layers): ModuleList(\n",
+ " (0): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 1024, kernel_size=(7,), stride=(1,))\n",
+ " )\n",
+ " (1): EncodecLSTM(\n",
+ " (lstm): LSTM(1024, 1024, num_layers=2)\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConvTranspose1d(\n",
+ " (conv): ConvTranspose1d(1024, 512, kernel_size=(16,), stride=(8,))\n",
+ " )\n",
+ " (4): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(512, 256, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(256, 512, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (5): ELU(alpha=1.0)\n",
+ " (6): EncodecConvTranspose1d(\n",
+ " (conv): ConvTranspose1d(512, 256, kernel_size=(10,), stride=(5,))\n",
+ " )\n",
+ " (7): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(256, 128, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (8): ELU(alpha=1.0)\n",
+ " (9): EncodecConvTranspose1d(\n",
+ " (conv): ConvTranspose1d(256, 128, kernel_size=(8,), stride=(4,))\n",
+ " )\n",
+ " (10): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(128, 64, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (11): ELU(alpha=1.0)\n",
+ " (12): EncodecConvTranspose1d(\n",
+ " (conv): ConvTranspose1d(128, 64, kernel_size=(8,), stride=(4,))\n",
+ " )\n",
+ " (13): EncodecResnetBlock(\n",
+ " (block): ModuleList(\n",
+ " (0): ELU(alpha=1.0)\n",
+ " (1): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 32, kernel_size=(3,), stride=(1,))\n",
+ " )\n",
+ " (2): ELU(alpha=1.0)\n",
+ " (3): EncodecConv1d(\n",
+ " (conv): Conv1d(32, 64, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " (shortcut): Identity()\n",
+ " )\n",
+ " (14): ELU(alpha=1.0)\n",
+ " (15): EncodecConv1d(\n",
+ " (conv): Conv1d(64, 1, kernel_size=(7,), stride=(1,))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (quantizer): EncodecResidualVectorQuantizer(\n",
+ " (layers): ModuleList(\n",
+ " (0-3): 4 x EncodecVectorQuantization(\n",
+ " (codebook): EncodecEuclideanCodebook()\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (decoder): MusicgenForCausalLM(\n",
+ " (model): MusicgenModel(\n",
+ " (decoder): MusicgenDecoder(\n",
+ " (embed_tokens): ModuleList(\n",
+ " (0-3): 4 x Embedding(2049, 1024)\n",
+ " )\n",
+ " (embed_positions): MusicgenSinusoidalPositionalEmbedding()\n",
+ " (layers): ModuleList(\n",
+ " (0-23): 24 x MusicgenDecoderLayer(\n",
+ " (self_attn): MusicgenAttention(\n",
+ " (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (q_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (out_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " )\n",
+ " (activation_fn): GELUActivation()\n",
+ " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (encoder_attn): MusicgenAttention(\n",
+ " (k_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (v_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (q_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " (out_proj): Linear(in_features=1024, out_features=1024, bias=False)\n",
+ " )\n",
+ " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (fc1): Linear(in_features=1024, out_features=4096, bias=False)\n",
+ " (fc2): Linear(in_features=4096, out_features=1024, bias=False)\n",
+ " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " (lm_heads): ModuleList(\n",
+ " (0-3): 4 x Linear(in_features=1024, out_features=2048, bias=False)\n",
+ " )\n",
+ " )\n",
+ " (enc_to_dec_proj): Linear(in_features=768, out_features=1024, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "b2m.to('cuda')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "7aad33a3-351d-472b-86f7-55ca5ac90a26",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "fmri = np.load('/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy')\n",
+ "exafmri = fmri[:, 9]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d4d9c347-bed2-45d0-ba84-e033f18ba4a7",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "texa = torch.from_numpy(exafmri).unsqueeze(0).to('cuda')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "1a5d2300-f587-43d2-aeee-4575777710ca",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 60784])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(texa.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "97d2a4a0-9dac-4e8d-8cbf-9c5da2bd589e",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 128, 1024])\n"
+ ]
+ }
+ ],
+ "source": [
+ "def generate(self, x, max_length=256):\n",
+ "    \"\"\"Greedily decode MusicGen codebook tokens conditioned on fMRI voxels.\n",
+ "\n",
+ "    Args:\n",
+ "        self: model exposing `pseudo_text_encoder`, `musicgen_decoder`,\n",
+ "            `num_codebooks` and `pad_token_id` (called below as `generate(b2m, ...)`).\n",
+ "        x: voxel tensor of shape [batch_size, n_voxels] ([1, 60784] for the sample below).\n",
+ "        max_length: number of decoding steps to run.\n",
+ "\n",
+ "    Returns:\n",
+ "        Long tensor of shape [batch_size * num_codebooks, max_length + 1].\n",
+ "        Column 0 is the `pad_token_id` start token, not a generated code, so it\n",
+ "        has to be dropped before the ids are handed to the audio codec.\n",
+ "    \"\"\"\n",
+ "    with torch.no_grad():\n",
+ "        # voxels -> pseudo text-encoder states, [batch_size, 128, 768]\n",
+ "        pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n",
+ "        # MusicGen's own projection to the decoder width, [batch_size, 128, 1024]\n",
+ "        projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n",
+ "        print(projected_pseudo_encoded_fmri.shape)\n",
+ "        # one start row per (sample, codebook), allocated directly on the input's device\n",
+ "        decoder_input_ids = torch.full(\n",
+ "            (x.shape[0] * self.num_codebooks, 1),\n",
+ "            int(self.pad_token_id),\n",
+ "            dtype=torch.long,\n",
+ "            device=x.device,\n",
+ "        )\n",
+ "        for _ in range(max_length):\n",
+ "            # no KV cache: the whole prefix is re-run through the decoder on every step\n",
+ "            logits = self.musicgen_decoder.decoder(\n",
+ "                input_ids=decoder_input_ids,\n",
+ "                encoder_hidden_states=projected_pseudo_encoded_fmri,\n",
+ "            ).logits\n",
+ "            # greedy choice at the last position, one id per row\n",
+ "            next_token = logits[:, -1, :].argmax(dim=-1)\n",
+ "            decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(-1)], dim=-1)\n",
+ "    return decoder_input_ids\n",
+ "\n",
+ "sampled = generate(b2m, texa)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "fc41e095-263f-4724-91f1-d54e33f3a9b5",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "torch.save(sampled, './samplet.pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "5b8eee6f-06e5-4608-974d-bc20ae297076",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:1 │\n",
+ "│ │\n",
+ "│ ❱ 1 sampled.to('cpu') │\n",
+ "│ 2 │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "RuntimeError: CUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 sampled.to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mRuntimeError: \u001b[0mCUDA error: device-side assert triggered\n",
+ "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n",
+ "incorrect.\n",
+ "For debugging consider passing \u001b[33mCUDA_LAUNCH_BLOCKING\u001b[0m=\u001b[1;36m1\u001b[0m.\n",
+ "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n",
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "sampled.to('cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "45aff4da-82bf-48d3-bc74-2456cd63d021",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:4 │\n",
+ "│ │\n",
+ "│ 1 b2m = b2m.to('cpu') │\n",
+ "│ 2 ss = torch.load('./samplet.pt').to('cpu') │\n",
+ "│ 3 with torch.no_grad(): │\n",
+ "│ ❱ 4 │ decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(0).un │\n",
+ "│ 5 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:742 in decode │\n",
+ "│ │\n",
+ "│ 739 │ │ if chunk_length is None: │\n",
+ "│ 740 │ │ │ if len(audio_codes) != 1: │\n",
+ "│ 741 │ │ │ │ raise ValueError(f\"Expected one frame, got {len(audio_codes)}\") │\n",
+ "│ ❱ 742 │ │ │ audio_values = self._decode_frame(audio_codes[0], audio_scales[0]) │\n",
+ "│ 743 │ │ else: │\n",
+ "│ 744 │ │ │ decoded_frames = [] │\n",
+ "│ 745 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:706 in _decode_frame │\n",
+ "│ │\n",
+ "│ 703 │ │\n",
+ "│ 704 │ def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) - │\n",
+ "│ 705 │ │ codes = codes.transpose(0, 1) │\n",
+ "│ ❱ 706 │ │ embeddings = self.quantizer.decode(codes) │\n",
+ "│ 707 │ │ outputs = self.decoder(embeddings) │\n",
+ "│ 708 │ │ if scale is not None: │\n",
+ "│ 709 │ │ │ outputs = outputs * scale.view(-1, 1, 1) │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:435 in decode │\n",
+ "│ │\n",
+ "│ 432 │ │ quantized_out = torch.tensor(0.0, device=codes.device) │\n",
+ "│ 433 │ │ for i, indices in enumerate(codes): │\n",
+ "│ 434 │ │ │ layer = self.layers[i] │\n",
+ "│ ❱ 435 │ │ │ quantized = layer.decode(indices) │\n",
+ "│ 436 │ │ │ quantized_out = quantized_out + quantized │\n",
+ "│ 437 │ │ return quantized_out │\n",
+ "│ 438 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:391 in decode │\n",
+ "│ │\n",
+ "│ 388 │ │ return embed_in │\n",
+ "│ 389 │ │\n",
+ "│ 390 │ def decode(self, embed_ind): │\n",
+ "│ ❱ 391 │ │ quantize = self.codebook.decode(embed_ind) │\n",
+ "│ 392 │ │ quantize = quantize.permute(0, 2, 1) │\n",
+ "│ 393 │ │ return quantize │\n",
+ "│ 394 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n",
+ "│ odec/modeling_encodec.py:372 in decode │\n",
+ "│ │\n",
+ "│ 369 │ │ return embed_ind │\n",
+ "│ 370 │ │\n",
+ "│ 371 │ def decode(self, embed_ind): │\n",
+ "│ ❱ 372 │ │ quantize = nn.functional.embedding(embed_ind, self.embed) │\n",
+ "│ 373 │ │ return quantize │\n",
+ "│ 374 │\n",
+ "│ 375 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py: │\n",
+ "│ 2210 in embedding │\n",
+ "│ │\n",
+ "│ 2207 │ │ # torch.embedding_renorm_ │\n",
+ "│ 2208 │ │ # remove once script supports set_grad_enabled │\n",
+ "│ 2209 │ │ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) │\n",
+ "│ ❱ 2210 │ return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) │\n",
+ "│ 2211 │\n",
+ "│ 2212 │\n",
+ "│ 2213 def embedding_bag( │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "IndexError: index out of range in self\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m4\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0mb2m = b2m.to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0mss = torch.load(\u001b[33m'\u001b[0m\u001b[33m./samplet.pt\u001b[0m\u001b[33m'\u001b[0m).to(\u001b[33m'\u001b[0m\u001b[33mcpu\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m\u001b[94mwith\u001b[0m torch.no_grad(): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m4 \u001b[2m│ \u001b[0mdecoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(\u001b[94m0\u001b[0m).un \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m742\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m739 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m chunk_length \u001b[95mis\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m740 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(audio_codes) != \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m741 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mExpected one frame, got \u001b[0m\u001b[33m{\u001b[0m\u001b[96mlen\u001b[0m(audio_codes)\u001b[33m}\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m742 \u001b[2m│ │ │ \u001b[0maudio_values = \u001b[96mself\u001b[0m._decode_frame(audio_codes[\u001b[94m0\u001b[0m], audio_scales[\u001b[94m0\u001b[0m]) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m743 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m744 \u001b[0m\u001b[2m│ │ │ \u001b[0mdecoded_frames = [] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m745 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m706\u001b[0m in \u001b[92m_decode_frame\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m703 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m704 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_decode_frame\u001b[0m(\u001b[96mself\u001b[0m, codes: torch.Tensor, scale: Optional[torch.Tensor] = \u001b[94mNone\u001b[0m) - \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m705 \u001b[0m\u001b[2m│ │ \u001b[0mcodes = codes.transpose(\u001b[94m0\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m706 \u001b[2m│ │ \u001b[0membeddings = \u001b[96mself\u001b[0m.quantizer.decode(codes) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m707 \u001b[0m\u001b[2m│ │ \u001b[0moutputs = \u001b[96mself\u001b[0m.decoder(embeddings) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m708 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m scale \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m709 \u001b[0m\u001b[2m│ │ │ \u001b[0moutputs = outputs * scale.view(-\u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m435\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m432 \u001b[0m\u001b[2m│ │ \u001b[0mquantized_out = torch.tensor(\u001b[94m0.0\u001b[0m, device=codes.device) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m433 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfor\u001b[0m i, indices \u001b[95min\u001b[0m \u001b[96menumerate\u001b[0m(codes): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m434 \u001b[0m\u001b[2m│ │ │ \u001b[0mlayer = \u001b[96mself\u001b[0m.layers[i] \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m435 \u001b[2m│ │ │ \u001b[0mquantized = layer.decode(indices) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m436 \u001b[0m\u001b[2m│ │ │ \u001b[0mquantized_out = quantized_out + quantized \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m437 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m quantized_out \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m438 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m391\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m388 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m embed_in \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m389 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m390 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecode\u001b[0m(\u001b[96mself\u001b[0m, embed_ind): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m391 \u001b[2m│ │ \u001b[0mquantize = \u001b[96mself\u001b[0m.codebook.decode(embed_ind) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m392 \u001b[0m\u001b[2m│ │ \u001b[0mquantize = quantize.permute(\u001b[94m0\u001b[0m, \u001b[94m2\u001b[0m, \u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m393 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m quantize \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m394 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33modec/\u001b[0m\u001b[1;33mmodeling_encodec.py\u001b[0m:\u001b[94m372\u001b[0m in \u001b[92mdecode\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m369 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m embed_ind \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m370 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m371 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mdecode\u001b[0m(\u001b[96mself\u001b[0m, embed_ind): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m372 \u001b[2m│ │ \u001b[0mquantize = nn.functional.embedding(embed_ind, \u001b[96mself\u001b[0m.embed) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m373 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m quantize \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m374 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m375 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/\u001b[0m\u001b[1;33mfunctional.py\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[94m2210\u001b[0m in \u001b[92membedding\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2207 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# torch.embedding_renorm_\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2208 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# remove once script supports set_grad_enabled\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2209 \u001b[0m\u001b[2m│ │ \u001b[0m_no_grad_embedding_renorm_(weight, \u001b[96minput\u001b[0m, max_norm, norm_type) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2210 \u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m torch.embedding(weight, \u001b[96minput\u001b[0m, padding_idx, scale_grad_by_freq, sparse) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2211 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2212 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m2213 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92membedding_bag\u001b[0m( \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mIndexError: \u001b[0mindex out of range in self\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "b2m = b2m.to('cpu')\n",
+ "# map_location='cpu' deserializes the saved CUDA tensor straight onto the CPU,\n",
+ "# so loading does not depend on the GPU the codes were generated on.\n",
+ "ss = torch.load('./samplet.pt', map_location='cpu')\n",
+ "# generate() seeds every codebook row with b2m.pad_token_id and returns it as column 0.\n",
+ "# That id is not a codebook entry (it appears to be what raised\n",
+ "# 'IndexError: index out of range in self'), so drop every step that holds it.\n",
+ "is_generated_step = (ss != b2m.pad_token_id).all(dim=0)\n",
+ "codes = ss[:, is_generated_step]\n",
+ "with torch.no_grad():\n",
+ "    # decode() wants [frames, batch, num_codebooks, steps]; both leading axes are 1 here\n",
+ "    decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes=codes.unsqueeze(0).unsqueeze(0), audio_scales=[None])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "7cca542f-a48e-43e0-bddb-aadf7a8c1068",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 257])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ss.shape"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}