diff --git "a/examples/gene_classification.ipynb" "b/examples/gene_classification.ipynb" --- "a/examples/gene_classification.ipynb" +++ "b/examples/gene_classification.ipynb" @@ -2,593 +2,204 @@ "cells": [ { "cell_type": "markdown", + "id": "08f41458-5304-48c5-9e92-f9b56ab052c4", "metadata": {}, "source": [ "## Geneformer Fine-Tuning for Classification of Dosage-Sensitive vs. -Insensitive Transcription Factors (TFs)" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "GPU_NUMBER = [0]\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n", - "os.environ[\"NCCL_DEBUG\"] = \"INFO\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# imports\n", - "import datetime\n", - "import subprocess\n", - "import math\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "from datasets import load_from_disk\n", - "from sklearn import preprocessing\n", - "from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve\n", - "from sklearn.model_selection import StratifiedKFold\n", - "import torch\n", - "from transformers import BertForTokenClassification\n", - "from transformers import Trainer\n", - "from transformers.training_args import TrainingArguments\n", - "from tqdm.notebook import tqdm\n", - "\n", - "from geneformer import DataCollatorForGeneClassification\n", - "from geneformer.pretrainer import token_dictionary" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Gene Attribute Information" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "# table of corresponding Ensembl IDs, gene names, and gene types (e.g. 
coding, miRNA, etc.)\n", - "gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n", - "\n", - "# create dictionaries for corresponding attributes\n", - "gene_id_type_dict = dict(zip(gene_info[\"ensembl_id\"],gene_info[\"gene_type\"]))\n", - "gene_name_id_dict = dict(zip(gene_info[\"gene_name\"],gene_info[\"ensembl_id\"]))\n", - "gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}" - ] - }, { "cell_type": "markdown", + "id": "79539e95-2c9c-4162-835c-f0d158abb15d", "metadata": {}, "source": [ - "## Load Training Data and Class Labels" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "# function for preparing targets and labels\n", - "def prep_inputs(genegroup1, genegroup2, id_type):\n", - " if id_type == \"gene_name\":\n", - " targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]\n", - " targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]\n", - " elif id_type == \"ensembl_id\":\n", - " targets1 = [gene for gene in genegroup1 if gene in token_dictionary]\n", - " targets2 = [gene for gene in genegroup2 if gene in token_dictionary]\n", - " \n", - " targets1_id = [token_dictionary[gene] for gene in targets1]\n", - " targets2_id = [token_dictionary[gene] for gene in targets2]\n", - " \n", - " targets = np.array(targets1_id + targets2_id)\n", - " labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))\n", - " nsplits = min(5, min(len(targets1_id), len(targets2_id))-1)\n", - " assert nsplits > 2\n", - " print(f\"# targets1: {len(targets1_id)}\\n# targets2: {len(targets2_id)}\\n# splits: {nsplits}\")\n", - " return targets, labels, nsplits" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# preparing targets and labels for dosage sensitive vs insensitive TFs\n", - "dosage_tfs = 
pd.read_csv(\"/path/to/dosage_sens_tf_labels.csv\", header=0)\n", - "sensitive = dosage_tfs[\"dosage_sensitive\"].dropna()\n", - "insensitive = dosage_tfs[\"dosage_insensitive\"].dropna()\n", - "targets, labels, nsplits = prep_inputs(sensitive, insensitive, \"ensembl_id\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "# load training dataset\n", - "train_dataset=load_from_disk(\"/path/to/gene_train_data.dataset\")\n", - "shuffled_train_dataset = train_dataset.shuffle(seed=42)\n", - "subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(50_000)])" + "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. The example below uses default hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." 
] }, { "cell_type": "markdown", + "id": "51b4852a-9f03-4bc3-ba33-79eaa4582d50", "metadata": {}, "source": [ - "## Define Functions for Training and Cross-Validating Classifier" + "### Train gene classifier with 5-fold cross-validation:" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, + "id": "58d59e09-5e6c-4fba-ba2b-3aee103869fd", "metadata": {}, "outputs": [], "source": [ - "def preprocess_classifier_batch(cell_batch, max_len):\n", - " if max_len == None:\n", - " max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n", - " def pad_label_example(example):\n", - " example[\"labels\"] = np.pad(example[\"labels\"], \n", - " (0, max_len-len(example[\"input_ids\"])), \n", - " mode='constant', constant_values=-100)\n", - " example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n", - " (0, max_len-len(example[\"input_ids\"])), \n", - " mode='constant', constant_values=token_dictionary.get(\"\"))\n", - " example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"\")).astype(int)\n", - " return example\n", - " padded_batch = cell_batch.map(pad_label_example)\n", - " return padded_batch\n", - "\n", - "# forward batch size is batch size for model inference (e.g. 
200)\n", - "def classifier_predict(model, evalset, forward_batch_size, mean_fpr):\n", - " predict_logits = []\n", - " predict_labels = []\n", - " model.eval()\n", - " \n", - " # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims\n", - " evalset_len = len(evalset)\n", - " max_divisible = find_largest_div(evalset_len, forward_batch_size)\n", - " if len(evalset) - max_divisible == 1:\n", - " evalset_len = max_divisible\n", - " \n", - " max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n", - " \n", - " for i in range(0, evalset_len, forward_batch_size):\n", - " max_range = min(i+forward_batch_size, evalset_len)\n", - " batch_evalset = evalset.select([i for i in range(i, max_range)])\n", - " padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)\n", - " padded_batch.set_format(type=\"torch\")\n", - " \n", - " input_data_batch = padded_batch[\"input_ids\"]\n", - " attn_msk_batch = padded_batch[\"attention_mask\"]\n", - " label_batch = padded_batch[\"labels\"]\n", - " with torch.no_grad():\n", - " outputs = model(\n", - " input_ids = input_data_batch.to(\"cuda\"), \n", - " attention_mask = attn_msk_batch.to(\"cuda\"), \n", - " labels = label_batch.to(\"cuda\"), \n", - " )\n", - " predict_logits += [torch.squeeze(outputs.logits.to(\"cpu\"))]\n", - " predict_labels += [torch.squeeze(label_batch.to(\"cpu\"))]\n", - " \n", - " logits_by_cell = torch.cat(predict_logits)\n", - " all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])\n", - " labels_by_cell = torch.cat(predict_labels)\n", - " all_labels = torch.flatten(labels_by_cell)\n", - " logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]\n", - " y_pred = [vote(item[0]) for item in logit_label_paired]\n", - " y_true = [item[1] for item in logit_label_paired]\n", - " logits_list = [item[0] for item in logit_label_paired]\n", - " # probability of class 1\n", - " y_score = 
[py_softmax(item)[1] for item in logits_list]\n", - " conf_mat = confusion_matrix(y_true, y_pred)\n", - " fpr, tpr, _ = roc_curve(y_true, y_score)\n", - " # plot roc_curve for this split\n", - " plt.plot(fpr, tpr)\n", - " plt.xlim([0.0, 1.0])\n", - " plt.ylim([0.0, 1.05])\n", - " plt.xlabel('False Positive Rate')\n", - " plt.ylabel('True Positive Rate')\n", - " plt.title('ROC')\n", - " plt.show()\n", - " # interpolate to graph\n", - " interp_tpr = np.interp(mean_fpr, fpr, tpr)\n", - " interp_tpr[0] = 0.0\n", - " return fpr, tpr, interp_tpr, conf_mat \n", + "import datetime\n", + "import pickle\n", + "from geneformer import Classifier\n", "\n", - "def vote(logit_pair):\n", - " a, b = logit_pair\n", - " if a > b:\n", - " return 0\n", - " elif b > a:\n", - " return 1\n", - " elif a == b:\n", - " return \"tie\"\n", - " \n", - "def py_softmax(vector):\n", - "\te = np.exp(vector)\n", - "\treturn e / e.sum()\n", - " \n", - "# get cross-validated mean and sd metrics\n", - "def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):\n", - " wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]\n", - " print(wts)\n", - " all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]\n", - " mean_tpr = np.sum(all_weighted_tpr, axis=0)\n", - " mean_tpr[-1] = 1.0\n", - " all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n", - " roc_auc = np.sum(all_weighted_roc_auc)\n", - " roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n", - " return mean_tpr, roc_auc, roc_auc_sd\n", + "current_date = datetime.datetime.now()\n", + "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}\"\n", + "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n", "\n", - "# Function to find the largest number smaller\n", - "# than or equal to N that is divisible by k\n", - "def find_largest_div(N, K):\n", - " rem = N % 
K\n", - " if(rem == 0):\n", - " return N\n", - " else:\n", - " return N - rem" + "output_prefix = \"tf_dosage_sens_test\"\n", + "output_dir = f\"/path/to/output_dir/{datestamp}\"\n", + "!mkdir $output_dir" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 2, + "id": "9e33942f-39e4-4db4-a3de-5949bed9fa5d", "metadata": {}, "outputs": [], "source": [ - "# cross-validate gene classifier\n", - "def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):\n", - " # check if output directory already written to\n", - " # ensure not overwriting previously saved model\n", - " model_dir_test = os.path.join(output_dir, \"ksplit0/models/pytorch_model.bin\")\n", - " if os.path.isfile(model_dir_test) == True:\n", - " raise Exception(\"Model already saved to this directory.\")\n", - " \n", - " # initiate eval metrics to return\n", - " num_classes = len(set(labels))\n", - " mean_fpr = np.linspace(0, 1, 100)\n", - " all_tpr = []\n", - " all_roc_auc = []\n", - " all_tpr_wt = []\n", - " label_dicts = []\n", - " confusion = np.zeros((num_classes,num_classes))\n", - " \n", - " # set up cross-validation splits\n", - " skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)\n", - " # train and evaluate\n", - " iteration_num = 0\n", - " for train_index, eval_index in tqdm(skf.split(targets, labels)):\n", - " if len(labels) > 500:\n", - " print(\"early stopping activated due to large # of training examples\")\n", - " nsplits = 3\n", - " if iteration_num == 3:\n", - " break\n", - " print(f\"****** Crossval split: {iteration_num}/{nsplits-1} ******\\n\")\n", - " # generate cross-validation splits\n", - " targets_train, targets_eval = targets[train_index], targets[eval_index]\n", - " labels_train, labels_eval = labels[train_index], labels[eval_index]\n", - " label_dict_train = dict(zip(targets_train, labels_train))\n", - " label_dict_eval = dict(zip(targets_eval, labels_eval))\n", - " label_dicts 
+= (iteration_num, targets_train, targets_eval, labels_train, labels_eval)\n", - " \n", - " # function to filter by whether contains train or eval labels\n", - " def if_contains_train_label(example):\n", - " a = label_dict_train.keys()\n", - " b = example['input_ids']\n", - " return not set(a).isdisjoint(b)\n", - "\n", - " def if_contains_eval_label(example):\n", - " a = label_dict_eval.keys()\n", - " b = example['input_ids']\n", - " return not set(a).isdisjoint(b)\n", - " \n", - " # filter dataset for examples containing classes for this split\n", - " print(f\"Filtering training data\")\n", - " trainset = data.filter(if_contains_train_label, num_proc=num_proc)\n", - " print(f\"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\\n\")\n", - " print(f\"Filtering evalation data\")\n", - " evalset = data.filter(if_contains_eval_label, num_proc=num_proc)\n", - " print(f\"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\\n\")\n", - "\n", - " # minimize to smaller training sample\n", - " training_size = min(subsample_size, len(trainset))\n", - " trainset_min = trainset.select([i for i in range(training_size)])\n", - " eval_size = min(training_size, len(evalset))\n", - " half_training_size = round(eval_size/2)\n", - " evalset_train_min = evalset.select([i for i in range(half_training_size)])\n", - " evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])\n", - " \n", - " # label conversion functions\n", - " def generate_train_labels(example):\n", - " example[\"labels\"] = [label_dict_train.get(token_id, -100) for token_id in example[\"input_ids\"]]\n", - " return example\n", - "\n", - " def generate_eval_labels(example):\n", - " example[\"labels\"] = [label_dict_eval.get(token_id, -100) for token_id in example[\"input_ids\"]]\n", - " return example\n", - " \n", - " # label datasets \n", - " print(f\"Labeling training data\")\n", - " trainset_labeled = trainset_min.map(generate_train_labels)\n", - " 
print(f\"Labeling evaluation data\")\n", - " evalset_train_labeled = evalset_train_min.map(generate_eval_labels)\n", - " print(f\"Labeling evaluation OOS data\")\n", - " evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)\n", - " \n", - " # create output directories\n", - " ksplit_output_dir = os.path.join(output_dir, f\"ksplit{iteration_num}\")\n", - " ksplit_model_dir = os.path.join(ksplit_output_dir, \"models/\") \n", - " \n", - " # ensure not overwriting previously saved model\n", - " model_output_file = os.path.join(ksplit_model_dir, \"pytorch_model.bin\")\n", - " if os.path.isfile(model_output_file) == True:\n", - " raise Exception(\"Model already saved to this directory.\")\n", - "\n", - " # make training and model output directories\n", - " subprocess.call(f'mkdir {ksplit_output_dir}', shell=True)\n", - " subprocess.call(f'mkdir {ksplit_model_dir}', shell=True)\n", - " \n", - " # load model\n", - " model = BertForTokenClassification.from_pretrained(\n", - " \"/gladstone/theodoris/lab/ctheodoris/archive/geneformer_files/geneformer/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/\",\n", - " num_labels=2,\n", - " output_attentions = False,\n", - " output_hidden_states = False\n", - " )\n", - " if freeze_layers is not None:\n", - " modules_to_freeze = model.bert.encoder.layer[:freeze_layers]\n", - " for module in modules_to_freeze:\n", - " for param in module.parameters():\n", - " param.requires_grad = False\n", - " \n", - " model = model.to(\"cuda:0\")\n", - " \n", - " # add output directory to training args and initiate\n", - " training_args[\"output_dir\"] = ksplit_output_dir\n", - " training_args_init = TrainingArguments(**training_args)\n", - " \n", - " # create the trainer\n", - " trainer = Trainer(\n", - " model=model,\n", - " args=training_args_init,\n", - " data_collator=DataCollatorForGeneClassification(),\n", - " train_dataset=trainset_labeled,\n", - " eval_dataset=evalset_train_labeled\n", 
- " )\n", - "\n", - " # train the gene classifier\n", - " trainer.train()\n", - " \n", - " # save model\n", - " trainer.save_model(ksplit_model_dir)\n", - " \n", - " # evaluate model\n", - " fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)\n", - " \n", - " # append to tpr and roc lists\n", - " confusion = confusion + conf_mat\n", - " all_tpr.append(interp_tpr)\n", - " all_roc_auc.append(auc(fpr, tpr))\n", - " # append number of eval examples by which to weight tpr in averaged graphs\n", - " all_tpr_wt.append(len(tpr))\n", - " \n", - " iteration_num = iteration_num + 1\n", - " \n", - " # get overall metrics for cross-validation\n", - " mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)\n", - " return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define Functions for Plotting Results" + "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sensitivity_TFs.pickle\n", + "with open(\"/path/to/dosage_sensitivity_TFs.pickle\", \"rb\") as fp:\n", + " gene_class_dict = pickle.load(fp)" ] }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# plot ROC curve\n", - "def plot_ROC(bundled_data, title):\n", - " plt.figure()\n", - " lw = 2\n", - " for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:\n", - " plt.plot(mean_fpr, mean_tpr, color=color,\n", - " lw=lw, label=\"{0} (AUC {1:0.2f} $\\pm$ {2:0.2f})\".format(sample, roc_auc, roc_auc_sd))\n", - " plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')\n", - " plt.xlim([0.0, 1.0])\n", - " plt.ylim([0.0, 1.05])\n", - " plt.xlabel('False Positive Rate')\n", - " plt.ylabel('True Positive Rate')\n", - " plt.title(title)\n", - " 
plt.legend(loc=\"lower right\")\n", - " plt.show()\n", - " \n", - "# plot confusion matrix\n", - "def plot_confusion_matrix(classes_list, conf_mat, title):\n", - " display_labels = []\n", - " i = 0\n", - " for label in classes_list:\n", - " display_labels += [\"{0}\\nn={1:.0f}\".format(label, sum(conf_mat[:,i]))]\n", - " i = i + 1\n", - " display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm=\"l1\"), \n", - " display_labels=display_labels)\n", - " display.plot(cmap=\"Blues\",values_format=\".2g\")\n", - " plt.title(title)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance" - ] - }, - { - "cell_type": "markdown", + "execution_count": 3, + "id": "f4053ee9-3506-4c97-b544-8d667f0adfab", "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Hyperparameter tuning is highly recommended for optimal results. No training_args provided; using default hyperparameters.\n" + ] + } + ], "source": [ - "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." 
+ "cc = Classifier(classifier=\"gene\",\n", + " gene_class_dict = gene_class_dict,\n", + " max_ncells = 10_000,\n", + " freeze_layers = 4,\n", + " num_crossval_splits = 5,\n", + " forward_batch_size=200,\n", + " nproc=16)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "e4855e53-1cd7-4af0-b786-02b6c0e55f8c", "metadata": {}, - "outputs": [], - "source": [ - "# set model parameters\n", - "# max input size\n", - "max_input_size = 2 ** 11 # 2048\n", - "\n", - "# set training hyperparameters\n", - "# max learning rate\n", - "max_lr = 5e-5\n", - "# how many pretrained layers to freeze\n", - "freeze_layers = 4\n", - "# number gpus\n", - "num_gpus = 1\n", - "# number cpu cores\n", - "num_proc = 24\n", - "# batch size for training and eval\n", - "geneformer_batch_size = 12\n", - "# learning schedule\n", - "lr_schedule_fn = \"linear\"\n", - "# warmup steps\n", - "warmup_steps = 500\n", - "# number of epochs\n", - "epochs = 1\n", - "# optimizer\n", - "optimizer = \"adamw\"" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "tags": [] - }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6a3f7bcf2a314368b00f49c74a775571", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Saving the dataset (0/1 shards): 0%| | 0/33558 [00:00:45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. 
It is recommended to upgrade the kernel to the minimum version or higher.\n", + "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, @@ -599,47 +210,55 @@ "
\n", " \n", " \n", - " [834/834 01:33, Epoch 1/1]\n", + " [834/834 02:37, Epoch 1/1]\n", "
\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "
StepTraining Loss
1000.684000830.729100
1660.667600
2490.553100
2000.6176003320.409100
3000.4774004150.294300
4000.3343004980.197000
5000.2295005810.138300
6000.1527006640.099900
7000.1256007470.083700
8000.1049008300.072300

" @@ -652,108 +271,77 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-4d8947ed4c65f4a4.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-8a83f628e23d5548.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c6c437341faa1cfe.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2010c177e27e09d1.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-15543d980ad3cbb0.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-a81a942ab15e4aa3.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5d2c963673bb1115.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6c7cc476a9d722c3.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e274abd189113bba.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-1aedba9e0b982e5c.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6668161997480231.arrow\n", - "Loading cached processed 
dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d802b8093fb9c6f7.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3ea48baa5fe880e2.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-86024b6184e99afe.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-7a47db2c9f9758a4.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-af1f6b8f743677db.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-67cffffa35fa22f7.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-81ed63bd02a44ee5.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6e5a21d4d57e333d.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-eecde81c07e6d036.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-fcc19fab82bb7115.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ea856d7fa4e78b24.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-698344adb3749f61.arrow\n", - "Loading cached 
processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ee3f9e89abdbee4c.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d98fd9d7fda61d3b.arrow\n" + "****** Validation split: 2/5 ******\n", + "\n" ] }, { "data": { - "image/png": "", + "application/vnd.jupyter.widget-view+json": { + "model_id": "d186836393d84c19b9c0dffafb31a09c", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "

" + "Filter (num_proc=16): 0%| | 0/33558 [00:00:45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", + "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, @@ -764,47 +352,55 @@ "
\n", " \n", " \n", - " [834/834 01:33, Epoch 1/1]\n", + " [834/834 02:34, Epoch 1/1]\n", "
\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
StepTraining Loss
1000.658900830.695400
1660.634600
2000.5854002490.540200
3000.4746003320.414800
4000.3466004150.298500
5000.2574004980.199100
6000.1858005810.133200
7000.1342006640.096300
8000.1145007470.078100
8300.068100

" @@ -817,96 +413,77 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-cbfcb02a16dd9d81.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b151d664d8c68613.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-52266cf801a76344.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5c7ceff44bad692c.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-81bcbb23e61bfc0c.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e99a8c7eedd34769.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6d7d5150907035d9.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-735b525b0abf0f74.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-9a47cf8290cd2f6b.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-56deb15eec02ca33.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2aea162267b33f73.arrow\n", - "Loading cached processed 
dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3bc7a169c841323d.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-1f67206928846c7a.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-88375062775280fb.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-bb45ebd2db699b53.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-fd6e4344cc2f8033.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b8a9338cde5e5801.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c013876f43a71ad7.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-148c328cb89da5c3.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-488b3d116a6d3b19.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-835e3e1538e24397.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d176e8ab14f1ce28.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3451fb13f869a5b0.arrow\n", - "Loading cached 
processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-56f270f895acc3ff.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-db497551e7a1e808.arrow\n" + "****** Validation split: 3/5 ******\n", + "\n" ] }, { "data": { - "image/png": "", + "application/vnd.jupyter.widget-view+json": { + "model_id": "93e9c12bc6e243b39224994add37ce21", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "

" + "Filter (num_proc=16): 0%| | 0/33558 [00:00:45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", + "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, @@ -917,47 +494,55 @@ "
\n", " \n", " \n", - " [834/834 01:33, Epoch 1/1]\n", + " [834/834 02:35, Epoch 1/1]\n", "
\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "
StepTraining Loss
1000.645900830.708600
1660.656300
2490.553600
2000.5828003320.430600
3000.4617004150.300000
4000.3502004980.202900
5000.2628005810.144700
6000.1804006640.109900
7000.1409007470.096000
8000.1096008300.086700

" @@ -970,84 +555,77 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-8e85e7414566994a.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e2704cdfc217c3e3.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e213b038886d7cd4.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d6c9eba9fe9ffafc.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-442181417de57bb6.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0d8563be811b9c30.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-85690e0bf5863858.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3bdda0a32e054f19.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3abe0ffb170c29f0.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b132478871346000.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-09db8f6a69301008.arrow\n", - "Loading cached processed 
dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-34ae599619e2ced6.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c74b97625f913f63.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-228b6002a6690208.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d644cc9c55478a2a.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d3d097800ebd687c.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2e536900ba2b88cc.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0434f2adbb78af27.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-926036de71570e84.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d7f012de8332824e.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-57a002ae2aa9ba42.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0476d5fed302e1c5.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-69341790285e8ce2.arrow\n", - "Loading cached 
processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ee190fa69ba78df3.arrow\n", - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-4b3dc879e23e8e63.arrow\n" + "****** Validation split: 4/5 ******\n", + "\n" ] }, { "data": { - "image/png": "", + "application/vnd.jupyter.widget-view+json": { + "model_id": "1a9cebe980534274907ae3858a706c37", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Filter (num_proc=16): 0%| | 0/33558 [00:00" + "Filter (num_proc=16): 0%| | 0/33558 [00:00:45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", + "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, @@ -1058,47 +636,55 @@ "

\n", " \n", " \n", - " [834/834 01:32, Epoch 1/1]\n", + " [834/834 02:35, Epoch 1/1]\n", "
\n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
StepTraining Loss
1000.660300830.697500
2000.5880001660.632000
3000.4654002490.524600
4000.3314003320.394300
5000.2411004150.264700
6000.1688004980.180100
7000.1366005810.128300
8000.1139006640.094200
7470.082200
8300.078500

" @@ -1111,1300 +697,530 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c438e6f7f8463bbc.arrow\n" + "****** Validation split: 5/5 ******\n", + "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6f8a9dd0a5754dec845c0022470a8c96", + "model_id": "455067153dc145cba4e3cfdc63f129cc", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "Filter (num_proc=16): 0%| | 0/33558 [00:00\n", + " \n", + " \n", + " [834/834 02:35, Epoch 1/1]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
830.711400
1660.644000
2490.535900
3320.395400
4150.275400
4980.193600
5810.129300
6640.093300
7470.070000
8300.067100

" + ], "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, + } + ], + "source": [ + "# 6 layer Geneformer: https://huggingface.co./ctheodoris/Geneformer/blob/main/model.safetensors\n", + "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n", + " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", + " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", + " output_directory=output_dir,\n", + " output_prefix=output_prefix)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "11a1329b-4968-45f3-ac7a-2438b574404e", + "metadata": {}, + "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e103daf395794272989c209b32c12afc", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "

" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "81053043727a4c1dbe23304e5ad6282a", - "version_major": 2, - "version_minor": 0 - }, + "image/png": "", "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "
" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5d1d3f2835b74004b267d67d04c24663", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "
" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, + } + ], + "source": [ + "cc.plot_conf_mat(\n", + " conf_mat_dict={\"Geneformer\": all_metrics[\"conf_matrix\"]},\n", + " output_directory=output_dir,\n", + " output_prefix=output_prefix,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "edf6ffd9-8b84-4d31-8b39-11959140382f", + "metadata": {}, + "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "14f38354b0354bc187be9db34990fcce", - "version_major": 2, - "version_minor": 0 - }, + "image/png": "", "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "
" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4e3d47f0ecdc489ca34de778ebfb3021", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "
" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5997f34a471f4a918fd32043fc519bb3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "affe20b63e08414cb0863e1f6c1aad18", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fca7f8cafa504738b7eaddd3f7b708fc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "11f299f23b124674ab9e334bdbe09288", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "01a88ef05cb64f24adecfb5674265a02", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": 
{}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2f88e6525cbd486c9f03491a04681283", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8bb884df7370471d986c51c10431ba10", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4b82e5fe600b4270bb6268e68f76d093", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cd15c803ecc34a8d878df577ffd80252", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "246cac7b5a0b4fd799e7e2081badbdbf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - 
"output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fbc93f4256724314a5141ac29062bae9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b38551b3ac134fef8aa0c6ea3b7fa2a0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "16ddc360a6b64906bd3f1d1adcc94efe", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "44b3af87a1794fc09d00dd3743c4705d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "****** Crossval split: 4/4 ******\n", - "\n", - "Filtering training data\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "be5426abaf5b41ebb51e2567dd73b0a4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Filtered 35%; 32428 remain\n", - "\n", - "Filtering evalation data\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ff5aad423e4f4bbab54518bc5f0fd028", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Filtered 53%; 23660 remain\n", - "\n", - "Labeling training data\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "78c25d0976854653be92baf65ca71158", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Labeling evaluation data\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c445de0805e145249f4647e5552292a2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Labeling evaluation OOS 
data\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c553f188f56e47acafa77fab9cb2b21f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", - "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", - ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [834/834 01:35, Epoch 1/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
1000.663500
2000.601800
3000.486200
4000.340400
5000.242700
6000.202300
7000.153600
8000.124400

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0e1c475ab2ff4bfa8c65a24d587c8ad0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2ee8ff99342d4741a3f4ec4176b5d746", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "78a1a6af9439481ebe87731bb2d37c95", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "411ed284d33740eca1f0cef18df500a4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "aafdf3014691426c9c6acca3834c45f2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { 
- "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5aa3add5de134f589eaab69087b66549", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7d255e53e1c2408697da1fa08860c9c0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "29b8945f64354ae1b840a1dc316dedbf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "de251d1fba3d4a67893047ee8275d606", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8928cf69ea8746b2bef14028c0c0274a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", 
- "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0c0c4e21626f4ab99ce0696ee9322e0c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9e3499a2376d43bab0086cba34d1b522", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f33d4f879c294c6a8a6455b3692488d5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "38dd78e3ebf44c2bad58f9576a525ab3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b052e8b179584043945b49de9af31676", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": 
"stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e3e11781b4394db1a01454ef37a490f2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "915efb0adfb44c5caa01cf213c3cd56b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ceb10f0f87d044ebab534aefef5ec69c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "31f4bd65079e4983b8a1937901cfbace", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ccb5be44b5494de8862488f82bf01741", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": 
[ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9da6bd7370db44889cab2fb81dcebe11", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, + } + ], + "source": [ + "cc.plot_roc(\n", + " roc_metric_dict={\"Geneformer\": all_metrics[\"all_roc_metrics\"]},\n", + " model_style_dict={\"Geneformer\": {\"color\": \"red\", \"linestyle\": \"-\"}},\n", + " title=\"Dosage-sensitive vs -insensitive factors\",\n", + " output_directory=output_dir,\n", + " output_prefix=output_prefix,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d10ac27f-8d70-400e-8a00-d0b84c1d02b4", + "metadata": {}, + "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "12bddf69336d481fb0076dced187523c", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, + "{'conf_matrix': Dosage-sensitive TFs Dosage-insensitive TFs\n", + " Dosage-sensitive TFs 61229.0 14801.0\n", + " Dosage-insensitive TFs 9094.0 73907.0,\n", + " 'macro_f1': [0.8489695337205987,\n", + " 0.8637730998133415,\n", + " 0.9122635701525341,\n", + " 0.8180200155972593,\n", + " 0.7913574275548942],\n", + " 'acc': [0.8544562281799618,\n", + " 0.8647275498539312,\n", + " 0.9122812348079727,\n", + " 0.8182044035899506,\n", + " 0.798060129740519],\n", + " 'all_roc_metrics': {'mean_tpr': array([0. 
, 0.29330305, 0.39824459, 0.48477052, 0.53910681,\n", + " 0.58654819, 0.62233428, 0.65499297, 0.68383714, 0.7105218 ,\n", + " 0.7331015 , 0.75404762, 0.77191402, 0.79007262, 0.80530801,\n", + " 0.81812243, 0.83182971, 0.84348565, 0.85308334, 0.86179954,\n", + " 0.87018186, 0.87841599, 0.88666193, 0.89398957, 0.90104605,\n", + " 0.90768847, 0.91468381, 0.92081589, 0.92687436, 0.93170239,\n", + " 0.93600138, 0.93963402, 0.9430781 , 0.94641134, 0.94881205,\n", + " 0.95143243, 0.95361201, 0.95556462, 0.95766077, 0.95966244,\n", + " 0.96118109, 0.96277551, 0.96448544, 0.96590662, 0.96726595,\n", + " 0.96852001, 0.96991619, 0.97113487, 0.9723888 , 0.97361378,\n", + " 0.97487929, 0.97591807, 0.97725326, 0.97856005, 0.97952476,\n", + " 0.98071045, 0.98164245, 0.98264028, 0.98393822, 0.9850845 ,\n", + " 0.98620898, 0.9872157 , 0.98857151, 0.98954745, 0.99058733,\n", + " 0.99138259, 0.99226871, 0.99306583, 0.99380789, 0.99461065,\n", + " 0.99527049, 0.99592002, 0.99655526, 0.99691174, 0.99757778,\n", + " 0.9978895 , 0.99816814, 0.99852539, 0.99874352, 0.99896924,\n", + " 0.99925024, 0.9993954 , 0.99949426, 0.99964604, 0.99974177,\n", + " 0.99977018, 0.9998233 , 0.99984802, 0.99990114, 0.99994688,\n", + " 0.99996108, 0.99997159, 1. , 1. , 1. ,\n", + " 1. , 1. , 1. , 1. , 1. ]),\n", + " 'mean_fpr': array([0. 
, 0.01010101, 0.02020202, 0.03030303, 0.04040404,\n", + " 0.05050505, 0.06060606, 0.07070707, 0.08080808, 0.09090909,\n", + " 0.1010101 , 0.11111111, 0.12121212, 0.13131313, 0.14141414,\n", + " 0.15151515, 0.16161616, 0.17171717, 0.18181818, 0.19191919,\n", + " 0.2020202 , 0.21212121, 0.22222222, 0.23232323, 0.24242424,\n", + " 0.25252525, 0.26262626, 0.27272727, 0.28282828, 0.29292929,\n", + " 0.3030303 , 0.31313131, 0.32323232, 0.33333333, 0.34343434,\n", + " 0.35353535, 0.36363636, 0.37373737, 0.38383838, 0.39393939,\n", + " 0.4040404 , 0.41414141, 0.42424242, 0.43434343, 0.44444444,\n", + " 0.45454545, 0.46464646, 0.47474747, 0.48484848, 0.49494949,\n", + " 0.50505051, 0.51515152, 0.52525253, 0.53535354, 0.54545455,\n", + " 0.55555556, 0.56565657, 0.57575758, 0.58585859, 0.5959596 ,\n", + " 0.60606061, 0.61616162, 0.62626263, 0.63636364, 0.64646465,\n", + " 0.65656566, 0.66666667, 0.67676768, 0.68686869, 0.6969697 ,\n", + " 0.70707071, 0.71717172, 0.72727273, 0.73737374, 0.74747475,\n", + " 0.75757576, 0.76767677, 0.77777778, 0.78787879, 0.7979798 ,\n", + " 0.80808081, 0.81818182, 0.82828283, 0.83838384, 0.84848485,\n", + " 0.85858586, 0.86868687, 0.87878788, 0.88888889, 0.8989899 ,\n", + " 0.90909091, 0.91919192, 0.92929293, 0.93939394, 0.94949495,\n", + " 0.95959596, 0.96969697, 0.97979798, 0.98989899, 1. 
]),\n", + " 'all_roc_auc': [0.9373324264902606,\n", + " 0.9410936383111078,\n", + " 0.9635257667493496,\n", + " 0.8903987740960708,\n", + " 0.8781592994811886],\n", + " 'roc_auc': 0.9141830130444975,\n", + " 'roc_auc_sd': 0.03204329033266111}}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_metrics" + ] + }, + { + "cell_type": "markdown", + "id": "7007e45e-16c2-47a3-962c-92b9fe867bde", + "metadata": {}, + "source": [ + "### Train gene classifier with all data:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6df82c21-937c-4563-ba6b-a52ce287f542", + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "import pickle\n", + "from geneformer import Classifier\n", + "\n", + "current_date = datetime.datetime.now()\n", + "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}{current_date.hour:02d}{current_date.minute:02d}{current_date.second:02d}\"\n", + "datestamp_min = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n", + "\n", + "\n", + "output_prefix = \"tf_dosage_sens_alldata\"\n", + "output_dir = f\"/path/to/output_dir/{datestamp}\"\n", + "!mkdir $output_dir" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f031131c-54fd-4ad1-a925-bf0846cc3235", + "metadata": {}, + "outputs": [], + "source": [ + "# Example input_data_file: https://huggingface.co./datasets/ctheodoris/Genecorpus-30M/blob/main/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sensitivity_TFs.pickle\n", + "with open(\"/path/to/dosage_sensitivity_TFs.pickle\", \"rb\") as fp:\n", + " gene_class_dict = pickle.load(fp)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "cd27b15c-52d4-46a6-af8c-812c8731f82c", + "metadata": {}, + "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "\n" + "Hyperparameter tuning is highly recommended for 
optimal results. No training_args provided; using default hyperparameters.\n" ] - }, + } + ], + "source": [ + "cc = Classifier(classifier=\"gene\",\n", + " gene_class_dict = gene_class_dict,\n", + " max_ncells = 10_000,\n", + " freeze_layers = 4,\n", + " num_crossval_splits = 0,\n", + " forward_batch_size=200,\n", + " nproc=16)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3d542bda-fbab-4d63-ab58-00d4caa996b9", + "metadata": {}, + "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b89b616cd8064d248b37cc642a09b9bf", + "model_id": "7f77eaec105642b199a9e797fccdbf4b", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" + "Saving the dataset (0/1 shards): 0%| | 0/33558 [00:00\n", + " \n", + " \n", + " [834/834 02:35, Epoch 1/1]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
830.700600
1660.643100
2490.544700
3320.412900
4150.298600
4980.205700
5810.138900
6640.103200
7470.090000
8300.083100

" + ], "text/plain": [ - "

" + "" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "[0.24272061700106187, 0.1890124629743475, 0.1665455764824233, 0.212820656122506, 0.18890068741966132]\n" - ] } ], "source": [ - "# cross-validate gene classifier\n", - "all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts \\\n", - " = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "# bundle data for plotting\n", - "bundled_data = []\n", - "bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, \"Geneformer\", \"red\")]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot ROC curve\n", - "plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot confusion matrix\n", - "classes_list = [\"Dosage Sensitive\", \"Dosage Insensitive\"]\n", - "plot_confusion_matrix(classes_list, confusion, \"Geneformer\")" + "# 6 layer Geneformer: https://huggingface.co./ctheodoris/Geneformer/blob/main/model.safetensors\n", + "trainer_test = cc.train_all_data(model_directory=\"/path/to/Geneformer\",\n", + " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n", + " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n", + " output_directory=output_dir,\n", + " output_prefix=output_prefix)" ] } ], @@ -2424,14 +1240,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" - }, - "vscode": { - "interpreter": { - "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829" - } + "version": "3.11.5" } }, "nbformat": 4, - 
"nbformat_minor": 4 + "nbformat_minor": 5 }