Spaces:
Sleeping
Sleeping
Massimo G. Totaro
commited on
Commit
·
ddc1bd3
1
Parent(s):
fba8f5e
QOL and gradio upgrade
Browse files- .gitignore +2 -1
- README.md +1 -1
- app.py +14 -26
- data.py +22 -31
- instructions.md +58 -36
- model.py +11 -7
- requirements.txt +1 -0
.gitignore
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
Dockerfile
|
2 |
*.ipynb
|
3 |
-
|
|
|
|
1 |
Dockerfile
|
2 |
*.ipynb
|
3 |
+
out.*
|
4 |
+
*/
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
|
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: bsd-2-clause
|
|
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 5.5.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: bsd-2-clause
|
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
from
|
2 |
-
from gradio import Blocks, Button, Checkbox, Dropdown, Examples, File, HTML, Markdown, Textbox
|
3 |
|
4 |
from model import get_models
|
5 |
from data import Data
|
@@ -17,19 +16,14 @@ def app(*argv):
|
|
17 |
# Unpack the arguments
|
18 |
seq, trg, model_name, *_ = argv
|
19 |
scoring = SCORING[scoring_strategy.value]
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
except Exception as e:
|
24 |
-
# If an error occurs, return an HTML error message
|
25 |
-
return f'<!DOCTYPE html><html><body><h1 style="background-color:#F70D1A;text-align:center;">Error: {str(e)}</h1></body></html>', None
|
26 |
# If no error occurs, return the calculated data
|
27 |
-
return
|
28 |
|
29 |
# Create the Gradio interface
|
30 |
-
with open("instructions.md", "r", encoding="utf-8") as md
|
31 |
-
NamedTemporaryFile(mode='w+') as out_file,\
|
32 |
-
Blocks() as esm_scan:
|
33 |
|
34 |
# Define the interface components
|
35 |
Markdown(md.read())
|
@@ -46,20 +40,14 @@ with open("instructions.md", "r", encoding="utf-8") as md,\
|
|
46 |
value=""
|
47 |
)
|
48 |
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
49 |
-
scoring_strategy = Checkbox(value=True, label="Use
|
50 |
-
|
51 |
-
out =
|
52 |
-
|
53 |
-
value=out_file.name,
|
54 |
-
visible=False,
|
55 |
-
label="Download",
|
56 |
-
file_count='single',
|
57 |
-
interactive=False
|
58 |
-
)
|
59 |
btn.click(
|
60 |
fn=app,
|
61 |
inputs=[seq, trg, model_name],
|
62 |
-
outputs=[out,
|
63 |
)
|
64 |
ex = Examples(
|
65 |
examples=[
|
@@ -87,9 +75,9 @@ with open("instructions.md", "r", encoding="utf-8") as md,\
|
|
87 |
inputs=[seq,
|
88 |
trg,
|
89 |
model_name],
|
90 |
-
outputs=[out,
|
91 |
-
|
92 |
-
|
93 |
)
|
94 |
|
95 |
# Launch the Gradio interface
|
|
|
1 |
+
from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Examples, File, Image, Markdown, Textbox
|
|
|
2 |
|
3 |
from model import get_models
|
4 |
from data import Data
|
|
|
16 |
# Unpack the arguments
|
17 |
seq, trg, model_name, *_ = argv
|
18 |
scoring = SCORING[scoring_strategy.value]
|
19 |
+
# Calculate the data based on the input parameters
|
20 |
+
data = Data(seq, trg, model_name, scoring).calculate()
|
21 |
+
|
|
|
|
|
|
|
22 |
# If no error occurs, return the calculated data
|
23 |
+
return Image(value=data.image(), type='filepath', visible=True), DownloadButton(value=data.csv(), visible=True)
|
24 |
|
25 |
# Create the Gradio interface
|
26 |
+
with open("instructions.md", "r", encoding="utf-8") as md, Blocks() as esm_scan:
|
|
|
|
|
27 |
|
28 |
# Define the interface components
|
29 |
Markdown(md.read())
|
|
|
40 |
value=""
|
41 |
)
|
42 |
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
43 |
+
scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
|
44 |
+
dlb = DownloadButton(label="Download raw data", visible=False)
|
45 |
+
out = Image(visible=False)
|
46 |
+
btn = Button(value="Run", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
btn.click(
|
48 |
fn=app,
|
49 |
inputs=[seq, trg, model_name],
|
50 |
+
outputs=[out, dlb]
|
51 |
)
|
52 |
ex = Examples(
|
53 |
examples=[
|
|
|
75 |
inputs=[seq,
|
76 |
trg,
|
77 |
model_name],
|
78 |
+
outputs=[out],
|
79 |
+
fn=app,
|
80 |
+
cache_examples=False
|
81 |
)
|
82 |
|
83 |
# Launch the Gradio interface
|
data.py
CHANGED
@@ -1,12 +1,8 @@
|
|
|
|
1 |
from math import ceil
|
2 |
-
from re import match
|
3 |
-
import seaborn as sns
|
4 |
-
|
5 |
-
from model import Model
|
6 |
-
|
7 |
-
|
8 |
import matplotlib.pyplot as plt
|
9 |
import pandas as pd
|
|
|
10 |
import seaborn as sns
|
11 |
|
12 |
from model import Model
|
@@ -26,19 +22,18 @@ class Data:
|
|
26 |
"""Parse input substitutions"""
|
27 |
self.mode = None
|
28 |
self.sub = list()
|
29 |
-
self.trg = trg.strip().upper()
|
30 |
self.resi = list()
|
31 |
|
32 |
# Identify running mode
|
33 |
-
if len(self.trg
|
34 |
# If single string of same length as sequence, seq vs seq mode
|
35 |
self.mode = 'MUT'
|
36 |
-
for resi, (src, trg) in enumerate(zip(self.seq, self.trg), 1):
|
37 |
if src != trg:
|
38 |
self.sub.append(f"{src}{resi}{trg}")
|
39 |
self.resi.append(resi)
|
40 |
else:
|
41 |
-
self.trg = self.trg.split()
|
42 |
if all(match(r'\d+', x) for x in self.trg):
|
43 |
# If all strings are numbers, deep mutational scanning mode
|
44 |
self.mode = 'DMS'
|
@@ -64,7 +59,7 @@ class Data:
|
|
64 |
|
65 |
self.sub = pd.DataFrame(self.sub, columns=['0'])
|
66 |
|
67 |
-
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=
|
68 |
"initialise data"
|
69 |
# if model has changed, load new model
|
70 |
if self.model.model_name != model_name:
|
@@ -76,13 +71,14 @@ class Data:
|
|
76 |
self.scoring_strategy = scoring_strategy
|
77 |
self.token_probs = None
|
78 |
self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
|
79 |
-
self.
|
80 |
-
self.
|
81 |
|
82 |
def parse_output(self) -> None:
|
83 |
"format output data for visualisation"
|
84 |
if self.mode == 'TMS':
|
85 |
self.process_tms_mode()
|
|
|
86 |
else:
|
87 |
if self.mode == 'DMS':
|
88 |
self.sort_by_residue_and_score()
|
@@ -90,14 +86,12 @@ class Data:
|
|
90 |
self.sort_by_score()
|
91 |
else:
|
92 |
raise RuntimeError(f"Unrecognised mode {self.mode}")
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
.background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
|
100 |
-
.to_html(justify='center'))
|
101 |
|
102 |
def sort_by_score(self):
|
103 |
self.out = self.out.sort_values(self.model_name, ascending=False)
|
@@ -155,10 +149,7 @@ class Data:
|
|
155 |
else:
|
156 |
self.plot_multiple_heatmaps(ncols, nrows)
|
157 |
|
158 |
-
|
159 |
-
plt.savefig(self.out_buffer, format='svg')
|
160 |
-
with open(self.out_buffer, 'r', encoding='utf-8') as f:
|
161 |
-
self.out_str = f.read()
|
162 |
|
163 |
def plot_single_heatmap(self):
|
164 |
fig = plt.figure(figsize=(12, 6))
|
@@ -200,10 +191,10 @@ class Data:
|
|
200 |
self.parse_output()
|
201 |
return self
|
202 |
|
203 |
-
def
|
204 |
-
"return output data
|
205 |
-
return
|
206 |
|
207 |
-
def
|
208 |
-
"return output data
|
209 |
-
return self.
|
|
|
1 |
+
import dataframe_image as dfi
|
2 |
from math import ceil
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import matplotlib.pyplot as plt
|
4 |
import pandas as pd
|
5 |
+
from re import match
|
6 |
import seaborn as sns
|
7 |
|
8 |
from model import Model
|
|
|
22 |
"""Parse input substitutions"""
|
23 |
self.mode = None
|
24 |
self.sub = list()
|
25 |
+
self.trg = trg.strip().upper().split()
|
26 |
self.resi = list()
|
27 |
|
28 |
# Identify running mode
|
29 |
+
if len(self.trg) == 1 and len(self.trg[0]) == len(self.seq) and match(r'^\w+$', self.trg[0]):
|
30 |
# If single string of same length as sequence, seq vs seq mode
|
31 |
self.mode = 'MUT'
|
32 |
+
for resi, (src, trg) in enumerate(zip(self.seq, self.trg[0]), 1):
|
33 |
if src != trg:
|
34 |
self.sub.append(f"{src}{resi}{trg}")
|
35 |
self.resi.append(resi)
|
36 |
else:
|
|
|
37 |
if all(match(r'\d+', x) for x in self.trg):
|
38 |
# If all strings are numbers, deep mutational scanning mode
|
39 |
self.mode = 'DMS'
|
|
|
59 |
|
60 |
self.sub = pd.DataFrame(self.sub, columns=['0'])
|
61 |
|
62 |
+
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
|
63 |
"initialise data"
|
64 |
# if model has changed, load new model
|
65 |
if self.model.model_name != model_name:
|
|
|
71 |
self.scoring_strategy = scoring_strategy
|
72 |
self.token_probs = None
|
73 |
self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
|
74 |
+
self.out_img = f'{out_file}.png'
|
75 |
+
self.out_csv = f'{out_file}.csv'
|
76 |
|
77 |
def parse_output(self) -> None:
|
78 |
"format output data for visualisation"
|
79 |
if self.mode == 'TMS':
|
80 |
self.process_tms_mode()
|
81 |
+
self.out.to_csv(self.out_csv, float_format='%.2f')
|
82 |
else:
|
83 |
if self.mode == 'DMS':
|
84 |
self.sort_by_residue_and_score()
|
|
|
86 |
self.sort_by_score()
|
87 |
else:
|
88 |
raise RuntimeError(f"Unrecognised mode {self.mode}")
|
89 |
+
out_df = (self.out.style
|
90 |
+
.format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
|
91 |
+
.hide(axis=0).hide(axis=1)
|
92 |
+
.background_gradient(cmap="RdYlGn", vmax=8, vmin=-8))
|
93 |
+
dfi.export(out_df, self.out_img, max_rows=-1, max_cols=-1, dpi=300)
|
94 |
+
self.out.to_csv(self.out_csv, float_format='%.2f', index=False, header=False)
|
|
|
|
|
95 |
|
96 |
def sort_by_score(self):
|
97 |
self.out = self.out.sort_values(self.model_name, ascending=False)
|
|
|
149 |
else:
|
150 |
self.plot_multiple_heatmaps(ncols, nrows)
|
151 |
|
152 |
+
plt.savefig(self.out_img, format='png', dpi=300)
|
|
|
|
|
|
|
153 |
|
154 |
def plot_single_heatmap(self):
|
155 |
fig = plt.figure(figsize=(12, 6))
|
|
|
191 |
self.parse_output()
|
192 |
return self
|
193 |
|
194 |
+
def csv(self):
|
195 |
+
"return output data"
|
196 |
+
return self.out_csv
|
197 |
|
198 |
+
def image(self):
|
199 |
+
"return output data"
|
200 |
+
return self.out_img
|
instructions.md
CHANGED
@@ -1,39 +1,61 @@
|
|
1 |
# **ESM-Scan**
|
|
|
2 |
Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
Running a calculation
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
The
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# **ESM-Scan**
|
2 |
+
|
3 |
Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
|
4 |
|
5 |
+
<details>
|
6 |
+
<summary> <b> USAGE INSTRUCTIONS </b> </summary>
|
7 |
+
|
8 |
+
## Setup
|
9 |
+
|
10 |
+
No setup is required. Simply fill in the input boxes with the necessary data and click the **Run** button.
|
11 |
+
You can find a list of examples at the bottom of the page; clicking on them will autofill the fields for you.
|
12 |
+
If the server remains idle for a period, it will enter standby mode. Running a calculation will wake the tool from standby, but note that the first run may take longer due to startup and model loading.
|
13 |
+
|
14 |
+
## Input
|
15 |
+
|
16 |
+
**Sequence**: Enter the full amino acid sequence to be analyzed in the **Sequence** text box.
|
17 |
+
Note: While wildcard characters (e.g., `-X.B`) can be included, they currently cannot be visualised.
|
18 |
+
|
19 |
+
**Substitutions**: Specify the substitutions you wish to test in the **Substitutions** box. The tool supports four running modes based on your input:
|
20 |
+
|
21 |
+
- **Single Substitution**: Input one or more substitutions (e.g. `R218K R218W`) to score specific changes.
|
22 |
+
- **Residue Position**: Provide residue positions to evaluate all possible substitutions at those sites.
|
23 |
+
- **Same-Length Sequence**: Provide a second sequence of the same length as the input sequence; every position where the two sequences differ is scored as an individual substitution.
|
24 |
+
- **Different Inputs**: For any other input format, a deep mutational scan of the full sequence will be performed.
|
25 |
+
|
26 |
+
**Model Selection**: Choose an ESM model for calculations from those available on Hugging Face Model Hub.
|
27 |
+
The model `esm2_t33_650M_UR50D` offers an optimal balance between cost and accuracy [*](https://doi.org/10.1126/science.ade2574).
|
28 |
+
|
29 |
+
**Accuracy Option**: The **Use higher accuracy scoring** option applies a masked-marginals scoring strategy, which considers sequence context during inference.
|
30 |
+
While this method is slower, it enhances accuracy. If you experience long runtimes, unchecking this option can significantly speed up calculations at the cost of some accuracy.
|
31 |
+
|
32 |
+
**Deep Mutational Scan Recommendations**: When performing a deep mutational scan, it is advisable to use smaller models (8M, 35M, or 150M parameters) due to significant runtime concerns—especially with longer sequences or during peak server usage times.
|
33 |
+
For example, calculating a 300-residue-long sequence with larger models may require over 30 minutes.
|
34 |
+
Generally, accuracy is more affected by the scoring strategy than by model size; therefore, prioritise reducing model size when optimizing for runtime.
|
35 |
+
The computational cost of the scoring strategy scales with the number of substitutions tested, while model cost scales with wild-type sequence length.
|
36 |
+
|
37 |
+
**Concurrent Substitutions**: To calculate the effect of multiple concurrent substitutions, you must manually change the input sequence and rerun the calculation. Accuracy is not guaranteed, as this use case has not yet been tested.
|
38 |
+
|
39 |
+
## Output
|
40 |
+
|
41 |
+
Results are displayed in a color-coded table, except for deep mutational scans, which produce a heatmap.
|
42 |
+
In the table:
|
43 |
+
|
44 |
+
- Beneficial substitutions are highlighted in green with positive values.
|
45 |
+
- Detrimental substitutions appear in red with negative values.
|
46 |
+
|
47 |
+
As a rule of thumb, score differences of *4* or more are considered significant. For instance:
|
48 |
+
|
49 |
+
- A substitution scoring *-6* is likely detrimental to protein functionality.
|
50 |
+
- A score of *+2* is generally regarded as neutral.
|
51 |
+
|
52 |
+
You can download the raw output data using the **Download raw data** button at the bottom of the page.
|
53 |
+
|
54 |
+
<b>
|
55 |
+
If you use this tool in your research, please cite:
|
56 |
+
|
57 |
+
- Totaro, M.G. (2023). “ESM-Scan - a tool to guide amino acid substitutions.” bioRxiv. [doi.org/10.1101/2023.12.12.571273](https://doi.org/10.1101/2023.12.12.571273)
|
58 |
+
- Meier, J. (2021). “Language Models Enable Zero-Shot Prediction of the Effects of Mutations on Protein Function.” bioRxiv (Cold Spring Harbor Laboratory), July. [doi.org/10.1101/2021.07.09.450648](https://doi.org/10.1101/2021.07.09.450648)
|
59 |
+
</b>
|
60 |
+
|
61 |
+
</details>
|
model.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
from huggingface_hub import HfApi
|
2 |
import torch
|
|
|
3 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
4 |
from transformers.tokenization_utils_base import BatchEncoding
|
5 |
from transformers.modeling_outputs import MaskedLMOutput
|
@@ -10,9 +11,9 @@ def get_models() -> list[None|str]:
|
|
10 |
if not any(
|
11 |
out := [
|
12 |
m.modelId for m in HfApi().list_models(
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
sort="lastModified",
|
17 |
direction=-1
|
18 |
)
|
@@ -34,6 +35,9 @@ class Model:
|
|
34 |
# Check if CUDA is available and if so, use it
|
35 |
if torch.cuda.is_available():
|
36 |
self.model = self.model.cuda()
|
|
|
|
|
|
|
37 |
|
38 |
def tokenise(self, input: str) -> BatchEncoding:
|
39 |
"""Convert input string to batch of tokens."""
|
@@ -41,7 +45,7 @@ class Model:
|
|
41 |
|
42 |
def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
|
43 |
"""Run model on batch of tokens."""
|
44 |
-
return self.model(batch_tokens, **kwargs)
|
45 |
|
46 |
def __getitem__(self, key: str) -> int:
|
47 |
"""Get token ID from character."""
|
@@ -70,7 +74,7 @@ class Model:
|
|
70 |
if data.scoring_strategy.startswith("masked-marginals"):
|
71 |
all_token_probs = []
|
72 |
# For each token in the batch
|
73 |
-
for i in range(batch_tokens.size()[1]):
|
74 |
# If the token is in the list of residues
|
75 |
if i in data.resi:
|
76 |
# Clone the batch tokens and mask the current token
|
@@ -96,4 +100,4 @@ class Model:
|
|
96 |
token_probs,
|
97 |
),
|
98 |
axis=1,
|
99 |
-
)
|
|
|
1 |
+
from huggingface_hub import HfApi
|
2 |
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
5 |
from transformers.tokenization_utils_base import BatchEncoding
|
6 |
from transformers.modeling_outputs import MaskedLMOutput
|
|
|
11 |
if not any(
|
12 |
out := [
|
13 |
m.modelId for m in HfApi().list_models(
|
14 |
+
author="facebook",
|
15 |
+
model_name="esm",
|
16 |
+
task="fill-mask",
|
17 |
sort="lastModified",
|
18 |
direction=-1
|
19 |
)
|
|
|
35 |
# Check if CUDA is available and if so, use it
|
36 |
if torch.cuda.is_available():
|
37 |
self.model = self.model.cuda()
|
38 |
+
self.device = torch.device("cuda")
|
39 |
+
else:
|
40 |
+
self.device = torch.device("cpu")
|
41 |
|
42 |
def tokenise(self, input: str) -> BatchEncoding:
|
43 |
"""Convert input string to batch of tokens."""
|
|
|
45 |
|
46 |
def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
|
47 |
"""Run model on batch of tokens."""
|
48 |
+
return self.model(batch_tokens.to(self.device), **kwargs)
|
49 |
|
50 |
def __getitem__(self, key: str) -> int:
|
51 |
"""Get token ID from character."""
|
|
|
74 |
if data.scoring_strategy.startswith("masked-marginals"):
|
75 |
all_token_probs = []
|
76 |
# For each token in the batch
|
77 |
+
for i in tqdm(range(batch_tokens.size()[1])):
|
78 |
# If the token is in the list of residues
|
79 |
if i in data.resi:
|
80 |
# Clone the batch tokens and mask the current token
|
|
|
100 |
token_probs,
|
101 |
),
|
102 |
axis=1,
|
103 |
+
)
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
gradio
|
2 |
pandas
|
3 |
seaborn
|
|
|
1 |
+
dataframe-image
|
2 |
gradio
|
3 |
pandas
|
4 |
seaborn
|