import gradio as gr
import pandas as pd
from PIL import Image
from rdkit import RDLogger
from molecule_generation_helpers import *
from property_prediction_helpers import *
# Set RDKit's logger level to ERROR (presumably to keep lower-severity RDKit
# messages out of the app's console output).
RDLogger.logger().setLevel(RDLogger.ERROR)
# Predefined dataset paths (these should be adjusted to your file paths)
# Built-in datasets offered in the "Select Dataset" dropdown.
# Each value is one comma-separated string. Its first field is the CSV file
# that load_predefined_dataset() previews; the remaining fields look like
# "<test csv>, <input column>, <output column>" -- TODO confirm against the
# code that consumes this string.
# The " " entry is a blank placeholder choice and does not name a dataset.
predefined_datasets = {
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}
# Models: choices of the "Select Model" checkbox group in the Property Prediction tab
models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
# Fusion Types: choices of the "Fusion Type" radio group in the Property Prediction tab
fusion_available = ["Concat"]
# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
    """Preview a predefined dataset and refresh the column dropdowns.

    Args:
        dataset_name: Key into ``predefined_datasets`` (the dropdown value).

    Returns:
        A 4-tuple matching the Gradio outputs this handler is wired to:
        the first rows of the dataset's first CSV, two ``gr.update`` objects
        carrying that CSV's column names as dropdown choices, and the
        lower-cased dataset name.  When the name is unknown or maps to the
        blank placeholder entry, an empty DataFrame, two empty choice lists
        and the text ``"Dataset not found"`` are returned instead.

    Raises:
        Whatever ``pandas.read_csv`` raises if the configured CSV cannot be
        read.
    """
    spec = predefined_datasets.get(dataset_name) or ""
    # The first comma-separated field is the CSV path.  Stripping it makes the
    # blank placeholder entry (" ") count as "no dataset" instead of being
    # handed to pandas as a file name, which would raise.
    csv_path = spec.split(",")[0].strip()
    if not csv_path:
        return (
            pd.DataFrame(),
            gr.update(choices=[]),
            gr.update(choices=[]),
            "Dataset not found",
        )

    df = pd.read_csv(csv_path)
    return (
        df.head(),
        gr.update(choices=list(df.columns)),
        gr.update(choices=list(df.columns)),
        dataset_name.lower(),
    )
# Function to handle dataset selection (predefined or custom)
def handle_dataset_selection(selected_dataset):
    """Return the visibility updates for the dataset-related widgets.

    The result is a tuple of eight ``gr.update(visible=...)`` objects.  Their
    order follows the ``outputs`` list this handler is wired to: dataset name,
    train file, train preview, test file, test preview, predefined preview,
    input column selector, output column selector.

    Args:
        selected_dataset: Current value of the dataset dropdown.

    Returns:
        For "Custom Dataset" everything except the predefined preview is
        shown; for any other value only the dataset name box is shown.
    """
    if selected_dataset == "Custom Dataset":
        # Upload fields, previews and column selectors become visible.
        visibility = (True, True, True, True, True, False, True, True)
    else:
        visibility = (True, False, False, False, False, False, False, False)
    return tuple(gr.update(visible=shown) for shown in visibility)
# Dynamically show relevant hyperparameters based on selected model
def update_hyperparameters(model_name):
    """Return visibility updates for the downstream hyperparameter widgets.

    Args:
        model_name: Value of the "Select Downstream Model" dropdown.

    Returns:
        A tuple of five ``gr.update(visible=...)`` objects in the order of
        the handler's outputs: max_depth, n_estimators, alpha, degree, kernel.
        Any name without an entry in the table below (including ``None``)
        hides all five widgets, so the handler always returns one update per
        output instead of falling through and returning ``None``.
    """
    hide_all = (False, False, False, False, False)
    # Flags per model, in output order:
    # (max_depth, n_estimators, alpha, degree, kernel)
    visibility_by_model = {
        "XGBClassifier": (True, True, True, False, False),
        "SVR": (False, False, False, True, True),
        "Kernel Ridge": (False, False, True, True, True),
        "Linear Regression": hide_all,
        "Default - Auto": hide_all,
    }
    visibility = visibility_by_model.get(model_name, hide_all)
    return tuple(gr.update(visible=shown) for shown in visibility)
