---
language:
- en
license: mit
library_name: peft
tags:
- ESM-2
- protein language model
- binding sites
- biology
datasets:
- AmelieSchreiber/binding_sites_random_split_by_family_550K
metrics:
- accuracy
- f1
- roc_auc
- precision
- recall
- matthews_correlation
pipeline_tag: token-classification
base_model: facebook/esm2_t12_35M_UR50D
---

# ESM-2 for Binding Site Prediction

**This model is overfit (see below).** This model is a finetuned version of the 35M parameter `esm2_t12_35M_UR50D` ([see here](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
and [here](https://huggingface.co/docs/transformers/model_doc/esm) for more details). The model was finetuned with LoRA for
the binary token classification task of predicting binding sites (and active sites) of protein sequences based on sequence alone.
The model may also be undertrained; nevertheless, it achieved better performance on the test set in terms of loss, accuracy,
precision, recall, F1 score, ROC AUC, and Matthews Correlation Coefficient (MCC) than the models trained on the smaller
dataset [found here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) of ~209K protein sequences. Note that
this model has high recall, meaning it detects most binding sites, but low precision, meaning it will also return many
false positives.

## Training procedure

This model was finetuned on ~549K protein sequences from the UniProt database. The dataset can be found
[here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). The model obtains
the following train and test metrics:

```python
train_metrics = {'accuracy': 0.9905461579981686,
                 'precision': 0.7695765003685506,
                 'recall': 0.9841352974610041,
                 'f1': 0.8637307441810476,
                 'auc': 0.9874413786006525,
                 'mcc': 0.8658850560635515}
test_metrics = {'accuracy': 0.9394282959813123,
                'precision': 0.3662722265170941,
                'recall': 0.8330231316088238,
                'f1': 0.5088208423175958,
                'auc': 0.8883078682492643,
                'mcc': 0.5283098562376193}
```
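
The card does not state how these token-level metrics were aggregated. As a minimal sketch, assuming they are pooled over every residue token of every sequence (micro-averaged) and that special tokens and padding carry the label `-100` (the usual Hugging Face convention for positions ignored by the loss), the confusion counts behind precision, recall and accuracy could be collected as follows. `token_confusion_counts` is a hypothetical helper, not the original training code.

```python
from typing import Sequence, Tuple

IGNORE_INDEX = -100  # Hugging Face convention for positions excluded from the loss


def token_confusion_counts(
    labels: Sequence[Sequence[int]],
    predictions: Sequence[Sequence[int]],
    ignore_index: int = IGNORE_INDEX,
) -> Tuple[int, int, int, int]:
    """Return (tp, fp, fn, tn) pooled over every non-ignored token of every sequence."""
    tp = fp = fn = tn = 0
    for seq_labels, seq_preds in zip(labels, predictions):
        for label, pred in zip(seq_labels, seq_preds):
            if label == ignore_index:
                continue  # skip <cls>, <eos> and padding positions
            if label == 1 and pred == 1:
                tp += 1
            elif label == 0 and pred == 1:
                fp += 1
            elif label == 1 and pred == 0:
                fn += 1
            else:  # label == 0 and pred == 0
                tn += 1
    return tp, fp, fn, tn


# Toy batch: two sequences whose special/padding positions are labelled -100.
labels = [[-100, 0, 1, 1, 0, -100], [-100, 0, 0, 1, -100, -100]]
preds = [[1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 1, 1]]

tp, fp, fn, tn = token_confusion_counts(labels, preds)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
accuracy = (tp + tn) / (tp + fp + fn + tn)
print(tp, fp, fn, tn)  # 2 2 1 2
print(f"precision={precision:.3f} recall={recall:.3f} accuracy={accuracy:.3f}")
```

Predictions at ignored positions (here the first and last tokens) do not affect any of the counts.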

To analyze the train and test metrics, we will consider each metric individually and then offer a comprehensive view of the 
model’s performance. Let's start:

### **1. Accuracy**
- **Train**: 99.05%
- **Test**: 93.94%

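The accuracy is high on both the training and test datasets, indicating that the model labels most residues correctly. However,
binding-site residues make up only a small fraction of all residues, so accuracy is dominated by the negative class and, on its own,
says little about how well binding sites are detected.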
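
Because the card reports accuracy, precision and recall together, the share of binding-site residues in the test set can be backed out from them. A minimal sketch, assuming all three metrics are micro-averaged over the same set of residue tokens; `implied_prevalence` is a hypothetical helper introduced for illustration.

```python
def implied_prevalence(accuracy: float, precision: float, recall: float) -> float:
    """Share of positive tokens implied by pooled accuracy, precision and recall.

    With prevalence p (per token): TP = recall * p, FN = (1 - recall) * p and
    FP = TP * (1 / precision - 1). All errors are FP + FN = 1 - accuracy.
    """
    errors_per_unit_prevalence = recall / precision + 1.0 - 2.0 * recall
    return (1.0 - accuracy) / errors_per_unit_prevalence


# Reported test metrics (see the Training procedure section).
test_p = implied_prevalence(
    accuracy=0.9394282959813123,
    precision=0.3662722265170941,
    recall=0.8330231316088238,
)
print(f"implied share of binding-site residues: {test_p:.3f}")  # 0.038
print(f"accuracy of always predicting 'No binding site': {1 - test_p:.3f}")  # 0.962
```

Under that assumption only about 3.8% of test residues are binding sites, so a trivial classifier that never predicts a binding site would score about 96.2% accuracy, higher than the model's 93.94%.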

### **2. Precision**
- **Train**: 76.96%
- **Test**: 36.63%

The precision, which measures the proportion of true positive predictions among all positive predictions, drops sharply on
the test set (from 76.96% to 36.63%). This suggests that the model produces many false positives on unseen data.
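
Precision can be turned into a more tangible quantity: since precision = TP / (TP + FP), the number of false positives per true positive is FP / TP = 1 / precision - 1. A short sketch using the reported values:

```python
def false_positives_per_true_positive(precision: float) -> float:
    # precision = TP / (TP + FP)  =>  FP / TP = 1 / precision - 1
    return 1.0 / precision - 1.0


for split, p in [("train", 0.7695765003685506), ("test", 0.3662722265170941)]:
    print(f"{split}: {false_positives_per_true_positive(p):.2f} false positives per true positive")
```

On the training set there are about 0.30 false positives per correctly flagged binding residue; on the test set there are about 1.73.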

### **3. Recall**
- **Train**: 98.41%
- **Test**: 83.30%

The recall, which indicates the proportion of actual positives correctly identified, remains high on the test set, although
lower than on the training set. This suggests the model is sensitive and identifies most of the positive cases.
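
The complement of recall is the share of true binding residues the model misses, FN / (TP + FN) = 1 - recall. A short sketch using the reported values:

```python
def missed_fraction(recall: float) -> float:
    # recall = TP / (TP + FN), so the share of positives missed is FN / (TP + FN) = 1 - recall
    return 1.0 - recall


for split, r in [("train", 0.9841352974610041), ("test", 0.8330231316088238)]:
    print(f"{split}: {missed_fraction(r):.1%} of true binding-site residues missed")
```

On the test set about one in six binding residues (16.7%) is missed, versus 1.6% on the training set.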

### **4. F1-Score**
- **Train**: 86.37%
- **Test**: 50.88%

The F1-score is the harmonic mean of precision and recall. The significant drop in the F1-score from training to testing indicates 
that the balance between precision and recall has worsened in the test set, which is primarily due to the lower precision.
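
As a consistency check, the reported F1 scores can be recomputed from the reported precision and recall. A small sketch, assuming F1 was computed from the same pooled counts as precision and recall, in which case the harmonic mean 2PR / (P + R) reproduces it up to floating-point rounding:

```python
def f1_score(precision: float, recall: float) -> float:
    # Harmonic mean of precision and recall: 2PR / (P + R)
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) from the metrics above
reported = {
    "train": (0.7695765003685506, 0.9841352974610041, 0.8637307441810476),
    "test": (0.3662722265170941, 0.8330231316088238, 0.5088208423175958),
}
for split, (p, r, f1) in reported.items():
    print(f"{split}: recomputed F1 = {f1_score(p, r):.4f}, reported F1 = {f1:.4f}")
```

Because the harmonic mean is pulled toward the smaller of its two inputs, the low test precision caps the test F1 even though recall stays above 83%.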

### **5. AUC (Area Under the ROC Curve)**
- **Train**: 98.74%
- **Test**: 88.83%

The AUC is high in both training and testing, although it decreases on the test set. A high AUC indicates good separability: the
model is able to distinguish between the positive and negative classes well.
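
ROC AUC is normally computed from per-residue scores (for example, the softmax probability of the "Binding site" class) rather than from hard argmax labels; the card does not say which was used here. A minimal sketch of the probabilistic interpretation on hypothetical toy scores: the AUC equals the probability that a randomly chosen binding residue is scored higher than a randomly chosen non-binding residue, with ties counted as one half. `roc_auc` is a hypothetical helper.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy per-residue probabilities for the "Binding site" class.
labels = [0, 0, 1, 0, 1, 0]
scores = [0.10, 0.40, 0.35, 0.20, 0.90, 0.05]
print(roc_auc(labels, scores))  # 0.875
```

The pairwise loop is quadratic and only meant to illustrate the definition; for millions of residues a sort-based implementation such as scikit-learn's `roc_auc_score` is the practical choice.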

### **6. MCC (Matthews Correlation Coefficient)**
- **Train**: 86.59%
- **Test**: 52.83%

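MCC is a balanced metric that takes all four confusion-matrix entries (true and false positives and negatives) into account, which
makes it informative even when the classes are imbalanced. The decline in MCC from training to testing indicates a decrease in the
quality of the binary classifications.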
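
MCC is computed from the confusion counts as (TP·TN - FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). A minimal sketch with hypothetical counts for 1,000 residues, 40 of them binding sites, shows why it is more informative than accuracy under class imbalance. When any row or column of the confusion matrix is empty the denominator is zero, and the sketch follows the common convention of returning 0.

```python
import math


def matthews_corrcoef(tp: int, fp: int, fn: int, tn: int) -> float:
    """MCC from binary confusion-matrix counts; returns 0.0 if any marginal is empty."""
    numerator = tp * tn - fp * fn
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return numerator / denominator if denominator else 0.0


def accuracy(tp: int, fp: int, fn: int, tn: int) -> float:
    return (tp + tn) / (tp + fp + fn + tn)


# Hypothetical counts (tp, fp, fn, tn): 1,000 residues, 40 true binding sites.
always_negative = (0, 0, 40, 960)  # never predicts "Binding site"
example_model = (34, 66, 6, 894)   # flags 100 residues, 34 of them correctly

for name, counts in [("always negative", always_negative), ("example model", example_model)]:
    print(f"{name}: accuracy={accuracy(*counts):.3f}, MCC={matthews_corrcoef(*counts):.3f}")
```

Here the always-negative predictor reaches 96.0% accuracy, higher than the example model's 92.8%, yet its MCC is 0, while the example model's MCC is about 0.51.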

### **Overall Analysis**

- **Overfitting**: The significant drop in metrics such as precision, F1-score, and MCC from training to test set suggests that the model might be overfitting to the training data, i.e., it may not generalize well to unseen data.
  
- **High Recall, Low Precision**: The model has a high recall but low precision on the test set, indicating that it is identifying too many cases as positive, including those that are actually negative (false positives). This could be a reflection of a model that is biased towards predicting the positive class.

- **Improvement Suggestions**:
  - **Data Augmentation**: Consider data augmentation strategies to make the model more robust.
  - **Class Weights**: If there is a class imbalance in the dataset, adjusting the class weights during training might help.
  - **Hyperparameter Tuning**: Experiment with different hyperparameters, including the learning rate, batch size, etc., to see whether the model's performance on the test set improves.
  - **Feature Engineering**: Consider revisiting the features used to train the model. Sometimes, introducing new features or removing irrelevant ones can help improve performance.
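
For the class-weight suggestion, a common starting point is the "balanced" heuristic, the same formula scikit-learn's `compute_class_weight` uses: weight = n_tokens / (n_classes × count of that class). A minimal sketch, assuming token labels use `-100` for ignored positions; `balanced_class_weights` is a hypothetical helper, and the 4% positive rate in the toy data is illustrative only.

```python
from collections import Counter
from typing import Dict, Iterable

IGNORE_INDEX = -100  # positions excluded from the loss (special tokens, padding)


def balanced_class_weights(labels: Iterable[int], ignore_index: int = IGNORE_INDEX) -> Dict[int, float]:
    """Weight each class by n_tokens / (n_classes * n_tokens_in_class)."""
    counts = Counter(label for label in labels if label != ignore_index)
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * count) for label, count in sorted(counts.items())}


# Toy token labels: 96 non-binding, 4 binding, plus 10 ignored positions.
token_labels = [0] * 96 + [1] * 4 + [IGNORE_INDEX] * 10
print(balanced_class_weights(token_labels))  # {0: 0.5208333333333334, 1: 12.5}
```

In PyTorch such weights can be passed to `torch.nn.CrossEntropyLoss(weight=...)` as a float tensor ordered by class index. Upweighting the positive class tends to raise recall at the expense of precision, so for a model that already has high recall and low precision, milder weights than the balanced ones may be worth trying.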

In conclusion, while the model performs excellently on the training set, its performance drops in the test set, suggesting that there 
is room for improvement to make the model more generalizable to unseen data. It would be beneficial to look into strategies to reduce 
overfitting and improve precision without significantly sacrificing recall.

The increase in dataset size from ~209K to ~549K protein sequences clearly improved the test metrics.
We used Hugging Face's Parameter-Efficient Fine-Tuning (PEFT) library to finetune with Low-Rank Adaptation (LoRA). We chose
a LoRA rank of 2, as this slightly improved the test metrics compared to ranks 8 and 16 for the
same model trained on the smaller dataset.
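
To make the choice of rank concrete: LoRA freezes a weight matrix of shape d_out × d_in and trains a low-rank update BA, where A is r × d_in and B is d_out × r, so each adapted matrix adds r(d_in + d_out) trainable parameters. In PEFT the rank is the `r` argument of `LoraConfig`. The sketch below counts adapter parameters for ranks 2, 8 and 16, assuming the 12 layers and hidden size 480 of `esm2_t12_35M_UR50D` and, since the card does not list the target modules, assuming LoRA is applied to the query, key and value projections of every layer.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA adds A (rank x d_in) and B (d_out x rank) next to a frozen d_out x d_in weight
    return rank * (d_in + d_out)


HIDDEN = 480           # hidden size of esm2_t12_35M_UR50D
LAYERS = 12            # transformer layers in esm2_t12_35M_UR50D
TARGETS_PER_LAYER = 3  # assumption: query, key and value projections

for rank in (2, 8, 16):
    total = LAYERS * TARGETS_PER_LAYER * lora_params(HIDDEN, HIDDEN, rank)
    print(f"rank {rank:>2}: {total:,} trainable LoRA parameters")
```

Under these assumptions rank 2 trains about 69K adapter parameters (plus the token-classification head), roughly 0.2% of the 35M base parameters, and the count scales linearly with the rank.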

### Framework versions

- PEFT 0.5.0

## Using the model

To use the model on one of your protein sequences, try running the following:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)  # no padding needed for a single sequence

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
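
The example above prints one `(token, label)` pair per residue. If binding sites are needed as residue ranges instead, a hypothetical helper like `binding_segments` can group consecutive predicted binding residues into 1-based, inclusive ranges. It assumes the labels are in sequence order with special tokens already removed, and that each residue maps to exactly one token, as with the ESM-2 tokenizer for standard amino acids.

```python
from typing import List, Optional, Sequence, Tuple


def binding_segments(labels: Sequence[str], positive: str = "Binding site") -> List[Tuple[int, int]]:
    """Group consecutive predicted binding residues into 1-based inclusive (start, end) ranges."""
    segments: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for i, label in enumerate(labels, start=1):
        if label == positive and start is None:
            start = i  # a new binding segment begins here
        elif label != positive and start is not None:
            segments.append((start, i - 1))  # the segment ended at the previous residue
            start = None
    if start is not None:
        segments.append((start, len(labels)))  # segment runs to the last residue
    return segments


residue_labels = ["No binding site", "Binding site", "Binding site", "No binding site", "Binding site"]
print(binding_segments(residue_labels))  # [(2, 3), (5, 5)]
```

With the inference example, the input list could be built as `[id2label[p] for t, p in zip(tokens, predictions[0].tolist()) if t not in ['<pad>', '<cls>', '<eos>']]`.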