Which Layer to use for Representations?
Hi there, I'm trying out your FastESM implementations and would like to understand which layer to use for protein representations. I understand that in principle any layer could do the job, but which one is supposed to give the same results as layer 33 in the original esm2_t33_650M_UR50D model? It confuses me that the embedding shape I retrieved from your esm2_650M implementation is 1 x 220 x 1280.
Thanks in advance, Alex.
Ps:
How to handle this warning: Some weights of FastEsmModel were not initialized from the model checkpoint at Synthyra/ESM2-650M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
I thought this model was already pretrained and ready to use.
Hi there, I'm trying out your FastESM implementations and would like to understand which layer to use for protein representations. I understand that in principle any layer could do the job, but which one is supposed to give the same results as layer 33 in the original esm2_t33_650M_UR50D model? It confuses me that the embedding shape I retrieved from your esm2_650M implementation is 1 x 220 x 1280.
Thanks for trying it out. The last hidden state is typically used for representations; it has size (1, L, 1280) for any protein of length L-2 (cls and eos tokens are added), for either this model or the original esm2_t33_650M_UR50D. FastESM is trained a bit more than the original version; if you prefer that one, you can use this link for our implementation, which is exactly the same but faster.
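For concreteness, here is a minimal sketch of pulling the last hidden state with the Hugging Face API. Treat the loading call as an assumption (AutoModel/AutoTokenizer with trust_remote_code=True is the usual pattern for custom-code checkpoints); the model card has the canonical snippet.

import torch
from transformers import AutoModel, AutoTokenizer

# Assumption: the checkpoint loads through the auto classes with trust_remote_code=True.
model_path = "Synthyra/ESM2-650M"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()

sequence = "MSEQWENCE"  # example sequence
tokenized = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    last_hidden = model(**tokenized).last_hidden_state  # (1, len(sequence) + 2, 1280)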
How to handle this warning: Some weights of FastEsmModel were not initialized from the model checkpoint at Synthyra/ESM2-650M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Those weights are newly initialized because they are not used in pretraining. The pooler weights multiply the cls token to get a pooled representation; for any protein length L they output a fixed size (1, d), where d=1280 for this model. They need to be fine-tuned for specific tasks. If you are using the model for mask filling or working with the last hidden state, this will not affect your results at all. By the way, the pooler layers are randomly initialized for any ESM model, not just this one.
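Roughly speaking, the pooler just applies a learned dense projection (plus a tanh in the standard transformers implementation) to the cls position, which is why its output has a fixed size regardless of sequence length. A standalone sketch of the idea, not the actual module inside the model:

import torch
import torch.nn as nn

hidden_size = 1280
dense = nn.Linear(hidden_size, hidden_size)  # these are the randomly initialized weights from the warning

last_hidden_state = torch.randn(1, 220, hidden_size)  # (batch, seq_len, d), stand-in for a real model output
cls_hidden = last_hidden_state[:, 0, :]               # cls token, (1, 1280)
pooled = torch.tanh(dense(cls_hidden))                # pooled representation, (1, 1280) for any protein length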
Please let me know if you have any questions or features you'd like added.
Best,
Logan
Thanks for answering so quickly and for your effort!
sequence = "MSEQWENCE"
tokenized = tokenizer(sequence, padding=True, return_tensors='pt')
with torch.no_grad():
    embedding = model(**tokenized).last_hidden_state
So did I understand correctly that the embedding (which is to be used further, e.g. for activity prediction or masking/demasking) of a 218 AA long protein sequence is a 2-dimensional tensor of shape 220 x 1280 for FastESM 650M?
Or is my desired embedding at embedding.numpy()[0][10]?
I must admit that I expected a 1-dimensional tensor of shape (1) x 1280 for every input sequence, regardless of its length.
For reference, please find the Fair-ESM implementation documentation (https://github.com/facebookresearch/esm) below. As I understand it, this obtains the representation from layer 33 (-> results["representations"][33]):
import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]
Thanks in Advance, Alex
Hey Alex,
I believe results["representations"][33] would give you the last hidden state of size (batch_size, seq_len, 1280). The equivalent with Hugging Face based ESM models is the .last_hidden_state mentioned in the README. If you want the cls token only, of size (batch_size, 1280), you can slice the last hidden state X like this: cls = X[:, 0, :]. You could also try mean pooling if you need fixed-length representations from the model.
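As a quick sketch (reusing tokenized and model from the earlier snippet), the two pooling options look like this; the naive mean is fine for a single unpadded sequence, while the mask-aware version further down handles padded batches:

with torch.no_grad():
    X = model(**tokenized).last_hidden_state  # (batch_size, seq_len, 1280)

cls_embedding = X[:, 0, :]      # cls pooling: (batch_size, 1280)
mean_embedding = X.mean(dim=1)  # naive mean pooling over the sequence dimension: (batch_size, 1280)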
The pooler output discussed above, X = model_output.pooler_output, is a learned linear projection of the cls token. By default the weights are random, because you should fine-tune them for a given downstream task.
So, if you need information for every residue, you can get it from the last hidden state by grabbing .last_hidden_state from the model's output. If you need layers other than the last one (i.e. layer 31, or layer 2, whatever), you can get a tuple of all the hidden states by passing output_hidden_states=True when you call the model and grabbing .hidden_states from the model output. Then, if that is in X, you can get the 30th layer with X[30], of size (batch_size, seq_len, 1280). Note: this will actually return 34 hidden states, because it stores the initial token embedding as well as the outputs of all 33 layers.
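A short sketch of grabbing an intermediate layer this way (reusing the names from the earlier snippet):

with torch.no_grad():
    outputs = model(**tokenized, output_hidden_states=True)

hidden_states = outputs.hidden_states  # tuple of 34 tensors: initial embedding + all 33 layers
layer_30 = hidden_states[30]           # (batch_size, seq_len, 1280)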
If you want a fixed-length, 1D representation per protein, you can pool the last hidden state. Popular methods include cls pooling and mean pooling. If you do mean pooling, make sure you take the attention mask into account if your input is batched. For example:
def mean_pooling(x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Expand the attention mask to the hidden dimension so padded positions are zeroed out
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(x.size()).float()
    # Sum over the sequence dimension and divide by the number of real (non-padding) tokens
    return torch.sum(x * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
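Hypothetical usage with a padded batch (variable names are just for illustration):

batch = tokenizer(["MSEQWENCE", "MKTVRQERLKSIVRILERSK"], padding=True, return_tensors="pt")
with torch.no_grad():
    hidden = model(**batch).last_hidden_state           # (2, max_len, 1280)
pooled = mean_pooling(hidden, batch["attention_mask"])  # (2, 1280)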
If you want to fine-tune the model using a fixed-length representation, you can train the model based on the outputs of pooler_output. EsmForSequenceClassification (or AutoModelForSequenceClassification in this case) does this automatically.
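If you go that route, a minimal setup might look like the sketch below. Whether this exact checkpoint exposes a classification head through the auto classes is an assumption on my part (the pattern is the same for the standard facebook/esm2_t33_650M_UR50D), and num_labels=2 is just a placeholder.

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: the checkpoint works with the sequence-classification auto class;
# the "newly initialized" warning is expected here because the head starts untrained.
model = AutoModelForSequenceClassification.from_pretrained(
    "Synthyra/ESM2-650M", num_labels=2, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("Synthyra/ESM2-650M", trust_remote_code=True)
# ...then fine-tune on your labeled sequences, e.g. with the transformers Trainer.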
Please let me know if you have any other questions.
Best,
Logan
Hi Logan,
thanks again for your very fast answer. That's the part I had forgotten and did not understand before.
Yeah, to obtain a whole-sequence representation I need to get all residue embeddings for each sequence first and then take their mean.
This helped a lot and now I'm happy :)
Thanks for your efforts and best regards!
Alex