GaganaMD commited on
Commit
50fcaf6
·
verified ·
1 Parent(s): 576036c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +189 -0
  2. client.py +14 -0
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, EsmForProteinFolding
3
+ from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
4
+ from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
5
+ import torch
6
+ from logging import getLogger
7
+
8
+ logger = getLogger(__name__)
9
+
10
+ def convert_outputs_to_pdb(outputs):
11
+ final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
12
+ outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
13
+ final_atom_positions = final_atom_positions.cpu().numpy()
14
+ final_atom_mask = outputs["atom37_atom_exists"]
15
+ pdbs = []
16
+ for i in range(outputs["aatype"].shape[0]):
17
+ aa = outputs["aatype"][i]
18
+ pred_pos = final_atom_positions[i]
19
+ mask = final_atom_mask[i]
20
+ resid = outputs["residue_index"][i] + 1
21
+ pred = OFProtein(
22
+ aatype=aa,
23
+ atom_positions=pred_pos,
24
+ atom_mask=mask,
25
+ residue_index=resid,
26
+ b_factors=outputs["plddt"][i],
27
+ chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
28
+ )
29
+ pdbs.append(to_pdb(pred))
30
+ return pdbs[0]
31
+
32
+ def fold_prot_locally(sequence):
33
+ logger.info("Folding: " + sequence)
34
+ tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
35
+
36
+ with torch.no_grad():
37
+ output = model(tokenized_input)
38
+ pdb = convert_outputs_to_pdb(output)
39
+ return pdb
40
+
41
+ def get_esm2_embeddings(sequence):
42
+ logger.info("Getting embeddings for: " + sequence)
43
+ tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
44
+
45
+ with torch.no_grad():
46
+ aa = tokenized_input
47
+ L = aa.shape[1]
48
+ device = tokenized_input.device
49
+ attention_mask = torch.ones_like(aa, device=device)
50
+
51
+ # === ESM ===
52
+ esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
53
+ esm_s = model.compute_language_model_representations(esmaa)
54
+
55
+ return {"res": esm_s.cpu().tolist()}
56
+
57
+ def get_esmfold_embeddings(sequence):
58
+ logger.info("Getting embeddings for: " + sequence)
59
+ tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
60
+
61
+ with torch.no_grad():
62
+ output = model(tokenized_input)
63
+
64
+ return {"res": output["s_s"].cpu().tolist()}
65
+
66
+ def suggest(option):
67
+ if option == "Plastic degradation protein":
68
+ suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
69
+ elif option == "Antifreeze protein":
70
+ suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
71
+ elif option == "AI Generated protein":
72
+ suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
73
+ elif option == "7-bladed propeller fold":
74
+ suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
75
+ else:
76
+ suggestion = ""
77
+ return suggestion
78
+
79
+
80
+ def molecule(mol):
81
+ x = (
82
+ """<!DOCTYPE html>
83
+ <html>
84
+ <head>
85
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
86
+ <style>
87
+ body{
88
+ font-family:sans-serif
89
+ }
90
+ .mol-container {
91
+ width: 100%;
92
+ height: 600px;
93
+ position: relative;
94
+ }
95
+ .mol-container select{
96
+ background-image:None;
97
+ }
98
+ </style>
99
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
100
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
101
+ </head>
102
+ <body>
103
+ <div id="container" class="mol-container"></div>
104
+
105
+ <script>
106
+ let pdb = `"""
107
+ + mol
108
+ + """`
109
+
110
+ $(document).ready(function () {
111
+ let element = $("#container");
112
+ let config = { backgroundColor: "white" };
113
+ let viewer = $3Dmol.createViewer(element, config);
114
+ viewer.addModel(pdb, "pdb");
115
+ viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"whiteCarbon" } });
116
+ viewer.zoomTo();
117
+ viewer.render();
118
+ viewer.zoom(0.8, 2000);
119
+ })
120
+ </script>
121
+ </body></html>"""
122
+ )
123
+
124
+ return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
125
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
126
+ allow-scripts allow-same-origin allow-popups
127
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
128
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
129
+
130
+
131
+ sample_code = """
132
+ from gradio_client import Client
133
+ client = Client("https://wwydmanski-esmfold.hf.space/")
134
+ def fold_huggingface(sequence, fname=None):
135
+ result = client.predict(
136
+ sequence, # str in 'sequence' Textbox component
137
+ api_name="/pdb")
138
+ if fname is None:
139
+ with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
140
+ fp.write(result)
141
+ fp.flush()
142
+ return fp.name
143
+ else:
144
+ with open(fname, "w") as fp:
145
+ fp.write(result)
146
+ fp.flush()
147
+ return fname
148
+ pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
149
+ """
150
+
151
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
152
+ model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).cuda()
153
+ model.esm = model.esm.half()
154
+ torch.backends.cuda.matmul.allow_tf32 = True
155
+
156
+ with gr.Blocks() as demo:
157
+ gr.Markdown("# ESMFold")
158
+ with gr.Row():
159
+ with gr.Column():
160
+ inp = gr.Textbox(lines=1, label="Sequence")
161
+ name = gr.Dropdown(label="Choose a Sample Protein", value="Plastic degradation protein", choices=["Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"])
162
+ btn = gr.Button("🔬 Predict Structure ")
163
+
164
+ with gr.Row():
165
+ with gr.Column():
166
+ gr.Markdown("## Sample code")
167
+ gr.Code(sample_code, label="Sample usage", language="python", interactive=False)
168
+
169
+ with gr.Row():
170
+ gr.Markdown("## Output")
171
+
172
+ with gr.Row():
173
+ with gr.Column():
174
+ out = gr.Code(label="Output", interactive=False)
175
+ with gr.Column():
176
+ out_mol = gr.HTML(label="3D Structure")
177
+
178
+ with gr.Row(visible=False):
179
+ with gr.Column():
180
+ gr.Markdown("## Embeddings")
181
+ embs = gr.JSON(label="Embeddings")
182
+
183
+ name.change(fn=suggest, inputs=name, outputs=inp)
184
+ btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
185
+ btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
186
+ btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
187
+ out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
188
+
189
+ demo.launch()
client.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ from gradio_client import Client
3
+
4
+ #%%
5
+ # client = Client("https://huggingface.co/spaces/GaganaMD/Protein-Structure-Prediction")
6
+ client = Client("http://localhost:7860")
7
+
8
+ # %%
9
+ result = client.predict(
10
+ "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", # str in 'sequence' Textbox component
11
+ api_name="/esm2_embeddings")
12
+
13
+ # %%
14
+ result