import gradio as gr
import numpy as np
import random
import json
import os
from gradio_molecule3d import Molecule3D # Import Molecule3D component
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw, QED
class VirtualScreeningBOApp:
def __init__(
self,
ligands,
initial_pairs,
protein_pdb_path, # Hardcoded PDB file path
max_iterations=3,
comparisons_per_iteration=2,
show_smiles=True # <--- Added argument
):
self.ligands = ligands
self.current_pairs = initial_pairs
self.completed_pairs = {}
self.comparison_results = []
self.bo_iteration = 0
self.is_completed = False
self.max_iterations = max_iterations
self.comparisons_per_iteration = comparisons_per_iteration
self.protein_pdb_path = protein_pdb_path # Store PDB path
self.protein_pdb_data = self._read_pdb_file() # Read PDB data
self.show_smiles = show_smiles # <--- Store argument
self.app = None
def _read_pdb_file(self):
"""Read the PDB file from the hardcoded path."""
try:
with open(self.protein_pdb_path, 'r') as f:
pdb_data = f.read()
return pdb_data
except FileNotFoundError:
print(f"Error: Protein PDB file not found at {self.protein_pdb_path}")
return None
def _iteration_status(self):
"""Return a string like 'Iteration 1/3' (1-based for user display)."""
return f"**Iteration**: {self.bo_iteration + 1}/{self.max_iterations}"
def compute_properties(self, smiles):
    """Compute a few RDKit descriptors for one SMILES string.

    Returns a dict with keys "SMILES", "MW", "LogP", "TPSA" and "QED"
    (descriptor values rounded to two decimals). All descriptor values are
    None when ``smiles`` is None or the module-level ``Chem`` is None (then
    "SMILES" holds the input, or "N/A" if it is falsy), or when RDKit cannot
    parse the string (then "SMILES" is "Invalid SMILES").
    """
    def placeholder(label):
        # Same key order as the populated result, descriptors left empty.
        return {"SMILES": label, "MW": None, "LogP": None, "TPSA": None, "QED": None}

    if Chem is None or smiles is None:
        return placeholder(smiles if smiles else "N/A")

    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return placeholder("Invalid SMILES")

    calculators = (
        ("MW", Descriptors.MolWt),
        ("LogP", Descriptors.MolLogP),
        ("TPSA", Descriptors.TPSA),
        ("QED", QED.qed),
    )
    result = {"SMILES": smiles}
    for key, calculate in calculators:
        result[key] = round(calculate(molecule), 2)
    return result
def _mol_to_image(self, ligand_name):
    """Return a 300x300 RDKit depiction of a ligand as a numpy array.

    The SMILES is looked up in ``self.ligands``; an unknown name falls back to
    an empty string. Returns None when the module-level ``Chem`` is None or
    RDKit fails to parse the SMILES.
    NOTE(review): confirm how RDKit treats the empty-string fallback; it may
    produce an empty molecule (blank image) rather than None.
    """
    molecule = None if Chem is None else Chem.MolFromSmiles(self.ligands.get(ligand_name, ""))
    return None if molecule is None else np.array(Draw.MolToImage(molecule, size=(300, 300)))
def _generate_new_pairs(self, n):
"""Randomly pick n pairs from the ligand dictionary."""
keys = list(self.ligands.keys())
pairs = []
for _ in range(n):
a = random.choice(keys)
b = random.choice(keys)
while b == a:
b = random.choice(keys)
pairs.append((a, b))
return pairs
def _save_results(self):
"""Saves iteration results to JSON."""
filename = f"comparison_results_iter_{self.bo_iteration}.json"
with open(filename, "w") as f:
json.dump(self.comparison_results, f, indent=4)
print(f"Results of iteration {self.bo_iteration} saved to {filename}")
def get_pair_index(self, pair_label):
    """Parse 'Pair X (Pending)' or 'Pair X ✔' into the zero-based index X-1.

    Returns -1 when the label is empty/None or its second token is missing
    or not an integer.

    NOTE: this class defines ``get_pair_index`` a second time further down;
    Python keeps the later definition. This copy now matches it (returning
    -1, not 0, for bad labels) so the two can no longer disagree.
    """
    if not pair_label:
        return -1
    try:
        return int(pair_label.split()[1]) - 1
    except (IndexError, ValueError):
        return -1
