Spaces:

aksell committed on
Commit
c8fb9e8
1 Parent(s): 5d3313d

Expander with PDB code, file, or sequence input

Browse files

The user picks one of the inputs.
The caching of the files in session storage is still confusing;
we need to figure out how to improve that.

hexviz/attention.py CHANGED
@@ -2,6 +2,7 @@ from io import StringIO
2
  from typing import List, Optional
3
  from urllib import request
4
 
 
5
  import streamlit as st
6
  import torch
7
  from Bio.PDB import PDBParser, Polypeptide, Structure
@@ -29,6 +30,15 @@ def get_pdb_file(pdb_code: str) -> Structure:
29
  file = StringIO(pdb_data)
30
  return file
31
 
 
 
 
 
 
 
 
 
 
32
 
33
  def get_chains(structure: Structure) -> List[str]:
34
  """
@@ -125,6 +135,7 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: Optional
125
  chains = list(structure.get_chains())
126
 
127
  attention_pairs = []
 
128
  for chain in chains:
129
  sequence = get_sequence(chain)
130
  attention = get_attention(sequence=sequence, model_type=model_type)
@@ -145,7 +156,6 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: Optional
145
 
146
  top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
147
 
148
- top_residues = []
149
  for res, attn_sum in top_n_residues:
150
  coord = chain[res]["CA"].coord.tolist()
151
  top_residues.append((attn_sum, coord, chain.id, res))
 
2
  from typing import List, Optional
3
  from urllib import request
4
 
5
+ import requests
6
  import streamlit as st
7
  import torch
8
  from Bio.PDB import PDBParser, Polypeptide, Structure
 
30
  file = StringIO(pdb_data)
31
  return file
32
 
33
+ @st.cache
34
+ def get_pdb_from_seq(sequence: str) -> str:
35
+ """
36
+ Get structure from sequence
37
+ """
38
+ url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
39
+ res = requests.post(url, data=sequence)
40
+ pdb_str = res.text
41
+ return pdb_str
42
 
43
  def get_chains(structure: Structure) -> List[str]:
44
  """
 
135
  chains = list(structure.get_chains())
136
 
137
  attention_pairs = []
138
+ top_residues = []
139
  for chain in chains:
140
  sequence = get_sequence(chain)
141
  attention = get_attention(sequence=sequence, model_type=model_type)
 
156
 
157
  top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
158
 
 
159
  for res, attn_sum in top_n_residues:
160
  coord = chain[res]["CA"].coord.tolist()
161
  top_residues.append((attn_sum, coord, chain.id, res))
hexviz/pages/🗺️Identify_Interesting_Heads.py CHANGED
@@ -4,7 +4,7 @@ from hexviz.attention import get_attention, get_sequence, get_structure
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
  from hexviz.view import (menu_items, select_heads_and_layers, select_model,
7
- select_pdb, select_sequence_slice)
8
 
9
  st.set_page_config(layout="wide", menu_items=menu_items)
10
  st.subheader("Find interesting heads and layers")
@@ -15,9 +15,16 @@ models = [
15
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
16
  ]
17
 
 
 
 
 
 
 
 
 
18
  selected_model = select_model(models)
19
 
20
- pdb_id = select_pdb()
21
 
22
  structure = get_structure(pdb_id)
23
 
 
4
  from hexviz.models import Model, ModelType
5
  from hexviz.plot import plot_tiled_heatmap
6
  from hexviz.view import (menu_items, select_heads_and_layers, select_model,
7
+ select_pdb, select_protein, select_sequence_slice)
8
 
9
  st.set_page_config(layout="wide", menu_items=menu_items)
10
  st.subheader("Find interesting heads and layers")
 
15
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
16
  ]
17
 
18
+ with st.expander("Input a PDB id, upload a PDB file or input a sequence"):
19
+ pdb_id = select_pdb()
20
+ uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
21
+ # TODO set max length of input sequence
22
+ input_sequence = st.text_area("3.Input sequence (Folded with ESMfold) Max 400 resis", "", max_chars=400)
23
+ pdb_str, structure, source = select_protein(pdb_id, uploaded_file, input_sequence)
24
+ st.write(f"Using: {source}")
25
+
26
  selected_model = select_model(models)
27
 
 
28
 
29
  structure = get_structure(pdb_id)
30
 
hexviz/view.py CHANGED
@@ -1,9 +1,10 @@
1
  from io import StringIO
2
 
 
3
  import streamlit as st
4
  from Bio.PDB import PDBParser
5
 
6
- from hexviz.attention import get_pdb_file
7
 
