Christina Theodoris committed
Commit 9e9cca9 • 1 Parent(s): d6c634c

Add classifier module and examples
Files changed:
- docs/source/api.rst +8 -0
- docs/source/geneformer.classifier.rst +9 -0
- examples/cell_classification.ipynb +0 -0
- examples/extract_and_plot_cell_embeddings.ipynb +1 -1
- examples/gene_classification.ipynb +0 -0
- geneformer/__init__.py +18 -10
- geneformer/classifier.py +1203 -0
- geneformer/classifier_utils.py +406 -0
- geneformer/emb_extractor.py +3 -9
- geneformer/evaluation_utils.py +287 -0
- geneformer/in_silico_perturber_stats.py +6 -0
- geneformer/tokenizer.py +25 -14
docs/source/api.rst CHANGED

@@ -9,6 +9,14 @@ Tokenizer
 
     geneformer.tokenizer
 
+Classifier
+----------
+
+.. toctree::
+    :maxdepth: 1
+
+    geneformer.classifier
+
 Embedding Extractor
 -------------------
 
docs/source/geneformer.classifier.rst ADDED

@@ -0,0 +1,9 @@
+geneformer.classifier
+=====================
+
+.. automodule:: geneformer.classifier
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :exclude-members:
+       validate_options
examples/cell_classification.ipynb CHANGED

The diff for this file is too large to render. See raw diff.
examples/extract_and_plot_cell_embeddings.ipynb CHANGED

@@ -129,7 +129,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.
+   "version": "3.11.5"
   }
  },
  "nbformat": 4,
examples/gene_classification.ipynb CHANGED

The diff for this file is too large to render. See raw diff.
geneformer/__init__.py CHANGED

@@ -1,12 +1,20 @@
-
-from . import
-from . import
-
-
-
-
-
-
+# ruff: noqa: F401
+from . import classifier  # noqa
+from . import (
+    collator_for_classification,
+    emb_extractor,
+    in_silico_perturber,
+    in_silico_perturber_stats,
+    pretrainer,
+    tokenizer,
+)
+from .classifier import Classifier
+from .collator_for_classification import (
+    DataCollatorForCellClassification,
+    DataCollatorForGeneClassification,
+)
 from .emb_extractor import EmbExtractor
 from .in_silico_perturber import InSilicoPerturber
-from .in_silico_perturber_stats import InSilicoPerturberStats
+from .in_silico_perturber_stats import InSilicoPerturberStats
+from .pretrainer import GeneformerPretrainer
+from .tokenizer import TranscriptomeTokenizer
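With this change, the classifier and its collators become importable from the package root. The following simply mirrors the exports added above:

from geneformer import (
    Classifier,
    DataCollatorForCellClassification,
    DataCollatorForGeneClassification,
    GeneformerPretrainer,
    TranscriptomeTokenizer,
)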
geneformer/classifier.py ADDED

@@ -0,0 +1,1203 @@
"""
Geneformer classifier.

**Input data:**

Cell state classifier:
| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels
| in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)

Gene classifier:
| Dictionary in format {Gene_label: list(genes)} for gene labels
| and single-cell transcriptomes as Geneformer rank value encodings
| in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)

**Usage:**

.. code-block :: python

    >>> from geneformer import Classifier
    >>> cc = Classifier(classifier="cell",  # example of cell state classifier
    ...                 cell_state_dict={"state_key": "disease", "states": "all"},
    ...                 filter_data={"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]},
    ...                 training_args=training_args,
    ...                 freeze_layers=2,
    ...                 num_crossval_splits=1,
    ...                 forward_batch_size=200,
    ...                 nproc=16)
    >>> cc.prepare_data(input_data_file="path/to/input_data",
    ...                 output_directory="path/to/output_directory",
    ...                 output_prefix="output_prefix")
    >>> all_metrics = cc.validate(model_directory="path/to/model",
    ...                           prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
    ...                           id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
    ...                           output_directory="path/to/output_directory",
    ...                           output_prefix="output_prefix",
    ...                           predict_eval=True)
    >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
    ...                  output_directory="path/to/output_directory",
    ...                  output_prefix="output_prefix",
    ...                  custom_class_order=["healthy", "disease1", "disease2"])
    >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
    ...                     id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
    ...                     title="disease",
    ...                     output_directory="path/to/output_directory",
    ...                     output_prefix="output_prefix",
    ...                     custom_class_order=["healthy", "disease1", "disease2"])
"""

import datetime
import logging
import os
import pickle
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm, trange
from transformers import Trainer
from transformers.training_args import TrainingArguments

from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
from . import classifier_utils as cu
from . import evaluation_utils as eu
from . import perturber_utils as pu
from .tokenizer import TOKEN_DICTIONARY_FILE

sns.set()


logger = logging.getLogger(__name__)

class Classifier:
    valid_option_dict = {
        "classifier": {"cell", "gene"},
        "cell_state_dict": {None, dict},
        "gene_class_dict": {None, dict},
        "filter_data": {None, dict},
        "rare_threshold": {int, float},
        "max_ncells": {None, int},
        "max_ncells_per_class": {None, int},
        "training_args": {None, dict},
        "freeze_layers": {int},
        "num_crossval_splits": {0, 1, 5},
        "eval_size": {int, float},
        "no_eval": {bool},
        "stratify_splits_col": {None, str},
        "forward_batch_size": {int},
        "nproc": {int},
    }

    def __init__(
        self,
        classifier=None,
        cell_state_dict=None,
        gene_class_dict=None,
        filter_data=None,
        rare_threshold=0,
        max_ncells=None,
        max_ncells_per_class=None,
        training_args=None,
        freeze_layers=0,
        num_crossval_splits=1,
        eval_size=0.2,
        stratify_splits_col=None,
        no_eval=False,
        forward_batch_size=100,
        nproc=4,
    ):
        """
        Initialize Geneformer classifier.

        **Parameters:**

        classifier : {"cell", "gene"}
            | Whether to fine-tune a cell state or gene classifier.
        cell_state_dict : None, dict
            | Cell states to fine-tune model to distinguish.
            | Two-item dictionary with keys: state_key and states
            | state_key: key specifying name of column in .dataset that defines the states to model
            | states: list of values in the state_key column that specifies the states to model
            | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
            | Of note, if using "all", states will be defined after data is filtered.
            | Must have at least 2 states to model.
            | For example: {"state_key": "disease",
            |               "states": ["nf", "hcm", "dcm"]}
            | or
            | {"state_key": "disease",
            |  "states": "all"}
        gene_class_dict : None, dict
            | Gene classes to fine-tune model to distinguish.
            | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
            |                        Gene_label_B: list(geneB1, geneB2, ...)}
            | Gene values should be Ensembl IDs.
        filter_data : None, dict
            | Default is to fine-tune with all input data.
            | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        rare_threshold : float
            | Threshold below which rare cell states should be removed.
            | For example, setting to 0.05 will remove cell states representing
            | < 5% of the total cells from the cell state classifier's possible classes.
        max_ncells : None, int
            | Maximum number of cells to use for fine-tuning.
            | Default is to fine-tune with all input data.
        max_ncells_per_class : None, int
            | Maximum number of cells per cell class to use for fine-tuning.
            | Of note, will be applied after max_ncells above.
            | (Only valid for cell classification.)
        training_args : None, dict
            | Training arguments for fine-tuning.
            | If None, defaults will be inferred for 6 layer Geneformer.
            | Otherwise, the provided dictionary will override the inferred defaults;
            | valid keys are those of Hugging Face TrainingArguments:
            | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
            | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
        freeze_layers : int
            | Number of layers to freeze from fine-tuning.
            | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
        num_crossval_splits : {0, 1, 5}
            | 0: train on all data without splitting
            | 1: split data into train and eval sets by designated eval_size
            | 5: split data into 5 folds of train and eval sets by designated eval_size
        eval_size : None, float
            | Proportion of data to hold out for evaluation (e.g. 0.2 if intending 80:20 train/eval split)
        stratify_splits_col : None, str
            | Name of column in .dataset to be used for stratified splitting.
            | Proportion of each class in this column will be the same in the splits as in the original dataset.
        no_eval : bool
            | If True, will skip eval step and use all data for training.
            | Otherwise, will perform eval during training.
        forward_batch_size : int
            | Batch size for forward pass (for evaluation, not training).
        nproc : int
            | Number of CPU processes to use.

        """

        self.classifier = classifier
        self.cell_state_dict = cell_state_dict
        self.gene_class_dict = gene_class_dict
        self.filter_data = filter_data
        self.rare_threshold = rare_threshold
        self.max_ncells = max_ncells
        self.max_ncells_per_class = max_ncells_per_class
        self.training_args = training_args
        self.freeze_layers = freeze_layers
        self.num_crossval_splits = num_crossval_splits
        self.eval_size = eval_size
        self.stratify_splits_col = stratify_splits_col
        self.no_eval = no_eval
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc

        if self.training_args is None:
            logger.warning(
                "Hyperparameter tuning is highly recommended for optimal results. "
                "No training_args provided; using default hyperparameters."
            )

        self.validate_options()

        if self.filter_data is None:
            self.filter_data = dict()

        if self.classifier == "cell":
            if self.cell_state_dict["states"] != "all":
                self.filter_data[
                    self.cell_state_dict["state_key"]
                ] = self.cell_state_dict["states"]

        # load token dictionary (Ensembl IDs:token)
        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}

        # filter genes for gene classification for those in token dictionary
        if self.classifier == "gene":
            all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
            missing_genes = [
                gene
                for gene in all_gene_class_values
                if gene not in self.gene_token_dict.keys()
            ]
            if len(missing_genes) == len(all_gene_class_values):
                logger.error(
                    "None of the provided genes to classify are in token dictionary."
                )
                raise
            elif len(missing_genes) > 0:
                logger.warning(
                    f"Genes to classify {missing_genes} are not in token dictionary."
                )
            self.gene_class_dict = {
                k: set([self.gene_token_dict.get(gene) for gene in v])
                for k, v in self.gene_class_dict.items()
            }
            empty_classes = []
            for k, v in self.gene_class_dict.items():
                if len(v) == 0:
                    empty_classes += [k]
            if len(empty_classes) > 0:
                logger.error(
                    f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
                )
                raise

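The constructor arguments for a gene classifier parallel the cell-state example in the module docstring. A minimal sketch (the gene lists are illustrative placeholders for lists of Ensembl IDs):

gc = Classifier(
    classifier="gene",
    gene_class_dict={
        "dosage_sensitive": dosage_sensitive_genes,      # illustrative lists of
        "dosage_insensitive": dosage_insensitive_genes,  # Ensembl gene IDs
    },
    max_ncells=10_000,
    freeze_layers=2,
    forward_batch_size=200,
    nproc=16,
)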
    def validate_options(self):
        # confirm arguments are within valid options and compatible with each other
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            if not isinstance(attr_value, (list, dict)):
                if attr_value in valid_options:
                    continue
            valid_type = False
            for option in valid_options:
                # str included so string-valued options (e.g. stratify_splits_col)
                # pass type validation
                if (option in [int, float, list, dict, bool, str]) and isinstance(
                    attr_value, option
                ):
                    valid_type = True
                    break
            if valid_type:
                continue
            logger.error(
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            raise

        if self.filter_data is not None:
            for key, value in self.filter_data.items():
                if not isinstance(value, list):
                    self.filter_data[key] = [value]
                    logger.warning(
                        "Values in filter_data dict must be lists. "
                        f"Changing {key} value to list ([{value}])."
                    )

        if self.classifier == "cell":
            if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
                logger.error(
                    "Invalid keys for cell_state_dict. "
                    "The cell_state_dict should have only 2 keys: state_key and states"
                )
                raise

            if self.cell_state_dict["states"] != "all":
                if not isinstance(self.cell_state_dict["states"], list):
                    logger.error(
                        "States in cell_state_dict should be list of states to model."
                    )
                    raise
                if len(self.cell_state_dict["states"]) < 2:
                    logger.error(
                        "States in cell_state_dict should contain at least 2 states to classify."
                    )
                    raise

        if self.classifier == "gene":
            if len(self.gene_class_dict.keys()) < 2:
                logger.error(
                    "Gene_class_dict should contain at least 2 gene classes to classify."
                )
                raise

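The option table mixes literal values with types: a value is accepted if it equals one of the literals or is an instance of one of the listed types. A standalone illustration of the same check (a sketch, not part of the module):

valid_options = {None, int}  # e.g. the "max_ncells" entry above

def is_valid(value):
    # accept a literal match (None) or a type match (any int)
    if value in valid_options:
        return True
    return any(isinstance(value, opt) for opt in valid_options if isinstance(opt, type))

assert is_valid(None) and is_valid(500) and not is_valid("500")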
    def prepare_data(
        self,
        input_data_file,
        output_directory,
        output_prefix,
        split_id_dict=None,
        test_size=None,
        attr_to_split=None,
        attr_to_balance=None,
        max_trials=100,
        pval_threshold=0.1,
    ):
        """
        Prepare data for cell state or gene classification.

        **Parameters**

        input_data_file : Path
            | Path to directory containing .dataset input
        output_directory : Path
            | Path to directory where prepared data will be saved
        output_prefix : str
            | Prefix for output file
        split_id_dict : None, dict
            | Dictionary of IDs for train and test splits
            | Three-item dictionary with keys: attr_key, train, test
            | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
            | train: list of IDs in the attr_key column to include in the train split
            | test: list of IDs in the attr_key column to include in the test split
            | For example: {"attr_key": "individual",
            |               "train": ["patient1", "patient2", "patient3", "patient4"],
            |               "test": ["patient5", "patient6"]}
        test_size : None, float
            | Proportion of data to be saved separately and held out for test set
            | (e.g. 0.2 if intending to hold out 20%)
            | The training set will be further split to train / validation in self.validate
            | Note: only available for CellClassifiers
        attr_to_split : None, str
            | Key for attribute on which to split data while balancing potential confounders
            | e.g. "patient_id" for splitting by patient while balancing other characteristics
            | Note: only available for CellClassifiers
        attr_to_balance : None, list
            | List of attribute keys on which to balance data while splitting on attr_to_split
            | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
            | Note: only available for CellClassifiers
        max_trials : None, int
            | Maximum number of random splitting trials to attempt in order to balance the other attributes
            | If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best
            | Note: only available for CellClassifiers
        pval_threshold : None, float
            | P-value threshold to use for attribute balancing across splits
            | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
        """

        # prepare data and labels for classification
        data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

        if self.classifier == "cell":
            if "label" in data.features:
                logger.error(
                    "Column name 'label' must be reserved for class IDs. Please rename column."
                )
                raise
        elif self.classifier == "gene":
            if "labels" in data.features:
                logger.error(
                    "Column name 'labels' must be reserved for class IDs. Please rename column."
                )
                raise

        if self.classifier == "cell":
            # remove cell states representing < rare_threshold of cells
            data = cu.remove_rare(
                data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
            )
            # downsample to max cells and max per class
            data = cu.downsample_and_shuffle(
                data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
            )
            # rename cell state column to "label"
            data = cu.rename_cols(data, self.cell_state_dict["state_key"])

        # convert classes to numerical labels and save as id_class_dict
        # of note, will label all genes in gene_class_dict
        # if (cross-)validating, genes will be relabeled in column "labels" for each split
        # at the time of training with Classifier.validate
        data, id_class_dict = cu.label_classes(
            self.classifier, data, self.gene_class_dict, self.nproc
        )

        # save id_class_dict for future reference
        id_class_output_path = (
            Path(output_directory) / f"{output_prefix}_id_class_dict"
        ).with_suffix(".pkl")
        with open(id_class_output_path, "wb") as f:
            pickle.dump(id_class_dict, f)

        if split_id_dict is not None:
            data_dict = dict()
            data_dict["train"] = pu.filter_by_dict(
                data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
            )
            data_dict["test"] = pu.filter_by_dict(
                data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
            )
            train_data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled_train"
            ).with_suffix(".dataset")
            test_data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled_test"
            ).with_suffix(".dataset")
            data_dict["train"].save_to_disk(train_data_output_path)
            data_dict["test"].save_to_disk(test_data_output_path)
        elif (test_size is not None) and (self.classifier == "cell"):
            if 1 > test_size > 0:
                data_dict, balance_df = cu.balance_attr_splits(
                    data,
                    attr_to_split,
                    attr_to_balance,
                    test_size,
                    max_trials,
                    pval_threshold,
                    self.cell_state_dict["state_key"],
                    self.nproc,
                )
                balance_df.to_csv(
                    f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
                )
                train_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_train"
                ).with_suffix(".dataset")
                test_data_output_path = (
                    Path(output_directory) / f"{output_prefix}_labeled_test"
                ).with_suffix(".dataset")
                data_dict["train"].save_to_disk(train_data_output_path)
                data_dict["test"].save_to_disk(test_data_output_path)
        else:
            data_output_path = (
                Path(output_directory) / f"{output_prefix}_labeled"
            ).with_suffix(".dataset")
            data.save_to_disk(data_output_path)

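For instance, a held-out patient split could be prepared as follows (a sketch; the file paths and patient IDs are illustrative):

cc.prepare_data(
    input_data_file="path/to/input_data.dataset",  # illustrative paths
    output_directory="path/to/out",
    output_prefix="disease_split",
    split_id_dict={
        "attr_key": "individual",
        "train": ["patient1", "patient2", "patient3", "patient4"],
        "test": ["patient5", "patient6"],
    },
)
# writes disease_split_labeled_train.dataset, disease_split_labeled_test.dataset,
# and disease_split_id_class_dict.pkl to the output directory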
    def train_all_data(
        self,
        model_directory,
        prepared_input_data_file,
        id_class_dict_file,
        output_directory,
        output_prefix,
        save_eval_output=True,
    ):
        """
        Train cell state or gene classifier using all data.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        prepared_input_data_file : Path
            | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        output_directory : Path
            | Path to directory where model and eval data will be saved
        output_prefix : str
            | Prefix for output files
        save_eval_output : bool
            | Whether to save cross-fold eval output
            | Saves as pickle file of dictionary of eval metrics

        **Output**

        Returns trainer after fine-tuning with all data.

        """

        ##### Load data and prepare output directory #####
        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)
        class_id_dict = {v: k for k, v in id_class_dict.items()}

        # load previously filtered and prepared data
        data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
        data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
        if output_directory[-1:] != "/":  # add slash for dir if not present
            output_directory = output_directory + "/"
        output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
        subprocess.call(f"mkdir {output_dir}", shell=True)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        if self.classifier == "gene":
            targets = pu.flatten_list(self.gene_class_dict.values())
            labels = pu.flatten_list(
                [
                    [class_id_dict[label]] * len(targets)
                    for label, targets in self.gene_class_dict.items()
                ]
            )
            assert len(targets) == len(labels)
            data = cu.prep_gene_classifier_all_data(
                data, targets, labels, self.max_ncells, self.nproc
            )

        trainer = self.train_classifier(
            model_directory, num_classes, data, None, output_dir
        )

        return trainer

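Once hyperparameters are settled, a final model can be fit on all data (typically with num_crossval_splits=0 in the constructor). A sketch with illustrative paths:

trainer = cc.train_all_data(
    model_directory="path/to/pretrained_model",  # illustrative paths
    prepared_input_data_file="path/to/out/disease_labeled.dataset",
    id_class_dict_file="path/to/out/disease_id_class_dict.pkl",
    output_directory="path/to/out",
    output_prefix="disease_final",
)
# the fine-tuned model is saved under
# {output_directory}/{datestamp}_geneformer_cellClassifier_{output_prefix}/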
    def validate(
        self,
        model_directory,
        prepared_input_data_file,
        id_class_dict_file,
        output_directory,
        output_prefix,
        split_id_dict=None,
        attr_to_split=None,
        attr_to_balance=None,
        max_trials=100,
        pval_threshold=0.1,
        save_eval_output=True,
        predict_eval=True,
        predict_trainer=False,
    ):
        """
        (Cross-)validate cell state or gene classifier.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        prepared_input_data_file : Path
            | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        output_directory : Path
            | Path to directory where model and eval data will be saved
        output_prefix : str
            | Prefix for output files
        split_id_dict : None, dict
            | Dictionary of IDs for train and eval splits
            | Three-item dictionary with keys: attr_key, train, eval
            | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
            | train: list of IDs in the attr_key column to include in the train split
            | eval: list of IDs in the attr_key column to include in the eval split
            | For example: {"attr_key": "individual",
            |               "train": ["patient1", "patient2", "patient3", "patient4"],
            |               "eval": ["patient5", "patient6"]}
            | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
        attr_to_split : None, str
            | Key for attribute on which to split data while balancing potential confounders
            | e.g. "patient_id" for splitting by patient while balancing other characteristics
            | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
        attr_to_balance : None, list
            | List of attribute keys on which to balance data while splitting on attr_to_split
            | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
        max_trials : None, int
            | Maximum number of random splitting trials to attempt in order to balance the other attributes
            | If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best
        pval_threshold : None, float
            | P-value threshold to use for attribute balancing across splits
            | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
        save_eval_output : bool
            | Whether to save cross-fold eval output
            | Saves as pickle file of dictionary of eval metrics
        predict_eval : bool
            | Whether or not to save eval predictions
            | Saves as a pickle file of self.evaluate predictions
        predict_trainer : bool
            | Whether or not to save eval predictions from trainer
            | Saves as a pickle file of trainer predictions
        """

        if self.num_crossval_splits == 0:
            logger.error("num_crossval_splits must be 1 or 5 to validate.")
            raise

        # ensure number of genes in each class is > 5 if validating model
        if self.classifier == "gene":
            insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
            if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
                logger.error(
                    f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
                )
                raise

        ##### Load data and prepare output directory #####
        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)
        class_id_dict = {v: k for k, v in id_class_dict.items()}

        # load previously filtered and prepared data
        data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
        data = data.shuffle(seed=42)  # reshuffle in case users provide unshuffled data

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
        if output_directory[-1:] != "/":  # add slash for dir if not present
            output_directory = output_directory + "/"
        output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
        subprocess.call(f"mkdir {output_dir}", shell=True)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        ##### (Cross-)validate the model #####
        results = []
        all_conf_mat = np.zeros((num_classes, num_classes))
        iteration_num = 1
        if self.classifier == "cell":
            for i in trange(self.num_crossval_splits):
                print(
                    f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                )
                ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                if self.num_crossval_splits == 1:
                    # single 1-eval_size:eval_size split
                    if split_id_dict is not None:
                        data_dict = dict()
                        data_dict["train"] = pu.filter_by_dict(
                            data,
                            {split_id_dict["attr_key"]: split_id_dict["train"]},
                            self.nproc,
                        )
                        data_dict["test"] = pu.filter_by_dict(
                            data,
                            {split_id_dict["attr_key"]: split_id_dict["eval"]},
                            self.nproc,
                        )
                    elif attr_to_split is not None:
                        data_dict, balance_df = cu.balance_attr_splits(
                            data,
                            attr_to_split,
                            attr_to_balance,
                            self.eval_size,
                            max_trials,
                            pval_threshold,
                            self.cell_state_dict["state_key"],
                            self.nproc,
                        )
                        balance_df.to_csv(
                            f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
                        )
                    else:
                        data_dict = data.train_test_split(
                            test_size=self.eval_size,
                            stratify_by_column=self.stratify_splits_col,
                            seed=42,
                        )
                    train_data = data_dict["train"]
                    eval_data = data_dict["test"]
                else:
                    # 5-fold cross-validate
                    num_cells = len(data)
                    # cast to int so the index ranges below are valid
                    fifth_cells = int(num_cells * 0.2)
                    num_eval = int(min((self.eval_size * num_cells), fifth_cells))
                    start = i * fifth_cells
                    end = start + num_eval
                    eval_indices = [j for j in range(start, end)]
                    train_indices = [
                        j for j in range(num_cells) if j not in eval_indices
                    ]
                    eval_data = data.select(eval_indices)
                    train_data = data.select(train_indices)
                trainer = self.train_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    predict_trainer,
                )
                result = self.evaluate_model(
                    trainer.model,
                    num_classes,
                    id_class_dict,
                    eval_data,
                    predict_eval,
                    ksplit_output_dir,
                    output_prefix,
                )
                results += [result]
                all_conf_mat = all_conf_mat + result["conf_mat"]
                iteration_num = iteration_num + 1

        elif self.classifier == "gene":
            # set up (cross-)validation splits
            targets = pu.flatten_list(self.gene_class_dict.values())
            labels = pu.flatten_list(
                [
                    [class_id_dict[label]] * len(targets)
                    for label, targets in self.gene_class_dict.items()
                ]
            )
            assert len(targets) == len(labels)
            n_splits = int(1 / self.eval_size)
            skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
            # (Cross-)validate
            for train_index, eval_index in tqdm(skf.split(targets, labels)):
                print(
                    f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                )
                ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                # filter data for examples containing classes for this split
                # subsample to max_ncells and relabel data in column "labels"
                train_data, eval_data = cu.prep_gene_classifier_split(
                    data,
                    targets,
                    labels,
                    train_index,
                    eval_index,
                    self.max_ncells,
                    iteration_num,
                    self.nproc,
                )

                trainer = self.train_classifier(
                    model_directory,
                    num_classes,
                    train_data,
                    eval_data,
                    ksplit_output_dir,
                    predict_trainer,
                )
                result = self.evaluate_model(
                    trainer.model,
                    num_classes,
                    id_class_dict,
                    eval_data,
                    predict_eval,
                    ksplit_output_dir,
                    output_prefix,
                )
                results += [result]
                all_conf_mat = all_conf_mat + result["conf_mat"]
                # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
                if iteration_num == self.num_crossval_splits:
                    break
                iteration_num = iteration_num + 1

        all_conf_mat_df = pd.DataFrame(
            all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
        )
        all_metrics = {
            "conf_matrix": all_conf_mat_df,
            "macro_f1": [result["macro_f1"] for result in results],
            "acc": [result["acc"] for result in results],
        }
        all_roc_metrics = None  # roc metrics not reported for multiclass
        if num_classes == 2:
            mean_fpr = np.linspace(0, 1, 100)
            all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
            all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
            all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
            mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
                all_tpr, all_roc_auc, all_tpr_wt
            )
            all_roc_metrics = {
                "mean_tpr": mean_tpr,
                "mean_fpr": mean_fpr,
                "all_roc_auc": all_roc_auc,
                "roc_auc": roc_auc,
                "roc_auc_sd": roc_auc_sd,
            }
        all_metrics["all_roc_metrics"] = all_roc_metrics
        if save_eval_output is True:
            eval_metrics_output_path = (
                Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
            ).with_suffix(".pkl")
            with open(eval_metrics_output_path, "wb") as f:
                pickle.dump(all_metrics, f)

        return all_metrics

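Five-fold cross-validation only changes the constructor arguments; the validate call itself is unchanged. A sketch with illustrative paths:

cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": "all"},
    num_crossval_splits=5,  # 5 folds, each holding out eval_size of the data
    eval_size=0.2,
    forward_batch_size=200,
    nproc=16,
)
all_metrics = cc.validate(
    model_directory="path/to/pretrained_model",  # illustrative paths
    prepared_input_data_file="path/to/out/disease_labeled.dataset",
    id_class_dict_file="path/to/out/disease_id_class_dict.pkl",
    output_directory="path/to/out",
    output_prefix="disease_cv",
)
# all_metrics["macro_f1"] and all_metrics["acc"] are then per-fold lists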
    def train_classifier(
        self,
        model_directory,
        num_classes,
        train_data,
        eval_data,
        output_directory,
        predict=False,
    ):
        """
        Fine-tune model for cell state or gene classification.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        num_classes : int
            | Number of classes for classifier
        train_data : Dataset
            | Loaded training .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        eval_data : None, Dataset
            | (Optional) Loaded evaluation .dataset input
            | For cell classifier, labels in column "label".
            | For gene classifier, labels in column "labels".
        output_directory : Path
            | Path to directory where fine-tuned model will be saved
        predict : bool
            | Whether or not to save eval predictions from trainer
        """

        ##### Validate and prepare data #####
        train_data, eval_data = cu.validate_and_clean_cols(
            train_data, eval_data, self.classifier
        )

        if (self.no_eval is True) and (eval_data is not None):
            logger.warning(
                "no_eval set to True; model will be trained without evaluation."
            )
            eval_data = None

        if (self.classifier == "gene") and (predict is True):
            logger.warning(
                "Predictions during training not currently available for gene classifiers; setting predict to False."
            )
            predict = False

        # ensure not overwriting previously saved model
        saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
        if os.path.isfile(saved_model_test) is True:
            logger.error("Model already saved to this designated output directory.")
            raise
        # make output directory
        subprocess.call(f"mkdir {output_directory}", shell=True)

        ##### Load model and training args #####
        if self.classifier == "cell":
            model_type = "CellClassifier"
        elif self.classifier == "gene":
            model_type = "GeneClassifier"
        model = pu.load_model(model_type, num_classes, model_directory, "train")

        def_training_args, def_freeze_layers = cu.get_default_train_args(
            model, self.classifier, train_data, output_directory
        )

        if self.training_args is not None:
            def_training_args.update(self.training_args)
        logging_steps = round(
            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
        )
        def_training_args["logging_steps"] = logging_steps
        def_training_args["output_dir"] = output_directory
        if eval_data is None:
            def_training_args["evaluation_strategy"] = "no"
            def_training_args["load_best_model_at_end"] = False
        training_args_init = TrainingArguments(**def_training_args)

        if self.freeze_layers is not None:
            def_freeze_layers = self.freeze_layers

        if def_freeze_layers > 0:
            modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        ##### Fine-tune the model #####
        # define the data collator
        if self.classifier == "cell":
            data_collator = DataCollatorForCellClassification()
        elif self.classifier == "gene":
            data_collator = DataCollatorForGeneClassification()

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=data_collator,
            train_dataset=train_data,
            eval_dataset=eval_data,
            compute_metrics=cu.compute_metrics,
        )

        # train the classifier
        trainer.train()
        trainer.save_model(output_directory)
        if predict is True:
            # make eval predictions and save predictions and metrics
            predictions = trainer.predict(eval_data)
            prediction_output_path = f"{output_directory}/predictions.pkl"
            with open(prediction_output_path, "wb") as f:
                pickle.dump(predictions, f)
            trainer.save_metrics("eval", predictions.metrics)
        return trainer

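The training_args dictionary passed to the constructor is merged over the inferred defaults before TrainingArguments is built, so only the fields to change need to be supplied. The values below are illustrative, not recommendations:

training_args = {
    "num_train_epochs": 1,          # standard Hugging Face TrainingArguments keys;
    "learning_rate": 5e-5,          # illustrative values, tune per task
    "per_device_train_batch_size": 12,
    "seed": 42,
}
cc = Classifier(classifier="cell",
                cell_state_dict={"state_key": "disease", "states": "all"},
                training_args=training_args,
                nproc=16)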
    def evaluate_model(
        self,
        model,
        num_classes,
        id_class_dict,
        eval_data,
        predict=False,
        output_directory=None,
        output_prefix=None,
    ):
        """
        Evaluate the fine-tuned model.

        **Parameters**

        model : nn.Module
            | Loaded fine-tuned model (e.g. trainer.model)
        num_classes : int
            | Number of classes for classifier
        id_class_dict : dict
            | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        eval_data : Dataset
            | Loaded evaluation .dataset input
        predict : bool
            | Whether or not to save eval predictions
        output_directory : Path
            | Path to directory where eval data will be saved
        output_prefix : str
            | Prefix for output files
        """

        ##### Evaluate the model #####
        labels = id_class_dict.keys()
        y_pred, y_true, logits_list = eu.classifier_predict(
            model, self.classifier, eval_data, self.forward_batch_size
        )
        conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
            y_pred, y_true, logits_list, num_classes, labels
        )
        if predict is True:
            pred_dict = {
                "pred_ids": y_pred,
                "label_ids": y_true,
                "predictions": logits_list,
            }
            pred_dict_output_path = (
                Path(output_directory) / f"{output_prefix}_pred_dict"
            ).with_suffix(".pkl")
            with open(pred_dict_output_path, "wb") as f:
                pickle.dump(pred_dict, f)
        return {
            "conf_mat": conf_mat,
            "macro_f1": macro_f1,
            "acc": acc,
            "roc_metrics": roc_metrics,
        }

    def evaluate_saved_model(
        self,
        model_directory,
        id_class_dict_file,
        test_data_file,
        output_directory,
        output_prefix,
        predict=True,
    ):
        """
        Evaluate a previously saved fine-tuned model.

        **Parameters**

        model_directory : Path
            | Path to directory containing model
        id_class_dict_file : Path
            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
            | (dictionary of format: numerical IDs: class_labels)
        test_data_file : Path
            | Path to directory containing test .dataset
        output_directory : Path
            | Path to directory where eval data will be saved
        output_prefix : str
            | Prefix for output files
        predict : bool
            | Whether or not to save eval predictions
        """

        # load numerical id to class dictionary (id:class)
        with open(id_class_dict_file, "rb") as f:
            id_class_dict = pickle.load(f)

        # get number of classes for classifier
        num_classes = cu.get_num_classes(id_class_dict)

        # load previously filtered and prepared data
        test_data = pu.load_and_filter(None, self.nproc, test_data_file)

        # load previously fine-tuned model
        if self.classifier == "cell":
            model_type = "CellClassifier"
        elif self.classifier == "gene":
            model_type = "GeneClassifier"
        model = pu.load_model(model_type, num_classes, model_directory, "eval")

        # evaluate the model
        results = self.evaluate_model(
            model,
            num_classes,
            id_class_dict,
            test_data,
            predict=predict,
            output_directory=output_directory,
            output_prefix=output_prefix,
        )

        all_conf_mat_df = pd.DataFrame(
            results["conf_mat"],
            columns=id_class_dict.values(),
            index=id_class_dict.values(),
        )
        all_metrics = {
            "conf_matrix": all_conf_mat_df,
            "macro_f1": results["macro_f1"],
            "acc": results["acc"],
        }
        all_roc_metrics = None  # roc metrics not reported for multiclass
        if num_classes == 2:
            mean_fpr = np.linspace(0, 1, 100)
            # results is a single evaluation result dict here (not a list of folds),
            # so wrap its ROC metrics in single-item lists before aggregating
            all_tpr = [results["roc_metrics"]["interp_tpr"]]
            all_roc_auc = [results["roc_metrics"]["auc"]]
            all_tpr_wt = [results["roc_metrics"]["tpr_wt"]]
            mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
                all_tpr, all_roc_auc, all_tpr_wt
            )
            all_roc_metrics = {
                "mean_tpr": mean_tpr,
                "mean_fpr": mean_fpr,
                "all_roc_auc": all_roc_auc,
            }
        all_metrics["all_roc_metrics"] = all_roc_metrics
        test_metrics_output_path = (
            Path(output_directory) / f"{output_prefix}_test_metrics_dict"
        ).with_suffix(".pkl")
        with open(test_metrics_output_path, "wb") as f:
            pickle.dump(all_metrics, f)

        return all_metrics

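A held-out test set prepared earlier with split_id_dict can then be scored against the saved model. A sketch; all paths (including the datestamped model directory) are illustrative:

all_metrics_test = cc.evaluate_saved_model(
    model_directory="path/to/out/datestamp_geneformer_cellClassifier_disease_final/",
    id_class_dict_file="path/to/out/disease_id_class_dict.pkl",
    test_data_file="path/to/out/disease_labeled_test.dataset",
    output_directory="path/to/out",
    output_prefix="disease_test",
)
print(all_metrics_test["macro_f1"], all_metrics_test["acc"])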
    def plot_conf_mat(
        self,
        conf_mat_dict,
        output_directory,
        output_prefix,
        custom_class_order=None,
    ):
        """
        Plot confusion matrix results of evaluating the fine-tuned model.

        **Parameters**

        conf_mat_dict : dict
            | Dictionary of model_name : confusion_matrix_DataFrame
            | (all_metrics["conf_matrix"] from self.validate)
        output_directory : Path
            | Path to directory where plots will be saved
        output_prefix : str
            | Prefix for output file
        custom_class_order : None, list
            | List of classes in custom order for plots.
            | Same order will be used for all models.
        """

        for model_name in conf_mat_dict.keys():
            eu.plot_confusion_matrix(
                conf_mat_dict[model_name],
                model_name,
                output_directory,
                output_prefix,
                custom_class_order,
            )

    def plot_roc(
        self,
        roc_metric_dict,
        model_style_dict,
        title,
        output_directory,
        output_prefix,
    ):
        """
        Plot ROC curve results of evaluating the fine-tuned model.

        **Parameters**

        roc_metric_dict : dict
            | Dictionary of model_name : roc_metrics
            | (all_metrics["all_roc_metrics"] from self.validate)
        model_style_dict : dict[dict]
            | Dictionary of model_name : dictionary of style_attribute : style
            | where style includes color and linestyle
            | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
        title : str
            | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
        output_directory : Path
            | Path to directory where plots will be saved
        output_prefix : str
            | Prefix for output file
        """

        eu.plot_ROC(
            roc_metric_dict, model_style_dict, title, output_directory, output_prefix
        )

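For a binary classifier (the only case in which validate reports ROC metrics), a call might look like this (model names and styles illustrative):

cc.plot_roc(
    roc_metric_dict={"Geneformer": all_metrics["all_roc_metrics"]},
    model_style_dict={"Geneformer": {"color": "black", "linestyle": "-"}},
    title="Dosage-sensitive vs -insensitive factors",
    output_directory="path/to/out",  # illustrative path
    output_prefix="dosage_tf",
)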
1126 |
+
def plot_predictions(
|
1127 |
+
self,
|
1128 |
+
predictions_file,
|
1129 |
+
id_class_dict_file,
|
1130 |
+
title,
|
1131 |
+
output_directory,
|
1132 |
+
output_prefix,
|
1133 |
+
custom_class_order=None,
|
1134 |
+
kwargs_dict=None,
|
1135 |
+
):
|
1136 |
+
"""
|
1137 |
+
Plot prediction results of evaluating the fine-tuned model.
|
1138 |
+
|
1139 |
+
**Parameters**
|
1140 |
+
|
1141 |
+
predictions_file : path
|
1142 |
+
| Path of model predictions output to plot
|
1143 |
+
| (saved output from self.validate if predict=True)
|
1144 |
+
| (or saved output from self.evaluate_saved_model)
|
1145 |
+
id_class_dict_file : Path
|
1146 |
+
| Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
|
1147 |
+
| (dictionary of format: numerical IDs: class_labels)
|
1148 |
+
title : str
|
1149 |
+
| Title for legend containing class labels.
|
1150 |
+
output_directory : Path
|
1151 |
+
| Path to directory where plots will be saved
|
1152 |
+
output_prefix : str
|
1153 |
+
| Prefix for output file
|
1154 |
+
custom_class_order : None, list
|
1155 |
+
| List of classes in custom order for plots.
|
1156 |
+
| Same order will be used for all models.
|
1157 |
+
kwargs_dict : None, dict
|
1158 |
+
| Dictionary of kwargs to pass to plotting function.
|
1159 |
+
"""
|
1160 |
+
# load predictions
|
1161 |
+
with open(predictions_file, "rb") as f:
|
1162 |
+
predictions = pickle.load(f)
|
1163 |
+
|
1164 |
+
# load numerical id to class dictionary (id:class)
|
1165 |
+
with open(id_class_dict_file, "rb") as f:
|
1166 |
+
id_class_dict = pickle.load(f)
|
1167 |
+
|
1168 |
+
if isinstance(predictions, dict):
|
1169 |
+
if all(
|
1170 |
+
[
|
1171 |
+
key in predictions.keys()
|
1172 |
+
for key in ["pred_ids", "label_ids", "predictions"]
|
1173 |
+
]
|
1174 |
+
):
|
1175 |
+
# format is output from self.evaluate_saved_model
|
1176 |
+
predictions_logits = np.array(predictions["predictions"])
|
1177 |
+
true_ids = predictions["label_ids"]
|
1178 |
+
else:
|
1179 |
+
# format is output from self.validate if predict=True
|
1180 |
+
predictions_logits = predictions.predictions
|
1181 |
+
true_ids = predictions.label_ids
|
1182 |
+
|
1183 |
+
num_classes = len(id_class_dict.keys())
|
1184 |
+
num_predict_classes = predictions_logits.shape[1]
|
1185 |
+
assert num_classes == num_predict_classes
|
1186 |
+
classes = id_class_dict.values()
|
1187 |
+
true_labels = [id_class_dict[idx] for idx in true_ids]
|
1188 |
+
predictions_df = pd.DataFrame(predictions_logits, columns=classes)
|
1189 |
+
if custom_class_order is not None:
|
1190 |
+
predictions_df = predictions_df.reindex(columns=custom_class_order)
|
1191 |
+
predictions_df["true"] = true_labels
|
1192 |
+
custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
|
1193 |
+
if custom_class_order is not None:
|
1194 |
+
custom_dict = dict(
|
1195 |
+
zip(custom_class_order, [i for i in range(len(custom_class_order))])
|
1196 |
+
)
|
1197 |
+
predictions_df = predictions_df.sort_values(
|
1198 |
+
by=["true"], key=lambda x: x.map(custom_dict)
|
1199 |
+
)
|
1200 |
+
|
1201 |
+
eu.plot_predictions(
|
1202 |
+
predictions_df, title, output_directory, output_prefix, kwargs_dict
|
1203 |
+
)
|
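
For orientation, a minimal usage sketch of the method added above. The constructor call and file names are assumptions based on the docstrings in this commit, not verbatim from it:

    from geneformer import Classifier

    cc = Classifier(classifier="cell")  # hypothetical constructor arguments
    cc.plot_predictions(
        predictions_file="outdir/test_predictions.pkl",      # e.g. saved by validate(predict=True)
        id_class_dict_file="outdir/test_id_class_dict.pkl",  # saved by Classifier.prepare_data
        title="Cell type",
        output_directory="outdir",
        output_prefix="test",
    )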
geneformer/classifier_utils.py
ADDED
@@ -0,0 +1,406 @@
+import logging
+import random
+from collections import Counter, defaultdict
+
+import numpy as np
+import pandas as pd
+from scipy.stats import chisquare, ranksums
+from sklearn.metrics import accuracy_score, f1_score
+
+from . import perturber_utils as pu
+
+logger = logging.getLogger(__name__)
+
+
+def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
+    data = data.shuffle(seed=42)
+    num_cells = len(data)
+    # if max number of cells is defined, then subsample to this max number
+    if max_ncells is not None:
+        if num_cells > max_ncells:
+            data = data.select([i for i in range(max_ncells)])
+    if max_ncells_per_class is not None:
+        class_labels = data[cell_state_dict["state_key"]]
+        random.seed(42)
+        subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
+        data = data.select(subsample_indices)
+    return data
+
+
+# subsample labels to maximum number N per class and return indices
+def subsample_by_class(labels, N):
+    label_indices = defaultdict(list)
+    # Gather indices for each label
+    for idx, label in enumerate(labels):
+        label_indices[label].append(idx)
+    selected_indices = []
+    # Select up to N indices for each label
+    for label, indices in label_indices.items():
+        if len(indices) > N:
+            selected_indices.extend(random.sample(indices, N))
+        else:
+            selected_indices.extend(indices)
+    return selected_indices
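
A quick illustration of the per-class subsampling above (values chosen for the example):

    import random
    from geneformer.classifier_utils import subsample_by_class

    random.seed(42)
    labels = ["T cell"] * 5 + ["B cell"] * 2
    idx = subsample_by_class(labels, N=3)
    # keeps at most 3 indices for "T cell" and both "B cell" indices,
    # e.g. [0, 4, 1, 5, 6] depending on the sampling order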
+
+
+def rename_cols(data, state_key):
+    data = data.rename_column(state_key, "label")
+    return data
+
+
+def validate_and_clean_cols(train_data, eval_data, classifier):
+    # validate that data has expected label column and remove others
+    if classifier == "cell":
+        label_col = "label"
+    elif classifier == "gene":
+        label_col = "labels"
+
+    cols_to_keep = [label_col] + ["input_ids", "length"]
+    if label_col not in train_data.column_names:
+        logger.error(f"train_data must contain column {label_col} with class labels.")
+        raise
+    else:
+        train_data = remove_cols(train_data, cols_to_keep)
+
+    if eval_data is not None:
+        if label_col not in eval_data.column_names:
+            logger.error(
+                f"eval_data must contain column {label_col} with class labels."
+            )
+            raise
+        else:
+            eval_data = remove_cols(eval_data, cols_to_keep)
+    return train_data, eval_data
+
+
+def remove_cols(data, cols_to_keep):
+    other_cols = list(data.features.keys())
+    other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
+    data = data.remove_columns(other_cols)
+    return data
+
+
+def remove_rare(data, rare_threshold, label, nproc):
+    if rare_threshold > 0:
+        total_cells = len(data)
+        label_counter = Counter(data[label])
+        # keep only classes above the rarity threshold
+        # (fix: iterate over label_counter.items() to unpack class/count pairs)
+        nonrare_label_dict = {
+            label: [
+                k
+                for k, v in label_counter.items()
+                if (v / total_cells) > rare_threshold
+            ]
+        }
+        data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
+    return data
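
The nonrare filter above reduces to a simple frequency cut; a standalone sketch of the same logic:

    from collections import Counter

    cell_types = ["Fibroblast"] * 90 + ["Cardiomyocyte"] * 9 + ["Doublet"] * 1
    counts = Counter(cell_types)
    nonrare = [k for k, v in counts.items() if v / len(cell_types) > 0.05]
    print(nonrare)  # ['Fibroblast', 'Cardiomyocyte'] -- 'Doublet' falls below the 5% threshold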
+
+
+def label_classes(classifier, data, gene_class_dict, nproc):
+    if classifier == "cell":
+        label_set = set(data["label"])
+    elif classifier == "gene":
+        # remove cells without any of the target genes
+        def if_contains_label(example):
+            a = pu.flatten_list(gene_class_dict.values())
+            b = example["input_ids"]
+            return not set(a).isdisjoint(b)
+
+        data = data.filter(if_contains_label, num_proc=nproc)
+        label_set = gene_class_dict.keys()
+
+        if len(data) == 0:
+            logger.error(
+                "No cells remain after filtering for target genes. Check target gene list."
+            )
+            raise
+
+    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
+    id_class_dict = {v: k for k, v in class_id_dict.items()}
+
+    def classes_to_ids(example):
+        if classifier == "cell":
+            example["label"] = class_id_dict[example["label"]]
+        elif classifier == "gene":
+            example["labels"] = label_gene_classes(
+                example, class_id_dict, gene_class_dict
+            )
+        return example
+
+    data = data.map(classes_to_ids, num_proc=nproc)
+    return data, id_class_dict
+
+
+def label_gene_classes(example, class_id_dict, gene_class_dict):
+    return [
+        class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
+        for token_id in example["input_ids"]
+    ]
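
Token positions whose gene is not in any target class are labeled -100, the ignore index used by PyTorch cross-entropy, so only target genes contribute to the gene-classification loss. A toy illustration with invented token IDs, using the token-to-class mapping as consumed by label_gene_classes:

    from geneformer.classifier_utils import label_gene_classes

    gene_class_dict = {101: "dosage_sensitive", 202: "dosage_insensitive"}  # token -> class
    class_id_dict = {"dosage_sensitive": 0, "dosage_insensitive": 1}        # class -> ID

    example = {"input_ids": [7, 101, 55, 202]}
    print(label_gene_classes(example, class_id_dict, gene_class_dict))
    # [-100, 0, -100, 1]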
+
+
+def prep_gene_classifier_split(
+    data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
+):
+    # generate cross-validation splits
+    targets = np.array(targets)
+    labels = np.array(labels)
+    targets_train, targets_eval = targets[train_index], targets[eval_index]
+    labels_train, labels_eval = labels[train_index], labels[eval_index]
+    label_dict_train = dict(zip(targets_train, labels_train))
+    label_dict_eval = dict(zip(targets_eval, labels_eval))
+
+    # function to filter by whether contains train or eval labels
+    def if_contains_train_label(example):
+        a = targets_train
+        b = example["input_ids"]
+        return not set(a).isdisjoint(b)
+
+    def if_contains_eval_label(example):
+        a = targets_eval
+        b = example["input_ids"]
+        return not set(a).isdisjoint(b)
+
+    # filter dataset for examples containing classes for this split
+    logger.info(f"Filtering training data for genes in split {iteration_num}")
+    train_data = data.filter(if_contains_train_label, num_proc=num_proc)
+    logger.info(
+        f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
+    )
+    logger.info(f"Filtering evaluation data for genes in split {iteration_num}")
+    eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
+    logger.info(
+        f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
+    )
+
+    # subsample to max_ncells
+    train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
+    eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
+
+    # relabel genes for this split
+    def train_classes_to_ids(example):
+        example["labels"] = [
+            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
+        ]
+        return example
+
+    def eval_classes_to_ids(example):
+        example["labels"] = [
+            label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
+        ]
+        return example
+
+    train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
+    eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
+
+    return train_data, eval_data
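
The train_index/eval_index pairs are expected to come from an upstream cross-validation splitter; a plausible sketch of a caller, where the StratifiedKFold choice, the token IDs, and the prepared Dataset `data` are all assumptions rather than code from this commit:

    from sklearn.model_selection import StratifiedKFold
    from geneformer.classifier_utils import prep_gene_classifier_split

    targets = [101, 202, 303, 404, 505, 606]  # hypothetical gene token IDs
    labels = [0, 1, 0, 1, 0, 1]               # numeric class per gene

    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
    for i, (train_index, eval_index) in enumerate(skf.split(targets, labels)):
        train_data, eval_data = prep_gene_classifier_split(
            data, targets, labels, train_index, eval_index,
            max_ncells=10_000, iteration_num=i, num_proc=4,
        )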
+
+
+def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
+    targets = np.array(targets)
+    labels = np.array(labels)
+    label_dict_train = dict(zip(targets, labels))
+
+    # function to filter by whether contains train labels
+    def if_contains_train_label(example):
+        a = targets
+        b = example["input_ids"]
+        return not set(a).isdisjoint(b)
+
+    # filter dataset for examples containing classes for this split
+    logger.info("Filtering training data for genes to classify.")
+    train_data = data.filter(if_contains_train_label, num_proc=num_proc)
+    logger.info(
+        f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
+    )
+
+    # subsample to max_ncells
+    train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
+
+    # relabel genes for this split
+    def train_classes_to_ids(example):
+        example["labels"] = [
+            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
+        ]
+        return example
+
+    train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
+
+    return train_data
+
+
+def balance_attr_splits(
+    data,
+    attr_to_split,
+    attr_to_balance,
+    eval_size,
+    max_trials,
+    pval_threshold,
+    state_key,
+    nproc,
+):
+    metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
+    for attr in attr_to_balance:
+        if attr == state_key:
+            metadata_df[attr] = data["label"]
+        else:
+            metadata_df[attr] = data[attr]
+    metadata_df = metadata_df.drop_duplicates()
+
+    split_attr_ids = list(metadata_df["split_attr_ids"])
+    assert len(split_attr_ids) == len(set(split_attr_ids))
+    eval_num = round(len(split_attr_ids) * eval_size)
+    colnames = (
+        ["trial_num", "train_ids", "eval_ids"]
+        + pu.flatten_list(
+            [
+                [
+                    f"{attr}_train_mean_or_counts",
+                    f"{attr}_eval_mean_or_counts",
+                    f"{attr}_pval",
+                ]
+                for attr in attr_to_balance
+            ]
+        )
+        + ["mean_pval"]
+    )
+    balance_df = pd.DataFrame(columns=colnames)
+    data_dict = dict()
+    trial_num = 1
+    for i in range(max_trials):
+        if not all(
+            count > 1 for count in list(Counter(metadata_df[state_key]).values())
+        ):
+            logger.error(
+                f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
+            )
+            raise
+        eval_base = []
+        for state in set(metadata_df[state_key]):
+            eval_base += list(
+                metadata_df.loc[
+                    metadata_df[state_key][metadata_df[state_key].eq(state)]
+                    .sample(1, random_state=i)
+                    .index
+                ]["split_attr_ids"]
+            )
+        non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
+        random.seed(i)
+        eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
+        train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
+        df_vals = [trial_num, train_ids, eval_ids]
+        pvals = []
+        for attr in attr_to_balance:
+            train_attr = list(
+                metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
+            )
+            eval_attr = list(
+                metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
+            )
+            if attr == state_key:
+                # ensure IDs are interpreted as categorical
+                train_attr = [str(item) for item in train_attr]
+                eval_attr = [str(item) for item in eval_attr]
+            if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
+                train_attr_mean = np.nanmean(train_attr)
+                eval_attr_mean = np.nanmean(eval_attr)
+                pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
+                df_vals += [train_attr_mean, eval_attr_mean, pval]
+            elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
+                obs_counts = Counter(train_attr)
+                exp_counts = Counter(eval_attr)
+                all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
+                obs = [obs_counts[cat] for cat in all_categ]
+                exp = [
+                    exp_counts[cat] * sum(obs) / sum(exp_counts.values())
+                    for cat in all_categ
+                ]
+                # fix: store the chi-square p-value rather than discarding it
+                pval = chisquare(f_obs=obs, f_exp=exp).pvalue
+                train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
+                eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
+                df_vals += [train_attr_counts, eval_attr_counts, pval]
+            else:
+                logger.error(
+                    f"Inconsistent data types in attribute {attr}. "
+                    "Cannot infer if continuous or categorical. "
+                    "Must be all numeric (continuous) or all strings (categorical) to balance."
+                )
+                raise
+            pvals += [pval]
+
+        df_vals += [np.nanmean(pvals)]
+        balance_df_i = pd.DataFrame(df_vals, index=colnames).T
+        balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
+        valid_pvals = [
+            pval_i
+            for pval_i in pvals
+            if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
+        ]
+        if all(i >= pval_threshold for i in valid_pvals):
+            data_dict["train"] = pu.filter_by_dict(
+                data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
+            )
+            data_dict["test"] = pu.filter_by_dict(
+                data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
+            )
+            return data_dict, balance_df
+        trial_num = trial_num + 1
+    # fix: fall back to the optimal trial (highest mean p-value), not the last one
+    balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
+    data_dict["train"] = pu.filter_by_dict(
+        data, {attr_to_split: balance_max_df["train_ids"]}, nproc
+    )
+    data_dict["test"] = pu.filter_by_dict(
+        data, {attr_to_split: balance_max_df["eval_ids"]}, nproc
+    )
+    logger.warning(
+        f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
+        f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
+    )
+    return data_dict, balance_df
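
For the categorical branch, the eval-split counts are rescaled so that observed and expected totals match before the chi-square test; a small worked example with invented counts:

    from collections import Counter
    from scipy.stats import chisquare

    obs_counts = Counter({"F": 30, "M": 30})   # train split
    exp_counts = Counter({"F": 12, "M": 8})    # eval split
    obs = [obs_counts[c] for c in ("F", "M")]                                        # [30, 30]
    exp = [exp_counts[c] * sum(obs) / sum(exp_counts.values()) for c in ("F", "M")]  # [36.0, 24.0]
    print(chisquare(f_obs=obs, f_exp=exp).pvalue)  # ~0.11: not significantly imbalanced at p<0.05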
+
+
+def get_num_classes(id_class_dict):
+    return len(set(id_class_dict.values()))
+
+
+def compute_metrics(pred):
+    labels = pred.label_ids
+    preds = pred.predictions.argmax(-1)
+    # calculate accuracy and macro f1 using sklearn's function
+    acc = accuracy_score(labels, preds)
+    macro_f1 = f1_score(labels, preds, average="macro")
+    return {"accuracy": acc, "macro_f1": macro_f1}
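
compute_metrics follows the Hugging Face Trainer callback contract: it receives an EvalPrediction-like object exposing .predictions (logits) and .label_ids. A self-contained check with mock values:

    from types import SimpleNamespace
    import numpy as np
    from geneformer.classifier_utils import compute_metrics

    pred = SimpleNamespace(
        predictions=np.array([[2.0, 0.1], [0.2, 1.5], [3.0, 0.5]]),  # logits
        label_ids=np.array([0, 1, 1]),
    )
    print(compute_metrics(pred))  # {'accuracy': 0.666..., 'macro_f1': 0.666...}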
+
+
+def get_default_train_args(model, classifier, data, output_dir):
+    num_layers = pu.quant_layers(model)
+    freeze_layers = 0
+    batch_size = 12
+    if classifier == "cell":
+        epochs = 10
+        evaluation_strategy = "epoch"
+        load_best_model_at_end = True
+    else:
+        epochs = 1
+        evaluation_strategy = "no"
+        load_best_model_at_end = False
+
+    # fix: default to an empty dict so models without 6 layers do not
+    # raise a NameError in the update below
+    default_training_args = {}
+    if num_layers == 6:
+        default_training_args = {
+            "learning_rate": 5e-5,
+            "lr_scheduler_type": "linear",
+            "warmup_steps": 500,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+        }
+
+    training_args = {
+        "num_train_epochs": epochs,
+        "do_train": True,
+        "do_eval": True,
+        "evaluation_strategy": evaluation_strategy,
+        "logging_steps": np.floor(len(data) / batch_size / 8),  # 8 evals per epoch
+        "save_strategy": "epoch",
+        "group_by_length": False,
+        "length_column_name": "length",
+        "disable_tqdm": False,
+        "weight_decay": 0.001,
+        "load_best_model_at_end": load_best_model_at_end,
+    }
+    training_args.update(default_training_args)
+
+    return training_args, freeze_layers
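
These defaults are plain dict entries whose keys mirror transformers.TrainingArguments fields (in versions contemporary with this commit); a plausible hand-off, with output_dir supplied again at construction since the helper never consumes its own output_dir argument:

    from transformers import TrainingArguments
    from geneformer.classifier_utils import get_default_train_args

    training_args, freeze_layers = get_default_train_args(
        model, "cell", train_data, "/path/to/output"  # model/train_data prepared upstream
    )
    args = TrainingArguments(output_dir="/path/to/output", **training_args)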
geneformer/emb_extractor.py
CHANGED
@@ -17,7 +17,6 @@ from pathlib import Path
 
 import anndata
 import matplotlib.pyplot as plt
-import numpy as np
 import pandas as pd
 import scanpy as sc
 import seaborn as sns
@@ -303,13 +302,6 @@ def make_colorbar(embs_df, label):
     cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
     label_colors = pd.DataFrame(cell_type_colors, columns=[label])
 
-    for i, row in label_colors.iterrows():
-        colors = row[0]
-        if len(colors) != 3 or any(np.isnan(colors)):
-            print(i, colors)
-
-    label_colors.isna().sum()
-
     # create dictionary for colors and classes
     label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
     return label_colors, label_color_dict
@@ -565,7 +557,9 @@ class EmbExtractor:
             filtered_input_data, cell_state, self.nproc
        )
         downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
-        model = pu.load_model(
+        model = pu.load_model(
+            self.model_type, self.num_classes, model_directory, mode="eval"
+        )
         layer_to_quant = pu.quant_layers(model) + self.emb_layer
         embs = get_embs(
             model,
geneformer/evaluation_utils.py
ADDED
@@ -0,0 +1,287 @@
+import logging
+import math
+import pickle
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import torch
+from datasets.utils.logging import disable_progress_bar, enable_progress_bar
+from sklearn import preprocessing
+from sklearn.metrics import (
+    ConfusionMatrixDisplay,
+    accuracy_score,
+    auc,
+    confusion_matrix,
+    f1_score,
+    roc_curve,
+)
+from tqdm.auto import trange
+
+from .emb_extractor import make_colorbar
+from .tokenizer import TOKEN_DICTIONARY_FILE
+
+logger = logging.getLogger(__name__)
+
+# load token dictionary (Ensembl IDs:token)
+with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+    gene_token_dict = pickle.load(f)
+
+
+def preprocess_classifier_batch(cell_batch, max_len, label_name):
+    if max_len is None:
+        max_len = max([len(i) for i in cell_batch["input_ids"]])
+
+    def pad_label_example(example):
+        example[label_name] = np.pad(
+            example[label_name],
+            (0, max_len - len(example["input_ids"])),
+            mode="constant",
+            constant_values=-100,
+        )
+        example["input_ids"] = np.pad(
+            example["input_ids"],
+            (0, max_len - len(example["input_ids"])),
+            mode="constant",
+            constant_values=gene_token_dict.get("<pad>"),
+        )
+        example["attention_mask"] = (
+            example["input_ids"] != gene_token_dict.get("<pad>")
+        ).astype(int)
+        return example
+
+    padded_batch = cell_batch.map(pad_label_example)
+    return padded_batch
+
+
+# Function to find the largest number smaller
+# than or equal to N that is divisible by k
+def find_largest_div(N, K):
+    rem = N % K
+    if rem == 0:
+        return N
+    else:
+        return N - rem
+
+
+def vote(logit_list):
+    m = max(logit_list)
+    # (fix: removed a dead `logit_list.index(m)` statement whose result was discarded)
+    indices = [i for i, x in enumerate(logit_list) if x == m]
+    if len(indices) > 1:
+        return "tie"
+    else:
+        return indices[0]
+
+
+def py_softmax(vector):
+    e = np.exp(vector)
+    return e / e.sum()
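
A quick demonstration of the two helpers above:

    import numpy as np
    from geneformer.evaluation_utils import py_softmax, vote

    print(vote([0.1, 2.3, 0.4]))   # 1 -- index of the single maximum logit
    print(vote([2.3, 2.3, 0.4]))   # 'tie' -- more than one position attains the max
    print(py_softmax(np.array([0.0, 1.0])))  # [0.2689, 0.7311]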
+
+
+def classifier_predict(model, classifier_type, evalset, forward_batch_size):
+    if classifier_type == "gene":
+        label_name = "labels"
+    elif classifier_type == "cell":
+        label_name = "label"
+
+    predict_logits = []
+    predict_labels = []
+    model.eval()
+
+    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
+    evalset_len = len(evalset)
+    max_divisible = find_largest_div(evalset_len, forward_batch_size)
+    if len(evalset) - max_divisible == 1:
+        evalset_len = max_divisible
+
+    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
+
+    disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
+    for i in trange(0, evalset_len, forward_batch_size):
+        max_range = min(i + forward_batch_size, evalset_len)
+        batch_evalset = evalset.select([i for i in range(i, max_range)])
+        padded_batch = preprocess_classifier_batch(
+            batch_evalset, max_evalset_len, label_name
+        )
+        padded_batch.set_format(type="torch")
+
+        input_data_batch = padded_batch["input_ids"]
+        attn_msk_batch = padded_batch["attention_mask"]
+        label_batch = padded_batch[label_name]
+        with torch.no_grad():
+            outputs = model(
+                input_ids=input_data_batch.to("cuda"),
+                attention_mask=attn_msk_batch.to("cuda"),
+                labels=label_batch.to("cuda"),
+            )
+        predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
+        predict_labels += [torch.squeeze(label_batch.to("cpu"))]
+
+    enable_progress_bar()
+    logits_by_cell = torch.cat(predict_logits)
+    last_dim = len(logits_by_cell.shape) - 1
+    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
+    labels_by_cell = torch.cat(predict_labels)
+    all_labels = torch.flatten(labels_by_cell)
+    logit_label_paired = [
+        item
+        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
+        if item[1] != -100
+    ]
+    y_pred = [vote(item[0]) for item in logit_label_paired]
+    y_true = [item[1] for item in logit_label_paired]
+    logits_list = [item[0] for item in logit_label_paired]
+    return y_pred, y_true, logits_list
+
+
+def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
+    conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
+    macro_f1 = f1_score(y_true, y_pred, average="macro")
+    acc = accuracy_score(y_true, y_pred)
+    roc_metrics = None  # roc metrics not reported for multiclass
+    if num_classes == 2:
+        y_score = [py_softmax(item)[1] for item in logits_list]
+        fpr, tpr, _ = roc_curve(y_true, y_score)
+        mean_fpr = np.linspace(0, 1, 100)
+        interp_tpr = np.interp(mean_fpr, fpr, tpr)
+        interp_tpr[0] = 0.0
+        tpr_wt = len(tpr)
+        roc_auc = auc(fpr, tpr)
+        roc_metrics = {
+            "fpr": fpr,
+            "tpr": tpr,
+            "interp_tpr": interp_tpr,
+            "auc": roc_auc,
+            "tpr_wt": tpr_wt,
+        }
+    return conf_mat, macro_f1, acc, roc_metrics
+
+
+# get cross-validated mean and sd metrics
+def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
+    wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
+    all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
+    mean_tpr = np.sum(all_weighted_tpr, axis=0)
+    mean_tpr[-1] = 1.0
+    all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
+    roc_auc = np.sum(all_weighted_roc_auc)
+    # cast to array so the elementwise subtraction works when a plain list is passed
+    roc_auc_sd = math.sqrt(
+        np.average((np.array(all_roc_auc) - roc_auc) ** 2, weights=wts)
+    )
+    return mean_tpr, roc_auc, roc_auc_sd
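
The fold weights are the per-fold TPR counts normalized to sum to 1, so the reported AUC is a weighted mean across folds. A worked two-fold example:

    all_roc_auc = [0.90, 0.80]
    all_tpr_wt = [75, 25]                             # per-fold weights (ROC threshold counts)
    wts = [w / sum(all_tpr_wt) for w in all_tpr_wt]   # [0.75, 0.25]
    roc_auc = sum(a * w for a, w in zip(all_roc_auc, wts))  # 0.875
    # weighted SD: sqrt(0.75*(0.90-0.875)**2 + 0.25*(0.80-0.875)**2) ~ 0.043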
+
+
+# plot ROC curve
+def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
+    fig = plt.figure()
+    fig.set_size_inches(10, 8)
+    sns.set(font_scale=2)
+    sns.set_style("white")
+    lw = 3
+    for model_name in roc_metric_dict.keys():
+        mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
+        mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
+        roc_auc = roc_metric_dict[model_name]["roc_auc"]
+        roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
+        color = model_style_dict[model_name]["color"]
+        linestyle = model_style_dict[model_name]["linestyle"]
+        if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
+            # raw f-string so the mathtext \pm is not treated as an escape sequence
+            label = rf"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
+        else:
+            label = f"{model_name} (AUC {roc_auc:0.2f})"
+        plt.plot(
+            mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
+        )
+
+    plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate")
+    plt.title(title)
+    plt.legend(loc="lower right")
+    plt.show()
+
+    output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
+    plt.savefig(output_file, bbox_inches="tight")
+
+
+# plot confusion matrix
+def plot_confusion_matrix(
+    conf_mat_df, title, output_dir, output_prefix, custom_class_order
+):
+    fig = plt.figure()
+    fig.set_size_inches(10, 10)
+    sns.set(font_scale=1)
+    sns.set_style("whitegrid", {"axes.grid": False})
+    if custom_class_order is not None:
+        conf_mat_df = conf_mat_df.reindex(
+            index=custom_class_order, columns=custom_class_order
+        )
+    display_labels = generate_display_labels(conf_mat_df)
+    conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
+    display = ConfusionMatrixDisplay(
+        confusion_matrix=conf_mat, display_labels=display_labels
+    )
+    display.plot(cmap="Blues", values_format=".2g")
+    plt.title(title)
+    plt.show()
+
+    output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
+    display.figure_.savefig(output_file, bbox_inches="tight")
+
+
+def generate_display_labels(conf_mat_df):
+    display_labels = []
+    i = 0
+    for label in conf_mat_df.index:
+        display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
+        i = i + 1
+    return display_labels
+
+
+def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
+    sns.set(font_scale=2)
+    plt.figure(figsize=(10, 10), dpi=150)
+    label_colors, label_color_dict = make_colorbar(predictions_df, "true")
+    predictions_df = predictions_df.drop(columns=["true"])
+    predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
+    predict_label_list = [label for label in predictions_df.columns]
+    predict_colors = pd.DataFrame(
+        pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
+    )
+
+    default_kwargs_dict = {
+        "row_cluster": False,
+        "col_cluster": False,
+        "row_colors": label_colors,
+        "col_colors": predict_colors,
+        "linewidths": 0,
+        "xticklabels": False,
+        "yticklabels": False,
+        "center": 0,
+        "cmap": "vlag",
+    }
+
+    if kwargs_dict is not None:
+        default_kwargs_dict.update(kwargs_dict)
+    g = sns.clustermap(predictions_df, **default_kwargs_dict)
+
+    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
+
+    for label_color in list(label_color_dict.keys()):
+        g.ax_col_dendrogram.bar(
+            0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
+        )
+
+    g.ax_col_dendrogram.legend(
+        title=f"{title}",
+        loc="lower center",
+        ncol=4,
+        bbox_to_anchor=(0.5, 1),
+        facecolor="white",
+    )
+
+    output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
+    plt.savefig(output_file, bbox_inches="tight")
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -801,6 +801,12 @@ class InSilicoPerturberStats:
                 logger.error("All states must be unique.")
                 raise
 
+            elif set(self.cell_states_to_model.keys()) == {
+                "state_key",
+                "start_state",
+                "goal_state",
+            }:
+                self.cell_states_to_model["alt_states"] = []
             else:
                 logger.error(
                     "cell_states_to_model must only have the following four keys: "
geneformer/tokenizer.py
CHANGED
@@ -43,12 +43,13 @@ from pathlib import Path
 from typing import Literal
 
 import anndata as ad
-import loompy as lp
 import numpy as np
 import scipy.sparse as sp
 from datasets import Dataset
 
-warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
+warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")  # noqa
+import loompy as lp  # noqa
+
 logger = logging.getLogger(__name__)
 
 GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
@@ -81,7 +82,7 @@ class TranscriptomeTokenizer:
         custom_attr_name_dict=None,
         nproc=1,
         chunk_size=512,
-
+        model_input_size=2048,
         special_token=False,
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
@@ -95,10 +96,10 @@ class TranscriptomeTokenizer:
            | Values are the names of the attributes in the dataset.
        nproc : int
            | Number of processes to use for dataset mapping.
-       chunk_size: int = 512
+       chunk_size : int = 512
            | Chunk size for anndata tokenizer.
-
-
+       model_input_size: int = 2048
+           | Max input size of model to truncate input to.
        special_token: bool = False
            | Option to add CLS and SEP tokens
        gene_median_file : Path
@@ -117,7 +118,7 @@ class TranscriptomeTokenizer:
         self.chunk_size = chunk_size
 
         # input size for tokenization
-        self.
+        self.model_input_size = model_input_size
 
         # add CLS and SEP tokens
         self.special_token = special_token
@@ -163,7 +164,9 @@ class TranscriptomeTokenizer:
             Path(data_directory), file_format
         )
         tokenized_dataset = self.create_dataset(
-            tokenized_cells,
+            tokenized_cells,
+            cell_metadata,
+            use_generator=use_generator,
         )
 
         output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
@@ -332,7 +335,7 @@ class TranscriptomeTokenizer:
                file_cell_metadata[k] += subview.ca[k].tolist()
        else:
            file_cell_metadata = None
-
+
        return tokenized_cells, file_cell_metadata
 
    def create_dataset(
@@ -367,12 +370,20 @@ class TranscriptomeTokenizer:
 
            # Truncate/Crop input_ids to input size
            if self.special_token:
-                example["input_ids"] = example["input_ids"][
-
-
+                example["input_ids"] = example["input_ids"][
+                    0 : self.model_input_size - 2
+                ]  # truncate to leave space for CLS and SEP token
+                example["input_ids"] = np.insert(
+                    example["input_ids"], 0, self.gene_token_dict.get("<cls>")
+                )
+                example["input_ids"] = np.insert(
+                    example["input_ids"],
+                    len(example["input_ids"]),
+                    self.gene_token_dict.get("<sep>"),
+                )
            else:
                # Truncate/Crop input_ids to input size
-                example["input_ids"] = example["input_ids"][0:self.
+                example["input_ids"] = example["input_ids"][0 : self.model_input_size]
            example["length"] = len(example["input_ids"])
 
            return example
@@ -380,4 +391,4 @@ class TranscriptomeTokenizer:
        output_dataset_truncated = output_dataset.map(
            format_cell_features, num_proc=self.nproc
        )
-        return output_dataset_truncated
+        return output_dataset_truncated
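
With the new parameter wired through, truncation follows model_input_size (minus 2 positions when special_token reserves room for <cls>/<sep>). A usage sketch; the paths and the custom_attr_name_dict contents are placeholders:

    from geneformer import TranscriptomeTokenizer

    tk = TranscriptomeTokenizer(
        custom_attr_name_dict={"cell_type": "cell_type"},  # hypothetical metadata mapping
        nproc=4,
        model_input_size=2048,   # inputs longer than this are truncated
        special_token=False,     # set True to add <cls>/<sep> within the size limit
    )
    tk.tokenize_data("data/loom_dir", "data/tokenized", "my_dataset", file_format="loom")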