# Function to select input and output columns and display a message
def select_columns(input_column, output_column, train_data, test_data, dataset_name):
    """Build the text for the "Selected Columns Info" box.

    Args:
        input_column: Name chosen in the input column dropdown.
        output_column: Name chosen in the output column dropdown.
        train_data: Uploaded train file object exposing ``.name``, or ``None``
            when nothing has been uploaded.
        test_data: Uploaded test file object exposing ``.name``, or ``None``.
        dataset_name: Content of the dataset name box.

    Returns:
        ``"<train name>,<test name>,<input>,<output>,<dataset name>"`` when
        both columns and both files are present; otherwise a prompt telling
        the user what is still missing.
    """
    if not (input_column and output_column):
        return "Please select both input and output columns."
    # A column dropdown can change while an upload is still missing; return a
    # prompt instead of reading ``.name`` from None, which would raise.
    if train_data is None or test_data is None:
        return "Please upload both train and test datasets."
    return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
# Function to set Dataset Name
def set_dataname(dataset_name, dataset_selector):
    """Pick the dataset name to show in the name box.

    Args:
        dataset_name: Current content of the dataset name box.
        dataset_selector: Current value of the dataset dropdown.

    Returns:
        ``dataset_name`` unchanged when "Custom Dataset" is selected,
        otherwise the dropdown value itself.
    """
    if dataset_selector == "Custom Dataset":
        return dataset_name
    return dataset_selector
# Function to display the head of the uploaded CSV file
def display_csv_head(file):
    """Preview an uploaded CSV and offer its columns in the dropdowns.

    Args:
        file: Uploaded file object exposing ``.name`` (the path passed to
            ``pandas.read_csv``), or ``None`` when no file is present.

    Returns:
        A 3-tuple: the first rows of the CSV and two ``gr.update`` objects
        whose choices are the CSV's column names.  Without a file, an empty
        DataFrame and two empty choice lists are returned.
    """
    # Nothing uploaded (or upload cleared): reset preview and choices.
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])

    frame = pd.read_csv(file.name)
    preview = frame.head()
    input_choices = gr.update(choices=list(frame.columns))
    output_choices = gr.update(choices=list(frame.columns))
    return preview, input_choices, output_choices
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
# Keys are the labels shown in the sample-molecule radio group; each entry
# holds the molecule's SMILES string and the path of an image file for it.
smiles_image_mapping = {
    # Sample molecule 1
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    # Sample molecule 2
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    # Sample molecule 3
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    # Sample molecule 4
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    # Sample molecule 5
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}
# Load images for selection
def load_image(path):
    """Open the image file registered for a sample molecule.

    Args:
        path: Despite its name, a key of ``smiles_image_mapping`` (for
            example ``"Mol 1"``), or ``None`` when the selection is cleared.

    Returns:
        The opened PIL image, or ``None`` when the key is not in the mapping
        or the image file cannot be opened.
    """
    import logging  # local import so the file's import block stays untouched

    try:
        image_path = smiles_image_mapping[path]["image"]
    except (KeyError, TypeError):
        # Cleared selection (None) or unknown key: there is nothing to show.
        return None

    try:
        return Image.open(image_path)
    except OSError as err:
        # Missing or unreadable file: keep the handler best-effort, but leave
        # a trace instead of swallowing the problem silently.
        logging.getLogger(__name__).warning(
            "Could not open sample image %s: %s", image_path, err
        )
        return None
# Function to handle image selection
def handle_image_selection(image_key):
    """Resolve a sample-molecule selection to its SMILES string and picture.

    Args:
        image_key: Key of ``smiles_image_mapping`` chosen in the radio group;
            an empty or ``None`` value means nothing is selected.

    Returns:
        ``(smiles, image)`` where ``image`` is whatever ``smiles_to_image``
        returns for that SMILES string, or ``(None, None)`` when nothing is
        selected.
    """
    if image_key:
        selected_smiles = smiles_image_mapping[image_key]["smiles"]
        return selected_smiles, smiles_to_image(selected_smiles)
    return None, None
# Introduction
# Read the introduction text first so the file is closed before the UI is built.
with open("INTRODUCTION.md") as f:
    introduction_text = f.read()

