Christina Theodoris commited on
Commit
9e9cca9
1 Parent(s): d6c634c

Add classifier module and examples

Browse files
docs/source/api.rst CHANGED
@@ -9,6 +9,14 @@ Tokenizer
9
 
10
  geneformer.tokenizer
11
 
 
 
 
 
 
 
 
 
12
  Embedding Extractor
13
  -------------------
14
 
 
9
 
10
  geneformer.tokenizer
11
 
12
+ Classifier
13
+ ----------
14
+
15
+ .. toctree::
16
+ :maxdepth: 1
17
+
18
+ geneformer.classifier
19
+
20
  Embedding Extractor
21
  -------------------
22
 
docs/source/geneformer.classifier.rst ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ geneformer.classifier
2
+ =====================
3
+
4
+ .. automodule:: geneformer.classifier
5
+ :members:
6
+ :undoc-members:
7
+ :show-inheritance:
8
+ :exclude-members:
9
+ validate_options
examples/cell_classification.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -129,7 +129,7 @@
129
  "name": "python",
130
  "nbconvert_exporter": "python",
131
  "pygments_lexer": "ipython3",
132
- "version": "3.10.11"
133
  }
134
  },
135
  "nbformat": 4,
 
129
  "name": "python",
130
  "nbconvert_exporter": "python",
131
  "pygments_lexer": "ipython3",
132
+ "version": "3.11.5"
133
  }
134
  },
135
  "nbformat": 4,
examples/gene_classification.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
geneformer/__init__.py CHANGED
@@ -1,12 +1,20 @@
1
- from . import tokenizer
2
- from . import pretrainer
3
- from . import collator_for_classification
4
- from . import in_silico_perturber
5
- from . import in_silico_perturber_stats
6
- from .tokenizer import TranscriptomeTokenizer
7
- from .pretrainer import GeneformerPretrainer
8
- from .collator_for_classification import DataCollatorForGeneClassification
9
- from .collator_for_classification import DataCollatorForCellClassification
 
 
 
 
 
 
10
  from .emb_extractor import EmbExtractor
11
  from .in_silico_perturber import InSilicoPerturber
12
- from .in_silico_perturber_stats import InSilicoPerturberStats
 
 
 
1
+ # ruff: noqa: F401
2
+ from . import classifier # noqa
3
+ from . import (
4
+ collator_for_classification,
5
+ emb_extractor,
6
+ in_silico_perturber,
7
+ in_silico_perturber_stats,
8
+ pretrainer,
9
+ tokenizer,
10
+ )
11
+ from .classifier import Classifier
12
+ from .collator_for_classification import (
13
+ DataCollatorForCellClassification,
14
+ DataCollatorForGeneClassification,
15
+ )
16
  from .emb_extractor import EmbExtractor
17
  from .in_silico_perturber import InSilicoPerturber
