---
license: apache-2.0
---
# InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on activations from [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D), using protein sequences from [UniProt](https://www.uniprot.org/). The [InterProt website](https://interprot.com/) provides an interactive visualizer of the SAE features.

## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```

## Usage

Load the SAE:
```python
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder

# 1280 is the width of the ESM2-650M activations; 4096 is the SAE dimension
sae_model = SparseAutoencoder(1280, 4096)
checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
sae_model.load_state_dict(load_file(checkpoint_path))
```
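
The checkpoint filename above appears to encode the two constructor arguments and the ESM layer (`plm1280`, `l24`, `sae4096`). As a minimal sketch, assuming other checkpoints follow the same naming pattern (inferred only from this one filename, not confirmed by the repository), a hypothetical helper `parse_checkpoint_name` could recover those values so they do not have to be hard-coded:

```python
import re
from typing import NamedTuple


class CheckpointInfo(NamedTuple):
    plm_dim: int    # width of the ESM2 activations fed to the SAE
    plm_layer: int  # ESM2 layer the activations were taken from
    sae_dim: int    # SAE dimension


# Assumed pattern, inferred only from the single filename shown in this README.
_CHECKPOINT_RE = re.compile(r"esm2_plm(\d+)_l(\d+)_sae(\d+)\.safetensors")


def parse_checkpoint_name(filename: str) -> CheckpointInfo:
    """Extract (plm_dim, plm_layer, sae_dim) from a checkpoint filename."""
    match = _CHECKPOINT_RE.fullmatch(filename)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {filename!r}")
    return CheckpointInfo(*(int(group) for group in match.groups()))


info = parse_checkpoint_name("esm2_plm1280_l24_sae4096.safetensors")
print(info)  # CheckpointInfo(plm_dim=1280, plm_layer=24, sae_dim=4096)
```

With this, `SparseAutoencoder(info.plm_dim, info.sae_dim)` and `esm_layer = info.plm_layer` stay consistent with whichever checkpoint file is loaded.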

Load ESM, run ESM inference, then pass the activations to the SAE:
```python
import torch
from transformers import AutoTokenizer, EsmModel

# Load ESM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Run ESM inference with some sequence and take layer 24 activations
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
esm_layer = 24

inputs = tokenizer([seq], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
# The "+ 2" accounts for the start and end tokens added by the tokenizer
esm_layer_acts = outputs.hidden_states[esm_layer]  # (1, sequence length + 2, 1280)

# Run SAE inference with ESM activations as input
sae_acts = sae_model.get_acts(esm_layer_acts)
sae_acts  # (1, sequence length + 2, 4096)
```
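
Because the activations include one row for each of the two special tokens, they need to be trimmed before they line up with residues. The following is a dependency-free sketch of that alignment for a single unpadded sequence, assuming the activations have been converted to nested lists (for example with `sae_acts[0].tolist()`) and that an activation of zero means the feature is inactive. `top_features_per_residue` is a hypothetical helper for illustration and is not part of `interprot`.

```python
from typing import List, Sequence, Tuple


def top_features_per_residue(
    seq: str,
    token_acts: Sequence[Sequence[float]],
    k: int = 3,
) -> List[Tuple[str, List[Tuple[int, float]]]]:
    """Pair each residue with its (at most) k most active SAE features.

    token_acts has one row per token, including the leading and trailing
    special tokens, so it must have len(seq) + 2 rows.
    """
    if len(token_acts) != len(seq) + 2:
        raise ValueError("expected len(seq) + 2 rows of activations")
    residue_acts = token_acts[1:-1]  # drop the two special-token rows
    result = []
    for residue, acts in zip(seq, residue_acts):
        # Rank feature indices by activation, strongest first
        ranked = sorted(enumerate(acts), key=lambda pair: pair[1], reverse=True)
        # Keep only features that are active (assumed: activation > 0)
        top = [(idx, val) for idx, val in ranked[:k] if val > 0]
        result.append((residue, top))
    return result


# Toy input: a 3-residue sequence with 4 SAE features per token.
toy_acts = [
    [0.0, 0.0, 0.0, 0.0],  # start token
    [0.0, 2.5, 0.0, 0.1],  # M
    [1.0, 0.0, 0.0, 0.0],  # K
    [0.0, 0.0, 0.0, 0.0],  # T
    [0.0, 0.0, 0.0, 0.0],  # end token
]
for residue, feats in top_features_per_residue("MKT", toy_acts, k=2):
    print(residue, feats)
```

With the real outputs, the call would look like `top_features_per_residue(seq, sae_acts[0].tolist())`. The returned feature indices are positions along the last (4096-wide) axis of `sae_acts`.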