# --------------------------------------------------------------------------
# Gradio event methods
# --------------------------------------------------------------------------
def show_initial(self):
    """Populate the UI with the first pair of the current iteration.

    Returns an 8-tuple of Gradio outputs: iteration banner, selection message,
    images of ligands A and B, a properties-table update, a dropdown update
    listing every pair with its status (the first one selected), and two
    no-op updates. Returns ``finish_bo_process()`` once all iterations have
    been used, and a "No pairs available." view when there are no pairs.
    """
    status = self._iteration_status()
    if self.bo_iteration >= self.max_iterations:
        return self.finish_bo_process()

    if not self.current_pairs:
        empty_table = gr.update(value=[], headers=["Ligand", "MW", "LogP", "TPSA", "QED"])
        empty_dropdown = gr.update(choices=[], value="")  # Set dropdown to empty
        return (status, "No pairs available.", None, None, empty_table, empty_dropdown, gr.update(), gr.update())

    # One label per pair; a check mark means a preference was recorded.
    labels = [
        f"Pair {number} ✔" if pair in self.completed_pairs else f"Pair {number} (Pending)"
        for number, pair in enumerate(self.current_pairs, start=1)
    ]
    selected = labels[0]
    ligandA_id, ligandB_id = self.current_pairs[0]

    image_a = self._mol_to_image(ligandA_id)
    image_b = self._mol_to_image(ligandB_id)
    props_a = self.compute_properties(self.ligands[ligandA_id])
    props_b = self.compute_properties(self.ligands[ligandB_id])

    fields = ["SMILES", "MW", "LogP", "TPSA", "QED"] if self.show_smiles else ["MW", "LogP", "TPSA", "QED"]
    headers = ["Ligand"] + fields
    rows = [
        ["Ligand A"] + [props_a[field] for field in fields],
        ["Ligand B"] + [props_b[field] for field in fields],
    ]

    pair_key = (ligandA_id, ligandB_id)
    if pair_key not in self.completed_pairs:
        arrow = "vs"
    elif self.completed_pairs[pair_key] == 1:
        arrow = ">"
    else:
        arrow = "<"

    message = f"**Currently selected**: {selected} => **Ligand A** {arrow} **Ligand B**"
    return (
        status,
        message,
        image_a,
        image_b,
        gr.update(value=rows, headers=headers),
        gr.update(choices=labels, value=selected),
        gr.update(),
        gr.update(),
    )
def update_view_on_dropdown(self, pair_label):
    """Refresh the pair view when the user selects a pair in the dropdown.

    Args:
        pair_label: Dropdown label such as "Pair 2 (Pending)" or "Pair 2 ✔".

    Returns:
        A 5-tuple: iteration banner, selection message, image of ligand A,
        image of ligand B, and a properties-table update. An empty label
        previews the first two ligands (when present); an unparsable or
        out-of-range label yields an "Invalid pair" message. Once all
        iterations are used up, ``finish_bo_process()`` is returned instead.
    """
    iteration_str = self._iteration_status()
    if self.bo_iteration >= self.max_iterations:
        # NOTE(review): show_initial/start_bo_iteration return this same call
        # where they otherwise return 8 values, but this handler returns 5;
        # confirm finish_bo_process() fits this handler's outputs.
        return self.finish_bo_process()

    # Keep the table's columns consistent with the populated view.
    if self.show_smiles:
        table_headers = ["Ligand", "SMILES", "MW", "LogP", "TPSA", "QED"]
    else:
        table_headers = ["Ligand", "MW", "LogP", "TPSA", "QED"]

    if not pair_label or pair_label.strip() == "":
        # Preview the first two ligands, guarding against smaller libraries
        # (indexing [1] used to raise IndexError with fewer than two ligands).
        ligand_names = list(self.ligands.keys())
        imgA = self._mol_to_image(ligand_names[0]) if len(ligand_names) > 0 else None
        imgB = self._mol_to_image(ligand_names[1]) if len(ligand_names) > 1 else None
        return (
            iteration_str,
            "Please select a valid pair",
            imgA,
            imgB,
            gr.update(value=[], headers=table_headers),
        )

    idx = self.get_pair_index(pair_label)
    if idx < 0 or idx >= len(self.current_pairs):
        return (
            iteration_str,
            f"Invalid pair: {pair_label}",
            None,
            None,
            gr.update(value=[], headers=table_headers),
        )

    ligandA_id, ligandB_id = self.current_pairs[idx]
    pair_key = (ligandA_id, ligandB_id)
    if pair_key in self.completed_pairs:
        # completed_pairs stores 1 when A was preferred, 0 when B was.
        arrow = ">" if self.completed_pairs[pair_key] == 1 else "<"
        label_symbol = "✔"
    else:
        arrow = "vs"
        label_symbol = "(Pending)"

    imgA = self._mol_to_image(ligandA_id)
    imgB = self._mol_to_image(ligandB_id)
    propsA = self.compute_properties(self.ligands[ligandA_id])
    propsB = self.compute_properties(self.ligands[ligandB_id])
    # Header names after "Ligand" are exactly the property keys.
    table_data = [
        ["Ligand A"] + [propsA[key] for key in table_headers[1:]],
        ["Ligand B"] + [propsB[key] for key in table_headers[1:]],
    ]

    pair_label_str = f"Pair {idx+1} {label_symbol}"
    current_selection_msg = (
        f"**Currently selected**: {pair_label_str} => **Ligand A** {arrow} **Ligand B**"
    )
    return (
        iteration_str,
        current_selection_msg,
        imgA,
        imgB,
        gr.update(value=table_data, headers=table_headers),
    )
