Synteract freezes when run in a loop after approximately 5K pairwise PPI comparisons
I am trying to run Synteract in a while loop, as shown below, where I fetch two protein sequences at a time and run synteract.py on them over and over again. I have to do this for 6 million pairwise comparisons. I activate the conda environment, open 7 separate terminal tabs with the same environment, and run 2K jobs in each. These jobs don't overwhelm GPU memory, since at most 7 jobs are running in parallel at any time. However, the jobs in all the terminals freeze after a total of roughly 5K jobs have finished. Can you help me understand why this happens and how I can prevent it?
Hi @rohitsatyam,
I'm not sure of the exact issue, but the highest throughput and most reliable (no freezing, crashing, etc.) inference will be achieved by running one process on your system - otherwise the registry can get overwhelmed. The simple inference script pasted below should run faster and more reliably than your current method. Please use something similar to this so that the model is initialized once per inference run (otherwise you are severely inflating our download count). Hopefully it is helpful. Please let me know if you have any other questions.
import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm
class PairDataset(Dataset):
    """Dataset of aligned sequence pairs.

    Item ``i`` is ``(sequences_a[i], sequences_b[i])``; the dataset length is
    taken from ``sequences_a``.
    """

    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        self.sequences_a, self.sequences_b = sequences_a, sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        first = self.sequences_a[idx]
        second = self.sequences_b[idx]
        return first, second
