# moavs / app.py
# tai-dang11 · file · e5e01c1
import gradio as gr
import numpy as np
import random
import json
import os
from gradio_molecule3d import Molecule3D # Import Molecule3D component
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw, QED
class VirtualScreeningBOApp:
    """Gradio app that collects pairwise ligand preferences over BO iterations.

    Each iteration shows a list of ligand pairs. For every pair the user records
    which ligand they prefer. Once every pair of the current iteration has been
    answered, the "Next BO Iteration" button saves all results collected so far
    to JSON and draws a fresh set of random pairs. After ``max_iterations``
    iterations the UI is replaced by a completion message.

    NOTE(review): all progress lives on this instance, so every browser
    session connected to the app shares the same state.
    """

    # Properties-table headers, with and without the SMILES column.
    _HEADERS_WITH_SMILES = ["Ligand", "SMILES", "MW", "LogP", "TPSA", "QED"]
    _HEADERS_NO_SMILES = ["Ligand", "MW", "LogP", "TPSA", "QED"]

    _FINAL_MESSAGE = (
        '<div style="text-align: center; font-size: 24px; margin-top: 50px;">'
        "<strong>The BO process has been completed.</strong><br>"
        "Thank you for your input!"
        "</div>"
    )

    def __init__(
        self,
        ligands,
        initial_pairs,
        protein_pdb_path,  # Hardcoded PDB file path
        max_iterations=3,
        comparisons_per_iteration=2,
        show_smiles=True,
    ):
        """Store the configuration and read the protein structure.

        Args:
            ligands: mapping of ligand id -> SMILES string.
            initial_pairs: iterable of ``(ligand_id, ligand_id)`` pairs for the
                first iteration. Pairs are stored as tuples so they can be used
                as dictionary keys in ``completed_pairs``.
            protein_pdb_path: path of the PDB file shown in the 3D viewer.
            max_iterations: number of BO iterations before the run finishes.
            comparisons_per_iteration: pairs drawn for each later iteration.
            show_smiles: whether the properties table has a SMILES column.
        """
        self.ligands = ligands
        self.current_pairs = [tuple(pair) for pair in initial_pairs]
        # (ligand_a, ligand_b) -> 1 if Ligand A was preferred, 0 otherwise.
        self.completed_pairs = {}
        # One record per submitted preference, across all iterations.
        self.comparison_results = []
        self.bo_iteration = 0  # zero-based
        self.is_completed = False
        self.max_iterations = max_iterations
        self.comparisons_per_iteration = comparisons_per_iteration
        self.protein_pdb_path = protein_pdb_path
        self.protein_pdb_data = self._read_pdb_file()
        self.show_smiles = show_smiles
        self.protein_view = None  # created in build_app()
        self.app = None

    def _read_pdb_file(self):
        """Return the text of ``self.protein_pdb_path``, or None if unreadable."""
        try:
            with open(self.protein_pdb_path, "r") as f:
                return f.read()
        except FileNotFoundError:
            print(f"Error: Protein PDB file not found at {self.protein_pdb_path}")
        except OSError as err:
            print(f"Error: Could not read protein PDB file {self.protein_pdb_path}: {err}")
        return None

    def _iteration_status(self):
        """Return e.g. '**Iteration**: 1/3' (1-based, clamped to max_iterations)."""
        current = min(self.bo_iteration + 1, self.max_iterations)
        return f"**Iteration**: {current}/{self.max_iterations}"

    def compute_properties(self, smiles):
        """Return RDKit descriptors (MW, LogP, TPSA, QED) for *smiles*.

        Numeric values are rounded to two decimals. If ``Chem`` is None or
        *smiles* is empty/None, the numeric fields are None and "SMILES" is the
        input (or "N/A" when empty). If RDKit cannot parse the string, "SMILES"
        is "Invalid SMILES".
        """
        empty = {"MW": None, "LogP": None, "TPSA": None, "QED": None}
        # NOTE(review): with the unconditional `from rdkit import Chem` import,
        # Chem is never None; this guard only matters if that import is made
        # optional.
        if Chem is None or not smiles:
            return {"SMILES": smiles if smiles else "N/A", **empty}
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {"SMILES": "Invalid SMILES", **empty}
        return {
            "SMILES": smiles,
            "MW": round(Descriptors.MolWt(mol), 2),
            "LogP": round(Descriptors.MolLogP(mol), 2),
            "TPSA": round(Descriptors.TPSA(mol), 2),
            "QED": round(QED.qed(mol), 2),
        }

    def _mol_to_image(self, ligand_name):
        """Render the ligand's SMILES as a 300x300 RGB array.

        Returns None when RDKit is unavailable, the ligand id is unknown or has
        an empty SMILES string, or the SMILES cannot be parsed.
        """
        if Chem is None:
            return None
        smiles = self.ligands.get(ligand_name, "")
        if not smiles:
            return None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return np.array(Draw.MolToImage(mol, size=(300, 300)))

    def _generate_new_pairs(self, n):
        """Randomly pick *n* pairs of two different ligands.

        Distinct unordered pairs are drawn while enough exist; only when *n*
        exceeds the number of possible pairs are repeats allowed. Each pair's
        A/B orientation is randomized. Returns [] if *n* <= 0 or fewer than two
        ligands are available (the old rejection loop never terminated then).
        """
        keys = list(self.ligands)
        if n <= 0 or len(keys) < 2:
            return []
        candidates = [(a, b) for i, a in enumerate(keys) for b in keys[i + 1:]]
        if n <= len(candidates):
            chosen = random.sample(candidates, n)
        else:
            chosen = [random.choice(candidates) for _ in range(n)]
        return [(a, b) if random.random() < 0.5 else (b, a) for a, b in chosen]

    def _save_results(self):
        """Write all comparison results collected so far to a per-iteration JSON file."""
        filename = f"comparison_results_iter_{self.bo_iteration}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.comparison_results, f, indent=4)
        print(f"Results of iteration {self.bo_iteration} saved to {filename}")

    def get_pair_index(self, pair_label):
        """Parse 'Pair X (Pending)' or 'Pair X ✔' into the zero-based index X-1.

        Returns -1 for None, empty, or unparsable labels.
        """
        if not pair_label:
            return -1
        try:
            return int(pair_label.split()[1]) - 1
        except (IndexError, ValueError):
            return -1

    # --------------------------------------------------------------------------
    # View helpers
    # --------------------------------------------------------------------------
    def _all_pairs_completed(self):
        """True when every pair in the current iteration has a recorded preference.

        Membership is checked per pair, so duplicate pairs in ``current_pairs``
        do not block progress. A count comparison against the dict would.
        """
        return all(pair in self.completed_pairs for pair in self.current_pairs)

    def _first_pending_index(self):
        """Index of the first current pair without a preference, or None."""
        return next(
            (i for i, pair in enumerate(self.current_pairs) if pair not in self.completed_pairs),
            None,
        )

    def _pair_labels(self):
        """Dropdown labels for the current pairs, marking answered ones with ✔."""
        return [
            f"Pair {i + 1} ✔" if pair in self.completed_pairs else f"Pair {i + 1} (Pending)"
            for i, pair in enumerate(self.current_pairs)
        ]

    def _arrow(self, pair):
        """'>' if Ligand A was preferred, '<' if Ligand B was, 'vs' if unanswered."""
        if pair not in self.completed_pairs:
            return "vs"
        return ">" if self.completed_pairs[pair] == 1 else "<"

    def _selection_message(self, idx, labels):
        """Markdown line describing the pair at *idx* and its recorded preference."""
        arrow = self._arrow(self.current_pairs[idx])
        return f"**Currently selected**: {labels[idx]} => **Ligand A** {arrow} **Ligand B**"

    def _table_headers(self):
        """Properties-table headers, including SMILES only when show_smiles is set."""
        return list(self._HEADERS_WITH_SMILES if self.show_smiles else self._HEADERS_NO_SMILES)

    def _empty_table_update(self):
        """Gradio update that clears the properties table."""
        return gr.update(value=[], headers=self._table_headers())

    def _table_update(self, ligand_a_id, ligand_b_id):
        """Gradio update that fills the properties table for one ligand pair."""
        headers = self._table_headers()
        rows = []
        for row_label, ligand_id in (("Ligand A", ligand_a_id), ("Ligand B", ligand_b_id)):
            props = self.compute_properties(self.ligands.get(ligand_id))
            # headers[0] is the "Ligand" column; the rest are property keys.
            rows.append([row_label] + [props[column] for column in headers[1:]])
        return gr.update(value=rows, headers=headers)

    def _pair_view(self, idx):
        """Return (image A, image B, table update) for ``current_pairs[idx]``."""
        ligand_a_id, ligand_b_id = self.current_pairs[idx]
        return (
            self._mol_to_image(ligand_a_id),
            self._mol_to_image(ligand_b_id),
            self._table_update(ligand_a_id, ligand_b_id),
        )

    def _completion_visibility(self):
        """Visibility of the protein viewer and submit button (hidden once finished).

        The protein viewer is not an output of the main handlers, so it is
        updated through a chained ``.then()`` event.
        """
        visible = not self.is_completed
        return gr.update(visible=visible), gr.update(visible=visible)

    # --------------------------------------------------------------------------
    # Gradio event methods
    # --------------------------------------------------------------------------
    def show_initial(self):
        """Populate the UI on page load.

        Returns 8 values matching the ``app.load`` outputs: status, message,
        image A, image B, table, dropdown, preference radio, BO button.
        """
        if self.is_completed or self.bo_iteration >= self.max_iterations:
            self.is_completed = True
            hidden = gr.update(visible=False)
            return ("", self._FINAL_MESSAGE, hidden, hidden, hidden, hidden, hidden, hidden)

        iteration_str = self._iteration_status()
        if not self.current_pairs:
            return (
                iteration_str,
                "No pairs available.",
                None,
                None,
                self._empty_table_update(),
                gr.update(choices=[], value=""),
                gr.update(),
                gr.update(),
            )

        labels = self._pair_labels()
        # Start on the first unanswered pair so a page reload resumes progress.
        idx = self._first_pending_index()
        if idx is None:
            idx = 0
        img_a, img_b, table = self._pair_view(idx)
        return (
            iteration_str,
            self._selection_message(idx, labels),
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=labels[idx]),
            gr.update(),
            # If everything was already answered (e.g. after a reload), allow moving on.
            gr.update(interactive=self._all_pairs_completed()),
        )

    def update_view_on_dropdown(self, pair_label):
        """Show the pair chosen in the dropdown.

        Returns 5 values matching the ``pair_dropdown.change`` outputs:
        status, message, image A, image B, table.
        """
        if self.is_completed or self.bo_iteration >= self.max_iterations:
            # Resetting the dropdown on the completion screen also fires this
            # event. Leave the final message untouched.
            return (gr.update(),) * 5

        iteration_str = self._iteration_status()
        if not pair_label or not pair_label.strip():
            return (
                iteration_str,
                "Please select a valid pair",
                None,
                None,
                self._empty_table_update(),
            )

        idx = self.get_pair_index(pair_label)
        if idx < 0 or idx >= len(self.current_pairs):
            return (
                iteration_str,
                f"Invalid pair: {pair_label}",
                None,
                None,
                self._empty_table_update(),
            )

        img_a, img_b, table = self._pair_view(idx)
        return (
            iteration_str,
            self._selection_message(idx, self._pair_labels()),
            img_a,
            img_b,
            table,
        )

    def show_pair(self, preference, pair_label):
        """Record *preference* for the selected pair and advance to the next pending one.

        Returns 8 values matching the ``submit_btn.click`` outputs: status,
        message, image A, image B, table, dropdown update, dropdown value,
        BO button.
        """
        iteration_str = self._iteration_status()
        no_change = gr.update()
        if self.is_completed:
            return (no_change,) * 8
        if not self.current_pairs:
            return (
                iteration_str,
                "No pairs available.",
                None,
                None,
                self._empty_table_update(),
                gr.update(choices=[], value=""),
                "",
                gr.update(interactive=False),
            )
        if preference not in ("Ligand A", "Ligand B"):
            return (
                iteration_str,
                "Please choose Ligand A or Ligand B before submitting.",
                no_change, no_change, no_change, no_change, no_change, no_change,
            )

        idx = self.get_pair_index(pair_label)
        if idx < 0 or idx >= len(self.current_pairs):
            idx = 0
        pair = self.current_pairs[idx]
        ligand_a_id, ligand_b_id = pair

        self.comparison_results.append({
            "Iteration": self.bo_iteration,
            "Pair": pair,
            "Preference": preference,
        })
        print(f"Logged preference: Iter={self.bo_iteration}, Pair=({ligand_a_id}, {ligand_b_id}), Choice={preference}")

        if preference == "Ligand A":
            self.completed_pairs[pair] = 1
            old_pair_str = "**Ligand A** > **Ligand B**"
        else:
            self.completed_pairs[pair] = 0
            old_pair_str = "**Ligand B** > **Ligand A**"

        labels = self._pair_labels()
        next_idx = self._first_pending_index()
        if next_idx is not None:
            # Move straight on to the next unanswered pair.
            shown_idx = next_idx
            bo_btn_state = False
        else:
            # Everything answered: keep showing the pair just submitted and
            # unlock the next iteration.
            shown_idx = idx
            bo_btn_state = True

        img_a, img_b, table = self._pair_view(shown_idx)
        dropdown_val = labels[shown_idx]
        selection_msg = (
            f"You chose {old_pair_str} for Pair {idx + 1}.<br>"
            f"{self._selection_message(shown_idx, labels)}"
        )
        print(selection_msg)
        return (
            iteration_str,
            selection_msg,
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=dropdown_val),
            dropdown_val,
            gr.update(interactive=bo_btn_state),
        )

    def start_bo_iteration(self):
        """Save results and move to the next iteration, or finish after the last one.

        Every branch returns 10 values matching the ``bo_btn.click`` outputs:
        status, message, image A, image B, table, dropdown update, dropdown
        value, BO button, submit button, preference radio.
        """
        if self.is_completed:
            return self.finish_bo_process()

        # This check also applies to the final iteration.
        if not self._all_pairs_completed():
            no_change = gr.update()
            return (
                self._iteration_status(),
                "Please complete all pairs first!",
                no_change,
                no_change,
                no_change,
                no_change,
                no_change,
                gr.update(interactive=False, value="Next BO Iteration"),
                no_change,
                no_change,
            )

        # Save before deciding whether to finish, so the last iteration is kept too.
        self._save_results()
        self.bo_iteration += 1
        if self.bo_iteration >= self.max_iterations:
            return self.finish_bo_process()

        new_pairs = self._generate_new_pairs(self.comparisons_per_iteration)
        self.current_pairs = new_pairs
        self.completed_pairs = {}
        iteration_str = self._iteration_status()
        labels = self._pair_labels()
        display_iteration = self.bo_iteration + 1  # 1-based, like the status text

        if new_pairs:
            img_a, img_b, table = self._pair_view(0)
            default_val = labels[0]
            msg = (
                f"Starting iteration {display_iteration}/{self.max_iterations} with {len(new_pairs)} new pairs.<br>"
                f"**Currently selected**: {default_val} => **Ligand A** vs **Ligand B**"
            )
        else:
            img_a = img_b = None
            table = self._empty_table_update()
            default_val = ""
            msg = f"Starting iteration {display_iteration}/{self.max_iterations}, but no new pairs?"

        return (
            iteration_str,
            msg,
            img_a,
            img_b,
            table,
            gr.update(choices=labels, value=default_val),
            default_val,
            # With no pairs there is nothing to answer, so let the user move on.
            gr.update(interactive=self._all_pairs_completed()),
            gr.update(visible=True),  # Submit button
            gr.update(visible=True),  # Preference radio
        )

    def finish_bo_process(self):
        """Mark the run finished and return 10 updates that show the final message.

        The protein viewer is hidden by the chained ``_completion_visibility``
        event. Mutating the component object server-side does not affect the
        rendered page.
        """
        self.is_completed = True
        hidden = gr.update(visible=False)
        return (
            "",                   # Clear iteration status
            self._FINAL_MESSAGE,  # Show completion message
            hidden,               # Ligand A image
            hidden,               # Ligand B image
            hidden,               # Properties table
            hidden,               # Dropdown
            None,                 # Clear dropdown value
            hidden,               # BO button
            hidden,               # Submit button
            hidden,               # Preference radio
        )

    def build_app(self):
        """Construct the Gradio Blocks UI, wire its events and store it in ``self.app``."""
        with gr.Blocks() as app:
            gr.Markdown("## Virtual Screening BO App")
            iteration_status_text = gr.Markdown(value="Loading...", label="Iteration Status")
            current_selection_text = gr.HTML(value="Initializing...", label="Current Selection")
            with gr.Row():
                pair_dropdown = gr.Dropdown(
                    label="Select a Pair",
                    allow_custom_value=True,
                    interactive=True,
                )
                preference_radio = gr.Radio(
                    ["Ligand A", "Ligand B"],
                    label="Your preference",
                    value="Ligand A",
                )
                bo_btn = gr.Button(value="Next BO Iteration", interactive=False)
            with gr.Row():
                with gr.Column():
                    out_img_a = gr.Image(label="Ligand A", width=650, height=325)
                    out_img_b = gr.Image(label="Ligand B", width=650, height=325)
                with gr.Column():
                    if self.protein_pdb_data:
                        self.protein_view = Molecule3D(
                            label="Protein Structure",
                            reps=[
                                {
                                    "model": 0,
                                    "chain": "",
                                    "resname": "",
                                    "style": "cartoon",
                                    "color": "spectrum",
                                    "residue_range": "",
                                    "around": 0,
                                    "byres": False,
                                    "visible": True,
                                }
                            ],
                            # Pass the PDB file path, not the data
                            value=self.protein_pdb_path,
                        )
                    else:
                        self.protein_view = gr.Markdown("**Protein PDB file not found.**")
            out_table = gr.Dataframe(headers=self._table_headers(), label="Properties")
            submit_btn = gr.Button("Submit Preference")

            # Components the main handlers don't output, so they are updated by a chained event.
            visibility_outputs = [self.protein_view, submit_btn]

            # Page load: show the initial view.
            app.load(
                fn=self.show_initial,
                inputs=None,
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_img_a,
                    out_img_b,
                    out_table,
                    pair_dropdown,
                    preference_radio,
                    bo_btn,
                ],
            ).then(fn=self._completion_visibility, inputs=None, outputs=visibility_outputs)

            # Dropdown change: show the selected pair.
            pair_dropdown.change(
                fn=self.update_view_on_dropdown,
                inputs=pair_dropdown,
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_img_a,
                    out_img_b,
                    out_table,
                ],
            )

            # Next BO iteration (or finish). pair_dropdown is listed twice:
            # once for its choices/value update and once for its raw value.
            bo_btn.click(
                fn=self.start_bo_iteration,
                inputs=[],
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_img_a,
                    out_img_b,
                    out_table,
                    pair_dropdown,
                    pair_dropdown,
                    bo_btn,
                    submit_btn,
                    preference_radio,
                ],
            ).then(fn=self._completion_visibility, inputs=None, outputs=visibility_outputs)

            # Submit a preference for the selected pair.
            submit_btn.click(
                fn=self.show_pair,
                inputs=[preference_radio, pair_dropdown],
                outputs=[
                    iteration_status_text,
                    current_selection_text,
                    out_img_a,
                    out_img_b,
                    out_table,
                    pair_dropdown,
                    pair_dropdown,
                    bo_btn,
                ],
            )
        self.app = app

    def launch(self, **kwargs):
        """Build the UI on first use and launch it, forwarding *kwargs* to ``Blocks.launch``."""
        if self.app is None:
            self.build_app()
        self.app.launch(**kwargs)
