"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /admin/home-ckadirt/miniconda3/envs/mindeye/lib/pyth ...\n",
+ " rank_zero_warn(\n",
+ "Using 16bit Automatic Mixed Precision (AMP)\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The len is 480\n",
+ "The len is 60\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "\n",
+ " | Name | Type | Params\n",
+ "-------------------------------------------------------------------------\n",
+ "0 | brain_network | BrainNetwork | 474 M \n",
+ "1 | ridge_regression | RidgeRegression | 248 M \n",
+ "2 | loss | CrossEntropyLoss | 0 \n",
+ "3 | pseudo_text_encoder | Sequential | 723 M \n",
+ "4 | musicgen_decoder | MusicgenForConditionalGeneration | 588 M \n",
+ "-------------------------------------------------------------------------\n",
+ "1.3 B Trainable params\n",
+ "2.1 M Non-trainable params\n",
+ "1.3 B Total params\n",
+ "5,250.392 Total estimated model params size (MB)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 96 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
+ " rank_zero_warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
+ "│ in <module>:12 │\n",
+ "│ │\n",
+ "│ 9 trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, │\n",
+ "│ 10 │\n",
+ "│ 11 # train the model │\n",
+ "│ ❱ 12 trainer.fit(b2m, datamodule=data_module) │\n",
+ "│ 13 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/trainer.py:529 in fit │\n",
+ "│ │\n",
+ "│ 526 │ │ \"\"\" │\n",
+ "│ 527 │ │ model = _maybe_unwrap_optimized(model) │\n",
+ "│ 528 │ │ self.strategy._lightning_module = model │\n",
+ "│ ❱ 529 │ │ call._call_and_handle_interrupt( │\n",
+ "│ 530 │ │ │ self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, │\n",
+ "│ 531 │ │ ) │\n",
+ "│ 532 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/call.py:42 in _call_and_handle_interrupt │\n",
+ "│ │\n",
+ "│ 39 │ try: │\n",
+ "│ 40 │ │ if trainer.strategy.launcher is not None: │\n",
+ "│ 41 │ │ │ return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, │\n",
+ "│ ❱ 42 │ │ return trainer_fn(*args, **kwargs) │\n",
+ "│ 43 │ │\n",
+ "│ 44 │ except _TunerExitException: │\n",
+ "│ 45 │ │ _call_teardown_hook(trainer) │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/trainer.py:568 in _fit_impl │\n",
+ "│ │\n",
+ "│ 565 │ │ │ model_provided=True, │\n",
+ "│ 566 │ │ │ model_connected=self.lightning_module is not None, │\n",
+ "│ 567 │ │ ) │\n",
+ "│ ❱ 568 │ │ self._run(model, ckpt_path=ckpt_path) │\n",
+ "│ 569 │ │ │\n",
+ "│ 570 │ │ assert self.state.stopped │\n",
+ "│ 571 │ │ self.training = False │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/trainer.py:973 in _run │\n",
+ "│ │\n",
+ "│ 970 │ │ # ---------------------------- │\n",
+ "│ 971 │ │ # RUN THE TRAINER │\n",
+ "│ 972 │ │ # ---------------------------- │\n",
+ "│ ❱ 973 │ │ results = self._run_stage() │\n",
+ "│ 974 │ │ │\n",
+ "│ 975 │ │ # ---------------------------- │\n",
+ "│ 976 │ │ # POST-Training CLEAN UP │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/trainer.py:1014 in _run_stage │\n",
+ "│ │\n",
+ "│ 1011 │ │ │ return self.predict_loop.run() │\n",
+ "│ 1012 │ │ if self.training: │\n",
+ "│ 1013 │ │ │ with isolate_rng(): │\n",
+ "│ ❱ 1014 │ │ │ │ self._run_sanity_check() │\n",
+ "│ 1015 │ │ │ with torch.autograd.set_detect_anomaly(self._detect_anomaly): │\n",
+ "│ 1016 │ │ │ │ self.fit_loop.run() │\n",
+ "│ 1017 │ │ │ return None │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/trainer.py:1043 in _run_sanity_check │\n",
+ "│ │\n",
+ "│ 1040 │ │ │ call._call_callback_hooks(self, \"on_sanity_check_start\") │\n",
+ "│ 1041 │ │ │ │\n",
+ "│ 1042 │ │ │ # run eval step │\n",
+ "│ ❱ 1043 │ │ │ val_loop.run() │\n",
+ "│ 1044 │ │ │ │\n",
+ "│ 1045 │ │ │ call._call_callback_hooks(self, \"on_sanity_check_end\") │\n",
+ "│ 1046 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops │\n",
+ "│ /utilities.py:177 in _decorator │\n",
+ "│ │\n",
+ "│ 174 │ │ else: │\n",
+ "│ 175 │ │ │ context_manager = torch.no_grad │\n",
+ "│ 176 │ │ with context_manager(): │\n",
+ "│ ❱ 177 │ │ │ return loop_run(self, *args, **kwargs) │\n",
+ "│ 178 │ │\n",
+ "│ 179 │ return _decorator │\n",
+ "│ 180 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops │\n",
+ "│ /evaluation_loop.py:115 in run │\n",
+ "│ │\n",
+ "│ 112 │ │ │ │ │ self._store_dataloader_outputs() │\n",
+ "│ 113 │ │ │ │ previous_dataloader_idx = dataloader_idx │\n",
+ "│ 114 │ │ │ │ # run step hooks │\n",
+ "│ ❱ 115 │ │ │ │ self._evaluation_step(batch, batch_idx, dataloader_idx) │\n",
+ "│ 116 │ │ │ except StopIteration: │\n",
+ "│ 117 │ │ │ │ # this needs to wrap the `*_step` call too (not just `next`) for `datalo │\n",
+ "│ 118 │ │ │ │ break │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops │\n",
+ "│ /evaluation_loop.py:375 in _evaluation_step │\n",
+ "│ │\n",
+ "│ 372 │ │ self.batch_progress.increment_started() │\n",
+ "│ 373 │ │ │\n",
+ "│ 374 │ │ hook_name = \"test_step\" if trainer.testing else \"validation_step\" │\n",
+ "│ ❱ 375 │ │ output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values()) │\n",
+ "│ 376 │ │ │\n",
+ "│ 377 │ │ self.batch_progress.increment_processed() │\n",
+ "│ 378 │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train │\n",
+ "│ er/call.py:291 in _call_strategy_hook │\n",
+ "│ │\n",
+ "│ 288 │ │ return None │\n",
+ "│ 289 │ │\n",
+ "│ 290 │ with trainer.profiler.profile(f\"[Strategy]{trainer.strategy.__class__.__name__}.{hoo │\n",
+ "│ ❱ 291 │ │ output = fn(*args, **kwargs) │\n",
+ "│ 292 │ │\n",
+ "│ 293 │ # restore current_fx when nested context │\n",
+ "│ 294 │ pl_module._current_fx_name = prev_fx_name │\n",
+ "│ │\n",
+ "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/strat │\n",
+ "│ egies/strategy.py:379 in validation_step │\n",
+ "│ │\n",
+ "│ 376 │ │ \"\"\" │\n",
+ "│ 377 │ │ with self.precision_plugin.val_step_context(): │\n",
+ "│ 378 │ │ │ assert isinstance(self.model, ValidationStep) │\n",
+ "│ ❱ 379 │ │ │ return self.model.validation_step(*args, **kwargs) │\n",
+ "│ 380 │ │\n",
+ "│ 381 │ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: │\n",
+ "│ 382 │ │ \"\"\"The actual test step. │\n",
+ "│ │\n",
+ "│ in validation_step:116 │\n",
+ "│ │\n",
+ "│ 113 │ │ # get the loss │\n",
+ "│ 114 │ │ loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebook │\n",
+ "│ 115 │ │ │\n",
+ "│ ❱ 116 │ │ acuracy = self.tokens_accuracy(logits, embeddings) │\n",
+ "│ 117 │ │ self.log('val_loss', loss, sync_dist=True) │\n",
+ "│ 118 │ │ self.log('val_accuracy', acuracy, sync_dist=True) │\n",
+ "│ 119 │ │ discrete_outputs = logits.argmax(dim=2) │\n",
+ "│ │\n",
+ "│ in tokens_accuracy:80 │\n",
+ "│ │\n",
+ "│ 77 │ │ # we need to get the index of the maximum value of each token │\n",
+ "│ 78 │ │ outputs = outputs.argmax(dim=2) │\n",
+ "│ 79 │ │ # now we need to compare the outputs with the embeddings │\n",
+ "│ ❱ 80 │ │ return (outputs == embeddings).float().mean() │\n",
+ "│ 81 │ │\n",
+ "│ 82 │ def on_train_epoch_end(self): │\n",
+ "│ 83 │ │ self.train_outptus = torch.cat(self.train_outptus) │\n",
+ "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+ "RuntimeError: The size of tensor a (749) must match the size of tensor b (750) at non-singleton dimension 1\n",
+ "\n"
+ ],
+ "text/plain": [
+ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m12\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 9 \u001b[0mtrainer = pl.Trainer(devices=\u001b[94m1\u001b[0m, accelerator=\u001b[33m\"\u001b[0m\u001b[33mgpu\u001b[0m\u001b[33m\"\u001b[0m, max_epochs=\u001b[94m400\u001b[0m, logger=wandb_logger, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m10 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m11 \u001b[0m\u001b[2m# train the model\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m12 trainer.fit(b2m, datamodule=data_module) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m13 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m529\u001b[0m in \u001b[92mfit\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 526 \u001b[0m\u001b[2;33m│ │ \u001b[0m\u001b[33m\"\"\"\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 527 \u001b[0m\u001b[2m│ │ \u001b[0mmodel = _maybe_unwrap_optimized(model) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 528 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.strategy._lightning_module = model \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 529 \u001b[2m│ │ \u001b[0mcall._call_and_handle_interrupt( \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 530 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, \u001b[96mself\u001b[0m._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 531 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 532 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mcall.py\u001b[0m:\u001b[94m42\u001b[0m in \u001b[92m_call_and_handle_interrupt\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 39 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mtry\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 40 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m trainer.strategy.launcher \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 41 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 42 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m trainer_fn(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 43 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 44 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mexcept\u001b[0m _TunerExitException: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 45 \u001b[0m\u001b[2m│ │ \u001b[0m_call_teardown_hook(trainer) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m568\u001b[0m in \u001b[92m_fit_impl\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 565 \u001b[0m\u001b[2m│ │ │ \u001b[0mmodel_provided=\u001b[94mTrue\u001b[0m, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 566 \u001b[0m\u001b[2m│ │ │ \u001b[0mmodel_connected=\u001b[96mself\u001b[0m.lightning_module \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m, \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 567 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 568 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m._run(model, ckpt_path=ckpt_path) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 569 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 570 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94massert\u001b[0m \u001b[96mself\u001b[0m.state.stopped \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 571 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.training = \u001b[94mFalse\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m973\u001b[0m in \u001b[92m_run\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 970 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 971 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# RUN THE TRAINER\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 972 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 973 \u001b[2m│ │ \u001b[0mresults = \u001b[96mself\u001b[0m._run_stage() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 974 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 975 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# ----------------------------\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 976 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# POST-Training CLEAN UP\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m1014\u001b[0m in \u001b[92m_run_stage\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1011 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[96mself\u001b[0m.predict_loop.run() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1012 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.training: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1013 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mwith\u001b[0m isolate_rng(): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1014 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._run_sanity_check() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1015 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mwith\u001b[0m torch.autograd.set_detect_anomaly(\u001b[96mself\u001b[0m._detect_anomaly): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1016 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.fit_loop.run() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1017 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mtrainer.py\u001b[0m:\u001b[94m1043\u001b[0m in \u001b[92m_run_sanity_check\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1040 \u001b[0m\u001b[2m│ │ │ \u001b[0mcall._call_callback_hooks(\u001b[96mself\u001b[0m, \u001b[33m\"\u001b[0m\u001b[33mon_sanity_check_start\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1041 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1042 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# run eval step\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1043 \u001b[2m│ │ │ \u001b[0mval_loop.run() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1044 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1045 \u001b[0m\u001b[2m│ │ │ \u001b[0mcall._call_callback_hooks(\u001b[96mself\u001b[0m, \u001b[33m\"\u001b[0m\u001b[33mon_sanity_check_end\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m1046 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mutilities.py\u001b[0m:\u001b[94m177\u001b[0m in \u001b[92m_decorator\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ │ \u001b[0mcontext_manager = torch.no_grad \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m176 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m context_manager(): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m177 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m loop_run(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m178 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m179 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m _decorator \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m180 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mevaluation_loop.py\u001b[0m:\u001b[94m115\u001b[0m in \u001b[92mrun\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m112 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._store_dataloader_outputs() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mprevious_dataloader_idx = dataloader_idx \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# run step hooks\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m115 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._evaluation_step(batch, batch_idx, dataloader_idx) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mexcept\u001b[0m \u001b[96mStopIteration\u001b[0m: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# this needs to wrap the `*_step` call too (not just `next`) for `datalo\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mbreak\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/loops\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/\u001b[0m\u001b[1;33mevaluation_loop.py\u001b[0m:\u001b[94m375\u001b[0m in \u001b[92m_evaluation_step\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m372 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.batch_progress.increment_started() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m373 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m374 \u001b[0m\u001b[2m│ │ \u001b[0mhook_name = \u001b[33m\"\u001b[0m\u001b[33mtest_step\u001b[0m\u001b[33m\"\u001b[0m \u001b[94mif\u001b[0m trainer.testing \u001b[94melse\u001b[0m \u001b[33m\"\u001b[0m\u001b[33mvalidation_step\u001b[0m\u001b[33m\"\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m375 \u001b[2m│ │ \u001b[0moutput = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values()) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m376 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m377 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.batch_progress.increment_processed() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m378 \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/train\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33mer/\u001b[0m\u001b[1;33mcall.py\u001b[0m:\u001b[94m291\u001b[0m in \u001b[92m_call_strategy_hook\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m288 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m289 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m290 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mwith\u001b[0m trainer.profiler.profile(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33m[Strategy]\u001b[0m\u001b[33m{\u001b[0mtrainer.strategy.\u001b[91m__class__\u001b[0m.\u001b[91m__name__\u001b[0m\u001b[33m}\u001b[0m\u001b[33m.\u001b[0m\u001b[33m{\u001b[0mhoo \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m291 \u001b[2m│ │ \u001b[0moutput = fn(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m292 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m293 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# restore current_fx when nested context\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m294 \u001b[0m\u001b[2m│ \u001b[0mpl_module._current_fx_name = prev_fx_name \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/pytorch_lightning/strat\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2;33megies/\u001b[0m\u001b[1;33mstrategy.py\u001b[0m:\u001b[94m379\u001b[0m in \u001b[92mvalidation_step\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m376 \u001b[0m\u001b[2;33m│ │ \u001b[0m\u001b[33m\"\"\"\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m377 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mwith\u001b[0m \u001b[96mself\u001b[0m.precision_plugin.val_step_context(): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m378 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94massert\u001b[0m \u001b[96misinstance\u001b[0m(\u001b[96mself\u001b[0m.model, ValidationStep) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m379 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[96mself\u001b[0m.model.validation_step(*args, **kwargs) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m380 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m381 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mtest_step\u001b[0m(\u001b[96mself\u001b[0m, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m382 \u001b[0m\u001b[2;90m│ │ \u001b[0m\u001b[33m\"\"\"The actual test step.\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92mvalidation_step\u001b[0m:\u001b[94m116\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m113 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# get the loss\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m114 \u001b[0m\u001b[2m│ │ \u001b[0mloss = \u001b[96mself\u001b[0m.loss(rearrange(logits, \u001b[33m'\u001b[0m\u001b[33m(b c) t d -> (b c t) d\u001b[0m\u001b[33m'\u001b[0m, c=\u001b[96mself\u001b[0m.num_codebook \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m115 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m116 \u001b[2m│ │ \u001b[0macuracy = \u001b[96mself\u001b[0m.tokens_accuracy(logits, embeddings) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.log(\u001b[33m'\u001b[0m\u001b[33mval_loss\u001b[0m\u001b[33m'\u001b[0m, loss, sync_dist=\u001b[94mTrue\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.log(\u001b[33m'\u001b[0m\u001b[33mval_accuracy\u001b[0m\u001b[33m'\u001b[0m, acuracy, sync_dist=\u001b[94mTrue\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m119 \u001b[0m\u001b[2m│ │ \u001b[0mdiscrete_outputs = logits.argmax(dim=\u001b[94m2\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m in \u001b[92mtokens_accuracy\u001b[0m:\u001b[94m80\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 77 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# we need to get the index of the maximum value of each token\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 78 \u001b[0m\u001b[2m│ │ \u001b[0moutputs = outputs.argmax(dim=\u001b[94m2\u001b[0m) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 79 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# now we need to compare the outputs with the embeddings\u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 80 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m (outputs == embeddings).float().mean() \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 81 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 82 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mon_train_epoch_end\u001b[0m(\u001b[96mself\u001b[0m): \u001b[31m│\u001b[0m\n",
+ "\u001b[31m│\u001b[0m \u001b[2m 83 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.train_outptus = torch.cat(\u001b[96mself\u001b[0m.train_outptus) \u001b[31m│\u001b[0m\n",
+ "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
+ "\u001b[1;91mRuntimeError: \u001b[0mThe size of tensor a \u001b[1m(\u001b[0m\u001b[1;36m749\u001b[0m\u001b[1m)\u001b[0m must match the size of tensor b \u001b[1m(\u001b[0m\u001b[1;36m750\u001b[0m\u001b[1m)\u001b[0m at non-singleton dimension \u001b[1;36m1\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "b2m = B2M()\n",
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
+ "\n",
+ "wandb.finish()\n",
+ "\n",
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
+ "\n",
+ "# define the trainer\n",
+ "trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
+ "\n",
+ "# train the model\n",
+ "trainer.fit(b2m, datamodule=data_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "4f458adf-1e89-4b8b-9514-08bcc9e8ef56",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Fri Sep 8 03:55:15 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 52C P0 205W / 400W | 32829MiB / 40960MiB | 97% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 54C P0 297W / 400W | 38991MiB / 40960MiB | 98% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 59C P0 354W / 400W | 39627MiB / 40960MiB | 89% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 50C P0 180W / 400W | 39719MiB / 40960MiB | 95% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 53C P0 190W / 400W | 35069MiB / 40960MiB | 98% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 51C P0 182W / 400W | 34235MiB / 40960MiB | 98% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 40C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 38C P0 56W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| 0 N/A N/A 874054 C ...3/envs/mindeye/bin/python 848MiB |\n",
+ "| 0 N/A N/A 877194 C ...ari/llama_env/bin/python3 31978MiB |\n",
+ "| 1 N/A N/A 877195 C ...ari/llama_env/bin/python3 38988MiB |\n",
+ "| 2 N/A N/A 877196 C ...ari/llama_env/bin/python3 39876MiB |\n",
+ "| 3 N/A N/A 877197 C ...ari/llama_env/bin/python3 39716MiB |\n",
+ "| 4 N/A N/A 877198 C ...ari/llama_env/bin/python3 35066MiB |\n",
+ "| 5 N/A N/A 877199 C ...ari/llama_env/bin/python3 34232MiB |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}