Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
app.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoTokenizer, EsmForProteinFolding
|
3 |
+
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
|
4 |
+
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
|
5 |
+
import torch
|
6 |
+
from logging import getLogger
|
7 |
+
|
8 |
+
logger = getLogger(__name__)
|
9 |
+
|
10 |
+
def convert_outputs_to_pdb(outputs):
    """Convert a folding-model output mapping into PDB text.

    Args:
        outputs: Mapping of tensors. The keys read here are ``"positions"``,
            ``"atom37_atom_exists"``, ``"aatype"``, ``"residue_index"``,
            ``"plddt"`` and, when present, ``"chain_index"``. Every value must
            support ``.to("cpu").numpy()``.

    Returns:
        The PDB string for the first entry along the leading axis of
        ``outputs["aatype"]``. A PDB string is still built for every entry.
    """
    # The atom14 -> atom37 expansion needs the tensor form of ``outputs``, so it
    # runs before everything is converted to NumPy.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {key: tensor.to("cpu").numpy() for key, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = arrays["atom37_atom_exists"]

    pdb_strings = [
        to_pdb(
            OFProtein(
                aatype=arrays["aatype"][idx],
                atom_positions=atom37_positions[idx],
                atom_mask=atom37_mask[idx],
                # Shifted by one; presumably 0-based -> 1-based residue numbering.
                residue_index=arrays["residue_index"][idx] + 1,
                b_factors=arrays["plddt"][idx],
                chain_index=arrays["chain_index"][idx] if "chain_index" in arrays else None,
            )
        )
        for idx in range(arrays["aatype"].shape[0])
    ]
    return pdb_strings[0]
|
31 |
+
|
32 |
+
def fold_prot_locally(sequence):
    """Fold ``sequence`` with the module-level ``model`` and return PDB text.

    Args:
        sequence: Sequence string; it is tokenized as a batch of one without
            special tokens.

    Returns:
        Whatever ``convert_outputs_to_pdb`` returns for the model output.
    """
    # Lazy %-style arguments: the message is only formatted when emitted, and a
    # non-str value no longer raises TypeError from string concatenation.
    logger.info("Folding: %s", sequence)

    encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    # Follow the model's device rather than hard-coding CUDA, so the input
    # always lands where the weights actually are.
    tokenized_input = encoded["input_ids"].to(model.device)

    with torch.no_grad():
        output = model(tokenized_input)
    return convert_outputs_to_pdb(output)
|
40 |
+
|
41 |
+
def get_esm2_embeddings(sequence):
    """Return the language-model representations of ``sequence``.

    Only ``model.af2_idx_to_esm_idx`` and
    ``model.compute_language_model_representations`` are invoked; the model's
    full forward pass is not run.

    Args:
        sequence: Sequence string; it is tokenized as a batch of one without
            special tokens.

    Returns:
        ``{"res": ...}`` where the value is the representation tensor converted
        to nested Python lists.
    """
    # Lazy %-style arguments avoid a TypeError when ``sequence`` is not a str.
    logger.info("Getting embeddings for: %s", sequence)

    encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    # Follow the model's device rather than hard-coding CUDA.
    aa = encoded["input_ids"].to(model.device)

    with torch.no_grad():
        # Every position is attended to; ones_like already inherits aa's device.
        attention_mask = torch.ones_like(aa)

        # === ESM ===
        esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
        esm_s = model.compute_language_model_representations(esmaa)

    return {"res": esm_s.cpu().tolist()}
|
56 |
+
|
57 |
+
def get_esmfold_embeddings(sequence):
    """Run the full model on ``sequence`` and return its ``"s_s"`` output.

    Args:
        sequence: Sequence string; it is tokenized as a batch of one without
            special tokens.

    Returns:
        ``{"res": ...}`` where the value is ``output["s_s"]`` converted to
        nested Python lists.
    """
    # Lazy %-style arguments avoid a TypeError when ``sequence`` is not a str.
    logger.info("Getting embeddings for: %s", sequence)

    encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    # Follow the model's device rather than hard-coding CUDA.
    tokenized_input = encoded["input_ids"].to(model.device)

    with torch.no_grad():
        output = model(tokenized_input)

    return {"res": output["s_s"].cpu().tolist()}
|
65 |
+
|
66 |
+
def suggest(option):
    """Return the preset sequence for a sample name, or ``""`` if unknown.

    Args:
        option: Sample name as offered by the dropdown.

    Returns:
        The matching sequence string; an empty string for any other value.
    """
    presets = (
        (
            "Plastic degradation protein",
            "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ",
        ),
        (
            "Antifreeze protein",
            "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH",
        ),
        (
            "AI Generated protein",
            "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS",
        ),
        (
            "7-bladed propeller fold",
            "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK",
        ),
    )
    # Scan with ``==`` (not a dict lookup) so unhashable inputs still fall
    # through to the empty default instead of raising.
    return next((seq for label, seq in presets if option == label), "")