# ---------------------------
# Example Usage
# ---------------------------
if __name__ == "__main__":
    # Example ligand library: id -> SMILES.
    ligands = {
        "L1": "CCN(CC)CC(=O)c1ccccc1N(C)C",
        "L2": "CC(C)Cc1ccc(cc1)C(C)C(=O)O",
        "L3": "Cn1c(=O)c2c(ncnc2N(C)C)n(C)c1=O",
        "L4": "CC(=O)Oc1ccccc1C(=O)O",
        "L5": "CCCC",
    }
    initial_pairs = [("L1", "L2"), ("L3", "L4")]
    # Hardcoded PDB file path
    protein_pdb = "1syn.pdb"  # Update this path as needed
    # Abort early if the structure file is missing.
    if not os.path.isfile(protein_pdb):
        # A string passed to SystemExit is printed to stderr and the exit status
        # is 1. Unlike the `exit()` helper injected by `site`, it is always available.
        raise SystemExit(f"Error: Protein PDB file not found at {protein_pdb}")
    app = VirtualScreeningBOApp(
        ligands=ligands,
        initial_pairs=initial_pairs,
        protein_pdb_path=protein_pdb,  # Provide the PDB file path
        max_iterations=2,
        comparisons_per_iteration=2,
        show_smiles=False,
    )
    app.launch(share=True, debug=True)