18
+ from .in_silico_perturber_stats import InSilicoPerturberStats
19
+ from .pretrainer import GeneformerPretrainer
20
+ from .tokenizer import TranscriptomeTokenizer
geneformer/classifier.py ADDED
@@ -0,0 +1,1203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer classifier.
3
+
4
+ **Input data:**
5
+
6
+ Cell state classifier:
7
+ | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels
8
+ | in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
9
+
10
+ Gene classifier:
11
+ | Dictionary in format {Gene_label: list(genes)} for gene labels
12
+ | and single-cell transcriptomes as Geneformer rank value encodings
13
+ | in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
14
+
15
+ **Usage:**
16
+
17
+ .. code-block :: python
18
+
19
+ >>> from geneformer import Classifier
20
+ >>> cc = Classifier(classifier="cell", # example of cell state classifier
21
+ ... cell_state_dict={"state_key": "disease", "states": "all"},
22
+ ... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
23
+ ... training_args=training_args,
24
+ ... freeze_layers = 2,
25
+ ... num_crossval_splits = 1,
26
+ ... forward_batch_size=200,
27
+ ... nproc=16)
28
+ >>> cc.prepare_data(input_data_file="path/to/input_data",
29
+ ... output_directory="path/to/output_directory",
30
+ ... output_prefix="output_prefix")
31
+ >>> all_metrics = cc.validate(model_directory="path/to/model",
32
+ ... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
33
+ ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
34
+ ... output_directory="path/to/output_directory",
35
+ ... output_prefix="output_prefix",
36
+ ... predict=True)
37
+ >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
38
+ ... output_directory="path/to/output_directory",
39
+ ... output_prefix="output_prefix",
40
+ ... custom_class_order=["healthy","disease1","disease2"])
41
+ >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
42
+ ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
43
+ ... title="disease",
44
+ ... output_directory="path/to/output_directory",
45
+ ... output_prefix="output_prefix",
46
+ ... custom_class_order=["healthy","disease1","disease2"])
47
+ """
48
+
49
+ import datetime
50
+ import logging
51
+ import os
52
+ import pickle
53
+ import subprocess
54
+ from pathlib import Path
55
+
56
+ import numpy as np
57
+ import pandas as pd
58
+ import seaborn as sns
59
+ from sklearn.model_selection import StratifiedKFold
60
+ from tqdm.auto import tqdm, trange
61
+ from transformers import Trainer
62
+ from transformers.training_args import TrainingArguments
63
+
64
+ from . import DataCollatorForCellClassification, DataCollatorForGeneClassification
65
+ from . import classifier_utils as cu
66
+ from . import evaluation_utils as eu
67
+ from . import perturber_utils as pu
68
+ from .tokenizer import TOKEN_DICTIONARY_FILE
69
+
70
+ sns.set()
71
+
72
+
73
+ logger = logging.getLogger(__name__)
74
+
75
+
76
+ class Classifier:
77
+ valid_option_dict = {
78
+ "classifier": {"cell", "gene"},
79
+ "cell_state_dict": {None, dict},
80
+ "gene_class_dict": {None, dict},
81
+ "filter_data": {None, dict},
82
+ "rare_threshold": {int, float},
83
+ "max_ncells": {None, int},
84
+ "max_ncells_per_class": {None, int},
85
+ "training_args": {None, dict},
86
+ "freeze_layers": {int},
87
+ "num_crossval_splits": {0, 1, 5},
88
+ "eval_size": {int, float},
89
+ "no_eval": {bool},
90
+ "stratify_splits_col": {None, str},
91
+ "forward_batch_size": {int},
92
+ "nproc": {int},
93
+ }
94
+
95
+ def __init__(
96
+ self,
97
+ classifier=None,
98
+ cell_state_dict=None,
99
+ gene_class_dict=None,
100
+ filter_data=None,
101
+ rare_threshold=0,
102
+ max_ncells=None,
103
+ max_ncells_per_class=None,
104
+ training_args=None,
105
+ freeze_layers=0,
106
+ num_crossval_splits=1,
107
+ eval_size=0.2,
108
+ stratify_splits_col=None,
109
+ no_eval=False,
110
+ forward_batch_size=100,
111
+ nproc=4,
112
+ ):
113
+ """
114
+ Initialize Geneformer classifier.
115
+
116
+ **Parameters:**
117
+
118
+ classifier : {"cell", "gene"}
119
+ | Whether to fine-tune a cell state or gene classifier.
120
+ cell_state_dict : None, dict
121
+ | Cell states to fine-tune model to distinguish.
122
+ | Two-item dictionary with keys: state_key and states
123
+ | state_key: key specifying name of column in .dataset that defines the states to model
124
+ | states: list of values in the state_key column that specifies the states to model
125
+ | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
126
+ | Of note, if using "all", states will be defined after data is filtered.
127
+ | Must have at least 2 states to model.
128
+ | For example: {"state_key": "disease",
129
+ | "states": ["nf", "hcm", "dcm"]}
130
+ | or
131
+ | {"state_key": "disease",
132
+ | "states": "all"}
133
+ gene_class_dict : None, dict
134
+ | Gene classes to fine-tune model to distinguish.
135
+ | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
136
+ | Gene_label_B: list(geneB1, geneB2, ...)}
137
+ | Gene values should be Ensembl IDs.
138
+ filter_data : None, dict
139
+ | Default is to fine-tune with all input data.
140
+ | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
141
+ rare_threshold : float
142
+ | Threshold below which rare cell states should be removed.
143
+ | For example, setting to 0.05 will remove cell states representing
144
+ | < 5% of the total cells from the cell state classifier's possible classes.
145
+ max_ncells : None, int
146
+ | Maximum number of cells to use for fine-tuning.
147
+ | Default is to fine-tune with all input data.
148
+ max_ncells_per_class : None, int
149
+ | Maximum number of cells per cell class to use for fine-tuning.
150
+ | Of note, will be applied after max_ncells above.
151
+ | (Only valid for cell classification.)
152
+ training_args : None, dict
153
+ | Training arguments for fine-tuning.
154
+ | If None, defaults will be inferred for 6 layer Geneformer.
155
+ | Otherwise, will use the Hugging Face defaults:
156
+ | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
157
+ | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
158
+ freeze_layers : int
159
+ | Number of layers to freeze from fine-tuning.
160
+ | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
161
+ num_crossval_splits : {0, 1, 5}
162
+ | 0: train on all data without splitting
163
+ | 1: split data into train and eval sets by designated eval_size
164
+ | 5: split data into 5 folds of train and eval sets by designated eval_size
165
+ eval_size : None, float
166
+ | Proportion of data to hold out for evaluation (e.g. 0.2 if intending 80:20 train/eval split)
167
+ stratify_splits_col : None, str
168
+ | Name of column in .dataset to be used for stratified splitting.
169
+ | Proportion of each class in this column will be the same in the splits as in the original dataset.
170
+ no_eval : bool
171
+ | If True, will skip eval step and use all data for training.
172
+ | Otherwise, will perform eval during training.
173
+ forward_batch_size : int
174
+ | Batch size for forward pass (for evaluation, not training).
175
+ nproc : int
176
+ | Number of CPU processes to use.
177
+
178
+ """
179
+
180
+ self.classifier = classifier
181
+ self.cell_state_dict = cell_state_dict
182
+ self.gene_class_dict = gene_class_dict
183
+ self.filter_data = filter_data
184
+ self.rare_threshold = rare_threshold
185
+ self.max_ncells = max_ncells
186
+ self.max_ncells_per_class = max_ncells_per_class
187
+ self.training_args = training_args
188
+ self.freeze_layers = freeze_layers
189
+ self.num_crossval_splits = num_crossval_splits
190
+ self.eval_size = eval_size
191
+ self.stratify_splits_col = stratify_splits_col
192
+ self.no_eval = no_eval
193
+ self.forward_batch_size = forward_batch_size
194
+ self.nproc = nproc
195
+
196
+ if self.training_args is None:
197
+ logger.warning(
198
+ "Hyperparameter tuning is highly recommended for optimal results. "
199
+ "No training_args provided; using default hyperparameters."
200
+ )
201
+
202
+ self.validate_options()
203
+
204
+ if self.filter_data is None:
205
+ self.filter_data = dict()
206
+
207
+ if self.classifier == "cell":
208
+ if self.cell_state_dict["states"] != "all":
209
+ self.filter_data[
210
+ self.cell_state_dict["state_key"]
211
+ ] = self.cell_state_dict["states"]
212
+
213
+ # load token dictionary (Ensembl IDs:token)
214
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
215
+ self.gene_token_dict = pickle.load(f)
216
+
217
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
218
+
219
+ # filter genes for gene classification for those in token dictionary
220
+ if self.classifier == "gene":
221
+ all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
222
+ missing_genes = [
223
+ gene
224
+ for gene in all_gene_class_values
225
+ if gene not in self.gene_token_dict.keys()
226
+ ]
227
+ if len(missing_genes) == len(all_gene_class_values):
228
+ logger.error(
229
+ "None of the provided genes to classify are in token dictionary."
230
+ )
231
+ raise
232
+ elif len(missing_genes) > 0:
233
+ logger.warning(
234
+ f"Genes to classify {missing_genes} are not in token dictionary."
235
+ )
236
+ self.gene_class_dict = {
237
+ k: set([self.gene_token_dict.get(gene) for gene in v])
238
+ for k, v in self.gene_class_dict.items()
239
+ }
240
+ empty_classes = []
241
+ for k, v in self.gene_class_dict.items():
242
+ if len(v) == 0:
243
+ empty_classes += [k]
244
+ if len(empty_classes) > 0:
245
+ logger.error(
246
+ f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
247
+ )
248
+ raise
249
+
250
+ def validate_options(self):
251
+ # confirm arguments are within valid options and compatible with each other
252
+ for attr_name, valid_options in self.valid_option_dict.items():
253
+ attr_value = self.__dict__[attr_name]
254
+ if not isinstance(attr_value, (list, dict)):
255
+ if attr_value in valid_options:
256
+ continue
257
+ valid_type = False
258
+ for option in valid_options:
259
+ if (option in [int, float, list, dict, bool]) and isinstance(
260
+ attr_value, option
261
+ ):
262
+ valid_type = True
263
+ break
264
+ if valid_type:
265
+ continue
266
+ logger.error(
267
+ f"Invalid option for {attr_name}. "
268
+ f"Valid options for {attr_name}: {valid_options}"
269
+ )
270
+ raise
271
+
272
+ if self.filter_data is not None:
273
+ for key, value in self.filter_data.items():
274
+ if not isinstance(value, list):
275
+ self.filter_data[key] = [value]
276
+ logger.warning(
277
+ "Values in filter_data dict must be lists. "
278
+ f"Changing {key} value to list ([{value}])."
279
+ )
280
+
281
+ if self.classifier == "cell":
282
+ if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
283
+ logger.error(
284
+ "Invalid keys for cell_state_dict. "
285
+ "The cell_state_dict should have only 2 keys: state_key and states"
286
+ )
287
+ raise
288
+
289
+ if self.cell_state_dict["states"] != "all":
290
+ if not isinstance(self.cell_state_dict["states"], list):
291
+ logger.error(
292
+ "States in cell_state_dict should be list of states to model."
293
+ )
294
+ raise
295
+ if len(self.cell_state_dict["states"]) < 2:
296
+ logger.error(
297
+ "States in cell_state_dict should contain at least 2 states to classify."
298
+ )
299
+ raise
300
+
301
+ if self.classifier == "gene":
302
+ if len(self.gene_class_dict.keys()) < 2:
303
+ logger.error(
304
+ "Gene_class_dict should contain at least 2 gene classes to classify."
305
+ )
306
+ raise
307
+
308
+ def prepare_data(
309
+ self,
310
+ input_data_file,
311
+ output_directory,
312
+ output_prefix,
313
+ split_id_dict=None,
314
+ test_size=0,
315
+ attr_to_split=None,
316
+ attr_to_balance=None,
317
+ max_trials=100,
318
+ pval_threshold=0.1,
319
+ ):
320
+ """
321
+ Prepare data for cell state or gene classification.
322
+
323
+ **Parameters**
324
+
325
+ input_data_file : Path
326
+ | Path to directory containing .dataset input
327
+ output_directory : Path
328
+ | Path to directory where prepared data will be saved
329
+ output_prefix : str
330
+ | Prefix for output file
331
+ split_id_dict : None, dict
332
+ | Dictionary of IDs for train and test splits
333
+ | Three-item dictionary with keys: attr_key, train, test
334
+ | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
335
+ | train: list of IDs in the attr_key column to include in the train split
336
+ | test: list of IDs in the attr_key column to include in the test split
337
+ | For example: {"attr_key": "individual",
338
+ | "train": ["patient1", "patient2", "patient3", "patient4"],
339
+ | "test": ["patient5", "patient6"]}
340
+ test_size : None, float
341
+ | Proportion of data to be saved separately and held out for test set
342
+ | (e.g. 0.2 if intending hold out 20%)
343
+ | The training set will be further split to train / validation in self.validate
344
+ | Note: only available for CellClassifiers
345
+ attr_to_split : None, str
346
+ | Key for attribute on which to split data while balancing potential confounders
347
+ | e.g. "patient_id" for splitting by patient while balancing other characteristics
348
+ | Note: only available for CellClassifiers
349
+ attr_to_balance : None, list
350
+ | List of attribute keys on which to balance data while splitting on attr_to_split
351
+ | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
352
+ | Note: only available for CellClassifiers
353
+ max_trials : None, int
354
+ | Maximum number of trials of random splitting to try to achieve balanced other attributes
355
+ | If no split is found without significant (p<0.05) differences in other attributes, will select best
356
+ | Note: only available for CellClassifiers
357
+ pval_threshold : None, float
358
+ | P-value threshold to use for attribute balancing across splits
359
+ | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
360
+ """
361
+
362
+ # prepare data and labels for classification
363
+ data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
364
+
365
+ if self.classifier == "cell":
366
+ if "label" in data.features:
367
+ logger.error(
368
+ "Column name 'label' must be reserved for class IDs. Please rename column."
369
+ )
370
+ raise
371
+ elif self.classifier == "gene":
372
+ if "labels" in data.features:
373
+ logger.error(
374
+ "Column name 'labels' must be reserved for class IDs. Please rename column."
375
+ )
376
+ raise
377
+
378
+ if self.classifier == "cell":
379
+ # remove cell states representing < rare_threshold of cells
380
+ data = cu.remove_rare(
381
+ data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
382
+ )
383
+ # downsample max cells and max per class
384
+ data = cu.downsample_and_shuffle(
385
+ data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
386
+ )
387
+ # rename cell state column to "label"
388
+ data = cu.rename_cols(data, self.cell_state_dict["state_key"])
389
+
390
+ # convert classes to numerical labels and save as id_class_dict
391
+ # of note, will label all genes in gene_class_dict
392
+ # if (cross-)validating, genes will be relabeled in column "labels" for each split
393
+ # at the time of training with Classifier.validate
394
+ data, id_class_dict = cu.label_classes(
395
+ self.classifier, data, self.gene_class_dict, self.nproc
396
+ )
397
+
398
+ # save id_class_dict for future reference
399
+ id_class_output_path = (
400
+ Path(output_directory) / f"{output_prefix}_id_class_dict"
401
+ ).with_suffix(".pkl")
402
+ with open(id_class_output_path, "wb") as f:
403
+ pickle.dump(id_class_dict, f)
404
+
405
+ if split_id_dict is not None:
406
+ data_dict = dict()
407
+ data_dict["train"] = pu.filter_by_dict(
408
+ data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
409
+ )
410
+ data_dict["test"] = pu.filter_by_dict(
411
+ data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
412
+ )
413
+ train_data_output_path = (
414
+ Path(output_directory) / f"{output_prefix}_labeled_train"
415
+ ).with_suffix(".dataset")
416
+ test_data_output_path = (
417
+ Path(output_directory) / f"{output_prefix}_labeled_test"
418
+ ).with_suffix(".dataset")
419
+ data_dict["train"].save_to_disk(train_data_output_path)
420
+ data_dict["test"].save_to_disk(test_data_output_path)
421
+ elif (test_size is not None) and (self.classifier == "cell"):
422
+ if 1 > test_size > 0:
423
+ data_dict, balance_df = cu.balance_attr_splits(
424
+ data,
425
+ attr_to_split,
426
+ attr_to_balance,
427
+ test_size,
428
+ max_trials,
429
+ pval_threshold,
430
+ self.cell_state_dict["state_key"],
431
+ self.nproc,
432
+ )
433
+ balance_df.to_csv(
434
+ f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
435
+ )
436
+ train_data_output_path = (
437
+ Path(output_directory) / f"{output_prefix}_labeled_train"
438
+ ).with_suffix(".dataset")
439
+ test_data_output_path = (
440
+ Path(output_directory) / f"{output_prefix}_labeled_test"
441
+ ).with_suffix(".dataset")
442
+ data_dict["train"].save_to_disk(train_data_output_path)
443
+ data_dict["test"].save_to_disk(test_data_output_path)
444
+ else:
445
+ data_output_path = (
446
+ Path(output_directory) / f"{output_prefix}_labeled"
447
+ ).with_suffix(".dataset")
448
+ data.save_to_disk(data_output_path)
449
+
450
+ def train_all_data(
451
+ self,
452
+ model_directory,
453
+ prepared_input_data_file,
454
+ id_class_dict_file,
455
+ output_directory,
456
+ output_prefix,
457
+ save_eval_output=True,
458
+ ):
459
+ """
460
+ Train cell state or gene classifier using all data.
461
+
462
+ **Parameters**
463
+
464
+ model_directory : Path
465
+ | Path to directory containing model
466
+ prepared_input_data_file : Path
467
+ | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
468
+ id_class_dict_file : Path
469
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
470
+ | (dictionary of format: numerical IDs: class_labels)
471
+ output_directory : Path
472
+ | Path to directory where model and eval data will be saved
473
+ output_prefix : str
474
+ | Prefix for output files
475
+ save_eval_output : bool
476
+ | Whether to save cross-fold eval output
477
+ | Saves as pickle file of dictionary of eval metrics
478
+
479
+ **Output**
480
+
481
+ Returns trainer after fine-tuning with all data.
482
+
483
+ """
484
+
485
+ ##### Load data and prepare output directory #####
486
+ # load numerical id to class dictionary (id:class)
487
+ with open(id_class_dict_file, "rb") as f:
488
+ id_class_dict = pickle.load(f)
489
+ class_id_dict = {v: k for k, v in id_class_dict.items()}
490
+
491
+ # load previously filtered and prepared data
492
+ data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
493
+ data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
494
+
495
+ # define output directory path
496
+ current_date = datetime.datetime.now()
497
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
498
+ if output_directory[-1:] != "/": # add slash for dir if not present
499
+ output_directory = output_directory + "/"
500
+ output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
501
+ subprocess.call(f"mkdir {output_dir}", shell=True)
502
+
503
+ # get number of classes for classifier
504
+ num_classes = cu.get_num_classes(id_class_dict)
505
+
506
+ if self.classifier == "gene":
507
+ targets = pu.flatten_list(self.gene_class_dict.values())
508
+ labels = pu.flatten_list(
509
+ [
510
+ [class_id_dict[label]] * len(targets)
511
+ for label, targets in self.gene_class_dict.items()
512
+ ]
513
+ )
514
+ assert len(targets) == len(labels)
515
+ data = cu.prep_gene_classifier_all_data(
516
+ data, targets, labels, self.max_ncells, self.nproc
517
+ )
518
+
519
+ trainer = self.train_classifier(
520
+ model_directory, num_classes, data, None, output_dir
521
+ )
522
+
523
+ return trainer
524
+
525
+ def validate(
526
+ self,
527
+ model_directory,
528
+ prepared_input_data_file,
529
+ id_class_dict_file,
530
+ output_directory,
531
+ output_prefix,
532
+ split_id_dict=None,
533
+ attr_to_split=None,
534
+ attr_to_balance=None,
535
+ max_trials=100,
536
+ pval_threshold=0.1,
537
+ save_eval_output=True,
538
+ predict_eval=True,
539
+ predict_trainer=False,
540
+ ):
541
+ """
542
+ (Cross-)validate cell state or gene classifier.
543
+
544
+ **Parameters**
545
+
546
+ model_directory : Path
547
+ | Path to directory containing model
548
+ prepared_input_data_file : Path
549
+ | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
550
+ id_class_dict_file : Path
551
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
552
+ | (dictionary of format: numerical IDs: class_labels)
553
+ output_directory : Path
554
+ | Path to directory where model and eval data will be saved
555
+ output_prefix : str
556
+ | Prefix for output files
557
+ split_id_dict : None, dict
558
+ | Dictionary of IDs for train and eval splits
559
+ | Three-item dictionary with keys: attr_key, train, eval
560
+ | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
561
+ | train: list of IDs in the attr_key column to include in the train split
562
+ | eval: list of IDs in the attr_key column to include in the eval split
563
+ | For example: {"attr_key": "individual",
564
+ | "train": ["patient1", "patient2", "patient3", "patient4"],
565
+ | "eval": ["patient5", "patient6"]}
566
+ | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
567
+ attr_to_split : None, str
568
+ | Key for attribute on which to split data while balancing potential confounders
569
+ | e.g. "patient_id" for splitting by patient while balancing other characteristics
570
+ | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
571
+ attr_to_balance : None, list
572
+ | List of attribute keys on which to balance data while splitting on attr_to_split
573
+ | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
574
+ max_trials : None, int
575
+ | Maximum number of trials of random splitting to try to achieve balanced other attribute
576
+ | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
577
+ pval_threshold : None, float
578
+ | P-value threshold to use for attribute balancing across splits
579
+ | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
580
+ save_eval_output : bool
581
+ | Whether to save cross-fold eval output
582
+ | Saves as pickle file of dictionary of eval metrics
583
+ predict_eval : bool
584
+ | Whether or not to save eval predictions
585
+ | Saves as a pickle file of self.evaluate predictions
586
+ predict_trainer : bool
587
+ | Whether or not to save eval predictions from trainer
588
+ | Saves as a pickle file of trainer predictions
589
+ """
590
+
591
+ if self.num_crossval_splits == 0:
592
+ logger.error("num_crossval_splits must be 1 or 5 to validate.")
593
+ raise
594
+
595
+ # ensure number of genes in each class is > 5 if validating model
596
+ if self.classifier == "gene":
597
+ insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
598
+ if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
599
+ logger.error(
600
+ f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
601
+ )
602
+ raise
603
+
604
+ ##### Load data and prepare output directory #####
605
+ # load numerical id to class dictionary (id:class)
606
+ with open(id_class_dict_file, "rb") as f:
607
+ id_class_dict = pickle.load(f)
608
+ class_id_dict = {v: k for k, v in id_class_dict.items()}
609
+
610
+ # load previously filtered and prepared data
611
+ data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
612
+ data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
613
+
614
+ # define output directory path
615
+ current_date = datetime.datetime.now()
616
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
617
+ if output_directory[-1:] != "/": # add slash for dir if not present
618
+ output_directory = output_directory + "/"
619
+ output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
620
+ subprocess.call(f"mkdir {output_dir}", shell=True)
621
+
622
+ # get number of classes for classifier
623
+ num_classes = cu.get_num_classes(id_class_dict)
624
+
625
+ ##### (Cross-)validate the model #####
626
+ results = []
627
+ all_conf_mat = np.zeros((num_classes, num_classes))
628
+ iteration_num = 1
629
+ if self.classifier == "cell":
630
+ for i in trange(self.num_crossval_splits):
631
+ print(
632
+ f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
633
+ )
634
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
635
+ if self.num_crossval_splits == 1:
636
+ # single 1-eval_size:eval_size split
637
+ if split_id_dict is not None:
638
+ data_dict = dict()
639
+ data_dict["train"] = pu.filter_by_dict(
640
+ data,
641
+ {split_id_dict["attr_key"]: split_id_dict["train"]},
642
+ self.nproc,
643
+ )
644
+ data_dict["test"] = pu.filter_by_dict(
645
+ data,
646
+ {split_id_dict["attr_key"]: split_id_dict["eval"]},
647
+ self.nproc,
648
+ )
649
+ elif attr_to_split is not None:
650
+ data_dict, balance_df = cu.balance_attr_splits(
651
+ data,
652
+ attr_to_split,
653
+ attr_to_balance,
654
+ self.eval_size,
655
+ max_trials,
656
+ pval_threshold,
657
+ self.cell_state_dict["state_key"],
658
+ self.nproc,
659
+ )
660
+
661
+ balance_df.to_csv(
662
+ f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
663
+ )
664
+ else:
665
+ data_dict = data.train_test_split(
666
+ test_size=self.eval_size,
667
+ stratify_by_column=self.stratify_splits_col,
668
+ seed=42,
669
+ )
670
+ train_data = data_dict["train"]
671
+ eval_data = data_dict["test"]
672
+ else:
673
+ # 5-fold cross-validate
674
+ num_cells = len(data)
675
+ fifth_cells = num_cells * 0.2
676
+ num_eval = min((self.eval_size * num_cells), fifth_cells)
677
+ start = i * fifth_cells
678
+ end = start + num_eval
679
+ eval_indices = [j for j in range(start, end)]
680
+ train_indices = [
681
+ j for j in range(num_cells) if j not in eval_indices
682
+ ]
683
+ eval_data = data.select(eval_indices)
684
+ train_data = data.select(train_indices)
685
+ trainer = self.train_classifier(
686
+ model_directory,
687
+ num_classes,
688
+ train_data,
689
+ eval_data,
690
+ ksplit_output_dir,
691
+ predict_trainer,
692
+ )
693
+ result = self.evaluate_model(
694
+ trainer.model,
695
+ num_classes,
696
+ id_class_dict,
697
+ eval_data,
698
+ predict_eval,
699
+ ksplit_output_dir,
700
+ output_prefix,
701
+ )
702
+ results += [result]
703
+ all_conf_mat = all_conf_mat + result["conf_mat"]
704
+ iteration_num = iteration_num + 1
705
+
706
+ elif self.classifier == "gene":
707
+ # set up (cross-)validation splits
708
+ targets = pu.flatten_list(self.gene_class_dict.values())
709
+ labels = pu.flatten_list(
710
+ [
711
+ [class_id_dict[label]] * len(targets)
712
+ for label, targets in self.gene_class_dict.items()
713
+ ]
714
+ )
715
+ assert len(targets) == len(labels)
716
+ n_splits = int(1 / self.eval_size)
717
+ skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
718
+ # (Cross-)validate
719
+ for train_index, eval_index in tqdm(skf.split(targets, labels)):
720
+ print(
721
+ f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
722
+ )
723
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
724
+ # filter data for examples containing classes for this split
725
+ # subsample to max_ncells and relabel data in column "labels"
726
+ train_data, eval_data = cu.prep_gene_classifier_split(
727
+ data,
728
+ targets,
729
+ labels,
730
+ train_index,
731
+ eval_index,
732
+ self.max_ncells,
733
+ iteration_num,
734
+ self.nproc,
735
+ )
736
+
737
+ trainer = self.train_classifier(
738
+ model_directory,
739
+ num_classes,
740
+ train_data,
741
+ eval_data,
742
+ ksplit_output_dir,
743
+ predict_trainer,
744
+ )
745
+ result = self.evaluate_model(
746
+ trainer.model,
747
+ num_classes,
748
+ id_class_dict,
749
+ eval_data,
750
+ predict_eval,
751
+ ksplit_output_dir,
752
+ output_prefix,
753
+ )
754
+ results += [result]
755
+ all_conf_mat = all_conf_mat + result["conf_mat"]
756
+ # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
757
+ if iteration_num == self.num_crossval_splits:
758
+ break
759
+ iteration_num = iteration_num + 1
760
+
761
+ all_conf_mat_df = pd.DataFrame(
762
+ all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
763
+ )
764
+ all_metrics = {
765
+ "conf_matrix": all_conf_mat_df,
766
+ "macro_f1": [result["macro_f1"] for result in results],
767
+ "acc": [result["acc"] for result in results],
768
+ }
769
+ all_roc_metrics = None # roc metrics not reported for multiclass
770
+ if num_classes == 2:
771
+ mean_fpr = np.linspace(0, 1, 100)
772
+ all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
773
+ all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
774
+ all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
775
+ mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
776
+ all_tpr, all_roc_auc, all_tpr_wt
777
+ )
778
+ all_roc_metrics = {
779
+ "mean_tpr": mean_tpr,
780
+ "mean_fpr": mean_fpr,
781
+ "all_roc_auc": all_roc_auc,
782
+ "roc_auc": roc_auc,
783
+ "roc_auc_sd": roc_auc_sd,
784
+ }
785
+ all_metrics["all_roc_metrics"] = all_roc_metrics
786
+ if save_eval_output is True:
787
+ eval_metrics_output_path = (
788
+ Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
789
+ ).with_suffix(".pkl")
790
+ with open(eval_metrics_output_path, "wb") as f:
791
+ pickle.dump(all_metrics, f)
792
+
793
+ return all_metrics
794
+
795
+ def train_classifier(
796
+ self,
797
+ model_directory,
798
+ num_classes,
799
+ train_data,
800
+ eval_data,
801
+ output_directory,
802
+ predict=False,
803
+ ):
804
+ """
805
+ Fine-tune model for cell state or gene classification.
806
+
807
+ **Parameters**
808
+
809
+ model_directory : Path
810
+ | Path to directory containing model
811
+ num_classes : int
812
+ | Number of classes for classifier
813
+ train_data : Dataset
814
+ | Loaded training .dataset input
815
+ | For cell classifier, labels in column "label".
816
+ | For gene classifier, labels in column "labels".
817
+ eval_data : None, Dataset
818
+ | (Optional) Loaded evaluation .dataset input
819
+ | For cell classifier, labels in column "label".
820
+ | For gene classifier, labels in column "labels".
821
+ output_directory : Path
822
+ | Path to directory where fine-tuned model will be saved
823
+ predict : bool
824
+ | Whether or not to save eval predictions from trainer
825
+ """
826
+
827
+ ##### Validate and prepare data #####
828
+ train_data, eval_data = cu.validate_and_clean_cols(
829
+ train_data, eval_data, self.classifier
830
+ )
831
+
832
+ if (self.no_eval is True) and (eval_data is not None):
833
+ logger.warning(
834
+ "no_eval set to True; model will be trained without evaluation."
835
+ )
836
+ eval_data = None
837
+
838
+ if (self.classifier == "gene") and (predict is True):
839
+ logger.warning(
840
+ "Predictions during training not currently available for gene classifiers; setting predict to False."
841
+ )
842
+ predict = False
843
+
844
+ # ensure not overwriting previously saved model
845
+ saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
846
+ if os.path.isfile(saved_model_test) is True:
847
+ logger.error("Model already saved to this designated output directory.")
848
+ raise
849
+ # make output directory
850
+ subprocess.call(f"mkdir {output_directory}", shell=True)
851
+
852
+ ##### Load model and training args #####
853
+ if self.classifier == "cell":
854
+ model_type = "CellClassifier"
855
+ elif self.classifier == "gene":
856
+ model_type = "GeneClassifier"
857
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
858
+
859
+ def_training_args, def_freeze_layers = cu.get_default_train_args(
860
+ model, self.classifier, train_data, output_directory
861
+ )
862
+
863
+ if self.training_args is not None:
864
+ def_training_args.update(self.training_args)
865
+ logging_steps = round(
866
+ len(train_data) / def_training_args["per_device_train_batch_size"] / 10
867
+ )
868
+ def_training_args["logging_steps"] = logging_steps
869
+ def_training_args["output_dir"] = output_directory
870
+ if eval_data is None:
871
+ def_training_args["evaluation_strategy"] = "no"
872
+ def_training_args["load_best_model_at_end"] = False
873
+ training_args_init = TrainingArguments(**def_training_args)
874
+
875
+ if self.freeze_layers is not None:
876
+ def_freeze_layers = self.freeze_layers
877
+
878
+ if def_freeze_layers > 0:
879
+ modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
880
+ for module in modules_to_freeze:
881
+ for param in module.parameters():
882
+ param.requires_grad = False
883
+
884
+ ##### Fine-tune the model #####
885
+ # define the data collator
886
+ if self.classifier == "cell":
887
+ data_collator = DataCollatorForCellClassification()
888
+ elif self.classifier == "gene":
889
+ data_collator = DataCollatorForGeneClassification()
890
+
891
+ # create the trainer
892
+ trainer = Trainer(
893
+ model=model,
894
+ args=training_args_init,
895
+ data_collator=data_collator,
896
+ train_dataset=train_data,
897
+ eval_dataset=eval_data,
898
+ compute_metrics=cu.compute_metrics,
899
+ )
900
+
901
+ # train the classifier
902
+ trainer.train()
903
+ trainer.save_model(output_directory)
904
+ if predict is True:
905
+ # make eval predictions and save predictions and metrics
906
+ predictions = trainer.predict(eval_data)
907
+ prediction_output_path = f"{output_directory}/predictions.pkl"
908
+ with open(prediction_output_path, "wb") as f:
909
+ pickle.dump(predictions, f)
910
+ trainer.save_metrics("eval", predictions.metrics)
911
+ return trainer
912
+
913
+ def evaluate_model(
914
+ self,
915
+ model,
916
+ num_classes,
917
+ id_class_dict,
918
+ eval_data,
919
+ predict=False,
920
+ output_directory=None,
921
+ output_prefix=None,
922
+ ):
923
+ """
924
+ Evaluate the fine-tuned model.
925
+
926
+ **Parameters**
927
+
928
+ model : nn.Module
929
+ | Loaded fine-tuned model (e.g. trainer.model)
930
+ num_classes : int
931
+ | Number of classes for classifier
932
+ id_class_dict : dict
933
+ | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
934
+ | (dictionary of format: numerical IDs: class_labels)
935
+ eval_data : Dataset
936
+ | Loaded evaluation .dataset input
937
+ predict : bool
938
+ | Whether or not to save eval predictions
939
+ output_directory : Path
940
+ | Path to directory where eval data will be saved
941
+ output_prefix : str
942
+ | Prefix for output files
943
+ """
944
+
945
+ ##### Evaluate the model #####
946
+ labels = id_class_dict.keys()
947
+ y_pred, y_true, logits_list = eu.classifier_predict(
948
+ model, self.classifier, eval_data, self.forward_batch_size
949
+ )
950
+ conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
951
+ y_pred, y_true, logits_list, num_classes, labels
952
+ )
953
+ if predict is True:
954
+ pred_dict = {
955
+ "pred_ids": y_pred,
956
+ "label_ids": y_true,
957
+ "predictions": logits_list,
958
+ }
959
+ pred_dict_output_path = (
960
+ Path(output_directory) / f"{output_prefix}_pred_dict"
961
+ ).with_suffix(".pkl")
962
+ with open(pred_dict_output_path, "wb") as f:
963
+ pickle.dump(pred_dict, f)
964
+ return {
965
+ "conf_mat": conf_mat,
966
+ "macro_f1": macro_f1,
967
+ "acc": acc,
968
+ "roc_metrics": roc_metrics,
969
+ }
970
+
971
+ def evaluate_saved_model(
972
+ self,
973
+ model_directory,
974
+ id_class_dict_file,
975
+ test_data_file,
976
+ output_directory,
977
+ output_prefix,
978
+ predict=True,
979
+ ):
980
+ """
981
+ Evaluate the fine-tuned model.
982
+
983
+ **Parameters**
984
+
985
+ model_directory : Path
986
+ | Path to directory containing model
987
+ id_class_dict_file : Path
988
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
989
+ | (dictionary of format: numerical IDs: class_labels)
990
+ test_data_file : Path
991
+ | Path to directory containing test .dataset
992
+ output_directory : Path
993
+ | Path to directory where eval data will be saved
994
+ output_prefix : str
995
+ | Prefix for output files
996
+ predict : bool
997
+ | Whether or not to save eval predictions
998
+ """
999
+
1000
+ # load numerical id to class dictionary (id:class)
1001
+ with open(id_class_dict_file, "rb") as f:
1002
+ id_class_dict = pickle.load(f)
1003
+
1004
+ # get number of classes for classifier
1005
+ num_classes = cu.get_num_classes(id_class_dict)
1006
+
1007
+ # load previously filtered and prepared data
1008
+ test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1009
+
1010
+ # load previously fine-tuned model
1011
+ if self.classifier == "cell":
1012
+ model_type = "CellClassifier"
1013
+ elif self.classifier == "gene":
1014
+ model_type = "GeneClassifier"
1015
+ model = pu.load_model(model_type, num_classes, model_directory, "eval")
1016
+
1017
+ # evaluate the model
1018
+ results = self.evaluate_model(
1019
+ model,
1020
+ num_classes,
1021
+ id_class_dict,
1022
+ test_data,
1023
+ predict=predict,
1024
+ output_directory=output_directory,
1025
+ output_prefix=output_prefix,
1026
+ )
1027
+
1028
+ all_conf_mat_df = pd.DataFrame(
1029
+ results["conf_mat"],
1030
+ columns=id_class_dict.values(),
1031
+ index=id_class_dict.values(),
1032
+ )
1033
+ all_metrics = {
1034
+ "conf_matrix": all_conf_mat_df,
1035
+ "macro_f1": results["macro_f1"],
1036
+ "acc": results["acc"],
1037
+ }
1038
+ all_roc_metrics = None # roc metrics not reported for multiclass
1039
+ if num_classes == 2:
1040
+ mean_fpr = np.linspace(0, 1, 100)
1041
+ all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
1042
+ all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
1043
+ all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
1044
+ mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
1045
+ all_tpr, all_roc_auc, all_tpr_wt
1046
+ )
1047
+ all_roc_metrics = {
1048
+ "mean_tpr": mean_tpr,
1049
+ "mean_fpr": mean_fpr,
1050
+ "all_roc_auc": all_roc_auc,
1051
+ }
1052
+ all_metrics["all_roc_metrics"] = all_roc_metrics
1053
+ test_metrics_output_path = (
1054
+ Path(output_directory) / f"{output_prefix}_test_metrics_dict"
1055
+ ).with_suffix(".pkl")
1056
+ with open(test_metrics_output_path, "wb") as f:
1057
+ pickle.dump(all_metrics, f)
1058
+
1059
+ return all_metrics
1060
+
1061
+ def plot_conf_mat(
1062
+ self,
1063
+ conf_mat_dict,
1064
+ output_directory,
1065
+ output_prefix,
1066
+ custom_class_order=None,
1067
+ ):
1068
+ """
1069
+ Plot confusion matrix results of evaluating the fine-tuned model.
1070
+
1071
+ **Parameters**
1072
+
1073
+ conf_mat_dict : dict
1074
+ | Dictionary of model_name : confusion_matrix_DataFrame
1075
+ | (all_metrics["conf_matrix"] from self.validate)
1076
+ output_directory : Path
1077
+ | Path to directory where plots will be saved
1078
+ output_prefix : str
1079
+ | Prefix for output file
1080
+ custom_class_order : None, list
1081
+ | List of classes in custom order for plots.
1082
+ | Same order will be used for all models.
1083
+ """
1084
+
1085
+ for model_name in conf_mat_dict.keys():
1086
+ eu.plot_confusion_matrix(
1087
+ conf_mat_dict[model_name],
1088
+ model_name,
1089
+ output_directory,
1090
+ output_prefix,
1091
+ custom_class_order,
1092
+ )
1093
+
1094
+ def plot_roc(
1095
+ self,
1096
+ roc_metric_dict,
1097
+ model_style_dict,
1098
+ title,
1099
+ output_directory,
1100
+ output_prefix,
1101
+ ):
1102
+ """
1103
+ Plot ROC curve results of evaluating the fine-tuned model.
1104
+
1105
+ **Parameters**
1106
+
1107
+ roc_metric_dict : dict
1108
+ | Dictionary of model_name : roc_metrics
1109
+ | (all_metrics["all_roc_metrics"] from self.validate)
1110
+ model_style_dict : dict[dict]
1111
+ | Dictionary of model_name : dictionary of style_attribute : style
1112
+ | where style includes color and linestyle
1113
+ | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
1114
+ title : str
1115
+ | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
1116
+ output_directory : Path
1117
+ | Path to directory where plots will be saved
1118
+ output_prefix : str
1119
+ | Prefix for output file
1120
+ """
1121
+
1122
+ eu.plot_ROC(
1123
+ roc_metric_dict, model_style_dict, title, output_directory, output_prefix
1124
+ )
1125
+
1126
+ def plot_predictions(
1127
+ self,
1128
+ predictions_file,
1129
+ id_class_dict_file,
1130
+ title,
1131
+ output_directory,
1132
+ output_prefix,
1133
+ custom_class_order=None,
1134
+ kwargs_dict=None,
1135
+ ):
1136
+ """
1137
+ Plot prediction results of evaluating the fine-tuned model.
1138
+
1139
+ **Parameters**
1140
+
1141
+ predictions_file : path
1142
+ | Path of model predictions output to plot
1143
+ | (saved output from self.validate if predict=True)
1144
+ | (or saved output from self.evaluate_saved_model)
1145
+ id_class_dict_file : Path
1146
+ | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1147
+ | (dictionary of format: numerical IDs: class_labels)
1148
+ title : str
1149
+ | Title for legend containing class labels.
1150
+ output_directory : Path
1151
+ | Path to directory where plots will be saved
1152
+ output_prefix : str
1153
+ | Prefix for output file
1154
+ custom_class_order : None, list
1155
+ | List of classes in custom order for plots.
1156
+ | Same order will be used for all models.
1157
+ kwargs_dict : None, dict
1158
+ | Dictionary of kwargs to pass to plotting function.
1159
+ """
1160
+ # load predictions
1161
+ with open(predictions_file, "rb") as f:
1162
+ predictions = pickle.load(f)
1163
+
1164
+ # load numerical id to class dictionary (id:class)
1165
+ with open(id_class_dict_file, "rb") as f:
1166
+ id_class_dict = pickle.load(f)
1167
+
1168
+ if isinstance(predictions, dict):
1169
+ if all(
1170
+ [
1171
+ key in predictions.keys()
1172
+ for key in ["pred_ids", "label_ids", "predictions"]
1173
+ ]
1174
+ ):
1175
+ # format is output from self.evaluate_saved_model
1176
+ predictions_logits = np.array(predictions["predictions"])
1177
+ true_ids = predictions["label_ids"]
1178
+ else:
1179
+ # format is output from self.validate if predict=True
1180
+ predictions_logits = predictions.predictions
1181
+ true_ids = predictions.label_ids
1182
+
1183
+ num_classes = len(id_class_dict.keys())
1184
+ num_predict_classes = predictions_logits.shape[1]
1185
+ assert num_classes == num_predict_classes
1186
+ classes = id_class_dict.values()
1187
+ true_labels = [id_class_dict[idx] for idx in true_ids]
1188
+ predictions_df = pd.DataFrame(predictions_logits, columns=classes)
1189
+ if custom_class_order is not None:
1190
+ predictions_df = predictions_df.reindex(columns=custom_class_order)
1191
+ predictions_df["true"] = true_labels
1192
+ custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
1193
+ if custom_class_order is not None:
1194
+ custom_dict = dict(
1195
+ zip(custom_class_order, [i for i in range(len(custom_class_order))])
1196
+ )
1197
+ predictions_df = predictions_df.sort_values(
1198
+ by=["true"], key=lambda x: x.map(custom_dict)
1199
+ )
1200
+
1201
+ eu.plot_predictions(
1202
+ predictions_df, title, output_directory, output_prefix, kwargs_dict
1203
+ )
geneformer/classifier_utils.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from collections import Counter, defaultdict
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from scipy.stats import chisquare, ranksums
8
+ from sklearn.metrics import accuracy_score, f1_score
9
+
10
+ from . import perturber_utils as pu
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
16
+ data = data.shuffle(seed=42)
17
+ num_cells = len(data)
18
+ # if max number of cells is defined, then subsample to this max number
19
+ if max_ncells is not None:
20
+ if num_cells > max_ncells:
21
+ data = data.select([i for i in range(max_ncells)])
22
+ if max_ncells_per_class is not None:
23
+ class_labels = data[cell_state_dict["state_key"]]
24
+ random.seed(42)
25
+ subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
26
+ data = data.select(subsample_indices)
27
+ return data
28
+
29
+
30
+ # subsample labels to maximum number N per class and return indices
31
+ def subsample_by_class(labels, N):
32
+ label_indices = defaultdict(list)
33
+ # Gather indices for each label
34
+ for idx, label in enumerate(labels):
35
+ label_indices[label].append(idx)
36
+ selected_indices = []
37
+ # Select up to N indices for each label
38
+ for label, indices in label_indices.items():
39
+ if len(indices) > N:
40
+ selected_indices.extend(random.sample(indices, N))
41
+ else:
42
+ selected_indices.extend(indices)
43
+ return selected_indices
44
+
45
+
46
+ def rename_cols(data, state_key):
47
+ data = data.rename_column(state_key, "label")
48
+ return data
49
+
50
+
51
+ def validate_and_clean_cols(train_data, eval_data, classifier):
52
+ # validate that data has expected label column and remove others
53
+ if classifier == "cell":
54
+ label_col = "label"
55
+ elif classifier == "gene":
56
+ label_col = "labels"
57
+
58
+ cols_to_keep = [label_col] + ["input_ids", "length"]
59
+ if label_col not in train_data.column_names:
60
+ logger.error(f"train_data must contain column {label_col} with class labels.")
61
+ raise
62
+ else:
63
+ train_data = remove_cols(train_data, cols_to_keep)
64
+
65
+ if eval_data is not None:
66
+ if label_col not in eval_data.column_names:
67
+ logger.error(
68
+ f"eval_data must contain column {label_col} with class labels."
69
+ )
70
+ raise
71
+ else:
72
+ eval_data = remove_cols(eval_data, cols_to_keep)
73
+ return train_data, eval_data
74
+
75
+
76
+ def remove_cols(data, cols_to_keep):
77
+ other_cols = list(data.features.keys())
78
+ other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
79
+ data = data.remove_columns(other_cols)
80
+ return data
81
+
82
+
83
+ def remove_rare(data, rare_threshold, label, nproc):
84
+ if rare_threshold > 0:
85
+ total_cells = len(data)
86
+ label_counter = Counter(data[label])
87
+ nonrare_label_dict = {
88
+ label: [k for k, v in label_counter if (v / total_cells) > rare_threshold]
89
+ }
90
+ data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
91
+ return data
92
+
93
+
94
+ def label_classes(classifier, data, gene_class_dict, nproc):
95
+ if classifier == "cell":
96
+ label_set = set(data["label"])
97
+ elif classifier == "gene":
98
+ # remove cells without any of the target genes
99
+ def if_contains_label(example):
100
+ a = pu.flatten_list(gene_class_dict.values())
101
+ b = example["input_ids"]
102
+ return not set(a).isdisjoint(b)
103
+
104
+ data = data.filter(if_contains_label, num_proc=nproc)
105
+ label_set = gene_class_dict.keys()
106
+
107
+ if len(data) == 0:
108
+ logger.error(
109
+ "No cells remain after filtering for target genes. Check target gene list."
110
+ )
111
+ raise
112
+
113
+ class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
114
+ id_class_dict = {v: k for k, v in class_id_dict.items()}
115
+
116
+ def classes_to_ids(example):
117
+ if classifier == "cell":
118
+ example["label"] = class_id_dict[example["label"]]
119
+ elif classifier == "gene":
120
+ example["labels"] = label_gene_classes(
121
+ example, class_id_dict, gene_class_dict
122
+ )
123
+ return example
124
+
125
+ data = data.map(classes_to_ids, num_proc=nproc)
126
+ return data, id_class_dict
127
+
128
+
129
+ def label_gene_classes(example, class_id_dict, gene_class_dict):
130
+ return [
131
+ class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
132
+ for token_id in example["input_ids"]
133
+ ]
134
+
135
+
136
+ def prep_gene_classifier_split(
137
+ data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
138
+ ):
139
+ # generate cross-validation splits
140
+ targets = np.array(targets)
141
+ labels = np.array(labels)
142
+ targets_train, targets_eval = targets[train_index], targets[eval_index]
143
+ labels_train, labels_eval = labels[train_index], labels[eval_index]
144
+ label_dict_train = dict(zip(targets_train, labels_train))
145
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
146
+
147
+ # function to filter by whether contains train or eval labels
148
+ def if_contains_train_label(example):
149
+ a = targets_train
150
+ b = example["input_ids"]
151
+ return not set(a).isdisjoint(b)
152
+
153
+ def if_contains_eval_label(example):
154
+ a = targets_eval
155
+ b = example["input_ids"]
156
+ return not set(a).isdisjoint(b)
157
+
158
+ # filter dataset for examples containing classes for this split
159
+ logger.info(f"Filtering training data for genes in split {iteration_num}")
160
+ train_data = data.filter(if_contains_train_label, num_proc=num_proc)
161
+ logger.info(
162
+ f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
163
+ )
164
+ logger.info(f"Filtering evalation data for genes in split {iteration_num}")
165
+ eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
166
+ logger.info(
167
+ f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
168
+ )
169
+
170
+ # subsample to max_ncells
171
+ train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
172
+ eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
173
+
174
+ # relabel genes for this split
175
+ def train_classes_to_ids(example):
176
+ example["labels"] = [
177
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
178
+ ]
179
+ return example
180
+
181
+ def eval_classes_to_ids(example):
182
+ example["labels"] = [
183
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
184
+ ]
185
+ return example
186
+
187
+ train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
188
+ eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
189
+
190
+ return train_data, eval_data
191
+
192
+
193
+ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
194
+ targets = np.array(targets)
195
+ labels = np.array(labels)
196
+ label_dict_train = dict(zip(targets, labels))
197
+
198
+ # function to filter by whether contains train labels
199
+ def if_contains_train_label(example):
200
+ a = targets
201
+ b = example["input_ids"]
202
+ return not set(a).isdisjoint(b)
203
+
204
+ # filter dataset for examples containing classes for this split
205
+ logger.info("Filtering training data for genes to classify.")
206
+ train_data = data.filter(if_contains_train_label, num_proc=num_proc)
207
+ logger.info(
208
+ f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
209
+ )
210
+
211
+ # subsample to max_ncells
212
+ train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
213
+
214
+ # relabel genes for this split
215
+ def train_classes_to_ids(example):
216
+ example["labels"] = [
217
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
218
+ ]
219
+ return example
220
+
221
+ train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
222
+
223
+ return train_data
224
+
225
+
226
+ def balance_attr_splits(
227
+ data,
228
+ attr_to_split,
229
+ attr_to_balance,
230
+ eval_size,
231
+ max_trials,
232
+ pval_threshold,
233
+ state_key,
234
+ nproc,
235
+ ):
236
+ metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
237
+ for attr in attr_to_balance:
238
+ if attr == state_key:
239
+ metadata_df[attr] = data["label"]
240
+ else:
241
+ metadata_df[attr] = data[attr]
242
+ metadata_df = metadata_df.drop_duplicates()
243
+
244
+ split_attr_ids = list(metadata_df["split_attr_ids"])
245
+ assert len(split_attr_ids) == len(set(split_attr_ids))
246
+ eval_num = round(len(split_attr_ids) * eval_size)
247
+ colnames = (
248
+ ["trial_num", "train_ids", "eval_ids"]
249
+ + pu.flatten_list(
250
+ [
251
+ [
252
+ f"{attr}_train_mean_or_counts",
253
+ f"{attr}_eval_mean_or_counts",
254
+ f"{attr}_pval",
255
+ ]
256
+ for attr in attr_to_balance
257
+ ]
258
+ )
259
+ + ["mean_pval"]
260
+ )
261
+ balance_df = pd.DataFrame(columns=colnames)
262
+ data_dict = dict()
263
+ trial_num = 1
264
+ for i in range(max_trials):
265
+ if not all(
266
+ count > 1 for count in list(Counter(metadata_df[state_key]).values())
267
+ ):
268
+ logger.error(
269
+ f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
270
+ )
271
+ raise
272
+ eval_base = []
273
+ for state in set(metadata_df[state_key]):
274
+ eval_base += list(
275
+ metadata_df.loc[
276
+ metadata_df[state_key][metadata_df[state_key].eq(state)]
277
+ .sample(1, random_state=i)
278
+ .index
279
+ ]["split_attr_ids"]
280
+ )
281
+ non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
282
+ random.seed(i)
283
+ eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
284
+ train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
285
+ df_vals = [trial_num, train_ids, eval_ids]
286
+ pvals = []
287
+ for attr in attr_to_balance:
288
+ train_attr = list(
289
+ metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
290
+ )
291
+ eval_attr = list(
292
+ metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
293
+ )
294
+ if attr == state_key:
295
+ # ensure IDs are interpreted as categorical
296
+ train_attr = [str(item) for item in train_attr]
297
+ eval_attr = [str(item) for item in eval_attr]
298
+ if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
299
+ train_attr_mean = np.nanmean(train_attr)
300
+ eval_attr_mean = np.nanmean(eval_attr)
301
+ pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
302
+ df_vals += [train_attr_mean, eval_attr_mean, pval]
303
+ elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
304
+ obs_counts = Counter(train_attr)
305
+ exp_counts = Counter(eval_attr)
306
+ all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
307
+ obs = [obs_counts[cat] for cat in all_categ]
308
+ exp = [
309
+ exp_counts[cat] * sum(obs) / sum(exp_counts.values())
310
+ for cat in all_categ
311
+ ]
312
+ chisquare(f_obs=obs, f_exp=exp).pvalue
313
+ train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
314
+ eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
315
+ df_vals += [train_attr_counts, eval_attr_counts, pval]
316
+ else:
317
+ logger.error(
318
+ f"Inconsistent data types in attribute {attr}. "
319
+ "Cannot infer if continuous or categorical. "
320
+ "Must be all numeric (continuous) or all strings (categorical) to balance."
321
+ )
322
+ raise
323
+ pvals += [pval]
324
+
325
+ df_vals += [np.nanmean(pvals)]
326
+ balance_df_i = pd.DataFrame(df_vals, index=colnames).T
327
+ balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
328
+ valid_pvals = [
329
+ pval_i
330
+ for pval_i in pvals
331
+ if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
332
+ ]
333
+ if all(i >= pval_threshold for i in valid_pvals):
334
+ data_dict["train"] = pu.filter_by_dict(
335
+ data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
336
+ )
337
+ data_dict["test"] = pu.filter_by_dict(
338
+ data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
339
+ )
340
+ return data_dict, balance_df
341
+ trial_num = trial_num + 1
342
+ balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
343
+ data_dict["train"] = pu.filter_by_dict(
344
+ data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
345
+ )
346
+ data_dict["test"] = pu.filter_by_dict(
347
+ data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
348
+ )
349
+ logger.warning(
350
+ f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
351
+ f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
352
+ )
353
+ return data_dict, balance_df
354
+
355
+
356
+ def get_num_classes(id_class_dict):
357
+ return len(set(id_class_dict.values()))
358
+
359
+
360
+ def compute_metrics(pred):
361
+ labels = pred.label_ids
362
+ preds = pred.predictions.argmax(-1)
363
+ # calculate accuracy and macro f1 using sklearn's function
364
+ acc = accuracy_score(labels, preds)
365
+ macro_f1 = f1_score(labels, preds, average="macro")
366
+ return {"accuracy": acc, "macro_f1": macro_f1}
367
+
368
+
369
+ def get_default_train_args(model, classifier, data, output_dir):
370
+ num_layers = pu.quant_layers(model)
371
+ freeze_layers = 0
372
+ batch_size = 12
373
+ if classifier == "cell":
374
+ epochs = 10
375
+ evaluation_strategy = "epoch"
376
+ load_best_model_at_end = True
377
+ else:
378
+ epochs = 1
379
+ evaluation_strategy = "no"
380
+ load_best_model_at_end = False
381
+
382
+ if num_layers == 6:
383
+ default_training_args = {
384
+ "learning_rate": 5e-5,
385
+ "lr_scheduler_type": "linear",
386
+ "warmup_steps": 500,
387
+ "per_device_train_batch_size": batch_size,
388
+ "per_device_eval_batch_size": batch_size,
389
+ }
390
+
391
+ training_args = {
392
+ "num_train_epochs": epochs,
393
+ "do_train": True,
394
+ "do_eval": True,
395
+ "evaluation_strategy": evaluation_strategy,
396
+ "logging_steps": np.floor(len(data) / batch_size / 8), # 8 evals per epoch
397
+ "save_strategy": "epoch",
398
+ "group_by_length": False,
399
+ "length_column_name": "length",
400
+ "disable_tqdm": False,
401
+ "weight_decay": 0.001,
402
+ "load_best_model_at_end": load_best_model_at_end,
403
+ }
404
+ training_args.update(default_training_args)
405
+
406
+ return training_args, freeze_layers
geneformer/emb_extractor.py CHANGED
@@ -17,7 +17,6 @@ from pathlib import Path
17
 
18
  import anndata
19
  import matplotlib.pyplot as plt
20
- import numpy as np
21
  import pandas as pd
22
  import scanpy as sc
23
  import seaborn as sns
@@ -303,13 +302,6 @@ def make_colorbar(embs_df, label):
303
  cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
304
  label_colors = pd.DataFrame(cell_type_colors, columns=[label])
305
 
306
- for i, row in label_colors.iterrows():
307
- colors = row[0]
308
- if len(colors) != 3 or any(np.isnan(colors)):
309
- print(i, colors)
310
-
311
- label_colors.isna().sum()
312
-
313
  # create dictionary for colors and classes
314
  label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
315
  return label_colors, label_color_dict
@@ -565,7 +557,9 @@ class EmbExtractor:
565
  filtered_input_data, cell_state, self.nproc
566
  )
567
  downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
568
- model = pu.load_model(self.model_type, self.num_classes, model_directory, mode = "eval")
 
 
569
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
570
  embs = get_embs(
571
  model,
 
17
 
18
  import anndata
19
  import matplotlib.pyplot as plt
 
20
  import pandas as pd
21
  import scanpy as sc
22
  import seaborn as sns
 
302
  cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
303
  label_colors = pd.DataFrame(cell_type_colors, columns=[label])
304
 
 
 
 
 
 
 
 
305
  # create dictionary for colors and classes
306
  label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
307
  return label_colors, label_color_dict
 
557
  filtered_input_data, cell_state, self.nproc
558
  )
559
  downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
560
+ model = pu.load_model(
561
+ self.model_type, self.num_classes, model_directory, mode="eval"
562
+ )
563
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
564
  embs = get_embs(
565
  model,
geneformer/evaluation_utils.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import pickle
4
+ from pathlib import Path
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ import torch
11
+ from datasets.utils.logging import disable_progress_bar, enable_progress_bar
12
+ from sklearn import preprocessing
13
+ from sklearn.metrics import (
14
+ ConfusionMatrixDisplay,
15
+ accuracy_score,
16
+ auc,
17
+ confusion_matrix,
18
+ f1_score,
19
+ roc_curve,
20
+ )
21
+ from tqdm.auto import trange
22
+
23
+ from .emb_extractor import make_colorbar
24
+ from .tokenizer import TOKEN_DICTIONARY_FILE
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # load token dictionary (Ensembl IDs:token)
29
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
30
+ gene_token_dict = pickle.load(f)
31
+
32
+
33
+ def preprocess_classifier_batch(cell_batch, max_len, label_name):
34
+ if max_len is None:
35
+ max_len = max([len(i) for i in cell_batch["input_ids"]])
36
+
37
+ def pad_label_example(example):
38
+ example[label_name] = np.pad(
39
+ example[label_name],
40
+ (0, max_len - len(example["input_ids"])),
41
+ mode="constant",
42
+ constant_values=-100,
43
+ )
44
+ example["input_ids"] = np.pad(
45
+ example["input_ids"],
46
+ (0, max_len - len(example["input_ids"])),
47
+ mode="constant",
48
+ constant_values=gene_token_dict.get("<pad>"),
49
+ )
50
+ example["attention_mask"] = (
51
+ example["input_ids"] != gene_token_dict.get("<pad>")
52
+ ).astype(int)
53
+ return example
54
+
55
+ padded_batch = cell_batch.map(pad_label_example)
56
+ return padded_batch
57
+
58
+
59
+ # Function to find the largest number smaller
60
+ # than or equal to N that is divisible by k
61
+ def find_largest_div(N, K):
62
+ rem = N % K
63
+ if rem == 0:
64
+ return N
65
+ else:
66
+ return N - rem
67
+
68
+
69
+ def vote(logit_list):
70
+ m = max(logit_list)
71
+ logit_list.index(m)
72
+ indices = [i for i, x in enumerate(logit_list) if x == m]
73
+ if len(indices) > 1:
74
+ return "tie"
75
+ else:
76
+ return indices[0]
77
+
78
+
79
+ def py_softmax(vector):
80
+ e = np.exp(vector)
81
+ return e / e.sum()
82
+
83
+
84
+ def classifier_predict(model, classifier_type, evalset, forward_batch_size):
85
+ if classifier_type == "gene":
86
+ label_name = "labels"
87
+ elif classifier_type == "cell":
88
+ label_name = "label"
89
+
90
+ predict_logits = []
91
+ predict_labels = []
92
+ model.eval()
93
+
94
+ # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
95
+ evalset_len = len(evalset)
96
+ max_divisible = find_largest_div(evalset_len, forward_batch_size)
97
+ if len(evalset) - max_divisible == 1:
98
+ evalset_len = max_divisible
99
+
100
+ max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
101
+
102
+ disable_progress_bar() # disable progress bar for preprocess_classifier_batch mapping
103
+ for i in trange(0, evalset_len, forward_batch_size):
104
+ max_range = min(i + forward_batch_size, evalset_len)
105
+ batch_evalset = evalset.select([i for i in range(i, max_range)])
106
+ padded_batch = preprocess_classifier_batch(
107
+ batch_evalset, max_evalset_len, label_name
108
+ )
109
+ padded_batch.set_format(type="torch")
110
+
111
+ input_data_batch = padded_batch["input_ids"]
112
+ attn_msk_batch = padded_batch["attention_mask"]
113
+ label_batch = padded_batch[label_name]
114
+ with torch.no_grad():
115
+ outputs = model(
116
+ input_ids=input_data_batch.to("cuda"),
117
+ attention_mask=attn_msk_batch.to("cuda"),
118
+ labels=label_batch.to("cuda"),
119
+ )
120
+ predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
121
+ predict_labels += [torch.squeeze(label_batch.to("cpu"))]
122
+
123
+ enable_progress_bar()
124
+ logits_by_cell = torch.cat(predict_logits)
125
+ last_dim = len(logits_by_cell.shape) - 1
126
+ all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
127
+ labels_by_cell = torch.cat(predict_labels)
128
+ all_labels = torch.flatten(labels_by_cell)
129
+ logit_label_paired = [
130
+ item
131
+ for item in list(zip(all_logits.tolist(), all_labels.tolist()))
132
+ if item[1] != -100
133
+ ]
134
+ y_pred = [vote(item[0]) for item in logit_label_paired]
135
+ y_true = [item[1] for item in logit_label_paired]
136
+ logits_list = [item[0] for item in logit_label_paired]
137
+ return y_pred, y_true, logits_list
138
+
139
+
140
+ def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
141
+ conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
142
+ macro_f1 = f1_score(y_true, y_pred, average="macro")
143
+ acc = accuracy_score(y_true, y_pred)
144
+ roc_metrics = None # roc metrics not reported for multiclass
145
+ if num_classes == 2:
146
+ y_score = [py_softmax(item)[1] for item in logits_list]
147
+ fpr, tpr, _ = roc_curve(y_true, y_score)
148
+ mean_fpr = np.linspace(0, 1, 100)
149
+ interp_tpr = np.interp(mean_fpr, fpr, tpr)
150
+ interp_tpr[0] = 0.0
151
+ tpr_wt = len(tpr)
152
+ roc_auc = auc(fpr, tpr)
153
+ roc_metrics = {
154
+ "fpr": fpr,
155
+ "tpr": tpr,
156
+ "interp_tpr": interp_tpr,
157
+ "auc": roc_auc,
158
+ "tpr_wt": tpr_wt,
159
+ }
160
+ return conf_mat, macro_f1, acc, roc_metrics
161
+
162
+
163
+ # get cross-validated mean and sd metrics
164
+ def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
165
+ wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
166
+ all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
167
+ mean_tpr = np.sum(all_weighted_tpr, axis=0)
168
+ mean_tpr[-1] = 1.0
169
+ all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
170
+ roc_auc = np.sum(all_weighted_roc_auc)
171
+ roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
172
+ return mean_tpr, roc_auc, roc_auc_sd
173
+
174
+
175
+ # plot ROC curve
176
+ def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
177
+ fig = plt.figure()
178
+ fig.set_size_inches(10, 8)
179
+ sns.set(font_scale=2)
180
+ sns.set_style("white")
181
+ lw = 3
182
+ for model_name in roc_metric_dict.keys():
183
+ mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
184
+ mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
185
+ roc_auc = roc_metric_dict[model_name]["roc_auc"]
186
+ roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
187
+ color = model_style_dict[model_name]["color"]
188
+ linestyle = model_style_dict[model_name]["linestyle"]
189
+ if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
190
+ label = f"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
191
+ else:
192
+ label = f"{model_name} (AUC {roc_auc:0.2f})"
193
+ plt.plot(
194
+ mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
195
+ )
196
+
197
+ plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
198
+ plt.xlim([0.0, 1.0])
199
+ plt.ylim([0.0, 1.05])
200
+ plt.xlabel("False Positive Rate")
201
+ plt.ylabel("True Positive Rate")
202
+ plt.title(title)
203
+ plt.legend(loc="lower right")
204
+ plt.show()
205
+
206
+ output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
207
+ plt.savefig(output_file, bbox_inches="tight")
208
+
209
+
210
+ # plot confusion matrix
211
+ def plot_confusion_matrix(
212
+ conf_mat_df, title, output_dir, output_prefix, custom_class_order
213
+ ):
214
+ fig = plt.figure()
215
+ fig.set_size_inches(10, 10)
216
+ sns.set(font_scale=1)
217
+ sns.set_style("whitegrid", {"axes.grid": False})
218
+ if custom_class_order is not None:
219
+ conf_mat_df = conf_mat_df.reindex(
220
+ index=custom_class_order, columns=custom_class_order
221
+ )
222
+ display_labels = generate_display_labels(conf_mat_df)
223
+ conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
224
+ display = ConfusionMatrixDisplay(
225
+ confusion_matrix=conf_mat, display_labels=display_labels
226
+ )
227
+ display.plot(cmap="Blues", values_format=".2g")
228
+ plt.title(title)
229
+ plt.show()
230
+
231
+ output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
232
+ display.figure_.savefig(output_file, bbox_inches="tight")
233
+
234
+
235
+ def generate_display_labels(conf_mat_df):
236
+ display_labels = []
237
+ i = 0
238
+ for label in conf_mat_df.index:
239
+ display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
240
+ i = i + 1
241
+ return display_labels
242
+
243
+
244
+ def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
245
+ sns.set(font_scale=2)
246
+ plt.figure(figsize=(10, 10), dpi=150)
247
+ label_colors, label_color_dict = make_colorbar(predictions_df, "true")
248
+ predictions_df = predictions_df.drop(columns=["true"])
249
+ predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
250
+ predict_label_list = [label for label in predictions_df.columns]
251
+ predict_colors = pd.DataFrame(
252
+ pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
253
+ )
254
+
255
+ default_kwargs_dict = {
256
+ "row_cluster": False,
257
+ "col_cluster": False,
258
+ "row_colors": label_colors,
259
+ "col_colors": predict_colors,
260
+ "linewidths": 0,
261
+ "xticklabels": False,
262
+ "yticklabels": False,
263
+ "center": 0,
264
+ "cmap": "vlag",
265
+ }
266
+
267
+ if kwargs_dict is not None:
268
+ default_kwargs_dict.update(kwargs_dict)
269
+ g = sns.clustermap(predictions_df, **default_kwargs_dict)
270
+
271
+ plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
272
+
273
+ for label_color in list(label_color_dict.keys()):
274
+ g.ax_col_dendrogram.bar(
275
+ 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
276
+ )
277
+
278
+ g.ax_col_dendrogram.legend(
279
+ title=f"{title}",
280
+ loc="lower center",
281
+ ncol=4,
282
+ bbox_to_anchor=(0.5, 1),
283
+ facecolor="white",
284
+ )
285
+
286
+ output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
287
+ plt.savefig(output_file, bbox_inches="tight")
geneformer/in_silico_perturber_stats.py CHANGED
@@ -801,6 +801,12 @@ class InSilicoPerturberStats:
801
  logger.error("All states must be unique.")
802
  raise
803
 
 
 
 
 
 
 
804
  else:
805
  logger.error(
806
  "cell_states_to_model must only have the following four keys: "
 
801
  logger.error("All states must be unique.")
802
  raise
803
 
804
+ elif set(self.cell_states_to_model.keys()) == {
805
+ "state_key",
806
+ "start_state",
807
+ "goal_state",
808
+ }:
809
+ self.cell_states_to_model["alt_states"] = []
810
  else:
811
  logger.error(
812
  "cell_states_to_model must only have the following four keys: "
geneformer/tokenizer.py CHANGED
@@ -43,12 +43,13 @@ from pathlib import Path
43
  from typing import Literal
44
 
45
  import anndata as ad
46
- import loompy as lp
47
  import numpy as np
48
  import scipy.sparse as sp
49
  from datasets import Dataset
50
 
51
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 
 
52
  logger = logging.getLogger(__name__)
53
 
54
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
@@ -81,7 +82,7 @@ class TranscriptomeTokenizer:
81
  custom_attr_name_dict=None,
82
  nproc=1,
83
  chunk_size=512,
84
- input_size=2048,
85
  special_token=False,
86
  gene_median_file=GENE_MEDIAN_FILE,
87
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
@@ -95,10 +96,10 @@ class TranscriptomeTokenizer:
95
  | Values are the names of the attributes in the dataset.
96
  nproc : int
97
  | Number of processes to use for dataset mapping.
98
- chunk_size: int = 512
99
  | Chunk size for anndata tokenizer.
100
- input_size: int = 2048
101
- | Input size for tokenization
102
  special_token: bool = False
103
  | Option to add CLS and SEP tokens
104
  gene_median_file : Path
@@ -117,7 +118,7 @@ class TranscriptomeTokenizer:
117
  self.chunk_size = chunk_size
118
 
119
  # input size for tokenization
120
- self.input_size = input_size
121
 
122
  # add CLS and SEP tokens
123
  self.special_token = special_token
@@ -163,7 +164,9 @@ class TranscriptomeTokenizer:
163
  Path(data_directory), file_format
164
  )
165
  tokenized_dataset = self.create_dataset(
166
- tokenized_cells, cell_metadata, use_generator=use_generator
 
 
167
  )
168
 
169
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
@@ -332,7 +335,7 @@ class TranscriptomeTokenizer:
332
  file_cell_metadata[k] += subview.ca[k].tolist()
333
  else:
334
  file_cell_metadata = None
335
-
336
  return tokenized_cells, file_cell_metadata
337
 
338
  def create_dataset(
@@ -367,12 +370,20 @@ class TranscriptomeTokenizer:
367
 
368
  # Truncate/Crop input_ids to input size
369
  if self.special_token:
370
- example["input_ids"] = example["input_ids"][0:self.input_size-2] # truncate to leave space for CLS and SEP token
371
- example["input_ids"] = np.insert(example["input_ids"], 0, self.gene_token_dict.get("<cls>"))
372
- example["input_ids"] = np.insert(example["input_ids"], len(example["input_ids"]), self.gene_token_dict.get("<sep>"))
 
 
 
 
 
 
 
 
373
  else:
374
  # Truncate/Crop input_ids to input size
375
- example["input_ids"] = example["input_ids"][0:self.input_size]
376
  example["length"] = len(example["input_ids"])
377
 
378
  return example
@@ -380,4 +391,4 @@ class TranscriptomeTokenizer:
380
  output_dataset_truncated = output_dataset.map(
381
  format_cell_features, num_proc=self.nproc
382
  )
383
- return output_dataset_truncated
 
43
  from typing import Literal
44
 
45
  import anndata as ad
 
46
  import numpy as np
47
  import scipy.sparse as sp
48
  from datasets import Dataset
49
 
50
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
51
+ import loompy as lp # noqa
52
+
53
  logger = logging.getLogger(__name__)
54
 
55
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
 
82
  custom_attr_name_dict=None,
83
  nproc=1,
84
  chunk_size=512,
85
+ model_input_size=2048,
86
  special_token=False,
87
  gene_median_file=GENE_MEDIAN_FILE,
88
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
96
  | Values are the names of the attributes in the dataset.
97
  nproc : int
98
  | Number of processes to use for dataset mapping.
99
+ chunk_size : int = 512
100
  | Chunk size for anndata tokenizer.
101
+ model_input_size: int = 2048
102
+ | Max input size of model to truncate input to.
103
  special_token: bool = False
104
  | Option to add CLS and SEP tokens
105
  gene_median_file : Path
 
118
  self.chunk_size = chunk_size
119
 
120
  # input size for tokenization
121
+ self.model_input_size = model_input_size
122
 
123
  # add CLS and SEP tokens
124
  self.special_token = special_token
 
164
  Path(data_directory), file_format
165
  )
166
  tokenized_dataset = self.create_dataset(
167
+ tokenized_cells,
168
+ cell_metadata,
169
+ use_generator=use_generator,
170
  )
171
 
172
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
 
335
  file_cell_metadata[k] += subview.ca[k].tolist()
336
  else:
337
  file_cell_metadata = None
338
+
339
  return tokenized_cells, file_cell_metadata
340
 
341
  def create_dataset(
 
370
 
371
  # Truncate/Crop input_ids to input size
372
  if self.special_token:
373
+ example["input_ids"] = example["input_ids"][
374
+ 0 : self.model_input_size - 2
375
+ ] # truncate to leave space for CLS and SEP token
376
+ example["input_ids"] = np.insert(
377
+ example["input_ids"], 0, self.gene_token_dict.get("<cls>")
378
+ )
379
+ example["input_ids"] = np.insert(
380
+ example["input_ids"],
381
+ len(example["input_ids"]),
382
+ self.gene_token_dict.get("<sep>"),
383
+ )
384
  else:
385
  # Truncate/Crop input_ids to input size
386
+ example["input_ids"] = example["input_ids"][0 : self._model_input_size]
387
  example["length"] = len(example["input_ids"])
388
 
389
  return example
 
391
  output_dataset_truncated = output_dataset.map(
392
  format_cell_features, num_proc=self.nproc
393
  )
394
+ return output_dataset_truncated