---
license: apache-2.0
---
# InterProt ESM2 SAE Models
A set of sparse autoencoder (SAE) models trained on [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation mostly follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function, though with far fewer latent dimensions.
Check out [https://interprot.com](https://interprot.com) for an interactive visualizer of the 4096-dimensional SAE on ESM layer 24.
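For readers unfamiliar with Top-K activations, here is a minimal sketch of the idea: for each input vector, keep the `k` largest pre-activations and set the rest to zero. The function name `topk_activation` and the toy values are illustrative assumptions, not the code in the `interprot` package, which may differ in details.

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest values along the last dimension and zero the rest."""
    values, indices = torch.topk(pre_acts, k, dim=-1)
    out = torch.zeros_like(pre_acts)
    # Write the kept values back to their original positions.
    out.scatter_(-1, indices, values)
    return out


# Toy example: one vector with 5 "latents", keeping the 2 largest.
pre_acts = torch.tensor([[0.1, 2.0, -1.0, 0.7, 1.5]])
sparse = topk_activation(pre_acts, k=2)
print(sparse)
```

With this kind of activation, at most `k` latents are nonzero for each residue position, which is what keeps the SAE code sparse.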
## Installation
```bash
pip install git+https://github.com/etowahadams/interprot.git
```
## Usage
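Load the SAE. Here `1280` is the ESM embedding dimension and `4096` is the number of SAE latents; `checkpoint_path` should point to a local copy of the checkpoint file: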
```python
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
sae_model = SparseAutoencoder(1280, 4096)
checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
sae_model.load_state_dict(load_file(checkpoint_path))
```
Load ESM, run ESM inference, then pass the layer-24 activations to the SAE:
```python
import torch
from transformers import AutoTokenizer, EsmModel

# Load ESM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Run ESM inference with some sequence and take layer 24 activations
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
esm_layer = 24
inputs = tokenizer([seq], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
esm_layer_acts = outputs.hidden_states[esm_layer]  # (1, sequence length + 2, 1280)

# Run SAE inference with ESM activations as input
sae_acts = sae_model.get_acts(esm_layer_acts)
sae_acts  # (1, sequence length + 2, 4096)
```
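Once you have `sae_acts`, a common next step is to see which latents fire on which residues. The sketch below uses a random stand-in tensor with the same layout so it runs on its own; in practice you would use the real `sae_acts` from above. It assumes a single unpadded sequence, so the two extra positions are the special tokens the ESM tokenizer adds at the start and end. Names such as `residue_acts` and `peak_per_latent` are illustrative.

```python
import torch

# Stand-in for the real `sae_acts`: (batch, sequence length + 2, 4096).
torch.manual_seed(0)
seq_len = 8
sae_acts = torch.rand(1, seq_len + 2, 4096)

# Drop the first and last positions (special tokens) so that
# row i corresponds to residue i of the input sequence.
residue_acts = sae_acts[0, 1:-1]  # (seq_len, 4096)

# Strongest latent at each residue position.
top_vals, top_latents = residue_acts.max(dim=-1)  # both (seq_len,)

# Latents with the highest peak activation anywhere in the sequence.
peak_per_latent = residue_acts.max(dim=0).values  # (4096,)
top_peaks, top_ids = torch.topk(peak_per_latent, k=5)

for latent_id, peak in zip(top_ids.tolist(), top_peaks.tolist()):
    print(f"latent {latent_id}: peak activation {peak:.3f}")
```

If you batch several sequences with padding, the trailing positions of shorter sequences are padding tokens, so mask with `inputs["attention_mask"]` instead of slicing with `1:-1`.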