Expander with pdb code, file or sequence input
Browse files
User picks one.
The caching of the files in session storage is still confusing.
Need to figure out how to improve that.
- hexviz/attention.py +11 -1
- hexviz/pages/🗺️Identify_Interesting_Heads.py +9 -2
- hexviz/view.py +19 -15
- hexviz/🧬Attention_Visualization.py +11 -4
hexviz/attention.py
CHANGED
@@ -2,6 +2,7 @@ from io import StringIO
|
|
2 |
from typing import List, Optional
|
3 |
from urllib import request
|
4 |
|
|
|
5 |
import streamlit as st
|
6 |
import torch
|
7 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
@@ -29,6 +30,15 @@ def get_pdb_file(pdb_code: str) -> Structure:
|
|
29 |
file = StringIO(pdb_data)
|
30 |
return file
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def get_chains(structure: Structure) -> List[str]:
|
34 |
"""
|
@@ -125,6 +135,7 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: Optional
|
|
125 |
chains = list(structure.get_chains())
|
126 |
|
127 |
attention_pairs = []
|
|
|
128 |
for chain in chains:
|
129 |
sequence = get_sequence(chain)
|
130 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
@@ -145,7 +156,6 @@ def get_attention_pairs(pdb_str: str, layer: int, head: int, chain_ids: Optional
|
|
145 |
|
146 |
top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
|
147 |
|
148 |
-
top_residues = []
|
149 |
for res, attn_sum in top_n_residues:
|
150 |
coord = chain[res]["CA"].coord.tolist()
|
151 |
top_residues.append((attn_sum, coord, chain.id, res))
|
|
|
2 |
from typing import List, Optional
|
3 |
from urllib import request
|
4 |
|
5 |
+
import requests
|
6 |
import streamlit as st
|
7 |
import torch
|
8 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
|
|
30 |
file = StringIO(pdb_data)
|
31 |
return file
|
32 |
|
33 |
+
@st.cache
|
34 |
+
def get_pdb_from_seq(sequence: str) -> str:
|
35 |
+
"""
|
36 |
+
Get structure from sequence
|
37 |
+
"""
|
38 |
+
url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
|
39 |
+
res = requests.post(url, data=sequence)
|
40 |
+
pdb_str = res.text
|
41 |
+
return pdb_str
|
42 |
|
43 |
def get_chains(structure: Structure) -> List[str]:
|
44 |
"""
|
|
|
135 |
chains = list(structure.get_chains())
|
136 |
|
137 |
attention_pairs = []
|
138 |
+
top_residues = []
|
139 |
for chain in chains:
|
140 |
sequence = get_sequence(chain)
|
141 |
attention = get_attention(sequence=sequence, model_type=model_type)
|
|
|
156 |
|
157 |
top_n_residues = sorted(residue_attention.items(), key=lambda x: x[1], reverse=True)[:top_n]
|
158 |
|
|
|
159 |
for res, attn_sum in top_n_residues:
|
160 |
coord = chain[res]["CA"].coord.tolist()
|
161 |
top_residues.append((attn_sum, coord, chain.id, res))
|
hexviz/pages/🗺️Identify_Interesting_Heads.py
CHANGED
@@ -4,7 +4,7 @@ from hexviz.attention import get_attention, get_sequence, get_structure
|
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
from hexviz.view import (menu_items, select_heads_and_layers, select_model,
|
7 |
-
select_pdb, select_sequence_slice)
|
8 |
|
9 |
st.set_page_config(layout="wide", menu_items=menu_items)
|
10 |
st.subheader("Find interesting heads and layers")
|
@@ -15,9 +15,16 @@ models = [
|
|
15 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
16 |
]
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
selected_model = select_model(models)
|
19 |
|
20 |
-
pdb_id = select_pdb()
|
21 |
|
22 |
structure = get_structure(pdb_id)
|
23 |
|
|
|
4 |
from hexviz.models import Model, ModelType
|
5 |
from hexviz.plot import plot_tiled_heatmap
|
6 |
from hexviz.view import (menu_items, select_heads_and_layers, select_model,
|
7 |
+
select_pdb, select_protein, select_sequence_slice)
|
8 |
|
9 |
st.set_page_config(layout="wide", menu_items=menu_items)
|
10 |
st.subheader("Find interesting heads and layers")
|
|
|
15 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
16 |
]
|
17 |
|
18 |
+
with st.expander("Input a PDB id, upload a PDB file or input a sequence"):
|
19 |
+
pdb_id = select_pdb()
|
20 |
+
uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
|
21 |
+
# TODO set max length of input sequence
|
22 |
+
input_sequence = st.text_area("3.Input sequence (Folded with ESMfold) Max 400 resis", "", max_chars=400)
|
23 |
+
pdb_str, structure, source = select_protein(pdb_id, uploaded_file, input_sequence)
|
24 |
+
st.write(f"Using: {source}")
|
25 |
+
|
26 |
selected_model = select_model(models)
|
27 |
|
|
|
28 |
|
29 |
structure = get_structure(pdb_id)
|
30 |
|
hexviz/view.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
from io import StringIO
|
2 |
|
|
|
3 |
import streamlit as st
|
4 |
from Bio.PDB import PDBParser
|
5 |
|
6 |
-
from hexviz.attention import get_pdb_file
|
7 |
|
8 |
menu_items = {
|
9 |
"Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
@@ -41,19 +42,16 @@ def select_model(models):
|
|
41 |
return select_model
|
42 |
|
43 |
def select_pdb():
|
44 |
-
st.sidebar.markdown(
|
45 |
-
"""
|
46 |
-
Select Protein
|
47 |
-
---
|
48 |
-
""")
|
49 |
stored_pdb = st.session_state.get("pdb_id", None)
|
50 |
-
pdb_id = st.
|
51 |
-
label="PDB ID",
|
52 |
value=stored_pdb or "2FZ5")
|
53 |
pdb_changed = stored_pdb != pdb_id
|
54 |
if pdb_changed:
|
55 |
-
st.session_state
|
56 |
-
|
|
|
|
|
57 |
if "sequence_slice" in st.session_state:
|
58 |
del st.session_state.sequence_slice
|
59 |
if "uploaded_pdb_str" in st.session_state:
|
@@ -61,25 +59,31 @@ def select_pdb():
|
|
61 |
st.session_state.pdb_id = pdb_id
|
62 |
return pdb_id
|
63 |
|
64 |
-
def select_protein(pdb_code, uploaded_file):
|
65 |
# We get the pdb from 1 of 3 places:
|
66 |
# 1. Cached pdb from session storage
|
67 |
# 2. PDB file from uploaded file
|
68 |
# 3. PDB file fetched based on the pdb_code input
|
69 |
parser = PDBParser()
|
70 |
if uploaded_file is not None:
|
71 |
-
if "pdb_str" in st.session_state:
|
72 |
-
del st.session_state.pdb_str
|
73 |
pdb_str = uploaded_file.read().decode("utf-8")
|
74 |
st.session_state["uploaded_pdb_str"] = pdb_str
|
75 |
-
|
|
|
76 |
pdb_str = st.session_state.uploaded_pdb_str
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
else:
|
78 |
file = get_pdb_file(pdb_code)
|
79 |
pdb_str = file.read()
|
|
|
80 |
|
81 |
structure = parser.get_structure(pdb_code, StringIO(pdb_str))
|
82 |
-
return pdb_str, structure
|
83 |
|
84 |
def select_heads_and_layers(sidebar, model):
|
85 |
sidebar.markdown(
|
|
|
1 |
from io import StringIO
|
2 |
|
3 |
+
import requests
|
4 |
import streamlit as st
|
5 |
from Bio.PDB import PDBParser
|
6 |
|
7 |
+
from hexviz.attention import get_pdb_file, get_pdb_from_seq
|
8 |
|
9 |
menu_items = {
|
10 |
"Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
|
|
|
42 |
return select_model
|
43 |
|
44 |
def select_pdb():
|
|
|
|
|
|
|
|
|
|
|
45 |
stored_pdb = st.session_state.get("pdb_id", None)
|
46 |
+
pdb_id = st.text_input(
|
47 |
+
label="1.PDB ID",
|
48 |
value=stored_pdb or "2FZ5")
|
49 |
pdb_changed = stored_pdb != pdb_id
|
50 |
if pdb_changed:
|
51 |
+
if "selected_chains" in st.session_state:
|
52 |
+
del st.session_state.selected_chains
|
53 |
+
if "selected_chain_index" in st.session_state:
|
54 |
+
del st.session_state.selected_chain_index
|
55 |
if "sequence_slice" in st.session_state:
|
56 |
del st.session_state.sequence_slice
|
57 |
if "uploaded_pdb_str" in st.session_state:
|
|
|
59 |
st.session_state.pdb_id = pdb_id
|
60 |
return pdb_id
|
61 |
|
62 |
+
def select_protein(pdb_code, uploaded_file, input_sequence):
|
63 |
# We get the pdb from 1 of 3 places:
|
64 |
# 1. Cached pdb from session storage
|
65 |
# 2. PDB file from uploaded file
|
66 |
# 3. PDB file fetched based on the pdb_code input
|
67 |
parser = PDBParser()
|
68 |
if uploaded_file is not None:
|
|
|
|
|
69 |
pdb_str = uploaded_file.read().decode("utf-8")
|
70 |
st.session_state["uploaded_pdb_str"] = pdb_str
|
71 |
+
source = f"uploaded pdb file {uploaded_file.name}"
|
72 |
+
elif "uploaded_pdb_str" in st.session_state:
|
73 |
pdb_str = st.session_state.uploaded_pdb_str
|
74 |
+
source = f"Uploaded file stored in cache"
|
75 |
+
elif input_sequence:
|
76 |
+
pdb_str = get_pdb_from_seq(str(input_sequence))
|
77 |
+
if "selected_chains" in st.session_state:
|
78 |
+
del st.session_state.selected_chains
|
79 |
+
source = f"Input sequence + ESM-fold"
|
80 |
else:
|
81 |
file = get_pdb_file(pdb_code)
|
82 |
pdb_str = file.read()
|
83 |
+
source = f"PDB ID: {pdb_code}"
|
84 |
|
85 |
structure = parser.get_structure(pdb_code, StringIO(pdb_str))
|
86 |
+
return pdb_str, structure, source
|
87 |
|
88 |
def select_heads_and_layers(sidebar, model):
|
89 |
sidebar.markdown(
|
hexviz/🧬Attention_Visualization.py
CHANGED
@@ -20,11 +20,18 @@ models = [
|
|
20 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
21 |
]
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
uploaded_file = st.file_uploader("Upload PDB", type=["pdb"])
|
|
|
|
|
|
|
26 |
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
chains = get_chains(structure)
|
29 |
|
30 |
selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
|
|
|
20 |
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
|
21 |
]
|
22 |
|
23 |
+
with st.expander("Input a PDB id, upload a PDB file or input a sequence"):
|
24 |
+
pdb_id = select_pdb()
|
25 |
+
uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
|
26 |
+
input_sequence = st.text_area("3.Input sequence", "")
|
27 |
+
pdb_str, structure, source = select_protein(pdb_id, uploaded_file, input_sequence)
|
28 |
+
st.write(f"Visualizing: {source}")
|
29 |
|
30 |
+
st.sidebar.markdown(
|
31 |
+
"""
|
32 |
+
Configure visualization
|
33 |
+
---
|
34 |
+
""")
|
35 |
chains = get_chains(structure)
|
36 |
|
37 |
selected_chains = st.sidebar.multiselect(label="Select Chain(s)", options=chains, default=st.session_state.get("selected_chains", None) or chains)
|