class PairCollator:
    """Turn a batch of ``(seq_a, seq_b)`` string pairs into padded model inputs.

    Each pair is rendered as ``"<a> [SEP] <b>"``: the rare residues U, Z, O and B
    become X, every character is space-separated, and the whole batch is
    tokenized together (padded to the longest entry, truncated to
    ``max_length``).
    """

    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        """Replace U/Z/O/B with X and put a single space between characters."""
        masked = re.sub(r'[UZOB]', 'X', seq)
        return ' '.join(masked)

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        firsts, seconds = zip(*batch)
        joined = [
            ' [SEP] '.join((self.sanitize_seq(a), self.sanitize_seq(b)))
            for a, b in zip(firsts, seconds)
        ]
        encoded = self.tokenizer(
            joined,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Only the two tensors the model forward pass consumes are kept.
        return {key: encoded[key] for key in ('input_ids', 'attention_mask')}
def main(args):
    """Run batched SYNTERACT inference over sequence pairs and save a CSV.

    Reads ``args.model_path``, ``args.batch_size``, ``args.num_workers`` and
    ``args.save_path``. The input lists below are placeholders; fill them with
    your own data before running.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    # When using PyTorch >= 2.5.1 on a Linux machine, "sdpa" attention greatly
    # speeds up inference.
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")

    # Load your data into two lists so that the PPI is predicted for each pair
    # (sequences_a[i], sequences_b[i]).
    # We recommend keeping only pairs whose combined length is at most 1022
    # residues (SYNTERACT has a 1024-token limit).
    # We also recommend sorting pairs by combined length in descending order;
    # this reduces padding and speeds up inference.
    # Example:
    #     from datasets import load_dataset
    #     data = load_dataset('Synthyra/NEGATOME', split='combined')
    #     # Drop pairs whose combined length exceeds 1022
    #     data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
    #     # Add a 'total_length' column holding the combined length
    #     data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
    #     # Longest pairs first
    #     data = data.sort("total_length", reverse=True)
    #     sequences_a = data['SeqA']
    #     sequences_b = data['SeqB']
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    data_loader = DataLoader(
        PairDataset(sequences_a, sequences_b),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=PairCollator(tokenizer, max_length=1024),
    )

    kept_a, kept_b, probs, preds = [], [], [], []
    step_size = args.batch_size
    print("Starting inference...")
    with torch.no_grad():
        progress = tqdm(data_loader, total=len(data_loader), desc="Batches processed")
        for step, batch in enumerate(progress):
            # With length-sorted input the first batches are the slowest, so the
            # early tqdm ETA overestimates the total time.
            logits = model(
                batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits.detach().cpu()
            # Column 1 is the interaction probability (1 - it gives "no interaction").
            interaction_prob = torch.softmax(logits, dim=1)[:, 1]
            predicted = torch.argmax(logits, dim=1)
            # The loader is not shuffled, so batch `step` covers this slice of the inputs.
            lo = step * step_size
            hi = min(lo + step_size, len(sequences_a))
            kept_a.extend(sequences_a[lo:hi])
            kept_b.extend(sequences_b[lo:hi])
            probs.extend(interaction_prob.tolist())
            preds.extend(predicted.tolist())

    probs = [round(p, 5) for p in probs]
    results_df = pd.DataFrame({
        'sequence_a': kept_a,
        'sequence_b': kept_b,
        'probabilities': probs,
        'prediction': preds,
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ('--model_path', str, 'GleghornLab/SYNTERACT'),
        ('--save_path', str, 'ppi_predictions.csv'),
        ('--batch_size', int, 2),
        # Raising --num_workers (4 is usually a good value) runs the DataLoader,
        # and therefore the collator's tokenization, in worker processes.
        ('--num_workers', int, 0),
    ):
        parser.add_argument(flag, type=kind, default=fallback)
    args = parser.parse_args()
    main(args)
The script produces a nice csv like this:
Synthyra will soon release a new version of Synteract, which we are internally calling SynteractTurbo. 6 million pairwise comparisons will take quite a while with Synteract 1.0. If you would like to collaborate, we could run some inference for you with SynteractTurbo to see whether the outputs are helpful. Reach out at [email protected] if you are interested.
Apologies for inflating your download counts; that wasn't my intention. I just wanted to perform quick PPI inference. Regarding the code above, I don't see an argument that accepts a multi-FASTA file to perform all-vs-all PPI inference. Is that possible? Apologies, as I am not very proficient in Python. Also, can the code above be revised so that, instead of writing the full sequences into the output columns (which increases the file size), it outputs only the protein headers? So far I have been using the following code:
import argparse
import re
import csv
import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, BertTokenizer
from Bio import SeqIO
def load_model():
    """Load the SYNTERACT classifier and tokenizer and place the model on a device.

    Returns:
        (model, tokenizer, device): the model is in eval mode on CUDA when it is
        available, otherwise on CPU.
    """
    checkpoint = 'GleghornLab/SYNTERACT'
    model = BertForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = BertTokenizer.from_pretrained(checkpoint)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    return model, tokenizer, device
def preprocess_sequence(sequence):
    """Prepare a raw protein sequence for the tokenizer.

    The rare residues U, Z, O and B are mapped to X, and a single space is
    inserted between consecutive characters.
    """
    cleaned = re.sub(r'[UZOB]', 'X', sequence)
    return ' '.join(cleaned)
def predict_interaction(model, tokenizer, device, seq_a, seq_b, max_length=1024):
    """Predict whether two protein sequences interact.

    Args:
        model: sequence-classification model returning ``.logits``.
        tokenizer: tokenizer matching ``model``.
        device: torch device the model lives on.
        seq_a, seq_b: raw amino-acid sequences.
        max_length: maximum tokenized length of the joined pair. Longer inputs
            are truncated instead of being passed to the model at full length.
            The default of 1024 is assumed to match SYNTERACT's limit — confirm
            for other checkpoints.

    Returns:
        (prediction, confidence): prediction is 0 (no interaction) or
        1 (interaction); confidence is the softmax probability of that class.
    """
    example = preprocess_sequence(seq_a) + ' [SEP] ' + preprocess_sequence(seq_b)
    # Without truncation, a pair longer than the model's position limit is fed
    # to the model at full length and the forward pass fails. With the
    # tokenizer's default right-side truncation (confirm), the tail of seq_b is
    # what gets dropped.
    inputs = tokenizer(
        example,
        return_tensors='pt',
        padding=False,
        truncation=True,
        max_length=max_length,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits.cpu().detach()
    probability = F.softmax(logits, dim=-1)
    prediction = probability.argmax(dim=-1).item()  # 0 = no interaction, 1 = interaction
    confidence = probability[0, prediction].item()
    return prediction, confidence
def read_fasta(file_path):
    """Return every record of the FASTA file at *file_path* as a list (via Biopython SeqIO)."""
    return list(SeqIO.parse(file_path, "fasta"))
def main():
    """Score the i-th record of one FASTA file against the i-th record of another.

    Command line: ``--fasta1``, ``--fasta2`` and ``--output`` (all required).
    Writes one CSV row per pair. Note that the "Sequence_A"/"Sequence_B"
    columns hold the FASTA record IDs, not the sequences.

    Raises:
        ValueError: if the two FASTA files contain different numbers of records.
    """
    cli = argparse.ArgumentParser(description="Predict interactions between protein sequences.")
    cli.add_argument("--fasta1", required=True, help="Path to the first FASTA file.")
    cli.add_argument("--fasta2", required=True, help="Path to the second FASTA file.")
    cli.add_argument("--output", required=True, help="Path to the output CSV file.")
    opts = cli.parse_args()

    # The model is loaded once and reused for every pair.
    model, tokenizer, device = load_model()

    first_records = read_fasta(opts.fasta1)
    second_records = read_fasta(opts.fasta2)
    if len(first_records) != len(second_records):
        raise ValueError("The two FASTA files must contain the same number of sequences.")

    with open(opts.output, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["Sequence_A", "Sequence_B", "Prediction", "Confidence"])
        for left, right in zip(first_records, second_records):
            label, score = predict_interaction(model, tokenizer, device, str(left.seq), str(right.seq))
            writer.writerow([left.id, right.id, label, score])


if __name__ == "__main__":
    main()
I did it this way to avoid symmetrical pairs (A-B and B-A), but I couldn't find a way to do that in Python. So I removed such pairs outside Python and passed two sequences at a time. Is there a way to download the model once and reuse it to reduce the overhead?
Regarding the collaboration, I need to consult my PI first. Kindly allow me some time.
No worries
@rohitsatyam! Hopefully the script below solves your issues. It can run all-vs-all on one file, or pair one file against another. If you build a D x D matrix of all-vs-all protein pairs, where D is the dataset size, you can remove the duplicates (A-B and B-A) by keeping only the entries where i < j (equivalently, only those where i > j). This also skips the self-pairs on the diagonal. The script does exactly this, so no computation is wasted on duplicates. Please let me know if you have additional questions.
import argparse
import re
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
from typing import List, Tuple, Dict
from tqdm.auto import tqdm
class ProteinPairDataset(Dataset):
    """Index-addressable wrapper around a pre-built list of protein pairs.

    Each item is returned unchanged as
    ``(protein_header_A, protein_header_B, sequence_A, sequence_B)``.
    """

    def __init__(self, pairs: List[Tuple[str, str, str, str]]):
        self.pairs = pairs

    def __getitem__(self, idx: int) -> Tuple[str, str, str, str]:
        entry = self.pairs[idx]
        return entry

    def __len__(self):
        return len(self.pairs)
class PairCollator:
    """Collate ``(header_a, header_b, seq_a, seq_b)`` items into a tokenized batch.

    Each pair becomes ``"<a> [SEP] <b>"`` after the rare residues U, Z, O and B
    are mapped to X and the characters are space-separated. The tokenizer output
    is returned with the headers attached under ``headers_a``/``headers_b``, so
    results can be written out by protein ID.
    """

    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        """Map U/Z/O/B to X and separate every character with one space."""
        masked = re.sub(r'[UZOB]', 'X', seq)
        return ' '.join(masked)

    def __call__(self, batch: List[Tuple[str, str, str, str]]) -> Dict:
        headers_a, headers_b, seqs_a, seqs_b = zip(*batch)
        joined = [
            " [SEP] ".join((self.sanitize_seq(left), self.sanitize_seq(right)))
            for left, right in zip(seqs_a, seqs_b)
        ]
        encoded = self.tokenizer(
            joined,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # The headers are passed through as tuples of strings, in batch order.
        encoded["headers_a"] = headers_a
        encoded["headers_b"] = headers_b
        return encoded
def read_fasta(file_path: str):
    """Parse *file_path* as FASTA with Biopython and return all records as a list."""
    parsed = SeqIO.parse(file_path, "fasta")
    return [record for record in parsed]
def generate_pairs_single(records: List) -> List[Tuple[str, str, str, str]]:
    """Generate every unordered pair of distinct records from one FASTA file.

    Only one ordering is kept: if record A is paired with B, then B-A is omitted.
    Self-pairs (A-A) are excluded. Pairs are ordered by (i, j) with i < j,
    matching a row-major walk over the upper triangle of the D x D matrix.

    Args:
        records: objects exposing ``.id`` and ``.seq`` (e.g. Biopython SeqRecords).

    Returns:
        List of ``(id_a, id_b, seq_a, seq_b)`` tuples; empty for fewer than two records.
    """
    from itertools import combinations

    # Convert each record exactly once. str(record.seq) can build a new string
    # on every call (Biopython's Seq decodes its stored bytes). Converting inside
    # the O(n^2) pairing would give every pair private copies of both
    # sequences; here all pairs share the same n string objects.
    entries = [(record.id, str(record.seq)) for record in records]
    return [
        (id_a, id_b, seq_a, seq_b)
        for (id_a, seq_a), (id_b, seq_b) in combinations(entries, 2)
    ]
def generate_pairs_two(records1: List, records2: List) -> List[Tuple[str, str, str, str]]:
    """Pair the i-th record of the first file with the i-th record of the second.

    Raises:
        ValueError: when the two record lists differ in length.

    Returns:
        List of ``(id_1, id_2, seq_1, seq_2)`` tuples in file order.
    """
    if len(records1) == len(records2):
        return [
            (first.id, second.id, str(first.seq), str(second.seq))
            for first, second in zip(records1, records2)
        ]
    raise ValueError("For paired inference, the two FASTA files must contain the same number of sequences.")
def _count_overlength_pairs(pairs: List[Tuple[str, str, str, str]], max_length: int) -> int:
    """Estimate how many pairs the collator will truncate.

    Assumes one token per residue plus three special tokens ([CLS], the
    separating [SEP], and a trailing [SEP]) — TODO confirm against the
    SYNTERACT tokenizer.
    """
    return sum(1 for _, _, seq_a, seq_b in pairs if len(seq_a) + len(seq_b) + 3 > max_length)


def main(args):
    """Run SYNTERACT PPI inference on FASTA input and save predictions to CSV.

    Modes:
        * ``args.fasta``: all-vs-all over one file, each unordered pair scored once.
        * ``args.fasta1`` + ``args.fasta2``: the i-th records of the two files are paired.

    Args:
        args: parsed CLI namespace; reads fasta, fasta1, fasta2, model_path,
            output, batch_size, num_workers and max_length.

    Raises:
        ValueError: if no valid input mode is given, or if the paired files
            have different record counts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model and tokenizer only once for the whole run.
    print(f"Loading model from {args.model_path}")
    # The parameter `attn_implementation="sdpa"` may speed up inference with PyTorch>=2.5.1 on Linux
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa")
    model.eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Model and tokenizer loaded.\n")

    # Determine input mode: all-vs-all (single FASTA) or paired (two FASTA files)
    if args.fasta:
        print(f"Reading FASTA from {args.fasta} for all-vs-all inference.")
        records = read_fasta(args.fasta)
        pairs = generate_pairs_single(records)
    elif args.fasta1 and args.fasta2:
        print(f"Reading FASTA files:\n FASTA1: {args.fasta1}\n FASTA2: {args.fasta2}")
        records1 = read_fasta(args.fasta1)
        records2 = read_fasta(args.fasta2)
        pairs = generate_pairs_two(records1, records2)
    else:
        raise ValueError("Please provide either --fasta for all-vs-all mode or both --fasta1 and --fasta2 for paired inference.")

    if not pairs:
        print("No pairs to process. Exiting.")
        return

    print(f"Total pairs to process: {len(pairs)}")
    # The collator tokenizes with truncation=True. Over-long pairs are therefore
    # cut (with right-side truncation, the tail of sequence B is lost) and still
    # scored, so warn instead of letting those predictions pass unnoticed.
    n_overlength = _count_overlength_pairs(pairs, args.max_length)
    if n_overlength:
        print(
            f"Warning: {n_overlength} pair(s) are estimated to exceed --max_length={args.max_length} "
            "tokens and will be truncated; their predictions use incomplete sequences."
        )

    dataset = ProteinPairDataset(pairs)
    collator = PairCollator(tokenizer, max_length=args.max_length)
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collator,
        shuffle=False,  # keep output rows in the same order as `pairs`
    )

    all_headers_a = []
    all_headers_b = []
    all_predictions = []
    all_confidences = []

    print("\nStarting batched inference...")
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Batches processed", leave=True):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.cpu()

            pred = torch.argmax(logits, dim=1)
            probs = torch.softmax(logits, dim=1)
            # Confidence = probability assigned to the predicted class.
            confidence = probs.gather(1, pred.unsqueeze(1)).squeeze(1)

            # Headers are passed through by the collator, in batch order.
            all_headers_a.extend(batch['headers_a'])
            all_headers_b.extend(batch['headers_b'])
            all_predictions.extend(pred.tolist())
            # One tolist() conversion per batch instead of .item() per element.
            all_confidences.extend(round(conf, 5) for conf in confidence.tolist())

    results_df = pd.DataFrame({
        "Protein_A": all_headers_a,
        "Protein_B": all_headers_b,
        "Prediction": all_predictions,
        "Confidence": all_confidences
    })
    print(f"\nSaving results to {args.output}")
    results_df.to_csv(args.output, index=False)
    print("Done.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="PPI Inference Script for SYNTERACT Model: "
        "Use a single FASTA file for all-vs-all (unique pairs) or two FASTA files for paired inference."
    )
    # Input file options: either a single FASTA or a pair of FASTA files
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--fasta", type=str, help="Path to a FASTA file for all-vs-all inference.")
    group.add_argument("--fasta1", type=str, help="Path to the first FASTA file for paired inference.")
    parser.add_argument("--fasta2", type=str, help="Path to the second FASTA file for paired inference. Required if --fasta1 is provided.")
    parser.add_argument("--model_path", type=str, default="GleghornLab/SYNTERACT", help="Path or identifier of the pretrained model.")
    parser.add_argument("--output", type=str, default="ppi_predictions.csv", help="Path to output CSV file.")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of worker processes for the dataloader.")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length for tokenization.")
    args = parser.parse_args()

    # In paired mode, ensure both FASTA files are provided
    if args.fasta1 and not args.fasta2:
        parser.error("--fasta1 requires --fasta2 for paired inference.")
    # --fasta2 is not in the mutually exclusive group, so argparse alone accepts
    # "--fasta X --fasta2 Y", and main() would run all-vs-all while silently
    # ignoring Y. Reject that combination explicitly.
    if args.fasta2 and not args.fasta1:
        parser.error("--fasta2 requires --fasta1 for paired inference; it cannot be combined with --fasta.")
    main(args)