Synteract freezes when running in a loop after approx 5K pairwise PPI comparisons

#15
by rohitsatyam - opened

Hi @colinhorger @gleghorn

I am trying to run SYNTERACT using a while loop, as shown below, where I fetch two protein sequences at a time and invoke synteract.py over and over again. I have to do this for 6 million pairwise comparisons. I do this by activating the conda environment, opening 7 separate tabs with the same environment, and running 2K jobs in each. These jobs don't overwhelm GPU memory, since at any time only 7 jobs are running in parallel. However, I see that the jobs in all the terminals freeze after a total of nearly 5K jobs have finished. Can you help me understand why that happens and how I can prevent it?
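
For reference, a simplified sketch of what each terminal runs (file names and synteract.py's command-line interface here are placeholders, not my exact script):

import subprocess

# Simplified sketch: each iteration launches a fresh Python process,
# which re-loads the model from scratch every time.
with open('pairs.tsv') as fh:              # one "seqA<TAB>seqB" pair per line
    while True:
        line = fh.readline()
        if not line:
            break
        seq_a, seq_b = line.rstrip('\n').split('\t')
        # hypothetical invocation; my real call passes the pair to synteract.py
        subprocess.run(['python', 'synteract.py', seq_a, seq_b], check=True)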

[Screenshots of the terminal sessions attached]

Gleghorn Lab org

Hi @rohitsatyam ,

I'm not sure of the exact issue, but the highest throughput and most reliable (no freezing, crashing, etc.) inference will be achieved by running one process on your system - otherwise the registry can get overwhelmed. The simple inference script pasted below should run faster and more reliably than your current method. Please use something similar to this so that the model is initialized once per inference run (otherwise you are severely inflating our download count). Hopefully it is helpful. Please let me know if you have any other questions.

import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm


class PairDataset(Dataset):
    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]
    

class PairCollator:
    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        # Replace ambiguous amino acids with 'X' and add spaces between residues
        seq = ' '.join(list(re.sub(r'[UZOB]', 'X', seq)))
        return seq

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b = zip(*batch)
        seqs = []
        for a, b in zip(seqs_a, seqs_b):
            seq = self.sanitize_seq(a) + ' [SEP] ' + self.sanitize_seq(b)
            seqs.append(seq)
        seqs = self.tokenizer(seqs, padding='longest', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': seqs['input_ids'],
            'attention_mask': seqs['attention_mask'],
        }


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    # When using PyTorch >= 2.5.1 on a Linux machine, sdpa attention will greatly speed up inference
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print(f"Tokenizer loaded")

    """
    Load your data into two lists of sequences, where you want the PPI for each pair (sequences_a[i], sequences_b[i]).
    We recommend trimming sequence pairs whose combined length exceeds 1022 tokens (to fit the 1024 max length limit of SYNTERACT).
    We also recommend sorting the sequences by length in descending order, as this will speed up inference by reducing padding.

    Example:
        from datasets import load_dataset
        data = load_dataset('Synthyra/NEGATOME', split='combined')
        # Filter out examples where the total length exceeds 1022
        data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
        # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
        data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
        # Sort the dataset by 'total_length' in descending order (longest sequences first)
        data = data.sort("total_length", reverse=True)
        # Now retrieve the sorted sequences
        sequences_a = data['SeqA']
        sequences_b = data['SeqB']
    """
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_seqs_a = []
    all_seqs_b = []
    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, total=len(data_loader), desc="Batches processed")):
            # Because sequences are sorted, the initial estimate for time will be much longer than the actual time it will take
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.detach().cpu()

            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1]  # 1 minus this gives the no-interaction probability
            pred = torch.argmax(logits, dim=1)

            # Store results
            batch_start = i * args.batch_size
            batch_end = min((i + 1) * args.batch_size, len(sequences_a))
            all_seqs_a.extend(sequences_a[batch_start:batch_end])
            all_seqs_b.extend(sequences_b[batch_start:batch_end])
            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # Create dataframe and save to CSV
    results_df = pd.DataFrame({
        'sequence_a': all_seqs_a,
        'sequence_b': all_seqs_b,
        'probabilities': all_probs,
        'prediction': all_preds
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=0) # can increase for dataloader multiprocessing; 4 is usually a good value
    args = parser.parse_args()

    main(args)
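
Once you fill in the data-loading section, a run looks like this (the script name is just an example; the flags are the ones defined above):

python synteract_inference.py --model_path GleghornLab/SYNTERACT --save_path ppi_predictions.csv --batch_size 8 --num_workers 4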

The script produces a nice CSV like this:

[Screenshot of example output CSV]

Gleghorn Lab org

Synthyra will have a version of Synteract coming out soon; internally we are calling it SynteractTurbo. 6 million pairwise comparisons will take quite a while with Synteract 1.0. If you would like to collaborate, we could run some inference for you with SynteractTurbo to see if the outputs are helpful. Reach out at [email protected] if you are interested.

Apologies for inflating your download counts. That wasn't my intention; I just wanted to perform quick PPI inference. Regarding the code above, I don't see an argument that accepts a multi-FASTA file to perform all-vs-all PPI inference. Is that possible? Apologies, I am not very proficient in Python. Also, could the code above be revised so that instead of writing the sequences in columns (which inflates the file size), it outputs only the protein headers? So far I have been using the following code:

import argparse
import re
import csv
import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, BertTokenizer
from Bio import SeqIO

def load_model():
    # Load model and tokenizer
    model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT')
    tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()
    return model, tokenizer, device

def preprocess_sequence(sequence):
    # Replace rare amino acids with X and add spaces between characters
    return ' '.join(list(re.sub(r'[UZOB]', 'X', sequence)))

def predict_interaction(model, tokenizer, device, seq_a, seq_b):
    # Preprocess sequences
    sequence_a = preprocess_sequence(seq_a)
    sequence_b = preprocess_sequence(seq_b)
    example = sequence_a + ' [SEP] ' + sequence_b
    
    # Tokenize example
    example = tokenizer(example, return_tensors='pt', padding=False).to(device)
    
    # Predict interaction
    with torch.no_grad():
         logits = model(**example).logits.cpu().detach()
    
    probability = F.softmax(logits, dim=-1)
    prediction = probability.argmax(dim=-1).item()  # 0 for no interaction, 1 for interaction
    confidence = probability[0, prediction].item()
    
    return prediction, confidence

def read_fasta(file_path):
    # Read sequences from a FASTA file
    records = []
    for record in SeqIO.parse(file_path, "fasta"):
        records.append(record)
    return records

def main():
    parser = argparse.ArgumentParser(description="Predict interactions between protein sequences.")
    parser.add_argument("--fasta1", required=True, help="Path to the first FASTA file.")
    parser.add_argument("--fasta2", required=True, help="Path to the second FASTA file.")
    parser.add_argument("--output", required=True, help="Path to the output CSV file.")
    args = parser.parse_args()

    # Load model and tokenizer
    model, tokenizer, device = load_model()

    # Read sequences from FASTA files
    records_a = read_fasta(args.fasta1)
    records_b = read_fasta(args.fasta2)

    # Ensure the files have the same number of sequences
    if len(records_a) != len(records_b):
        raise ValueError("The two FASTA files must contain the same number of sequences.")

    # Open CSV file for writing
    with open(args.output, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(["Sequence_A", "Sequence_B", "Prediction", "Confidence"])

        # Predict interactions for each pair of sequences
        for record_a, record_b in zip(records_a, records_b):
            prediction, confidence = predict_interaction(model, tokenizer, device, str(record_a.seq), str(record_b.seq))
            csvwriter.writerow([record_a.id, record_b.id, prediction, confidence])

if __name__ == "__main__":
    main()

The reason I did it this way was to avoid symmetrical pairs (A-B and B-A), but I couldn't find a way to do that in Python. So I removed such pairs outside Python and used two sequences at a time. Is there a way to download the model once and reuse it, to decrease the overhead?
Regarding the collaboration, I need to confer with my PI first. Kindly allow me some time.

Gleghorn Lab org
edited 12 days ago

No worries @rohitsatyam ! Hopefully the script below solves your issues. It can do all-vs-all or one file vs. another. If you build a D x D matrix of all-vs-all proteins, where D is the dataset size, you can get rid of duplicates (A-B and B-A) by keeping only one triangle of the matrix, i.e. only the entries where i < j. This script handles that, so there is no extra computation. Please let me know if you have additional questions.
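
For the pair-generation logic in isolation, Python's itertools.combinations yields exactly those unique i < j pairs (a minimal sketch; the full script below implements the same thing with an explicit double loop):

from itertools import combinations

proteins = ['P1', 'P2', 'P3']
# Each unordered pair appears exactly once: P1-P2, P1-P3, P2-P3
# (the reversed orderings P2-P1 etc. are never generated).
for a, b in combinations(proteins, 2):
    print(a, b)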

import argparse
import re
import torch
import torch.nn.functional as F
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
from typing import List, Tuple, Dict
from tqdm.auto import tqdm


class ProteinPairDataset(Dataset):
    """
    Expects a list of tuples:
       (protein_header_A, protein_header_B, sequence_A, sequence_B)
    """
    def __init__(self, pairs: List[Tuple[str, str, str, str]]):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[str, str, str, str]:
        return self.pairs[idx]


class PairCollator:
    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        # Replace any ambiguous amino acids with 'X' and add spaces between characters
        return ' '.join(list(re.sub(r'[UZOB]', 'X', seq)))

    def __call__(self, batch: List[Tuple[str, str, str, str]]) -> Dict:
        # Unpack headers and sequences
        headers_a, headers_b, seqs_a, seqs_b = zip(*batch)
        sequences = []
        for seq_a, seq_b in zip(seqs_a, seqs_b):
            # Tokenizer expects a single string; combine the sanitized sequences with a [SEP]
            combined = self.sanitize_seq(seq_a) + " [SEP] " + self.sanitize_seq(seq_b)
            sequences.append(combined)
        # Tokenize the combined sequences
        tokenized = self.tokenizer(
            sequences,
            padding='longest',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Also return the headers so we can write them out later
        tokenized["headers_a"] = headers_a  # tuple of strings
        tokenized["headers_b"] = headers_b
        return tokenized

def read_fasta(file_path: str):
    """Read sequences from a FASTA file using Biopython."""
    return list(SeqIO.parse(file_path, "fasta"))

def generate_pairs_single(records: List) -> List[Tuple[str, str, str, str]]:
    """
    For all-vs-all inference: generate unique pairs from a single FASTA.
    Only one ordering is kept (i.e. if record A is paired with B, then B-A is omitted).
    """
    pairs = []
    n = len(records)
    for i in range(n):
        for j in range(i + 1, n):
            pairs.append((records[i].id, records[j].id, str(records[i].seq), str(records[j].seq)))
    return pairs

def generate_pairs_two(records1: List, records2: List) -> List[Tuple[str, str, str, str]]:
    """
    For paired inference: generate pairs by zipping the two FASTA files.
    The two files must contain the same number of sequences.
    """
    if len(records1) != len(records2):
        raise ValueError("For paired inference, the two FASTA files must contain the same number of sequences.")
    pairs = []
    for rec1, rec2 in zip(records1, records2):
        pairs.append((rec1.id, rec2.id, str(rec1.seq), str(rec2.seq)))
    return pairs

def main(args):
    # Decide on device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model and tokenizer only once
    print(f"Loading model from {args.model_path}")
    # The parameter `attn_implementation="sdpa"` may speed up inference with PyTorch>=2.5.1 on Linux
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa")
    model.eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Model and tokenizer loaded.\n")

    # Determine input mode: all-vs-all (single FASTA) or paired (two FASTA files)
    if args.fasta:
        print(f"Reading FASTA from {args.fasta} for all-vs-all inference.")
        records = read_fasta(args.fasta)
        pairs = generate_pairs_single(records)
    elif args.fasta1 and args.fasta2:
        print(f"Reading FASTA files:\n  FASTA1: {args.fasta1}\n  FASTA2: {args.fasta2}")
        records1 = read_fasta(args.fasta1)
        records2 = read_fasta(args.fasta2)
        pairs = generate_pairs_two(records1, records2)
    else:
        raise ValueError("Please provide either --fasta for all-vs-all mode or both --fasta1 and --fasta2 for paired inference.")

    if not pairs:
        print("No pairs to process. Exiting.")
        return

    print(f"Total pairs to process: {len(pairs)}")

    # Create the dataset and dataloader
    dataset = ProteinPairDataset(pairs)
    collator = PairCollator(tokenizer, max_length=args.max_length)
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collator,
        shuffle=False
    )

    # Prepare lists to store results
    all_headers_a = []
    all_headers_b = []
    all_predictions = []
    all_confidences = []

    print("\nStarting batched inference...")
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Batches processed", leave=True):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Forward pass
            logits = model(input_ids, attention_mask=attention_mask).logits.cpu()
            # Compute prediction and confidence for each pair
            pred = torch.argmax(logits, dim=1)
            probs = torch.softmax(logits, dim=1)
            # Confidence: probability for the predicted class
            confidence = probs.gather(1, pred.unsqueeze(1)).squeeze(1)
            # Append batch results; note that headers are passed through by the collator
            all_headers_a.extend(batch['headers_a'])
            all_headers_b.extend(batch['headers_b'])
            all_predictions.extend(pred.tolist())
            # Round confidences to 5 decimal places
            all_confidences.extend([round(conf.item(), 5) for conf in confidence])

    # Create a DataFrame with the headers and prediction info
    results_df = pd.DataFrame({
        "Protein_A": all_headers_a,
        "Protein_B": all_headers_b,
        "Prediction": all_predictions,
        "Confidence": all_confidences
    })

    print(f"\nSaving results to {args.output}")
    results_df.to_csv(args.output, index=False)
    print("Done.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="PPI Inference Script for SYNTERACT Model: "
                    "Use a single FASTA file for all-vs-all (unique pairs) or two FASTA files for paired inference."
    )
    # Input file options: either a single FASTA or a pair of FASTA files
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--fasta", type=str, help="Path to a FASTA file for all-vs-all inference.")
    group.add_argument("--fasta1", type=str, help="Path to the first FASTA file for paired inference.")
    parser.add_argument("--fasta2", type=str, help="Path to the second FASTA file for paired inference. Required if --fasta1 is provided.")
    
    parser.add_argument("--model_path", type=str, default="GleghornLab/SYNTERACT", help="Path or identifier of the pretrained model.")
    parser.add_argument("--output", type=str, default="ppi_predictions.csv", help="Path to output CSV file.")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of worker processes for the dataloader.")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length for tokenization.")
    
    args = parser.parse_args()

    # In paired mode, ensure both FASTA files are provided
    if args.fasta1 and not args.fasta2:
        parser.error("--fasta1 requires --fasta2 for paired inference.")
    
    main(args)
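
For reference, assuming the script is saved as ppi_inference.py (the name is arbitrary, and the input FASTA file names are placeholders), the two modes look like:

# all-vs-all over a single multi-FASTA (unique pairs only)
python ppi_inference.py --fasta proteins.fasta --output ppi_predictions.csv --batch_size 8

# paired mode: record i of fasta1 vs record i of fasta2
python ppi_inference.py --fasta1 set_a.fasta --fasta2 set_b.fasta --output ppi_predictions.csv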
