license: mit
tags:
  - biology
  - protein

PLTNUM-ESM2-HeLa

PLTNUM is a protein language model trained to predict protein half-lives from their amino acid sequences.
This model is built on facebook/esm2_t33_650M_UR50D and was trained on a protein half-life dataset from the human HeLa cell line (see the paper under Citation).

Model Sources

Uses

How to Get Started with the Model

Use the code below to get started: it loads the model and tokenizer from the Hugging Face Hub and scores a single protein sequence.

import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig, PreTrainedModel, AutoTokenizer


class PLTNUM_PreTrainedModel(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # ESM2 encoder, built from the name/path stored in the config
        self.model = AutoModel.from_pretrained(self.config._name_or_path)

        # Single-logit linear head, applied after two different dropout rates in forward()
        self.fc_dropout1 = nn.Dropout(0.8)
        self.fc_dropout2 = nn.Dropout(0.4)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        # transformers-style initialization using the config's initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                nn.init.constant_(module.weight[module.padding_idx], 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, inputs):
        outputs = self.model(**inputs)
        # Embedding of the first token (<cls>) is used as the sequence representation
        cls_embedding = outputs.last_hidden_state[:, 0]
        # Average the shared head over both dropout branches; in eval mode dropout
        # is disabled, so this equals self.fc(cls_embedding)
        output = (
            self.fc(self.fc_dropout1(cls_embedding))
            + self.fc(self.fc_dropout2(cls_embedding))
        ) / 2
        return output  # raw logit, shape (batch_size, 1)

    def create_embedding(self, inputs):
        # Return the <cls> embedding, shape (batch_size, hidden_size)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0]


model = PLTNUM_PreTrainedModel.from_pretrained("sagawa/PLTNUM-ESM2-HeLa")
model.eval()  # disable dropout for inference (from_pretrained already does this)
tokenizer = AutoTokenizer.from_pretrained("sagawa/PLTNUM-ESM2-HeLa")
seq = "MSGRGKQGGKARAKAKTRSSRAGLQFPVGRVHRLLRKGNYSERVGAGAPVYLAAVLEYLTAEILELAGNAARDNKKTRIIPRHLQLAIRNDEELNKLLGRVTIAQGGVLPNIQAVLLPKKTESHHKPKGK"
inputs = tokenizer(
    [seq],
    add_special_tokens=True,  # wrap the residues in <cls> ... <eos>
    max_length=512,  # at most 510 residues are kept; longer sequences are truncated
    padding="max_length",
    truncation=True,
    return_offsets_mapping=False,
    return_attention_mask=True,
    return_tensors="pt",
)
with torch.no_grad():  # inference only, no gradients needed
    logit = model(inputs)
print(torch.sigmoid(logit))  # logit mapped to a score in (0, 1)
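`forward()` passes the first-token embedding through one shared linear layer twice, once after `Dropout(0.8)` and once after `Dropout(0.4)`, and averages the two logits (a pattern often called multi-sample dropout). The pure-Python sketch below mimics that head on a plain list of floats to show why the averaging only matters during training: with dropout disabled, both branches see the same input and the result reduces to a single linear layer. The helpers `dropout`, `linear`, `pltnum_head` and the toy weights are hypothetical illustrations, not part of the model's code.

```python
import math
import random


def dropout(x, p, training, rng):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(x)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in x]


def linear(x, weight, bias):
    """Single-output linear layer: dot(weight, x) + bias."""
    return sum(w * v for w, v in zip(weight, x)) + bias


def pltnum_head(cls_embedding, weight, bias, rates=(0.8, 0.4), training=False, rng=None):
    """Hypothetical re-creation of the two-dropout averaged head; returns one logit."""
    if rng is None:
        rng = random.Random(0)
    logits = [linear(dropout(cls_embedding, p, training, rng), weight, bias) for p in rates]
    return sum(logits) / len(logits)


def sigmoid(z):
    """Numerically stable logistic function mapping a logit to (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


# Toy 4-value "embedding" and weights; the real ESM2-650M embedding has 1280 values.
emb = [0.5, -1.0, 2.0, 0.0]
w, b = [0.1, 0.2, 0.3, 0.4], -0.1

eval_logit = pltnum_head(emb, w, b)  # dropout off: equals linear(emb, w, b)
print(eval_logit, sigmoid(eval_logit))
train_logit = pltnum_head(emb, w, b, training=True)  # varies with the dropout masks
print(train_logit)
```

Because `from_pretrained` returns the model in evaluation mode, both branches give identical logits at inference time; the second branch presumably serves only as extra regularization while training.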

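The tokenizer call produces exactly 512 tokens per input: the ESM tokenizer wraps the residues in a `<cls>` and an `<eos>` token, so at most 510 residues are kept, shorter inputs are padded on the right, and `attention_mask` marks real tokens with 1 and padding with 0. The sketch below reproduces that layout in pure Python to make the effect of `max_length`, `truncation=True` and `padding="max_length"` concrete. `TOY_VOCAB`, `CLS_ID`, `PAD_ID`, `EOS_ID` and `encode_fixed_length` are hypothetical placeholders, not guaranteed to match the real ESM2 vocabulary, and this is not a replacement for `AutoTokenizer`.

```python
# Toy ids for illustration only; not guaranteed to match the real ESM2 vocabulary.
CLS_ID, PAD_ID, EOS_ID = 0, 1, 2
TOY_VOCAB = {aa: i + 3 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}


def encode_fixed_length(seq, max_length=512):
    """Mimic a per-residue tokenizer called with add_special_tokens=True,
    truncation=True and padding="max_length": <cls> + residues + <eos>, then right-padding."""
    n_special = 2  # <cls> and <eos>
    residues = seq[: max_length - n_special]  # truncation drops residues from the end
    ids = [CLS_ID] + [TOY_VOCAB[aa] for aa in residues] + [EOS_ID]
    n_pad = max_length - len(ids)
    input_ids = ids + [PAD_ID] * n_pad
    attention_mask = [1] * len(ids) + [0] * n_pad
    return input_ids, attention_mask


short_ids, short_mask = encode_fixed_length("MSGRGK")
long_ids, long_mask = encode_fixed_length("M" * 600)
print(len(short_ids), sum(short_mask))  # 512 8
print(len(long_ids), sum(long_mask))  # 512 512
```

The example sequence in this README is well under 510 residues, so nothing is dropped; with these settings, a longer protein would be scored on its first 510 residues only.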
Citation

Prediction of Protein Half-lives from Amino Acid Sequences by Protein Language Models
Tatsuya Sagawa, Eisuke Kanao, Kosuke Ogata, Koshi Imami, Yasushi Ishihama
bioRxiv 2024.09.10.612367; doi: https://doi.org/10.1101/2024.09.10.612367