
Commit ebbe380 (parent: 466a8f2), committed by aksell

Add ModelType enum and Model class to hold layers and head count
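In short, ModelType names a pretrained backend and Model pairs it with the layer and head counts the UI needs. A minimal sketch, using only names that appear in the diffs below:

from protention.attention import Model, ModelType

# One entry per supported backend; the layer/head counts drive the UI bounds below
tape_bert = Model(name=ModelType.TAPE_BERT, layers=12, heads=12)
print(tape_bert.name.value)               # "bert-base"
print(tape_bert.layers, tape_bert.heads)  # 12 12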

protention/attention.py CHANGED
@@ -2,14 +2,24 @@ from enum import Enum
 from io import StringIO
 from urllib import request
 
+import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 from tape import ProteinBertModel, TAPETokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 
 
-class Model(str, Enum):
-    tape_bert = "bert-base"
+class ModelType(str, Enum):
+    TAPE_BERT = "bert-base"
+    PROT_T5 = "prot_t5_xl_half_uniref50-enc"
+
+
+class Model:
+    def __init__(self, name, layers, heads):
+        self.name: ModelType = name
+        self.layers: int = layers
+        self.heads: int = heads
+
 
 def get_structure(pdb_code: str) -> Structure:
     """
@@ -56,9 +66,9 @@ def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
-
+@st.cache
 def get_attention(
-    pdb_code: str, model: Model = Model.tape_bert
+    pdb_code: str, model: ModelType = ModelType.TAPE_BERT
 ):
     """
     Get attention from T5
@@ -70,8 +80,8 @@ def get_attention(
     # TODO handle multiple sequences
     sequence = sequences[0]
 
-    match model:
-        case model.tape_bert:
+    match model.name:
+        case ModelType.TAPE_BERT:
             tokenizer, model = get_tape_bert()
             token_idxs = tokenizer.encode(sequence).tolist()
             inputs = torch.tensor(token_idxs).unsqueeze(0)
@@ -80,9 +90,10 @@ def get_attention(
             # Remove attention from <CLS> (first) and <SEP> (last) token
             attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
             attns = torch.stack([attn.squeeze(0) for attn in attns])
-        case model.prot_T5:
+        case ModelType.PROT_T5:
             # Space separate sequences
            sequences = [" ".join(sequence) for sequence in sequences]
             tokenizer, model = get_protT5()
 
-    return attns
+    return attns
+
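A hedged usage sketch of the reworked get_attention: the call mirrors the test and the Streamlit app, and the [layers, heads, seq, seq] layout is inferred from the shape asserted in tests/test_attention.py.

from protention.attention import ModelType, get_attention

# Attention for every layer and head of TAPE BERT on PDB entry 1AKE;
# @st.cache memoizes repeated identical calls within a Streamlit session
attns = get_attention("1AKE", model=ModelType.TAPE_BERT)
print(attns.shape)  # torch.Size([12, 12, 456, 456]) per tests/test_attention.py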
protention/streamlit/Attention_On_Structure.py CHANGED
@@ -3,21 +3,31 @@ import stmol
 import streamlit as st
 from stmol import showmol
 
+from protention.attention import Model, ModelType, get_attention
+
 st.sidebar.title("pLM Attention Visualization")
 
 st.title("pLM Attention Visualization")
 
+# Define list of model types
+models = [
+    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
+]
+
+selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
+selected_model = next((model for model in models if model.name.value == selected_model_name), None)
+
 pdb_id = st.text_input("PDB ID", "4RW0")
-chain_id = None
 
 left, right = st.columns(2)
 with left:
-    layer = st.number_input("Layer", value=8)
+    layer = st.number_input("Layer", value=1, min_value=1, max_value=selected_model.layers)
 with right:
-    head = st.number_input("Head", value=5)
+    head = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
 
 min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.15)
 
+attention = get_attention(pdb_id, model=selected_model.name)
 
 def get_3dview(pdb):
     xyzview = py3Dmol.view(query=f"pdb:{pdb}")
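The commit stops at computing `attention`; how it gets drawn onto the structure is outside this diff. A hypothetical helper, not part of the commit, assuming the tensor is laid out [layer, head, res_i, res_j] as the test shape suggests and that the UI's 1-based layer/head values are shifted to 0-based indices:

def high_attention_pairs(attention, layer, head, min_attn):
    # layer/head come from st.number_input and are 1-based; the tensor is 0-indexed
    attn = attention[layer - 1, head - 1]
    idx = (attn >= min_attn).nonzero()
    # one (residue_i, residue_j, weight) triple per pair above the threshold
    return [(int(i), int(j), float(attn[i, j])) for i, j in idx]

pairs = high_attention_pairs(attention, layer, head, min_attn)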
tests/test_attention.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from Bio.PDB.Structure import Structure
 from transformers import T5EncoderModel, T5Tokenizer
 
-from protention.attention import (Model, get_attention, get_protT5,
+from protention.attention import (ModelType, get_attention, get_protT5,
                                   get_sequences, get_structure)
 
 
@@ -38,7 +38,7 @@ def test_get_protT5():
 
 def test_get_attention_tape():
 
-    result = get_attention("1AKE", model=Model.tape_bert)
+    result = get_attention("1AKE", model=ModelType.tape_bert)
 
     assert result is not None
     assert result.shape == torch.Size([12,12,456,456])