# Introduction tab: the markdown file's content followed by a debug section.
with gr.Blocks() as introduction:
    gr.Markdown(introduction_text)
    gr.Markdown("---\n# Debug")
    gr.HTML("HTML text: ")
    gr.Markdown("Markdown text: ![selfies-ted](file/img/selfies-ted.png)")
    gr.HTML("HTML text: ")
    gr.Markdown("Markdown text: ![Huggingface Logo](https://huggingface.co./front/assets/huggingface_logo-noborder.svg)")
# Property Prediction
with gr.Blocks() as property_prediction:
    # NOTE(review): the nesting of the layout containers below (which widgets
    # sit inside which Row/Column/Accordion) is reconstructed -- confirm it
    # against the intended layout.

    # Empty log table; it seeds both the gr.State handed to the handlers and
    # the initial value of the log Dataframe in the right column.
    log_df = pd.DataFrame(
        {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []}
    )
    state = gr.State({"log_df": log_df})
    gr.HTML(
        '''
Task : Property Prediction
Models are finetuned with different combination of modalities on the uploaded or selected built data set.
'''
    )
    with gr.Row():
        # Left column: dataset, model and hyperparameter selection
        with gr.Column():
            # Dropdown menu for predefined datasets including "Custom Dataset" option
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )
            # Hidden textbox holding the selection text; it is written by
            # load_predefined_dataset / select_columns and read by display_eval
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=False
            )
            with gr.Accordion("Dataset Settings", open=True):
                # File upload options for custom dataset (train and test)
                dataset_name = gr.Textbox(label="Dataset Name", visible=False)
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )
                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )
                # Predefined dataset displays
                # NOTE(review): this preview starts hidden and
                # handle_dataset_selection returns visible=False for it in both
                # branches, so it is never shown -- confirm that is intended.
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )
                # Dropdowns for selecting input and output columns for the custom dataset
                input_column_selector = gr.Dropdown(
                    label="Select Input Column", choices=[], visible=False
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column", choices=[], visible=False
                )
                # Three separate listeners are attached to the dataset dropdown's
                # change event: visibility, predefined preview, dataset name.
                # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
                dataset_selector.change(
                    handle_dataset_selection,
                    inputs=dataset_selector,
                    outputs=[
                        dataset_name,
                        train_file,
                        train_display,
                        test_file,
                        test_display,
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # When a predefined dataset is selected, load its head and update column selectors;
                # the fourth output writes the handler's message into the hidden textbox
                dataset_selector.change(
                    load_predefined_dataset,
                    inputs=dataset_selector,
                    outputs=[
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                        selected_columns_message,
                    ],
                )
                # When a custom train file is uploaded, display its head and update column selectors
                train_file.change(
                    display_csv_head,
                    inputs=train_file,
                    outputs=[
                        train_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # When a custom test file is uploaded, display its head.
                # NOTE(review): this also replaces the column selector choices
                # with the test file's columns -- confirm that is intended.
                test_file.change(
                    display_csv_head,
                    inputs=test_file,
                    outputs=[
                        test_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # Keep the dataset name box in sync with the dropdown selection
                dataset_selector.change(
                    set_dataname,
                    inputs=[dataset_name, dataset_selector],
                    outputs=dataset_name,
                )
                # Update the selected columns information when dropdown values are changed
                input_column_selector.change(
                    select_columns,
                    inputs=[
                        input_column_selector,
                        output_column_selector,
                        train_file,
                        test_file,
                        dataset_name,
                    ],
                    outputs=selected_columns_message,
                )
                output_column_selector.change(
                    select_columns,
                    inputs=[
                        input_column_selector,
                        output_column_selector,
                        train_file,
                        test_file,
                        dataset_name,
                    ],
                    outputs=selected_columns_message,
                )
            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model"
            )
            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"], label="Task Type"
            )
            ####### adding hyper parameter tuning ###########
            # These choice strings are the ones update_hyperparameters compares against
            model_name = gr.Dropdown(
                [
                    "Default - Auto",
                    "XGBClassifier",
                    "SVR",
                    "Kernel Ridge",
                    "Linear Regression",
                ],
                label="Select Downstream Model",
            )
            with gr.Accordion("Downstream Hyperparameter Settings", open=True):
                # Create placeholders for hyperparameter components; all start
                # hidden and are toggled by update_hyperparameters
                max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
                n_estimators = gr.Slider(
                    100, 5000, step=100, visible=False, label="n_estimators"
                )
                alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
                degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
                kernel = gr.Dropdown(
                    choices=["rbf", "poly", "linear"], visible=False, label="kernel"
                )
                # Output textbox: receives the result of create_downstream_model
                # and is later passed to display_eval
                output = gr.Textbox(label="Loaded Parameters")
            # When model is selected, update which hyperparameters are visible
            model_name.change(
                update_hyperparameters,
                inputs=[model_name],
                outputs=[max_depth, n_estimators, alpha, degree, kernel],
            )
            # Submit button to create the model with selected hyperparameters
            submit_button = gr.Button("Create Downstream Model")
            # When the submit button is clicked, run create_downstream_model
            submit_button.click(
                create_downstream_model,
                inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
                outputs=output,
            )
            ###### End of hyper param tuning #########
            fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
            eval_button = gr.Button("Train downstream model")
        # Right Column
        with gr.Column():
            eval_output = gr.Textbox(label="Train downstream model")
            plot_radio = gr.Radio(
                choices=["ROC-AUC", "Parity Plot", "Latent Space"],
                label="Select Plot Type",
            )
            plot_output = gr.Plot(label="Visualization")
            create_log = gr.Button("Store log")
            log_table = gr.Dataframe(
                value=log_df, label="Log of Selections and Results", interactive=False
            )
    # Run display_eval with the current selections; its result goes to eval_output
    eval_button.click(
        display_eval,
        inputs=[
            model_checkbox,
            selected_columns_message,
            task_radiobutton,
            output,
            fusion_radiobutton,
            state,
        ],
        outputs=eval_output,
    )
    # Redraw the plot whenever another plot type is chosen
    plot_radio.change(
        display_plot, inputs=[plot_radio, state], outputs=plot_output
    )
    # Pass the current selections and result to evaluate_and_log; its return
    # value replaces the log table
    create_log.click(
        evaluate_and_log,
        inputs=[
            model_checkbox,
            dataset_name,
            task_radiobutton,
            eval_output,
            state,
        ],
        outputs=log_table,
    )
# Molecule Generation
with gr.Blocks() as molecule_generation:
    # NOTE(review): the nesting of the layout containers below is
    # reconstructed -- confirm it against the intended layout.
    gr.HTML(
        '''
Task : Molecule Generation
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
'''
    )
    with gr.Row():
        # Left column: input molecule
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Show images for selection
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping.keys()),
                    label="Select from sample molecules",
                    value=None,
                )
                # Show the sample's image file when a sample is picked.
                # NOTE(review): handle_image_selection (wired below) also writes
                # image_display on the same event -- confirm both are needed.
                image_selector.change(load_image, image_selector, image_display)
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")
        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")
        # Handle image selection: fill in the sample's SMILES string and picture
        image_selector.change(
            handle_image_selection,
            inputs=image_selector,
            outputs=[smiles_input, image_display],
        )
        # Re-render the molecule image whenever the SMILES text changes
        smiles_input.change(
            smiles_to_image, inputs=smiles_input, outputs=image_display
        )
        # Generate button: generate_canonical fills the property table, the
        # generated output text and the generated molecule image
        generate_button.click(
            generate_canonical,
            inputs=smiles_input,
            outputs=[property_table, generated_output, gen_image_display],
        )
        # Clear button: reset all six widgets by returning one None per output
        clear_button.click(
            lambda: (None, None, None, None, None, None),
            outputs=[
                smiles_input,
                image_display,
                image_selector,
                gen_image_display,
                generated_output,
                property_table,
            ],
        )
# Render with tabs
# Combine the three Blocks into one tabbed app and start the server as soon as
# this module runs.
# NOTE(review): server_name="0.0.0.0" listens on all network interfaces, and
# allowed_paths=["./"] appears to let Gradio serve files from the whole
# working directory (possibly needed for the "file/img/..." image links) --
# confirm this exposure is intended, or narrow it to the image directory.
gr.TabbedInterface(
    [introduction, property_prediction, molecule_generation],
    ["Introduction", "Property Prediction", "Molecule Generation"],
).launch(server_name="0.0.0.0", allowed_paths=["./"])