def get_pair_index(self, pair_label):
    """Convert a dropdown label like 'Pair 4 (Pending)' / 'Pair 4 ✔' to index 3.

    Falsy labels (None, "") and labels whose second whitespace-separated token
    is missing or not an integer map to -1, which callers treat as invalid.
    """
    if not pair_label:
        return -1
    tokens = pair_label.split()
    if len(tokens) < 2:
        return -1
    try:
        number = int(tokens[1])
    except ValueError:
        return -1
    return number - 1
def show_pair(self, preference, pair_label):
    """Record the user's choice for one pair and advance to the next pending pair.

    Args:
        preference: "Ligand A" stores 1 (A preferred) in ``completed_pairs``;
            any other value stores 0 (B preferred). Every call is also
            appended to ``comparison_results``.
        pair_label: Dropdown label such as "Pair 2 (Pending)". An unparsable
            or out-of-range label falls back to the first pair.

    Returns:
        An 8-tuple of Gradio outputs: iteration banner, message, images of
        ligands A and B, properties-table update, dropdown update, the
        selected dropdown label, and an update for the BO button, which is
        interactive only once every pair has been answered. With no current
        pairs, ``show_initial()`` is returned and nothing is recorded.
    """
    iteration_str = self._iteration_status()
    if not self.current_pairs:
        # Nothing to record against; show the same "no pairs" view as on load.
        return self.show_initial()

    idx = self.get_pair_index(pair_label)
    if idx < 0 or idx >= len(self.current_pairs):
        idx = 0
    ligandA_id, ligandB_id = self.current_pairs[idx]
    answered = (ligandA_id, ligandB_id)

    # Re-answering a pair appends another log entry; completed_pairs keeps the latest.
    self.comparison_results.append({
        "Iteration": self.bo_iteration,
        "Pair": answered,
        "Preference": preference,
    })
    print(f"Logged preference: Iter={self.bo_iteration}, Pair=({ligandA_id}, {ligandB_id}), Choice={preference}")
    if preference == "Ligand A":
        self.completed_pairs[answered] = 1
        old_pair_str = "**Ligand A** > **Ligand B**"
        answered_arrow = ">"
    else:
        self.completed_pairs[answered] = 0
        old_pair_str = "**Ligand B** > **Ligand A**"
        answered_arrow = "<"

    updated_labels = [
        f"Pair {i+1} ✔" if p in self.completed_pairs else f"Pair {i+1} (Pending)"
        for i, p in enumerate(self.current_pairs)
    ]
    next_idx = next(
        (i for i, p in enumerate(self.current_pairs) if p not in self.completed_pairs),
        None,
    )
    if next_idx is not None:
        # Move on to the first still-pending pair; the BO button stays locked.
        shown_idx, arrow, bo_btn_state = next_idx, "vs", False
    else:
        # All pairs answered: keep showing the pair just answered (images,
        # table, dropdown and message all agree) and unlock the BO button.
        shown_idx, arrow, bo_btn_state = idx, answered_arrow, True
    shownA_id, shownB_id = self.current_pairs[shown_idx]
    dropdown_val = updated_labels[shown_idx]

    imgA = self._mol_to_image(shownA_id)
    imgB = self._mol_to_image(shownB_id)
    propsA = self.compute_properties(self.ligands[shownA_id])
    propsB = self.compute_properties(self.ligands[shownB_id])
    if self.show_smiles:
        table_headers = ["Ligand", "SMILES", "MW", "LogP", "TPSA", "QED"]
    else:
        table_headers = ["Ligand", "MW", "LogP", "TPSA", "QED"]
    # Header names after "Ligand" are exactly the property keys.
    table_data = [
        ["Ligand A"] + [propsA[key] for key in table_headers[1:]],
        ["Ligand B"] + [propsB[key] for key in table_headers[1:]],
    ]

    current_selection_msg = (
        f"**Currently selected**: {dropdown_val} => **Ligand A** {arrow} **Ligand B**"
    )
    # The original literal was split by a raw newline (a SyntaxError); use "\n".
    selection_msg = (
        f"You chose {old_pair_str} for {pair_label}.\n"
        f"{current_selection_msg}"
    )
    print(selection_msg)
    return (
        iteration_str,
        selection_msg,
        imgA,
        imgB,
        gr.update(value=table_data, headers=table_headers),
        gr.update(choices=updated_labels, value=dropdown_val),
        dropdown_val,
        gr.update(interactive=bo_btn_state),
    )
def start_bo_iteration(self):
iteration_str = self._iteration_status()
if self.bo_iteration >= self.max_iterations - 1:
self.is_completed = True
return self.finish_bo_process()
# Ensure all pairs are completed before proceeding
if len(self.completed_pairs) < len(self.current_pairs):
return (
iteration_str,
"Please complete all pairs first!", # Notify user
None,
None,
gr.update(value=[], headers=["Ligand", "MW", "LogP", "TPSA", "QED"]),
gr.update(),
gr.update(),
gr.update(interactive=False, value="Next BO Iteration"), # Keep button disabled
)
# Save results and increment the iteration
self._save_results()
self.bo_iteration += 1
iteration_str = self._iteration_status()
# Check if the BO process is complete
if self.bo_iteration == self.max_iterations:
final_message = """