|
78 |
+
|
79 |
+
|
80 |
+
def molecule(mol):
    """Wrap PDB text in an ``<iframe>`` that renders it with 3Dmol.js.

    The PDB text is embedded as a JSON-encoded JavaScript string, and the whole
    inner document is HTML-escaped before being placed in the single-quoted
    ``srcdoc`` attribute. Quotes, backticks, ``${`` or ``</script>`` inside
    ``mol`` therefore cannot terminate the script or the attribute.

    Args:
        mol: PDB file contents as a string. ``None`` is treated as an empty
            string.

    Returns:
        An HTML string containing a single ``<iframe>`` element.
    """
    import html
    import json

    pdb_text = "" if mol is None else mol
    # json.dumps produces a fully escaped JavaScript string literal; "<" is also
    # escaped so the payload can never spell a closing </script> tag.
    pdb_literal = json.dumps(pdb_text).replace("<", "\\u003c")

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>

<script>
let pdb = """
        + pdb_literal
        + """;

$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )

    # The browser decodes entities inside attribute values, so the iframe still
    # receives the original document text.
    srcdoc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
|
129 |
+
|
130 |
+
|
131 |
+
# Client-side example shown in the "Sample usage" code panel. It must run as-is
# when copied, so it imports everything it uses (``tempfile`` was missing).
sample_code = """
import tempfile
from gradio_client import Client
client = Client("https://wwydmanski-esmfold.hf.space/")
def fold_huggingface(sequence, fname=None):
    result = client.predict(
        sequence, # str in 'sequence' Textbox component
        api_name="/pdb")
    if fname is None:
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
            fp.write(result)
            fp.flush()
            return fp.name
    else:
        with open(fname, "w") as fp:
            fp.write(result)
            fp.flush()
            return fname
pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
"""
|
150 |
+
|
151 |
+
# Tokenizer and model are loaded once at import time; the request handlers in
# this file read them as module-level globals.
_ESMFOLD_CHECKPOINT = "facebook/esmfold_v1"

tokenizer = AutoTokenizer.from_pretrained(_ESMFOLD_CHECKPOINT)
model = EsmForProteinFolding.from_pretrained(_ESMFOLD_CHECKPOINT, low_cpu_mem_usage=True)
model = model.cuda()
# Only the ``esm`` submodule is cast to half precision.
model.esm = model.esm.half()
# Permit TF32 for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
|
155 |
+
|
156 |
+
# Names offered in the sample dropdown. "custom" has no preset in suggest(),
# which returns an empty string for it.
_SAMPLE_CHOICES = [
    "Antifreeze protein",
    "Plastic degradation protein",
    "AI Generated protein",
    "7-bladed propeller fold",
    "custom",
]

with gr.Blocks() as demo:
    gr.Markdown("# ESMFold")

    # Input: sequence box, sample picker and the predict button.
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(lines=1, label="Sequence")
            name = gr.Dropdown(
                label="Choose a Sample Protein",
                value="Plastic degradation protein",
                choices=_SAMPLE_CHOICES,
            )
            btn = gr.Button("🔬 Predict Structure ")

    # Read-only client example.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Sample code")
            gr.Code(sample_code, label="Sample usage", language="python", interactive=False)

    with gr.Row():
        gr.Markdown("## Output")

    # PDB text next to its 3D rendering.
    with gr.Row():
        with gr.Column():
            out = gr.Code(label="Output", interactive=False)
        with gr.Column():
            out_mol = gr.HTML(label="3D Structure")

    # Hidden row: gives the embedding handlers an output component.
    with gr.Row(visible=False):
        with gr.Column():
            gr.Markdown("## Embeddings")
            embs = gr.JSON(label="Embeddings")

    # Picking a sample fills the sequence box.
    name.change(fn=suggest, inputs=name, outputs=inp)
    # One click triggers all three handlers; each is also a named API endpoint.
    btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
    btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
    btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
    # Re-render the viewer whenever the PDB text changes.
    out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")

demo.launch()
|
client.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
from gradio_client import Client

#%%
# Alternative target (a Hugging Face Space URL), kept for switching away from
# the local server:
# client = Client("https://huggingface.co/spaces/GaganaMD/Protein-Structure-Prediction")
client = Client("http://localhost:7860")

# %%
# Value for the 'sequence' Textbox component of the endpoint.
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
result = client.predict(sequence, api_name="/esm2_embeddings")

# %%
# Bare expression so an interactive cell displays the value.
result
|