# Spaces: Sleeping  (page-status text captured with this file; not part of the program)
import gradio as gr | |
import numpy as np | |
import random | |
import json | |
import os | |
from gradio_molecule3d import Molecule3D # Import Molecule3D component | |
from rdkit import Chem | |
from rdkit.Chem import Descriptors, Draw, QED | |
class VirtualScreeningBOApp:
    """Gradio app that collects pairwise ligand preferences over several rounds.

    Each round ("BO iteration") presents a list of ligand pairs. The user
    picks the preferred ligand of every pair, then advances to the next
    round; the accumulated preferences are written to a JSON file each time
    a round is closed.

    All progress is stored on the instance and the Gradio callbacks are bound
    methods of that instance, so every browser session connected to one
    launched app reads and writes the same state.
    """

    # Numeric property columns shown in the table, in display order.
    _PROPERTY_KEYS = ("MW", "LogP", "TPSA", "QED")

    # HTML shown once every iteration has been completed.
    _FINAL_MESSAGE = (
        '<div style="text-align: center; font-size: 24px; margin-top: 50px;">'
        "<strong>The BO process has been completed.</strong><br>"
        "Thank you for your input!"
        "</div>"
    )

    def __init__(
        self,
        ligands,
        initial_pairs,
        protein_pdb_path,
        max_iterations=3,
        comparisons_per_iteration=2,
        show_smiles=True,
    ):
        """Store the configuration and read the protein structure file.

        Args:
            ligands: Mapping of ligand identifier to SMILES string.
            initial_pairs: Iterable of ``(ligand_a, ligand_b)`` identifier
                pairs shown in the first iteration.
            protein_pdb_path: Path of the PDB file shown in the 3D viewer.
            max_iterations: Number of rounds before the process finishes.
            comparisons_per_iteration: Number of pairs generated for every
                round after the first.
            show_smiles: Whether the properties table has a SMILES column.
        """
        self.ligands = ligands
        # Pairs are used as dict keys in ``completed_pairs``, so they must be
        # hashable tuples even if the caller passed lists.
        self.current_pairs = [tuple(pair) for pair in initial_pairs]
        self.completed_pairs = {}
        self.comparison_results = []
        self.bo_iteration = 0
        self.is_completed = False
        self.max_iterations = max_iterations
        self.comparisons_per_iteration = comparisons_per_iteration
        self.protein_pdb_path = protein_pdb_path
        self.protein_pdb_data = self._read_pdb_file()
        self.show_smiles = show_smiles
        self.protein_view = None
        self.app = None

    # --------------------------------------------------------------------------
    # Internal helpers
    # --------------------------------------------------------------------------
    def _read_pdb_file(self):
        """Return the text of the PDB file, or None when it cannot be read."""
        try:
            with open(self.protein_pdb_path, 'r') as f:
                return f.read()
        except FileNotFoundError:
            print(f"Error: Protein PDB file not found at {self.protein_pdb_path}")
            return None
        except OSError as exc:
            # Permission problems, directories, etc. are treated like a
            # missing file so the UI can still be built without the viewer.
            print(f"Error: could not read protein PDB file at {self.protein_pdb_path}: {exc}")
            return None

    def _iteration_status(self):
        """Return a string like '**Iteration**: 1/3' (1-based for display)."""
        return f"**Iteration**: {self.bo_iteration + 1}/{self.max_iterations}"

    def _is_finished(self):
        """Return True once the process was finished or ran out of iterations."""
        return self.is_completed or self.bo_iteration >= self.max_iterations

    def _all_pairs_completed(self):
        """Return True when every current pair has a recorded preference."""
        return all(pair in self.completed_pairs for pair in self.current_pairs)

    def _pair_labels(self):
        """Return the dropdown labels ('Pair N ✔' / 'Pair N (Pending)')."""
        return [
            f"Pair {i+1} ✔" if pair in self.completed_pairs else f"Pair {i+1} (Pending)"
            for i, pair in enumerate(self.current_pairs)
        ]

    def _table_headers(self):
        """Return the table headers, with a SMILES column if enabled."""
        headers = ["Ligand"]
        if self.show_smiles:
            headers.append("SMILES")
        headers.extend(self._PROPERTY_KEYS)
        return headers

    def _empty_table_update(self):
        """Return a Gradio update that clears the properties table."""
        return gr.update(value=[], headers=self._table_headers())

    def _table_update(self, ligand_a, ligand_b):
        """Return a Gradio update filling the table for the two ligands."""
        rows = []
        for title, ligand_id in (("Ligand A", ligand_a), ("Ligand B", ligand_b)):
            # ``.get`` mirrors ``_mol_to_image``: an unknown identifier
            # yields an "N/A" row instead of a KeyError.
            props = self.compute_properties(self.ligands.get(ligand_id))
            row = [title]
            if self.show_smiles:
                row.append(props["SMILES"])
            row.extend(props[key] for key in self._PROPERTY_KEYS)
            rows.append(row)
        return gr.update(value=rows, headers=self._table_headers())

    def _pair_view(self, idx):
        """Return ``(image_a, image_b, table_update, arrow)`` for pair *idx*.

        ``arrow`` is ">" when Ligand A was preferred, "<" when Ligand B was
        preferred, and "vs" while the pair is still pending.
        """
        ligand_a, ligand_b = self.current_pairs[idx]
        outcome = self.completed_pairs.get((ligand_a, ligand_b))
        if outcome is None:
            arrow = "vs"
        else:
            arrow = ">" if outcome == 1 else "<"
        return (
            self._mol_to_image(ligand_a),
            self._mol_to_image(ligand_b),
            self._table_update(ligand_a, ligand_b),
            arrow,
        )

    def compute_properties(self, smiles):
        """Compute basic properties of a SMILES string with RDKit.

        Returns a dict with the keys "SMILES", "MW", "LogP", "TPSA" and
        "QED". The numeric entries are None when *smiles* is None or cannot
        be parsed; otherwise they are rounded to two decimals.
        """
        no_values = dict.fromkeys(self._PROPERTY_KEYS)
        if Chem is None or smiles is None:
            return {"SMILES": smiles if smiles else "N/A", **no_values}
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {"SMILES": "Invalid SMILES", **no_values}
        return {
            "SMILES": smiles,
            "MW": round(Descriptors.MolWt(mol), 2),
            "LogP": round(Descriptors.MolLogP(mol), 2),
            "TPSA": round(Descriptors.TPSA(mol), 2),
            "QED": round(QED.qed(mol), 2),
        }

    def _mol_to_image(self, ligand_name):
        """Return a 300x300 image array for the ligand, or None if unparsable."""
        if Chem is None:
            return None
        smiles = self.ligands.get(ligand_name, "")
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return np.array(Draw.MolToImage(mol, size=(300, 300)))

    def _generate_new_pairs(self, n):
        """Randomly pick up to *n* distinct ligand pairs.

        A pair never contains the same ligand twice, and no unordered pair is
        returned more than once (a repeated pair could never be marked as
        completed separately). Fewer than *n* pairs are returned when the
        ligand dictionary does not contain enough distinct pairs; an empty
        list is returned when there are fewer than two ligands.
        """
        keys = list(self.ligands)
        if len(keys) < 2 or n <= 0:
            return []
        max_unique = len(keys) * (len(keys) - 1) // 2
        target = min(n, max_unique)
        seen = set()
        pairs = []
        # Rejection sampling: cheap while ``n`` is small relative to the
        # number of possible pairs, which is the expected use here.
        while len(pairs) < target:
            first, second = random.sample(keys, 2)
            key = frozenset((first, second))
            if key not in seen:
                seen.add(key)
                pairs.append((first, second))
        return pairs

    def _save_results(self):
        """Write all preferences recorded so far to a per-iteration JSON file."""
        filename = f"comparison_results_iter_{self.bo_iteration}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.comparison_results, f, indent=4)
        print(f"Results of iteration {self.bo_iteration} saved to {filename}")

    def get_pair_index(self, pair_label):
        """Parse 'Pair X (Pending)' or 'Pair X ✔' => integer X-1 (zero-based).

        Returns -1 when the label is empty, not a string, or not parsable.
        """
        if not pair_label or not isinstance(pair_label, str):
            return -1
        try:
            return int(pair_label.split()[1]) - 1
        except (IndexError, ValueError):
            return -1

    # --------------------------------------------------------------------------
    # Gradio event methods
    # --------------------------------------------------------------------------
    def show_initial(self):
        """Populate the page on load.

        Returns 8 values for: iteration status, selection text, image A,
        image B, properties table, pair dropdown, preference radio and the
        "Next BO Iteration" button.
        """
        if self._is_finished():
            self.is_completed = True
            hidden = gr.update(visible=False)
            return ("", self._FINAL_MESSAGE, hidden, hidden, hidden, hidden, hidden, hidden)
        iteration_str = self._iteration_status()
        if not self.current_pairs:
            return (
                iteration_str,
                "No pairs available.",
                None,
                None,
                self._empty_table_update(),
                gr.update(choices=[], value=""),
                gr.update(),
                gr.update(),
            )
        labels = self._pair_labels()
        img_a, img_b, table, arrow = self._pair_view(0)
        current_selection_msg = (
            f"**Currently selected**: {labels[0]} => **Ligand A** {arrow} **Ligand B**"
        )
        return (
            iteration_str,
            current_selection_msg,
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=labels[0]),
            gr.update(),
            # After a page reload the button starts disabled; re-enable it if
            # the round was already fully answered.
            gr.update(interactive=self._all_pairs_completed()),
        )

    def update_view_on_dropdown(self, pair_label):
        """Show the pair selected in the dropdown.

        Returns 5 values for: iteration status, selection text, image A,
        image B and the properties table.
        """
        if self._is_finished():
            # Clearing the dropdown on completion can fire this handler;
            # keep the completion message instead of overwriting it.
            hidden = gr.update(visible=False)
            return ("", self._FINAL_MESSAGE, hidden, hidden, hidden)
        iteration_str = self._iteration_status()
        idx = self.get_pair_index(pair_label)
        if not 0 <= idx < len(self.current_pairs):
            is_blank = not isinstance(pair_label, str) or not pair_label.strip()
            message = "Please select a valid pair" if is_blank else f"Invalid pair: {pair_label}"
            return (iteration_str, message, None, None, self._empty_table_update())
        img_a, img_b, table, arrow = self._pair_view(idx)
        pair_label_str = self._pair_labels()[idx]
        current_selection_msg = (
            f"**Currently selected**: {pair_label_str} => **Ligand A** {arrow} **Ligand B**"
        )
        return (iteration_str, current_selection_msg, img_a, img_b, table)

    def show_pair(self, preference, pair_label):
        """Record the preference for the selected pair and show the next one.

        Returns 8 values for: iteration status, selection text, image A,
        image B, properties table, pair dropdown (choices + value), pair
        dropdown (value) and the "Next BO Iteration" button.
        """
        unchanged = gr.update()
        if self._is_finished():
            hidden = gr.update(visible=False)
            return ("", self._FINAL_MESSAGE, hidden, hidden, hidden, hidden, unchanged, hidden)
        iteration_str = self._iteration_status()
        idx = self.get_pair_index(pair_label)
        # Refuse to record anything for an unknown pair or choice rather than
        # silently attributing it to the first pair / to Ligand B.
        if not 0 <= idx < len(self.current_pairs):
            return (
                iteration_str, f"Invalid pair: {pair_label}",
                unchanged, unchanged, unchanged, unchanged, unchanged, unchanged,
            )
        if preference not in ("Ligand A", "Ligand B"):
            return (
                iteration_str, "Please choose Ligand A or Ligand B.",
                unchanged, unchanged, unchanged, unchanged, unchanged, unchanged,
            )

        ligand_a, ligand_b = self.current_pairs[idx]
        self.comparison_results.append({
            "Iteration": self.bo_iteration,
            "Pair": (ligand_a, ligand_b),
            "Preference": preference,
        })
        print(f"Logged preference: Iter={self.bo_iteration}, Pair=({ligand_a}, {ligand_b}), Choice={preference}")
        prefers_a = preference == "Ligand A"
        self.completed_pairs[(ligand_a, ligand_b)] = 1 if prefers_a else 0
        old_pair_str = (
            "**Ligand A** > **Ligand B**" if prefers_a else "**Ligand B** > **Ligand A**"
        )

        labels = self._pair_labels()
        next_idx = next(
            (i for i, pair in enumerate(self.current_pairs) if pair not in self.completed_pairs),
            None,
        )
        all_done = next_idx is None
        # Move on to the first pending pair; when none is left, stay on the
        # pair that was just answered so dropdown and view agree.
        shown_idx = idx if all_done else next_idx
        img_a, img_b, table, arrow = self._pair_view(shown_idx)
        shown_label = labels[shown_idx]
        selection_msg = (
            f"You chose {old_pair_str} for {pair_label}.<br>"
            f"**Currently selected**: {shown_label} => **Ligand A** {arrow} **Ligand B**"
        )
        return (
            iteration_str,
            selection_msg,
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=shown_label),
            shown_label,
            gr.update(interactive=all_done),
        )

    def start_bo_iteration(self):
        """Close the current round and start the next one (or finish).

        Returns 10 values for: iteration status, selection text, image A,
        image B, properties table, pair dropdown (choices + value), pair
        dropdown (value), "Next BO Iteration" button, submit button and
        preference radio.
        """
        if self._is_finished():
            return self.finish_bo_process()
        unchanged = gr.update()
        if not self._all_pairs_completed():
            return (
                self._iteration_status(),
                "Please complete all pairs first!",
                unchanged,
                unchanged,
                unchanged,
                unchanged,
                unchanged,
                gr.update(interactive=False, value="Next BO Iteration"),
                unchanged,
                unchanged,
            )

        # Persist before deciding whether this was the last round, so the
        # final round's preferences are written as well.
        self._save_results()
        self.bo_iteration += 1
        if self.bo_iteration >= self.max_iterations:
            return self.finish_bo_process()

        new_pairs = self._generate_new_pairs(self.comparisons_per_iteration)
        self.current_pairs = new_pairs
        self.completed_pairs = {}
        labels = self._pair_labels()
        default_val = labels[0] if labels else ""
        round_str = f"{self.bo_iteration + 1}/{self.max_iterations}"
        if new_pairs:
            img_a, img_b, table, _ = self._pair_view(0)
            msg = (
                f"Starting iteration {round_str} with {len(new_pairs)} new pairs.<br>"
                f"**Currently selected**: {default_val} => **Ligand A** vs **Ligand B**"
            )
        else:
            msg = f"Starting iteration {round_str}, but no new pairs?"
            img_a = None
            img_b = None
            table = self._empty_table_update()
        return (
            self._iteration_status(),
            msg,
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=default_val),
            default_val,
            # With nothing to compare the round is trivially complete, so
            # leave the button usable instead of trapping the user.
            gr.update(interactive=not new_pairs),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    def finish_bo_process(self):
        """Mark the process as completed and return the final view.

        Returns 10 values in the same order as ``start_bo_iteration``. The
        protein viewer is not among the wired outputs, so it is left as is.
        """
        self.is_completed = True
        hidden = gr.update(visible=False)
        return (
            "",                   # Clear iteration status
            self._FINAL_MESSAGE,  # Show completion message
            hidden,               # Ligand A image
            hidden,               # Ligand B image
            hidden,               # Properties table
            hidden,               # Dropdown
            None,                 # Clear dropdown value
            hidden,               # BO button
            hidden,               # Submit button
            hidden,               # Preference radio
        )

    # --------------------------------------------------------------------------
    # App construction
    # --------------------------------------------------------------------------
    def build_app(self):
        """Create the Gradio Blocks layout, wire the events, store it in ``self.app``."""
        # NOTE(review): the nesting of rows/columns below is reconstructed;
        # confirm it against the intended page layout.
        with gr.Blocks() as app:
            gr.Markdown("## Virtual Screening BO App")
            iteration_status_text = gr.Markdown(value="Loading...", label="Iteration Status")
            current_selection_text = gr.HTML(value="Initializing...", label="Current Selection")
            with gr.Row():
                pair_dropdown = gr.Dropdown(
                    label="Select a Pair",
                    allow_custom_value=True,
                    interactive=True,
                )
                preference_radio = gr.Radio(
                    ["Ligand A", "Ligand B"],
                    label="Your preference",
                    value="Ligand A",
                )
            bo_btn = gr.Button(value="Next BO Iteration", interactive=False)
            with gr.Row():
                with gr.Column():
                    out_imgA = gr.Image(label="Ligand A", width=650, height=325)
                    out_imgB = gr.Image(label="Ligand B", width=650, height=325)
                with gr.Column():
                    if self.protein_pdb_data:
                        self.protein_view = Molecule3D(
                            label="Protein Structure",
                            reps=[
                                {
                                    "model": 0,
                                    "chain": "",
                                    "resname": "",
                                    "style": "cartoon",
                                    "color": "spectrum",
                                    "residue_range": "",
                                    "around": 0,
                                    "byres": False,
                                    "visible": True,
                                }
                            ],
                            # The component takes the file path, not the data.
                            value=self.protein_pdb_path,
                        )
                    else:
                        self.protein_view = gr.Markdown("**Protein PDB file not found.**")
            out_table = gr.Dataframe(
                headers=self._table_headers(),
                label="Properties",
            )
            submit_btn = gr.Button("Submit Preference")

            # Outputs shared by every handler; the protein view is static.
            view_outputs = [
                iteration_status_text,
                current_selection_text,
                out_imgA,
                out_imgB,
                out_table,
            ]
            app.load(
                fn=self.show_initial,
                inputs=None,
                outputs=view_outputs + [pair_dropdown, preference_radio, bo_btn],
            )
            pair_dropdown.change(
                fn=self.update_view_on_dropdown,
                inputs=pair_dropdown,
                outputs=view_outputs,
            )
            # The dropdown is listed twice on purpose: the handlers return
            # one update for its choices and one for its value.
            bo_btn.click(
                fn=self.start_bo_iteration,
                inputs=[],
                outputs=view_outputs + [
                    pair_dropdown,
                    pair_dropdown,
                    bo_btn,
                    submit_btn,
                    preference_radio,
                ],
            )
            submit_btn.click(
                fn=self.show_pair,
                inputs=[preference_radio, pair_dropdown],
                outputs=view_outputs + [pair_dropdown, pair_dropdown, bo_btn],
            )
        self.app = app

    def launch(self, **kwargs):
        """Build the app if needed and launch it, forwarding *kwargs* to Gradio."""
        if self.app is None:
            self.build_app()
        self.app.launch(**kwargs)
# ---------------------------
# Example Usage
# ---------------------------
if __name__ == "__main__":
    # Demo ligand library: identifier -> SMILES string.
    ligands = {
        "L1": "CCN(CC)CC(=O)c1ccccc1N(C)C",
        "L2": "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "L3": "Cn1c(=O)c2c(ncnc2N(C)C)n(C)c1=O",
        "L4": "CC(=O)Oc1ccccc1C(=O)O",
        "L5": "CCCC",
    }
    # Pairs shown in the first iteration.
    initial_pairs = [("L1", "L2"), ("L3", "L4")]

    # Hardcoded PDB file path; update as needed.
    protein_pdb = "1syn.pdb"

    # Stop early with a non-zero exit status when the structure is missing.
    # ``SystemExit`` is used instead of the interactive ``exit()`` helper,
    # which is only injected by the ``site`` module.
    if not os.path.isfile(protein_pdb):
        raise SystemExit(f"Error: Protein PDB file not found at {protein_pdb}")

    app = VirtualScreeningBOApp(
        ligands=ligands,
        initial_pairs=initial_pairs,
        protein_pdb_path=protein_pdb,
        max_iterations=2,
        comparisons_per_iteration=2,
        show_smiles=False,
    )
    app.launch(share=True, debug=True)