# Testing ESMB for Protein Binding Residue Prediction

This notebook tests ESM-2 LoRA models on the datasets found [here](https://github.com/hamzagamouh/pt-lm-gnn/tree/main/datasets/yu_merged) for the paper [Hybrid protein-ligand binding residue prediction with protein language models: Does the structure matter?](https://www.biorxiv.org/content/10.1101/2023.08.11.553028v1). The models in that paper are GCN, GAT, and ensemble structure-based models trained on PDB sequences to predict binding residues. They were the best-performing models that could be found as of 17 September 2023. To run this notebook, download the dataset(s) you want to test from the GitHub repository above and set the file path in the code below.

## Mount Your Google Drive if Necessary

In [None]:
# Colab only: expose Google Drive files under /content/drive so the dataset can be read.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install transformers -q
!pip install accelerate -q
!pip install peft -q
!pip install datasets -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.8/294.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.1/258.1 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.6/85.6 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import pandas as pd

# The dataset file is semicolon-separated; read it into a DataFrame.
data_df = pd.read_csv(
    "/content/drive/MyDrive/esmb_testing/CA_Training.txt",
    sep=';',
)

# Preview the first rows to check the column layout.
data_df.head()



Unnamed: 0,pdb_id,chain_id,binding_residues,sequence
0,1CB8,A,E380 D382 K383 D391 Y392,GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD...
1,3ALS,A,E112 N114 N115 D135,LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA...
2,2X7Q,A,N52 D197 G71 E73,LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI...
3,3BBY,A,D75 E77,KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD...
4,1B2L,A,D2 T4,MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA...


In [None]:
# Define a function to convert binding residues to binary labels
def binding_residues_to_labels(row):
    """Turn a row's binding-residue annotation into per-residue 0/1 labels.

    Args:
        row: A mapping (e.g. a DataFrame row) with a ``'sequence'`` string and a
            ``'binding_residues'`` value. The latter is either a whitespace-separated
            string of tokens such as ``"D2 T4"`` (one amino-acid letter followed by
            a 1-based position in ``sequence``) or a non-string value (e.g. NaN),
            meaning "no annotated binding residues".

    Returns:
        A list of ints, one per character of ``sequence``: 1 where the residue is
        listed as binding, otherwise 0. Positions outside ``1..len(sequence)``
        are ignored.
    """
    sequence = row['sequence']
    binding_residues = row['binding_residues']

    labels = [0] * len(sequence)

    # Non-string values (NaN for chains without annotations) leave every label at 0.
    if not isinstance(binding_residues, str):
        return labels

    for residue in binding_residues.split():
        # Drop the leading amino-acid letter and convert the 1-based position to 0-based.
        idx = int(residue[1:]) - 1
        # Check both bounds: a position <= 0 would otherwise become a negative
        # index and silently mark a residue counted from the END of the sequence.
        if 0 <= idx < len(labels):
            labels[idx] = 1

    return labels

# Attach the per-residue 0/1 labels to every chain as a new list-valued column.
data_df['binding_labels'] = data_df.apply(binding_residues_to_labels, axis='columns')

# Quick look at the new column next to the original annotations.
data_df.head()



Unnamed: 0,pdb_id,chain_id,binding_residues,sequence,binding_labels
0,1CB8,A,E380 D382 K383 D391 Y392,GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKDD...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,3ALS,A,E112 N114 N115 D135,LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDELA...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,2X7Q,A,N52 D197 G71 E73,LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRLI...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,3BBY,A,D75 E77,KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQIDD...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,1B2L,A,D2 T4,MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTALA...,"[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [None]:
# Maximum number of residues per chunk. ESM-2 inputs are capped at 1024 tokens
# (1022 residues plus <cls>/<eos>), and the batched prediction step tokenizes
# with max_length=1000, which leaves room for 998 residues (assuming one token
# per residue). 900 stays below both limits, so no chunk gets truncated.
MAX_CHUNK_SIZE = 900


def segment_into_chunks(row, chunk_size=None):
    """Split a chain's sequence and its labels into aligned, consecutive chunks.

    Args:
        row: A mapping with a ``'sequence'`` string and a ``'binding_labels'``
            sequence of per-residue labels of the same length.
        chunk_size: Maximum chunk length. Defaults to ``MAX_CHUNK_SIZE``, which is
            read at call time, so later changes to the constant take effect.

    Returns:
        ``(sequence_chunks, label_chunks)``: two lists whose i-th elements cover
        the same residue span. An empty sequence yields ``([], [])``.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if the sequence and
            label lengths differ (the chunks could not line up).
    """
    if chunk_size is None:
        chunk_size = MAX_CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    sequence = row['sequence']
    labels = row['binding_labels']
    if len(sequence) != len(labels):
        raise ValueError(
            f"sequence length {len(sequence)} != label length {len(labels)}"
        )

    # Slice both with the same offsets so chunk i of each covers the same residues.
    starts = range(0, len(sequence), chunk_size)
    sequence_chunks = [sequence[i:i + chunk_size] for i in starts]
    label_chunks = [labels[i:i + chunk_size] for i in starts]

    return sequence_chunks, label_chunks

# Chunk every chain; each result is a (sequence_chunks, label_chunks) pair.
chunk_pairs = [segment_into_chunks(row) for _, row in data_df.iterrows()]

# Store each half as an object column of lists, aligned to the frame's index.
data_df['sequence_chunks'] = pd.Series([pair[0] for pair in chunk_pairs], index=data_df.index, dtype=object)
data_df['label_chunks'] = pd.Series([pair[1] for pair in chunk_pairs], index=data_df.index, dtype=object)

# Show the chunked columns for the first few chains.
data_df[['pdb_id', 'chain_id', 'sequence_chunks', 'label_chunks']].head()



Unnamed: 0,pdb_id,chain_id,sequence_chunks,label_chunks
0,1CB8,A,[GTAELIMKRVMLDLKKPLRNMDKVAEKNLNTLQPDGSWKDVPYKD...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,3ALS,A,[LTSCPPLWTGFNGKCFRLFHNHLNFDNAENACRQFGLASCSGDEL...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
2,2X7Q,A,[LPTLKVAYIPEHFSTPLFFAQQQGYYKAHDLSIEFVKVPEGSGRL...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,3BBY,A,[KPAITLWSDAHFFSPYVLSAWVALQEKGLSFHIKTIDRVPLLQID...,"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,1B2L,A,[MDLTNKNVIFVAALGGIGLDTSRELVKRNLKNFVILDRVENPTAL...,"[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

import functools


@functools.lru_cache(maxsize=None)
def _load_binding_site_model(
    base_model_path="facebook/esm2_t12_35M_UR50D",
    model_path="AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
):
    """Load the ESM-2 base model with its LoRA adapter, plus the tokenizer.

    Cached, so each (base, adapter) pair is loaded only once per kernel.
    """
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()  # inference only
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    return model, tokenizer


def get_predictions(protein_sequence, loaded_model=None, loaded_tokenizer=None, max_length=1024):
    """Predict a 0/1 binding-site call for every residue of ``protein_sequence``.

    Args:
        protein_sequence: Amino-acid sequence as a plain string.
        loaded_model: Optional token-classification model. If omitted, the default
            ESM-2 35M + binding-site LoRA model is loaded once and then reused.
        loaded_tokenizer: Optional matching tokenizer (loaded likewise if omitted).
        max_length: Token budget including <cls>/<eos>. Longer inputs are truncated.

    Returns:
        A list of ints, one per residue token: 1 = "Binding site", 0 = "No binding site".
    """
    if loaded_model is None or loaded_tokenizer is None:
        default_model, default_tokenizer = _load_binding_site_model()
        if loaded_model is None:
            loaded_model = default_model
        if loaded_tokenizer is None:
            loaded_tokenizer = default_tokenizer

    # A single sequence needs no padding; padding to max_length only adds compute.
    inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=max_length)

    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    predicted_classes = torch.argmax(logits, dim=-1)[0]
    # Keep real (attended) positions, then drop the leading <cls> and trailing <eos>.
    # Filtering by token *string* instead would also drop residues that tokenize
    # to <unk>, shifting every later prediction relative to the labels.
    attended = inputs["attention_mask"][0].bool()
    residue_predictions = predicted_classes[attended][1:-1]

    return [int(pred) for pred in residue_predictions]

# Smoke test: print one 0/1 prediction per residue of a short example protein.
test_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
example_predictions = get_predictions(test_sequence)
print(example_predictions)



Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading (…)/adapter_config.json:   0%|          | 0.00/456 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/307k [00:00<?, ?B/s]

[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [None]:
# Step 1: Modify the get_predictions function
def get_predictions(protein_sequence, loaded_model, loaded_tokenizer, max_length=1000):
    """Predict a 0/1 binding-site call per residue using an already-loaded model.

    Taking the model and tokenizer as arguments lets the caller load them once
    and reuse them for every chunk.

    Args:
        protein_sequence: Amino-acid sequence (or chunk) as a plain string.
        loaded_model: Token-classification model whose output has ``.logits``.
        loaded_tokenizer: Matching tokenizer.
        max_length: Token budget including <cls>/<eos>. Longer inputs are truncated.

    Returns:
        A list of ints, one per residue token: 1 = "Binding site", 0 = "No binding site".
    """
    # A single sequence needs no padding; padding to max_length only adds compute.
    inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=max_length)

    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    predicted_classes = torch.argmax(logits, dim=-1)[0]
    # Keep real (attended) positions, then drop the leading <cls> and trailing <eos>.
    # Filtering by token *string* would also drop residues that tokenize to <unk>,
    # misaligning every later prediction with its label.
    attended = inputs["attention_mask"][0].bool()
    residue_predictions = predicted_classes[attended][1:-1]

    return [int(pred) for pred in residue_predictions]

# Build the inference model once (ESM-2 base + LoRA adapter) so every chunk reuses it.
base_model_path = "facebook/esm2_t12_35M_UR50D"
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(
    AutoModelForTokenClassification.from_pretrained(base_model_path),
    model_path,
).eval()  # eval() returns the module itself, switched to inference mode

# Step 2: Create a function to get predictions for each chunk and store them in a new column
def get_chunk_predictions(row, model=None, tokenizer=None):
    """Run ``get_predictions`` on every sequence chunk of one DataFrame row.

    Args:
        row: A mapping with a ``'sequence_chunks'`` list of sequence strings.
        model: Optional model override; defaults to the notebook-level ``loaded_model``.
        tokenizer: Optional tokenizer override; defaults to ``loaded_tokenizer``.

    Returns:
        One per-residue prediction list per chunk, in chunk order.
    """
    # The notebook-level objects are only read here, so no `global` statement is needed.
    model = loaded_model if model is None else model
    tokenizer = loaded_tokenizer if tokenizer is None else tokenizer
    return [get_predictions(chunk, model, tokenizer) for chunk in row['sequence_chunks']]

data_df['predictions_chunks'] = data_df.apply(get_chunk_predictions, axis=1)

# Step 3: Flatten the predictions and true labels columns to calculate metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef


def _flatten_chunks(column):
    """Concatenate a column of per-row chunk lists into one flat per-residue list."""
    return [value for row_chunks in column for chunk in row_chunks for value in chunk]


def _check_chunk_alignment(df):
    """Raise ValueError if any prediction chunk differs in length from its label chunk."""
    for pdb_id, label_chunks, pred_chunks in zip(df['pdb_id'], df['label_chunks'], df['predictions_chunks']):
        label_lengths = [len(chunk) for chunk in label_chunks]
        pred_lengths = [len(chunk) for chunk in pred_chunks]
        if label_lengths != pred_lengths:
            raise ValueError(
                f"{pdb_id}: prediction chunk lengths {pred_lengths} != label chunk lengths {label_lengths}"
            )


# Flattening pairs each label with the right prediction only if every chunk lines up;
# a per-chunk check gives a clear error instead of silently shifted pairs.
_check_chunk_alignment(data_df)

# Flatten the lists of labels and predictions to calculate metrics
true_labels_flat = _flatten_chunks(data_df['label_chunks'])
predictions_flat = _flatten_chunks(data_df['predictions_chunks'])

# Calculate the metrics
accuracy = accuracy_score(true_labels_flat, predictions_flat)
precision = precision_score(true_labels_flat, predictions_flat, zero_division=0)
recall = recall_score(true_labels_flat, predictions_flat, zero_division=0)
f1 = f1_score(true_labels_flat, predictions_flat, zero_division=0)
# NOTE: roc_auc_score receives hard 0/1 predictions here, not scores, so this
# "AUC" equals balanced accuracy ((TPR + TNR) / 2) rather than a threshold-free ROC AUC.
auc = roc_auc_score(true_labels_flat, predictions_flat)
mcc = matthews_corrcoef(true_labels_flat, predictions_flat)

# Print the metrics
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'AUC: {auc:.4f}')
print(f'MCC: {mcc:.4f}')


Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Accuracy: 0.8673
Precision: 0.0408
Recall: 0.3071
F1 Score: 0.0721
AUC: 0.5920
MCC: 0.0712


## Train/Test Metrics

Here you can compute the model's metrics on the original train/test split it was trained and evaluated on. Perhaps you can figure out why they are so different from the metrics on the dataset above?

### Loading and Tokenizing the Datasets

To run the model on the train/test split and get the metrics (accuracy, precision, recall, F1 score, AUC, and MCC), you will need to download the pickle files [found on Hugging Face here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). Go to "Files and versions" and download the four pickle files. You can ignore the TSV files unless you want to preprocess the data differently yourself. Once you have downloaded the pickle files, change the four pickle file paths in the cell below to their local paths on your machine, then run the cell. Only load pickle files from a source you trust, because unpickling can execute arbitrary code.

In [None]:
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# ESM-2 tokenizer matching the 35M base checkpoint used by the LoRA model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Return a new list where each label sequence is cut to at most ``max_length`` items."""
    truncated = []
    for label_seq in labels:
        truncated.append(label_seq[:max_length])
    return truncated

# Maximum tokenized length, counting the <cls> and <eos> tokens the tokenizer adds.
max_sequence_length = 1000

# Load the data from pickle files (change to match your local paths).
# NOTE(security): pickle.load can run arbitrary code -- only load files you trust.
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Align labels with tokens. The tokenizer prepends <cls>, so residue i sits at
# token position i + 1, and truncation keeps at most max_sequence_length - 2
# residues (assuming one token per residue). Without a <cls> offset, label i
# would land on token i, i.e. be scored against the previous residue (and
# against <cls> for i == 0). Truncate to the residue budget and prepend an
# ignored label for <cls>. DataCollatorForTokenClassification pads the tail
# (<eos>, padding) with -100, and compute_metrics masks -100 out.
IGNORE_INDEX = -100
residue_budget = max_sequence_length - 2
train_labels = [[IGNORE_INDEX] + list(label_seq) for label_seq in truncate_labels(train_labels, residue_budget)]
test_labels = [[IGNORE_INDEX] + list(label_seq) for label_seq in truncate_labels(test_labels, residue_budget)]

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

train_dataset, test_dataset


(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 450330
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 113475
 }))

### Getting the Train/Test Metrics

Next, run the following cell. Depending on your hardware, this may take a while. There are ~564K sequence chunks to process in total (450,330 train + 113,475 test, per the output above). The train set is about four times larger than the test set, so it takes correspondingly longer. Let both finish to see the train and test metrics.

In [None]:
from sklearn.metrics import(
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator

# Accelerator handles device placement for the model.
accelerator = Accelerator()

# ESM-2 base checkpoint plus the binding-site LoRA adapter
# (replace lora_model_path with a local path to evaluate your own adapter).
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"

# Wrap the token-classification base model with the adapter, then prepare it with Accelerate.
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = accelerator.prepare(PeftModel.from_pretrained(base_model, lora_model_path))

# Class index <-> class name for the two token classes.
id2label = dict(enumerate(["No binding site", "Binding site"]))
label2id = {name: idx for idx, name in id2label.items()}

# Pads input_ids/attention_mask and the per-token labels of each batch.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

import numpy as np  # used below; the notebook never imported it, so np.argmax raised NameError


# Define a function to compute the metrics
def compute_metrics(dataset):
    """Predict on ``dataset`` with the LoRA model and return token-level metrics.

    Uses the notebook-level ``model`` and ``data_collator``. Token positions whose
    label is -100 (special tokens and padding) are excluded from every metric.

    Returns:
        A dict with keys "accuracy", "precision", "recall", "f1", "auc", "mcc" and
        "roc_auc_proba". "auc" is scored on the hard argmax predictions; it is kept
        for comparability with the earlier section, and for 0/1 inputs it equals
        balanced accuracy. "roc_auc_proba" is the threshold-free ROC AUC computed
        from the predicted probability of class 1 ("Binding site").
    """
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Keep only scored positions; predictions[mask] holds one row of class logits per token.
    mask = labels != -100
    true_labels = labels[mask]
    scored_logits = predictions[mask]
    flat_predictions = np.argmax(scored_logits, axis=-1)
    # Two-class softmax probability of class 1: 1 / (1 + exp(l0 - l1)),
    # written with logaddexp so large logit gaps cannot overflow.
    binding_probability = np.exp(-np.logaddexp(0.0, scored_logits[:, 0] - scored_logits[:, 1]))

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    roc_auc_proba = roc_auc_score(true_labels, binding_probability)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "mcc": mcc,
        "roc_auc_proba": roc_auc_proba,
    }

# Score both splits with the same model. The smaller test split runs first so its
# result appears before the ~4x larger train split finishes.
test_metrics = compute_metrics(test_dataset)
train_metrics = compute_metrics(train_dataset)

train_metrics, test_metrics

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading (…)/adapter_config.json:   0%|          | 0.00/457 [00:00<?, ?B/s]

Downloading adapter_model.bin:   0%|          | 0.00/307k [00:00<?, ?B/s]