8
  menu_items = {
9
  "Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
@@ -41,19 +42,16 @@ def select_model(models):
41
  return select_model
42
 
43
  def select_pdb():
44
- st.sidebar.markdown(
45
- """
46
- Select Protein
47
- ---
48
- """)
49
  stored_pdb = st.session_state.get("pdb_id", None)
50
- pdb_id = st.sidebar.text_input(
51
- label="PDB ID",
52
  value=stored_pdb or "2FZ5")
53
  pdb_changed = stored_pdb != pdb_id
54
  if pdb_changed:
55
- st.session_state.selected_chains = None
56
- st.session_state.selected_chain_index = 0
 
 
57
  if "sequence_slice" in st.session_state:
58
  del st.session_state.sequence_slice
59
  if "uploaded_pdb_str" in st.session_state:
@@ -61,25 +59,31 @@ def select_pdb():
61
  st.session_state.pdb_id = pdb_id
62
  return pdb_id
63
 
64
- def select_protein(pdb_code, uploaded_file):
65
  # We get the pdb from 1 of 3 places:
66
  # 1. Cached pdb from session storage
67
  # 2. PDB file from uploaded file
68
  # 3. PDB file fetched based on the pdb_code input
69
  parser = PDBParser()
70
  if uploaded_file is not None:
71
- if "pdb_str" in st.session_state:
72
- del st.session_state.pdb_str
73
  pdb_str = uploaded_file.read().decode("utf-8")
74
  st.session_state["uploaded_pdb_str"] = pdb_str
75
- if "uploaded_pdb_str" in st.session_state:
 
76
  pdb_str = st.session_state.uploaded_pdb_str
 
 
 
 
 
 
77
  else:
78
  file = get_pdb_file(pdb_code)
79
  pdb_str = file.read()
 
80
 
81
  structure = parser.get_structure(pdb_code, StringIO(pdb_str))
82
- return pdb_str, structure
83
 
84
  def select_heads_and_layers(sidebar, model):
85
  sidebar.markdown(
 
1
  from io import StringIO
2
 
3
+ import requests
4
  import streamlit as st
5
  from Bio.PDB import PDBParser
6
 
7
+ from hexviz.attention import get_pdb_file, get_pdb_from_seq
8
 
9
  menu_items = {
10
  "Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
 
42
  return select_model
43
 
44
  def select_pdb():
 
 
 
 
 
45
  stored_pdb = st.session_state.get("pdb_id", None)
46
+ pdb_id = st.text_input(
47
+ label="1.PDB ID",
48
  value=stored_pdb or "2FZ5")
49
  pdb_changed = stored_pdb != pdb_id
50
  if pdb_changed:
51
+ if "selected_chains" in st.session_state:
52
+ del st.session_state.selected_chains
53
+ if "selected_chain_index" in st.session_state:
54
+ del st.session_state.selected_chain_index
55
  if "sequence_slice" in st.session_state:
56
  del st.session_state.sequence_slice
57
  if "uploaded_pdb_str" in st.session_state:
 
59
  st.session_state.pdb_id = pdb_id
60
  return pdb_id
61
 
62
+ def select_protein(pdb_code, uploaded_file, input_sequence):
63
  # We get the pdb from 1 of 3 places:
64
  # 1. Cached pdb from session storage
65
  # 2. PDB file from uploaded file
66
  # 3. PDB file fetched based on the pdb_code input
67
  parser = PDBParser()
68
  if uploaded_file is not None:
 
 
69
  pdb_str = uploaded_file.read().decode("utf-8")
70
  st.session_state["uploaded_pdb_str"] = pdb_str
71
+ source = f"uploaded pdb file {uploaded_file.name}"
72
+ elif "uploaded_pdb_str" in st.session_state:
73
  pdb_str = st.session_state.uploaded_pdb_str
74
+ source = f"Uploaded file stored in cache"
75
+ elif input_sequence:
76
+ pdb_str = get_pdb_from_seq(str(input_sequence))
77
+ if "selected_chains" in st.session_state:
78
+ del st.session_state.selected_chains
79
+ source = f"Input sequence + ESM-fold"
80
  else:
81
  file = get_pdb_file(pdb_code)
82
  pdb_str = file.read()
83
+ source = f"PDB ID: {pdb_code}"
84
 
85
  structure = parser.get_structure(pdb_code, StringIO(pdb_str))
86
+ return pdb_str, structure, source
87
 
88
  def select_heads_and_layers(sidebar, model):
89
  sidebar.markdown(
hexviz/🧬Attention_Visualization.py CHANGED
@@ -20,11 +20,18 @@ models = [
20
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
21
  ]
22
 
23
- pdb_id = select_pdb()
24
- with st.expander("Input sequence or upload PDB file"):
25
- uploaded_file = st.file_uploader("Upload PDB", type=["pdb"])
 
 
 
26
 
27
- pdb_str, structure = select_protein(pdb_id, uploaded_file)
 
 
 
 
28
  chains = get_chains(structure)
29
 
30
  selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
 
20
  Model(name=ModelType.ZymCTRL, layers=36, heads=16),
21
  ]
22
 
23
+ with st.expander("Input a PDB id, upload a PDB file or input a sequence"):
24
+ pdb_id = select_pdb()
25
+ uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
26
+ input_sequence = st.text_area("3.Input sequence", "")
27
+ pdb_str, structure, source = select_protein(pdb_id, uploaded_file, input_sequence)
28
+ st.write(f"Visualizing: {source}")
29
 
30
+ st.sidebar.markdown(
31
+ """
32
+ Configure visualization
33
+ ---
34
+ """)
35
  chains = get_chains(structure)
36
 
37
  selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)