ipd commited on
Commit
79d2379
·
1 Parent(s): b22c916
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. Dockerfile +13 -0
  3. Dockerfile-conda +13 -0
  4. INTRODUCTION.md +43 -0
  5. README.md +5 -3
  6. app.py +514 -0
  7. data/.DS_Store +0 -0
  8. data/bace/test.csv +0 -0
  9. data/bace/train.csv +0 -0
  10. data/bace/valid.csv +0 -0
  11. data/esol/test.csv +109 -0
  12. data/esol/train.csv +0 -0
  13. img/img1.png +0 -0
  14. img/img2.png +0 -0
  15. img/img3.png +0 -0
  16. img/img4.png +0 -0
  17. img/img5.png +0 -0
  18. img/selfies-ted.png +0 -0
  19. models/.DS_Store +0 -0
  20. models/fm4m.py +670 -0
  21. models/mhg_model/.DS_Store +0 -0
  22. models/mhg_model/README.md +75 -0
  23. models/mhg_model/__init__.py +5 -0
  24. models/mhg_model/graph_grammar/__init__.py +19 -0
  25. models/mhg_model/graph_grammar/algo/__init__.py +20 -0
  26. models/mhg_model/graph_grammar/algo/tree_decomposition.py +821 -0
  27. models/mhg_model/graph_grammar/graph_grammar/__init__.py +20 -0
  28. models/mhg_model/graph_grammar/graph_grammar/base.py +30 -0
  29. models/mhg_model/graph_grammar/graph_grammar/corpus.py +152 -0
  30. models/mhg_model/graph_grammar/graph_grammar/hrg.py +1065 -0
  31. models/mhg_model/graph_grammar/graph_grammar/symbols.py +180 -0
  32. models/mhg_model/graph_grammar/graph_grammar/utils.py +130 -0
  33. models/mhg_model/graph_grammar/hypergraph.py +544 -0
  34. models/mhg_model/graph_grammar/io/__init__.py +20 -0
  35. models/mhg_model/graph_grammar/io/smi.py +559 -0
  36. models/mhg_model/graph_grammar/nn/__init__.py +11 -0
  37. models/mhg_model/graph_grammar/nn/dataset.py +121 -0
  38. models/mhg_model/graph_grammar/nn/decoder.py +158 -0
  39. models/mhg_model/graph_grammar/nn/encoder.py +199 -0
  40. models/mhg_model/graph_grammar/nn/graph.py +313 -0
  41. models/mhg_model/images/mhg_example.png +0 -0
  42. models/mhg_model/images/mhg_example1.png +0 -0
  43. models/mhg_model/images/mhg_example2.png +0 -0
  44. models/mhg_model/load.py +84 -0
  45. models/mhg_model/mhg_gnn.egg-info/PKG-INFO +102 -0
  46. models/mhg_model/mhg_gnn.egg-info/SOURCES.txt +46 -0
  47. models/mhg_model/mhg_gnn.egg-info/dependency_links.txt +1 -0
  48. models/mhg_model/mhg_gnn.egg-info/requires.txt +7 -0
  49. models/mhg_model/mhg_gnn.egg-info/top_level.txt +2 -0
  50. models/mhg_model/models/__init__.py +5 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9.7
2
+
3
+ WORKDIR /app
4
+ COPY requirements.txt .
5
+ RUN pip install -r requirements.txt
6
+ # preload models
7
+ RUN python -c '\
8
+ from transformers import BartForConditionalGeneration, AutoTokenizer;\
9
+ AutoTokenizer.from_pretrained("ibm/materials.selfies-ted");\
10
+ BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")'
11
+ COPY . .
12
+
13
+ CMD ["python", "app.py"]
Dockerfile-conda ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM condaforge/miniforge3
2
+
3
+ WORKDIR /app
4
+ SHELL ["/bin/bash", "-i", "-c"]
5
+ RUN apt-get update && \
6
+ apt-get install -y build-essential libxrender1 libxext-dev
7
+ RUN conda create --name fm4m python=3.9.7
8
+ RUN conda activate fm4m
9
+ COPY requirements.txt .
10
+ RUN pip install -r requirements.txt
11
+ COPY . .
12
+
13
+ CMD ["python", "app.py"]
INTRODUCTION.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # selfies-ted
2
+
3
+ selfies-ted is an transformer based encoder decoder model for molecular representations using SELFIES.
4
+
5
+ ![selfies-ted](file/img/selfies-ted.png)
6
+
7
+
8
+ ## Usage
9
+
10
+ ### Import
11
+
12
+ ```
13
+ from transformers import AutoTokenizer, AutoModel
14
+ import selfies as sf
15
+ import torch
16
+ ```
17
+
18
+ ### Load the model and tokenizer
19
+ ```
20
+ tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
21
+ model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
22
+ ```
23
+ ### Encode SMILES strings to selfies
24
+ ```
25
+ smiles = "c1ccccc1"
26
+ selfies = sf.encoder(smiles)
27
+ selfies = selfies.replace("][", "] [")
28
+
29
+ ```
30
+ ### Get embedding
31
+ ```
32
+ token = tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
33
+ input_ids = token['input_ids']
34
+ attention_mask = token['attention_mask']
35
+ outputs = model.encoder(input_ids=input_ids, attention_mask=attention_mask)
36
+ model_output = outputs.last_hidden_state
37
+
38
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
39
+ sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
40
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
41
+ model_output = sum_embeddings / sum_mask
42
+ ```
43
+
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Test
3
- emoji: 📈
4
- colorFrom: yellow
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Fm4m Demo Simplified
3
+ emoji: 🐨
4
+ colorFrom: green
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
+ short_description: FM4M Demo Simplified
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from PIL import Image
4
+ from rdkit import RDLogger
5
+ from molecule_generation_helpers import *
6
+ from property_prediction_helpers import *
7
+
8
+ RDLogger.logger().setLevel(RDLogger.ERROR)
9
+
10
+ # Predefined dataset paths (these should be adjusted to your file paths)
11
+ predefined_datasets = {
12
+ " ": " ",
13
+ "BACE": f"./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
14
+ "ESOL": f"./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
15
+ }
16
+
17
+ # Models
18
+ models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
19
+
20
+ # Fusion Types
21
+ fusion_available = ["Concat"]
22
+
23
+
24
+ # Function to load a predefined dataset from the local path
25
+ def load_predefined_dataset(dataset_name):
26
+ val = predefined_datasets.get(dataset_name)
27
+ if val:
28
+ df = pd.read_csv(val.split(",")[0])
29
+ return (
30
+ df.head(),
31
+ gr.update(choices=list(df.columns)),
32
+ gr.update(choices=list(df.columns)),
33
+ dataset_name.lower(),
34
+ )
35
+ else:
36
+ return (
37
+ pd.DataFrame(),
38
+ gr.update(choices=[]),
39
+ gr.update(choices=[]),
40
+ f"Dataset not found",
41
+ )
42
+
43
+
44
+ # Function to handle dataset selection (predefined or custom)
45
+ def handle_dataset_selection(selected_dataset):
46
+ if selected_dataset == "Custom Dataset":
47
+ # Show file upload fields for train and test datasets if "Custom Dataset" is selected
48
+ return (
49
+ gr.update(visible=True),
50
+ gr.update(visible=True),
51
+ gr.update(visible=True),
52
+ gr.update(visible=True),
53
+ gr.update(visible=True),
54
+ gr.update(visible=False),
55
+ gr.update(visible=True),
56
+ gr.update(visible=True),
57
+ )
58
+ return (
59
+ gr.update(visible=True),
60
+ gr.update(visible=False),
61
+ gr.update(visible=False),
62
+ gr.update(visible=False),
63
+ gr.update(visible=False),
64
+ gr.update(visible=False),
65
+ gr.update(visible=False),
66
+ gr.update(visible=False),
67
+ )
68
+
69
+
70
+ # Dynamically show relevant hyperparameters based on selected model
71
+ def update_hyperparameters(model_name):
72
+ if model_name == "XGBClassifier":
73
+ return (
74
+ gr.update(visible=True),
75
+ gr.update(visible=True),
76
+ gr.update(visible=True),
77
+ gr.update(visible=False),
78
+ gr.update(visible=False),
79
+ )
80
+ elif model_name == "SVR":
81
+ return (
82
+ gr.update(visible=False),
83
+ gr.update(visible=False),
84
+ gr.update(visible=False),
85
+ gr.update(visible=True),
86
+ gr.update(visible=True),
87
+ )
88
+ elif model_name == "Kernel Ridge":
89
+ return (
90
+ gr.update(visible=False),
91
+ gr.update(visible=False),
92
+ gr.update(visible=True),
93
+ gr.update(visible=True),
94
+ gr.update(visible=True),
95
+ )
96
+ elif model_name == "Linear Regression":
97
+ return (
98
+ gr.update(visible=False),
99
+ gr.update(visible=False),
100
+ gr.update(visible=False),
101
+ gr.update(visible=False),
102
+ gr.update(visible=False),
103
+ )
104
+ elif model_name == "Default - Auto":
105
+ return (
106
+ gr.update(visible=False),
107
+ gr.update(visible=False),
108
+ gr.update(visible=False),
109
+ gr.update(visible=False),
110
+ gr.update(visible=False),
111
+ )
112
+
113
+
114
+ # Function to select input and output columns and display a message
115
+ def select_columns(input_column, output_column, train_data, test_data, dataset_name):
116
+ if input_column and output_column:
117
+ return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
118
+ return "Please select both input and output columns."
119
+
120
+
121
+ # Function to set Dataset Name
122
+ def set_dataname(dataset_name, dataset_selector):
123
+ return dataset_name if dataset_selector == "Custom Dataset" else dataset_selector
124
+
125
+
126
+ # Function to display the head of the uploaded CSV file
127
+ def display_csv_head(file):
128
+ if file is not None:
129
+ # Load the CSV file into a DataFrame
130
+ df = pd.read_csv(file.name)
131
+ return (
132
+ df.head(),
133
+ gr.update(choices=list(df.columns)),
134
+ gr.update(choices=list(df.columns)),
135
+ )
136
+ return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
137
+
138
+
139
+ # Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
140
+ smiles_image_mapping = {
141
+ # Example SMILES for ethanol
142
+ "Mol 1": {
143
+ "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
144
+ "image": "img/img1.png",
145
+ },
146
+ # Example SMILES for butane
147
+ "Mol 2": {
148
+ "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
149
+ "image": "img/img2.png",
150
+ },
151
+ # Example SMILES for ethylamine
152
+ "Mol 3": {
153
+ "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
154
+ "image": "img/img3.png",
155
+ },
156
+ # Example SMILES for diethyl ether
157
+ "Mol 4": {
158
+ "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
159
+ "image": "img/img4.png",
160
+ },
161
+ # Example SMILES for chloroethane
162
+ "Mol 5": {
163
+ "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
164
+ "image": "img/img5.png",
165
+ },
166
+ }
167
+
168
+
169
+ # Load images for selection
170
+ def load_image(path):
171
+ try:
172
+ return Image.open(smiles_image_mapping[path]["image"])
173
+ except:
174
+ pass
175
+
176
+
177
+ # Function to handle image selection
178
+ def handle_image_selection(image_key):
179
+ if not image_key:
180
+ return None, None
181
+ smiles = smiles_image_mapping[image_key]["smiles"]
182
+ mol_image = smiles_to_image(smiles)
183
+ return smiles, mol_image
184
+
185
+
186
+ # Introduction
187
+ with open("INTRODUCTION.md") as f:
188
+ # introduction = gr.Markdown(f.read())
189
+ with gr.Blocks() as introduction:
190
+ gr.Markdown(f.read())
191
+ gr.Markdown("---\n# Debug")
192
+ gr.HTML("HTML text: <img src='file/img/selfies-ted.png'>")
193
+ gr.Markdown("Markdown text: ![selfies-ted](file/img/selfies-ted.png)")
194
+ gr.HTML("HTML text: <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg'>")
195
+ gr.Markdown("Markdown text: ![Huggingface Logo](https://huggingface.co/front/assets/huggingface_logo-noborder.svg)")
196
+
197
+ # Property Prediction
198
+ with gr.Blocks() as property_prediction:
199
+ log_df = pd.DataFrame(
200
+ {"": [], 'Selected Models': [], 'Dataset': [], 'Task': [], 'Result': []}
201
+ )
202
+ state = gr.State({"log_df": log_df})
203
+ gr.HTML(
204
+ '''
205
+ <p style="text-align: center">
206
+ Task : Property Prediction
207
+ <br>
208
+ Models are finetuned with different combination of modalities on the uploaded or selected built data set.
209
+ </p>
210
+ '''
211
+ )
212
+ with gr.Row():
213
+ with gr.Column():
214
+ # Dropdown menu for predefined datasets including "Custom Dataset" option
215
+ dataset_selector = gr.Dropdown(
216
+ label="Select Dataset",
217
+ choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
218
+ )
219
+ # Display the message for selected columns
220
+ selected_columns_message = gr.Textbox(
221
+ label="Selected Columns Info", visible=False
222
+ )
223
+
224
+ with gr.Accordion("Dataset Settings", open=True):
225
+ # File upload options for custom dataset (train and test)
226
+ dataset_name = gr.Textbox(label="Dataset Name", visible=False)
227
+ train_file = gr.File(
228
+ label="Upload Custom Train Dataset",
229
+ file_types=[".csv"],
230
+ visible=False,
231
+ )
232
+ train_display = gr.Dataframe(
233
+ label="Train Dataset Preview (First 5 Rows)",
234
+ visible=False,
235
+ interactive=False,
236
+ )
237
+
238
+ test_file = gr.File(
239
+ label="Upload Custom Test Dataset",
240
+ file_types=[".csv"],
241
+ visible=False,
242
+ )
243
+ test_display = gr.Dataframe(
244
+ label="Test Dataset Preview (First 5 Rows)",
245
+ visible=False,
246
+ interactive=False,
247
+ )
248
+
249
+ # Predefined dataset displays
250
+ predefined_display = gr.Dataframe(
251
+ label="Predefined Dataset Preview (First 5 Rows)",
252
+ visible=False,
253
+ interactive=False,
254
+ )
255
+
256
+ # Dropdowns for selecting input and output columns for the custom dataset
257
+ input_column_selector = gr.Dropdown(
258
+ label="Select Input Column", choices=[], visible=False
259
+ )
260
+ output_column_selector = gr.Dropdown(
261
+ label="Select Output Column", choices=[], visible=False
262
+ )
263
+
264
+ # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
265
+ dataset_selector.change(
266
+ handle_dataset_selection,
267
+ inputs=dataset_selector,
268
+ outputs=[
269
+ dataset_name,
270
+ train_file,
271
+ train_display,
272
+ test_file,
273
+ test_display,
274
+ predefined_display,
275
+ input_column_selector,
276
+ output_column_selector,
277
+ ],
278
+ )
279
+
280
+ # When a predefined dataset is selected, load its head and update column selectors
281
+ dataset_selector.change(
282
+ load_predefined_dataset,
283
+ inputs=dataset_selector,
284
+ outputs=[
285
+ predefined_display,
286
+ input_column_selector,
287
+ output_column_selector,
288
+ selected_columns_message,
289
+ ],
290
+ )
291
+
292
+ # When a custom train file is uploaded, display its head and update column selectors
293
+ train_file.change(
294
+ display_csv_head,
295
+ inputs=train_file,
296
+ outputs=[
297
+ train_display,
298
+ input_column_selector,
299
+ output_column_selector,
300
+ ],
301
+ )
302
+
303
+ # When a custom test file is uploaded, display its head
304
+ test_file.change(
305
+ display_csv_head,
306
+ inputs=test_file,
307
+ outputs=[
308
+ test_display,
309
+ input_column_selector,
310
+ output_column_selector,
311
+ ],
312
+ )
313
+
314
+ dataset_selector.change(
315
+ set_dataname,
316
+ inputs=[dataset_name, dataset_selector],
317
+ outputs=dataset_name,
318
+ )
319
+
320
+ # Update the selected columns information when dropdown values are changed
321
+ input_column_selector.change(
322
+ select_columns,
323
+ inputs=[
324
+ input_column_selector,
325
+ output_column_selector,
326
+ train_file,
327
+ test_file,
328
+ dataset_name,
329
+ ],
330
+ outputs=selected_columns_message,
331
+ )
332
+
333
+ output_column_selector.change(
334
+ select_columns,
335
+ inputs=[
336
+ input_column_selector,
337
+ output_column_selector,
338
+ train_file,
339
+ test_file,
340
+ dataset_name,
341
+ ],
342
+ outputs=selected_columns_message,
343
+ )
344
+
345
+ model_checkbox = gr.CheckboxGroup(
346
+ choices=models_enabled, label="Select Model"
347
+ )
348
+
349
+ task_radiobutton = gr.Radio(
350
+ choices=["Classification", "Regression"], label="Task Type"
351
+ )
352
+
353
+ ####### adding hyper parameter tuning ###########
354
+ model_name = gr.Dropdown(
355
+ [
356
+ "Default - Auto",
357
+ "XGBClassifier",
358
+ "SVR",
359
+ "Kernel Ridge",
360
+ "Linear Regression",
361
+ ],
362
+ label="Select Downstream Model",
363
+ )
364
+ with gr.Accordion("Downstream Hyperparameter Settings", open=True):
365
+ # Create placeholders for hyperparameter components
366
+ max_depth = gr.Slider(1, 20, step=1, visible=False, label="max_depth")
367
+ n_estimators = gr.Slider(
368
+ 100, 5000, step=100, visible=False, label="n_estimators"
369
+ )
370
+ alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
371
+ degree = gr.Slider(1, 20, step=1, visible=False, label="degree")
372
+ kernel = gr.Dropdown(
373
+ choices=["rbf", "poly", "linear"], visible=False, label="kernel"
374
+ )
375
+
376
+ # Output textbox
377
+ output = gr.Textbox(label="Loaded Parameters")
378
+
379
+ # When model is selected, update which hyperparameters are visible
380
+ model_name.change(
381
+ update_hyperparameters,
382
+ inputs=[model_name],
383
+ outputs=[max_depth, n_estimators, alpha, degree, kernel],
384
+ )
385
+
386
+ # Submit button to create the model with selected hyperparameters
387
+ submit_button = gr.Button("Create Downstream Model")
388
+
389
+ # When the submit button is clicked, run the on_submit function
390
+ submit_button.click(
391
+ create_downstream_model,
392
+ inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
393
+ outputs=output,
394
+ )
395
+ ###### End of hyper param tuning #########
396
+
397
+ fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
398
+
399
+ eval_button = gr.Button("Train downstream model")
400
+
401
+ # Right Column
402
+ with gr.Column():
403
+ eval_output = gr.Textbox(label="Train downstream model")
404
+
405
+ plot_radio = gr.Radio(
406
+ choices=["ROC-AUC", "Parity Plot", "Latent Space"],
407
+ label="Select Plot Type",
408
+ )
409
+ plot_output = gr.Plot(label="Visualization")
410
+
411
+ create_log = gr.Button("Store log")
412
+
413
+ log_table = gr.Dataframe(
414
+ value=log_df, label="Log of Selections and Results", interactive=False
415
+ )
416
+
417
+ eval_button.click(
418
+ display_eval,
419
+ inputs=[
420
+ model_checkbox,
421
+ selected_columns_message,
422
+ task_radiobutton,
423
+ output,
424
+ fusion_radiobutton,
425
+ state,
426
+ ],
427
+ outputs=eval_output,
428
+ )
429
+
430
+ plot_radio.change(
431
+ display_plot, inputs=[plot_radio, state], outputs=plot_output
432
+ )
433
+
434
+ create_log.click(
435
+ evaluate_and_log,
436
+ inputs=[
437
+ model_checkbox,
438
+ dataset_name,
439
+ task_radiobutton,
440
+ eval_output,
441
+ state,
442
+ ],
443
+ outputs=log_table,
444
+ )
445
+
446
+
447
+ # Molecule Generation
448
+ with gr.Blocks() as molecule_generation:
449
+ gr.HTML(
450
+ '''
451
+ <p style="text-align: center">
452
+ Task : Molecule Generation
453
+ <br>
454
+ Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
455
+ </p>
456
+ '''
457
+ )
458
+ with gr.Row():
459
+ with gr.Column():
460
+ smiles_input = gr.Textbox(label="Input SMILES String")
461
+ image_display = gr.Image(label="Molecule Image", height=250, width=250)
462
+ # Show images for selection
463
+ with gr.Accordion("Select from sample molecules", open=False):
464
+ image_selector = gr.Radio(
465
+ choices=list(smiles_image_mapping.keys()),
466
+ label="Select from sample molecules",
467
+ value=None,
468
+ )
469
+ image_selector.change(load_image, image_selector, image_display)
470
+ clear_button = gr.Button("Clear")
471
+ generate_button = gr.Button("Submit", variant="primary")
472
+
473
+ # Right Column
474
+ with gr.Column():
475
+ gen_image_display = gr.Image(
476
+ label="Generated Molecule Image", height=250, width=250
477
+ )
478
+ generated_output = gr.Textbox(label="Generated Output")
479
+ property_table = gr.Dataframe(label="Molecular Properties Comparison")
480
+
481
+ # Handle image selection
482
+ image_selector.change(
483
+ handle_image_selection,
484
+ inputs=image_selector,
485
+ outputs=[smiles_input, image_display],
486
+ )
487
+ smiles_input.change(
488
+ smiles_to_image, inputs=smiles_input, outputs=image_display
489
+ )
490
+
491
+ # Generate button to display canonical SMILES and molecule image
492
+ generate_button.click(
493
+ generate_canonical,
494
+ inputs=smiles_input,
495
+ outputs=[property_table, generated_output, gen_image_display],
496
+ )
497
+ clear_button.click(
498
+ lambda: (None, None, None, None, None, None),
499
+ outputs=[
500
+ smiles_input,
501
+ image_display,
502
+ image_selector,
503
+ gen_image_display,
504
+ generated_output,
505
+ property_table,
506
+ ],
507
+ )
508
+
509
+
510
+ # Render with tabs
511
+ gr.TabbedInterface(
512
+ [introduction, property_prediction, molecule_generation],
513
+ ["Introduction", "Property Prediction", "Molecule Generation"],
514
+ ).launch(server_name="0.0.0.0", allowed_paths=["./"])
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/bace/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/bace/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/bace/valid.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/esol/test.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,selfies,prop,smiles
2
+ 0,[Cl] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [C] [C] [C] [C] [Branch1] [Branch2] [C] [O] [C] [Ring1] [=Branch1] [Ring1] [Ring1] [C] [Ring1] [Branch2] [C] [Ring1] [=C] [Branch1] [C] [Cl] [C] [Ring1] [=N] [Branch1] [C] [Cl] [Cl],-4.533,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
3
+ 1,[C] [C] [C] [C] [C] [=O],-1.103,CCCCC=O
4
+ 2,[O] [C] [C] [C] [C] [=C],-0.7909999999999999,OCCCC=C
5
+ 3,[C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [N] [=C] [C] [Branch1] [C] [N] [=C] [Branch1] [C] [Br] [C] [Ring1] [Branch2] [=O],-3.005,c1ccccc1n2ncc(N)c(Br)c2(=O)
6
+ 4,[N] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-1.231,Nc1ccc(O)cc1
7
+ 5,[C] [C] [Branch1] [C] [C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.817,CC(C)CCOC(=O)C
8
+ 6,[C] [O] [P] [=Branch1] [C] [=S] [Branch1] [Ring1] [O] [C] [S] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=O],-2.087,COP(=S)(OC)SCC(=O)N(C)C=O
9
+ 7,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=Branch1] [Ring2] [=C] [Ring1] [#Branch1] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-6.312,Clc1ccc(Cl)c(c1)c2ccc(Cl)c(Cl)c2
10
+ 8,[C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [N] [=C] [C] [=N] [C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [Cl],-4.438,c2(Cl)c(Cl)c(Cl)c1nccnc1c2(Cl)
11
+ 9,[C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [N] [=C] [Branch1] [=Branch1] [N] [=C] [Ring1] [#Branch1] [O] [N] [Branch1] [C] [C] [C],-3.57,CCCCc1c(C)nc(nc1O)N(C)C
12
+ 10,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [C] [=Branch1] [C] [=O] [O] [C] [C],-1.413,CCOC(=O)CC(=O)OCC
13
+ 11,[C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-3.192,CC(C)(C)c1ccc(O)cc1
14
+ 12,[C] [C] [=C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch1],-3.035,Cc1cccc(C)c1
15
+ 13,[C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.125,CCCOC(=O)C
16
+ 14,[C] [S] [C] [=N] [N] [=C] [Branch1] [=Branch2] [C] [=Branch1] [C] [=O] [N] [Ring1] [#Branch1] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.324,CSc1nnc(c(=O)n1N)C(C)(C)C
17
+ 15,[Cl] [C] [=C] [C] [=C] [Branch1] [Branch1] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Cl],-5.142,Clc1ccc(cc1)c2ccccc2Cl
18
+ 16,[C] [C] [C] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [C] [Branch1] [Ring2] [C] [Ring1] [Branch2] [C] [Branch1] [C] [O] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2],-1.5319999999999998,CC1CC(C)C(=O)C(C1)C(O)CC2CC(=O)NC(=O)C2
19
+ 17,[C] [N] [C] [=Branch1] [C] [=O] [O] [C] [=C] [C] [=C] [C] [Branch1] [Branch2] [N] [=C] [N] [Branch1] [C] [C] [C] [=C] [Ring1] [O],-1.846,CNC(=O)Oc1cccc(N=CN(C)C)c1
20
+ 18,[C] [C] [=C] [C] [=N] [C] [N] [Branch1] [=Branch1] [C] [C] [C] [Ring1] [Ring1] [C] [=N] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [N] [C] [Ring2] [Ring1] [Ring1] [=Ring1] [#C],-3.397,Cc3ccnc4N(C1CC1)c2ncccc2C(=O)Nc34
21
+ 19,[C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.389,CCNc1ccccc1
22
+ 20,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Ring2] [Ring1] [Ring1] [Ring1] [#Branch2],-6.297000000000001,Cc1c2ccccc2c(C)c3ccc4ccccc4c13
23
+ 21,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [F] [C] [Branch1] [C] [Cl] [=C] [Ring1] [=Branch2] [F],-5.462000000000001,Fc1cccc(F)c1C(=O)NC(=O)Nc2cc(Cl)c(F)c(Cl)c2F
24
+ 22,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.057,COc1ccc(Cl)cc1
25
+ 23,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=N] [Ring1] [=Branch1],-4.2010000000000005,o1c2ccccc2c3ccccc13
26
+ 24,[C] [=C] [C] [=C] [N] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [Ring1] [#Branch2] [=C] [Ring1] [=C],-3.846,c3ccc2nc1ccccc1cc2c3
27
+ 25,[C] [C] [C] [C] [=Branch1] [C] [=O] [C] [C] [Branch1] [P] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [O] [Ring1] [#Branch1] [C] [C] [Ring1] [P] [C] [C] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [O] [C] [=Branch1] [C] [=O] [C] [O],-2.893,CC12CC(=O)C3C(CCC4=CC(=O)CCC34C)C2CCC1(O)C(=O)CO
28
+ 26,[C] [C] [C] [=C] [C] [=C] [C] [Branch1] [Ring1] [C] [C] [=C] [Ring1] [Branch2] [N] [Branch1] [Ring2] [C] [O] [C] [C] [=Branch1] [C] [=O] [C] [Cl],-3.319,CCc1cccc(CC)c1N(COC)C(=O)CCl
29
+ 27,[C] [C] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-4.157,CCCCN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
30
+ 28,[C] [S] [C] [=Branch1] [C] [=S] [N] [C] [Ring1] [=Branch1] [=O],-0.396,C1SC(=S)NC1(=O)
31
+ 29,[O] [C] [=C] [C] [=C] [Branch1] [Branch2] [C] [Branch1] [C] [O] [=C] [Ring1] [#Branch1] [C] [O] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [Branch1] [C] [O] [=C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [C] [=Ring1] [=N] [O],-2.7310000000000003,Oc1ccc(c(O)c1)c3oc2cc(O)cc(O)c2c(=O)c3O
32
+ 30,[C] [N] [Branch1] [C] [C] [C] [=N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C],-3.164,CN(C)C=Nc1ccc(Cl)cc1C
33
+ 31,[N] [C] [=Branch1] [C] [=O] [N] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [=Branch1] [=O],0.652,NC(=O)NC1NC(=O)NC1=O
34
+ 32,[Cl] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-4.063,Clc1cccc2ccccc12
35
+ 33,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.352,Oc1ccc(Cl)c(Cl)c1
36
+ 34,[C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [=C] [Branch1] [C] [Cl] [Cl] [C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-6.775,CC1(C)C(C=C(Cl)Cl)C1C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
37
+ 35,[C] [=C] [C] [=C] [NH1] [N] [=N] [C] [Ring1] [Branch1] [=C] [Ring1] [=Branch2],-2.21,c2ccc1[nH]nnc1c2
38
+ 36,[C] [C] [Branch1] [C] [C] [C] [Branch2] [Ring1] [Branch1] [N] [C] [=C] [C] [=C] [Branch1] [=Branch1] [C] [=C] [Ring1] [=Branch1] [Cl] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-8.057,CC(C)C(Nc1ccc(cc1Cl)C(F)(F)F)C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
39
+ 37,[C] [C] [C],-1.5530000000000002,CCC
40
+ 38,[C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [O] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-3.792,C1Cc2cccc3cccc1c23
41
+ 39,[C] [C] [C] [#C],-1.092,CCC#C
42
+ 40,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.5580000000000003,Clc1ccc(Cl)cc1
43
+ 41,[C] [C] [=C] [NH1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch2] [Ring1] [=Branch1],-2.9810000000000003,Cc1c[nH]c2ccccc12
44
+ 42,[C] [C] [#N],0.152,CC#N
45
+ 43,[C] [C] [C] [C] [O],-0.688,CCCCO
46
+ 44,[C] [C] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C],-2.052,CC(=C)C(=C)C
47
+ 45,[C] [C] [C] [Branch1] [C] [C] [C] [C] [O],-1.308,CCC(C)CCO
48
+ 46,[Cl] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [Branch1] [C] [Cl] [=C] [Ring1] [#Branch1] [Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2] [Cl],-7.192,Clc1ccc(c(Cl)c1Cl)c2ccc(Cl)c(Cl)c2Cl
49
+ 47,[C] [C] [=C] [C] [=Branch2] [Ring1] [=Branch1] [=C] [C] [=C] [Ring1] [=Branch1] [N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-4.945,Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2
50
+ 48,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [Cl],-3.22,Oc1ccc(Cl)cc1Cl
51
+ 49,[C] [N] [C] [=Branch2] [Ring1] [Ring2] [=C] [Branch1] [C] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [S] [Ring1] [O] [=Branch1] [C] [=O] [=O] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=N] [Ring1] [=Branch1],-3.4730000000000003,CN2C(=C(O)c1ccccc1S2(=O)=O)C(=O)Nc3ccccn3
52
+ 50,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [C] [Ring2] [Ring1] [C] [=O],-3.872,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1=O
53
+ 51,[C] [C] [=C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1],-4.147,Cc1cccc2c(C)cccc12
54
+ 52,[N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [O] [N] [C] [N] [S] [Ring1] [=Branch1] [=Branch1] [C] [=O] [=O] [C] [=C] [Ring1] [N] [Cl],-1.72,NS(=O)(=O)c2cc1c(NCNS1(=O)=O)cc2Cl
55
+ 53,[O] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-2.725,Oc1cccc2cccnc12
56
+ 54,[C] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [Ring1] [#Branch2],-3.447,C1CCc2ccccc2C1
57
+ 55,[C] [C] [O] [C] [Branch1] [C] [C] [O] [C] [C],-0.899,CCOC(C)OCC
58
+ 56,[C] [C] [C] [C] [Ring1] [Ring1] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [Branch1] [C] [Ring1] [Branch2] [=O] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.464,CC12CC2(C)C(=O)N(C1=O)c3cc(Cl)cc(Cl)c3
59
+ 57,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-4.87,Cc1c2ccccc2cc3ccccc13
60
+ 58,[C] [C] [C] [C] [O] [C],-1.072,CCCCOC
61
+ 59,[C] [C] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [Ring1] [#Branch1] [C] [C] [C] [C] [C] [C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [C] [O] [C] [Ring1] [=Branch2] [Branch1] [N] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [#Branch1] [Ring1] [=C] [C] [=O],-3.0660000000000003,CC13CCC(=O)C=C1CCC4C2CCC(C(=O)CO)C2(CC(O)C34)C=O
62
+ 60,[C] [C] [C] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [O] [=O],-1.6030000000000002,CCC1(C(C)C)C(=O)NC(=O)NC1=O
63
+ 61,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-2.761,CCOC(=O)c1ccc(O)cc1
64
+ 62,[C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [C] [=C] [Ring2] [Ring1] [C] [C] [Ring1] [S] [=C] [Ring1] [=C] [C] [Ring1] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-6.885,c1cc2ccc3ccc4ccc5ccc6ccc1c7c2c3c4c5c67
65
+ 63,[C] [C] [N] [C] [=C] [C] [Branch1] [=Branch1] [N] [Branch1] [C] [C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch2] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [N] [=C] [Ring2] [Ring1] [Ring2] [Ring1] [=Branch1],-4.408,CCN2c1cc(N(C)C)cc(C)c1NC(=O)c3cccnc23
66
+ 64,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.301,CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
67
+ 65,[C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-3.3080000000000003,CCCCCC(C)C
68
+ 66,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [N] [N] [=C] [Branch1] [#C] [N] [=C] [Ring1] [#Branch1] [C] [Branch1] [Ring1] [O] [C] [=C] [Ring1] [=N] [O] [C] [N] [C] [C] [N] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [O] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-3.958,COc2cc1c(N)nc(nc1c(OC)c2OC)N3CCN(CC3)C(=O)OCC(C)(C)O
69
+ 67,[C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [#Branch1] [C] [=C] [Ring1] [O],-0.636,c1cC2C(=O)NC(=O)C2cc1
70
+ 68,[C] [C] [C] [=O],-0.3939999999999999,CCC=O
71
+ 69,[Cl] [C] [=C] [C] [=C] [Branch2] [Ring1] [=Branch2] [C] [N] [Branch1] [Branch2] [C] [C] [C] [C] [C] [Ring1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring2] [Ring1] [=Branch1],-5.126,Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1
72
+ 70,[C] [C] [C] [C] [C] [Branch1] [Ring1] [C] [C] [C] [=O],-2.232,CCCCC(CC)C=O
73
+ 71,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-2.312,O=C1NC(=O)NC(=O)C1(CC)CCC(C)C
74
+ 72,[C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-1.857,CC(=O)Nc1ccccc1
75
+ 73,[C] [=N] [C] [=C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [N] [=C] [Ring1] [#Branch2],-0.7170000000000001,c1nccc(C(=O)NN)c1
76
+ 74,[C] [C] [Branch1] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [Ring2] [C] [Ring1] [=Branch1] [C] [Ring1] [=Branch2] [=O],-2.158,CC2(C)C1CCC(C)(C1)C2=O
77
+ 75,[C] [O] [C] [=C] [N] [=C] [C] [=N] [C] [=N] [C] [Ring1] [=Branch1] [=N] [Ring1] [#Branch2],-1.589,COc2cnc1cncnc1n2
78
+ 76,[C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [Branch1] [C] [C] [O] [P] [=Branch1] [C] [=O] [Branch1] [Ring1] [O] [C] [O] [C],-0.949,CNC(=O)C=C(C)OP(=O)(OC)OC
79
+ 77,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [Branch1] [Ring1] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [Ring1] [=Branch1],-3.784,O2c1ccccc1N(CC)C(=O)c3ccccc23
80
+ 78,[C] [=C] [C] [=C] [C] [=C] [Branch1] [Ring1] [O] [C] [C] [Branch1] [Branch2] [C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [=N] [O] [C] [Ring1] [P] [=O],-4.0760000000000005,c1cc2ccc(OC)c(CC=C(C)(C))c2oc1=O
81
+ 79,[C] [C] [C] [S] [C] [C] [C],-2.307,CCCSCCC
82
+ 80,[C] [O] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-2.948,CON(C)C(=O)Nc1ccc(Cl)cc1
83
+ 81,[C] [C] [O] [C] [C],-0.718,CCOCC
84
+ 82,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [Ring1] [O],-3.858,CC34CCC1C(CCc2cc(O)ccc12)C3CC(O)C4O
85
+ 83,[C] [C] [N] [C] [=N] [C] [Branch1] [C] [Cl] [=N] [C] [Branch1] [O] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [#N] [=N] [Ring1] [=N],-2.49,CCNc1nc(Cl)nc(NC(C)(C)C#N)n1
86
+ 84,[C] [C] [Branch1] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-1.6469999999999998,CC(C)CC(C)(C)O
87
+ 85,[Cl] [C] [=C] [C] [=C] [C] [Branch1] [C] [Br] [=C] [Ring1] [#Branch1],-3.928,Clc1cccc(Br)c1
88
+ 86,[C] [C] [C] [C] [C] [C] [Branch1] [C] [O] [C] [C],-2.033,CCCCCC(O)CC
89
+ 87,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.126,O=C1NC(=O)NC(=O)C1(CC)CC=C(C)C
90
+ 88,[C] [C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [C] [Branch1] [C] [Br] [=C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [N] [=O],-2.766,CCC(C)C1(CC(Br)=C)C(=O)NC(=O)NC1=O
91
+ 89,[C] [O] [C] [=Branch1] [C] [=O] [C],-0.416,COC(=O)C
92
+ 90,[C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [#Branch1] [O],-3.129,CC(C)c1ccc(C)cc1O
93
+ 91,[C],-0.636,C
94
+ 92,[N] [C] [=N] [C] [Branch1] [C] [O] [=N] [C] [N] [=C] [NH1] [C] [Ring1] [#Branch2] [=Ring1] [Branch1],-1.74,Nc1nc(O)nc2nc[nH]c12
95
+ 93,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-4.692,Fc1cccc(F)c1C(=O)NC(=O)Nc2ccc(Cl)cc2
96
+ 94,[C] [C] [C] [C] [C] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O] [Ring1] [#Branch2],-2.579,CC12CCC(CC1)C(C)(C)O2
97
+ 95,[C] [C] [O],0.02,CCO
98
+ 96,[C] [=C] [Branch2] [Ring1] [C] [N] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [C] [C] [=C] [C] [=C] [Ring1] [P],-2.29,c1c(NC(=O)OC(C)C(=O)NCC)cccc1
99
+ 97,[C] [C] [Branch1] [C] [C] [=C] [C] [C] [Branch2] [Ring1] [#Branch2] [C] [=Branch1] [C] [=O] [O] [C] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [C] [C],-6.763,CC(C)=CC3C(C(=O)OCc2cccc(Oc1ccccc1)c2)C3(C)C
100
+ 98,[C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Branch1] [Branch2] [N] [C] [=Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-2.902,CCCCNC(=O)n1c(NC(=O)OC)nc2ccccc12
101
+ 99,[C] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.542,CN(C)c1ccccc1
102
+ 100,[C] [O] [C] [=Branch1] [C] [=O] [C] [=C],-0.878,COC(=O)C=C
103
+ 101,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [=N] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C] [=C] [Ring1] [=C],-4.477,CN(C)C(=O)Nc2ccc(Oc1ccc(Cl)cc1)cc2
104
+ 102,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.465,O=C1NC(=O)NC(=O)C1(C(C)C)CC=C(C)C
105
+ 103,[C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1] [C],-2.6210000000000004,Cc1ccc(O)cc1C
106
+ 104,[Cl] [C] [=C] [C] [=C] [C] [=Branch1] [Ring2] [=N] [Ring1] [=Branch1] [C] [Branch1] [C] [Cl] [Branch1] [C] [Cl] [Cl],-3.833,Clc1cccc(n1)C(Cl)(Cl)Cl
107
+ 105,[C] [C] [=Branch1] [C] [=O] [O] [C] [Branch2] [Ring1] [=C] [C] [C] [C] [C] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [#Branch1] [C] [Ring1] [O] [C] [C] [C] [Ring2] [Ring1] [C] [Ring1] [#C] [C] [C] [#C],-4.2410000000000005,CC(=O)OC3(CCC4C2CCC1=CC(=O)CCC1C2CCC34C)C#C
108
+ 106,[C] [N] [C] [=Branch1] [C] [=O] [O] [N] [=C] [Branch1] [Ring2] [C] [S] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.7,CNC(=O)ON=C(CSC)C(C)(C)C
109
+ 107,[C] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [O],-2.033,CCCCCCC(C)O
data/esol/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
img/img1.png ADDED
img/img2.png ADDED
img/img3.png ADDED
img/img4.png ADDED
img/img5.png ADDED
img/selfies-ted.png ADDED
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/fm4m.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import roc_auc_score, roc_curve
2
+
3
+ import datetime
4
+ import os
5
+ import umap
6
+ import numpy as np
7
+
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ import pickle
11
+ import json
12
+
13
+ from xgboost import XGBClassifier, XGBRegressor
14
+ import xgboost as xgb
15
+ from sklearn.metrics import roc_auc_score, mean_squared_error
16
+ import xgboost as xgb
17
+ from sklearn.svm import SVR
18
+ from sklearn.linear_model import LinearRegression
19
+ from sklearn.kernel_ridge import KernelRidge
20
+ import json
21
+ from sklearn.compose import TransformedTargetRegressor
22
+ from sklearn.preprocessing import MinMaxScaler
23
+
24
+
25
+ import torch
26
+ from transformers import AutoTokenizer, AutoModel
27
+
28
+ from .selfies_model.load import SELFIES as bart
29
+ from .mhg_model import load as mhg
30
+ from .smi_ted.smi_ted_light.load import load_smi_ted
31
+
32
+ datasets = {}
33
+ models = {}
34
+ downstream_models ={}
35
+
36
+
37
+ def avail_models_data():
38
+ global datasets
39
+ global models
40
+
41
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
42
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
43
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
44
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
45
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
46
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
47
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"}]
48
+
49
+
50
+ models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
51
+ {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
52
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
53
+ {"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
54
+
55
+
56
+ def avail_models(raw=False):
57
+ global models
58
+
59
+ models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
60
+ {"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
61
+ {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality"},
62
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
63
+ ]
64
+
65
+
66
+
67
+ if raw: return models
68
+ else:
69
+ return pd.DataFrame(models).drop('Name', axis=1)
70
+
71
+ return models
72
+
73
+ def avail_downstream_models():
74
+ global downstream_models
75
+
76
+ with open("downstream_models.json", "r") as outfile:
77
+ downstream_models = json.load(outfile)
78
+ return downstream_models
79
+
80
+ def avail_datasets():
81
+ global datasets
82
+
83
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv",
84
+ "Timestamp": "2024-06-26 11:27:37"},
85
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre",
86
+ "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
87
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv",
88
+ "Timestamp": "2024-06-26 11:33:47"},
89
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo",
90
+ "Timestamp": "2024-06-26 11:34:37"},
91
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace",
92
+ "Timestamp": "2024-06-26 11:36:40"},
93
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp",
94
+ "Timestamp": "2024-06-26 11:39:23"},
95
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox",
96
+ "Timestamp": "2024-06-26 11:42:43"}]
97
+
98
+ return datasets
99
+
100
+ def reset():
101
+
102
+ """datasets = {"esol": ["smiles", "ESOL predicted log solubility in mols per litre", "data/esol", "2024-06-26 11:36:46.509324"],
103
+ "freesolv": ["smiles", "expt", "data/freesolv", "2024-06-26 11:37:37.393273"],
104
+ "lipo": ["smiles", "y", "data/lipo", "2024-06-26 11:37:37.393273"],
105
+ "hiv": ["smiles", "HIV_active", "data/hiv", "2024-06-26 11:37:37.393273"],
106
+ "bace": ["smiles", "Class", "data/bace", "2024-06-26 11:38:40.058354"],
107
+ "bbbp": ["smiles", "p_np", "data/bbbp","2024-06-26 11:38:40.058354"],
108
+ "clintox": ["smiles", "CT_TOX", "data/clintox","2024-06-26 11:38:40.058354"],
109
+ "sider": ["smiles","1:", "data/sider","2024-06-26 11:38:40.058354"],
110
+ "tox21": ["smiles",":-2", "data/tox21","2024-06-26 11:38:40.058354"]
111
+ }"""
112
+
113
+ datasets = [
114
+ {"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
115
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
116
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
117
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
118
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
119
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
120
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"},
121
+ #{"Dataset": "sider", "Input": "smiles", "Output": "1:", "path": "data/sider", "Timestamp": "2024-06-26 11:38:40.058354"},
122
+ #{"Dataset": "tox21", "Input": "smiles", "Output": ":-2", "path": "data/tox21", "Timestamp": "2024-06-26 11:38:40.058354"}
123
+ ]
124
+
125
+ models = [{"Name": "bart", "Description": "BART model for string based SELFIES modality",
126
+ "Timestamp": "2024-06-21 12:32:20"},
127
+ {"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
128
+ "Timestamp": "2024-06-21 12:35:56"},
129
+ {"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
130
+ {"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
131
+ {"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
132
+ {"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
133
+
134
+
135
+ downstream_models = [
136
+ {"Name": "XGBClassifier", "Description": "XG Boost Classifier",
137
+ "Timestamp": "2024-06-21 12:31:20"},
138
+ {"Name": "XGBRegressor", "Description": "XG Boost Regressor",
139
+ "Timestamp": "2024-06-21 12:32:56"},
140
+ {"Name": "2-FNN", "Description": "A two layer feedforward network",
141
+ "Timestamp": "2024-06-24 14:34:16"},
142
+ {"Name": "3-FNN", "Description": "A three layer feedforward network",
143
+ "Timestamp": "2024-06-24 14:38:37"},
144
+ ]
145
+
146
+ with open("datasets.json", "w") as outfile:
147
+ json.dump(datasets, outfile)
148
+
149
+ with open("models.json", "w") as outfile:
150
+ json.dump(models, outfile)
151
+
152
+ with open("downstream_models.json", "w") as outfile:
153
+ json.dump(downstream_models, outfile)
154
+
155
+ def update_data_list(list_data):
156
+ #datasets[list_data[0]] = list_data[1:]
157
+
158
+ with open("datasets.json", "w") as outfile:
159
+ json.dump(datasets, outfile)
160
+
161
+ avail_models_data()
162
+
163
+ def update_model_list(list_model):
164
+ #models[list_model[0]] = list_model[1]
165
+
166
+ with open("models.json", "w") as outfile:
167
+ json.dump(list_model, outfile)
168
+
169
+ avail_models_data()
170
+
171
+ def update_downstream_model_list(list_model):
172
+ #models[list_model[0]] = list_model[1]
173
+
174
+ with open("downstream_models.json", "w") as outfile:
175
+ json.dump(list_model, outfile)
176
+
177
+ avail_models_data()
178
+
179
+ avail_models_data()
180
+
181
+ def get_representation(train_data,test_data,model_type, return_tensor=True):
182
+ alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
183
+ if model_type in alias.keys():
184
+ model_type = alias[model_type]
185
+
186
+ if model_type == "mhg":
187
+ model = mhg.load("models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
188
+ with torch.no_grad():
189
+ train_emb = model.encode(train_data)
190
+ x_batch = torch.stack(train_emb)
191
+
192
+ test_emb = model.encode(test_data)
193
+ x_batch_test = torch.stack(test_emb)
194
+ if not return_tensor:
195
+ x_batch = pd.DataFrame(x_batch)
196
+ x_batch_test = pd.DataFrame(x_batch_test)
197
+
198
+
199
+
200
+ elif model_type == "bart":
201
+ model = bart()
202
+ model.load()
203
+ x_batch = model.encode(train_data, return_tensor=return_tensor)
204
+ x_batch_test = model.encode(test_data, return_tensor=return_tensor)
205
+
206
+ elif model_type == "smi-ted":
207
+ model = load_smi_ted(folder='./models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
208
+ with torch.no_grad():
209
+ x_batch = model.encode(train_data, return_torch=return_tensor)
210
+ x_batch_test = model.encode(test_data, return_torch=return_tensor)
211
+
212
+ elif model_type == "mol-xl":
213
+ model = AutoModel.from_pretrained("ibm/MoLFormer-XL-both-10pct", deterministic_eval=True,
214
+ trust_remote_code=True)
215
+ tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
216
+
217
+ if type(train_data) == list:
218
+ inputs = tokenizer(train_data, padding=True, return_tensors="pt")
219
+ else:
220
+ inputs = tokenizer(list(train_data.values), padding=True, return_tensors="pt")
221
+
222
+ with torch.no_grad():
223
+ outputs = model(**inputs)
224
+
225
+ x_batch = outputs.pooler_output
226
+
227
+ if type(test_data) == list:
228
+ inputs = tokenizer(test_data, padding=True, return_tensors="pt")
229
+ else:
230
+ inputs = tokenizer(list(test_data.values), padding=True, return_tensors="pt")
231
+
232
+ with torch.no_grad():
233
+ outputs = model(**inputs)
234
+
235
+ x_batch_test = outputs.pooler_output
236
+
237
+ if not return_tensor:
238
+ x_batch = pd.DataFrame(x_batch)
239
+ x_batch_test = pd.DataFrame(x_batch_test)
240
+
241
+
242
+ return x_batch, x_batch_test
243
+
244
+ def single_modal(model,dataset, downstream_model,params):
245
+ print(model)
246
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED": "smi-ted"}
247
+ data = avail_models(raw=True)
248
+ df = pd.DataFrame(data)
249
+ print(list(df["Name"].values))
250
+ if alias[model] in list(df["Name"].values):
251
+ if model in alias.keys():
252
+ model_type = alias[model]
253
+ else:
254
+ model_type = model
255
+ else:
256
+ print("Model not available")
257
+ return
258
+
259
+ data = avail_datasets()
260
+ df = pd.DataFrame(data)
261
+ print(list(df["Dataset"].values))
262
+
263
+ if dataset in list(df["Dataset"].values):
264
+ task = dataset
265
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
266
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
267
+ print(f" Representation loaded successfully")
268
+ else:
269
+
270
+ print("Custom Dataset")
271
+ #return
272
+ components = dataset.split(",")
273
+ train_data = pd.read_csv(components[0])[components[2]]
274
+ test_data = pd.read_csv(components[1])[components[2]]
275
+
276
+ y_batch = pd.read_csv(components[0])[components[3]]
277
+ y_batch_test = pd.read_csv(components[1])[components[3]]
278
+
279
+
280
+ x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
281
+
282
+
283
+
284
+ print(f" Representation loaded successfully")
285
+
286
+
287
+
288
+
289
+
290
+ print(f" Calculating ROC AUC Score ...")
291
+
292
+ if downstream_model == "XGBClassifier":
293
+ xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
294
+ xgb_predict_concat.fit(x_batch, y_batch)
295
+
296
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
297
+
298
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
299
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
300
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
301
+
302
+ try:
303
+ with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
304
+ class_0,class_1 = pickle.load(f1)
305
+ except:
306
+ print("Generating latent plots")
307
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
308
+ verbose=False)
309
+ n_samples = np.minimum(1000, len(x_batch))
310
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
311
+ try:x = y_batch.values[:n_samples]
312
+ except: x = y_batch[:n_samples]
313
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
314
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
315
+
316
+ class_0 = features_umap[index_0]
317
+ class_1 = features_umap[index_1]
318
+ print("Generating latent plots : Done")
319
+
320
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
321
+
322
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
323
+
324
+ return result, roc_auc,fpr, tpr, class_0, class_1
325
+
326
+ elif downstream_model == "DefaultClassifier":
327
+ xgb_predict_concat = XGBClassifier() # n_estimators=5000, learning_rate=0.01, max_depth=10
328
+ xgb_predict_concat.fit(x_batch, y_batch)
329
+
330
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
331
+
332
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
333
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
334
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
335
+
336
+ try:
337
+ with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
338
+ class_0,class_1 = pickle.load(f1)
339
+ except:
340
+ print("Generating latent plots")
341
+ reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
342
+ n_samples = np.minimum(1000,len(x_batch))
343
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
344
+ try:x = y_batch.values[:n_samples]
345
+ except:x = y_batch[:n_samples]
346
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
347
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
348
+
349
+ class_0 = features_umap[index_0]
350
+ class_1 = features_umap[index_1]
351
+ print("Generating latent plots : Done")
352
+
353
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
354
+
355
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
356
+
357
+ return result, roc_auc,fpr, tpr, class_0, class_1
358
+
359
+ elif downstream_model == "SVR":
360
+ regressor = SVR(**params)
361
+ model = TransformedTargetRegressor(regressor= regressor,
362
+ transformer = MinMaxScaler(feature_range=(-1, 1))
363
+ ).fit(x_batch,y_batch)
364
+
365
+ y_prob = model.predict(x_batch_test)
366
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
367
+
368
+ print(f"RMSE Score: {RMSE_score:.4f}")
369
+ result = f"RMSE Score: {RMSE_score:.4f}"
370
+
371
+ print("Generating latent plots")
372
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
373
+ verbose=False)
374
+ n_samples = np.minimum(1000, len(x_batch))
375
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
376
+ try:x = y_batch.values[:n_samples]
377
+ except:x = y_batch[:n_samples]
378
+ #index_0 = [index for index in range(len(x)) if x[index] == 0]
379
+ #index_1 = [index for index in range(len(x)) if x[index] == 1]
380
+
381
+ class_0 = features_umap#[index_0]
382
+ class_1 = features_umap#[index_1]
383
+ print("Generating latent plots : Done")
384
+
385
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
386
+
387
+ elif downstream_model == "Kernel Ridge":
388
+ regressor = KernelRidge(**params)
389
+ model = TransformedTargetRegressor(regressor=regressor,
390
+ transformer=MinMaxScaler(feature_range=(-1, 1))
391
+ ).fit(x_batch, y_batch)
392
+
393
+ y_prob = model.predict(x_batch_test)
394
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
395
+
396
+ print(f"RMSE Score: {RMSE_score:.4f}")
397
+ result = f"RMSE Score: {RMSE_score:.4f}"
398
+
399
+ print("Generating latent plots")
400
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
401
+ verbose=False)
402
+ n_samples = np.minimum(1000, len(x_batch))
403
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
404
+ try:x = y_batch.values[:n_samples]
405
+ except:x = y_batch[:n_samples]
406
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
407
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
408
+
409
+ class_0 = features_umap#[index_0]
410
+ class_1 = features_umap#[index_1]
411
+ print("Generating latent plots : Done")
412
+
413
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
414
+
415
+
416
+ elif downstream_model == "Linear Regression":
417
+ regressor = LinearRegression(**params)
418
+ model = TransformedTargetRegressor(regressor=regressor,
419
+ transformer=MinMaxScaler(feature_range=(-1, 1))
420
+ ).fit(x_batch, y_batch)
421
+
422
+ y_prob = model.predict(x_batch_test)
423
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
424
+
425
+ print(f"RMSE Score: {RMSE_score:.4f}")
426
+ result = f"RMSE Score: {RMSE_score:.4f}"
427
+
428
+ print("Generating latent plots")
429
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
430
+ verbose=False)
431
+ n_samples = np.minimum(1000, len(x_batch))
432
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
433
+ try:x = y_batch.values[:n_samples]
434
+ except:x = y_batch[:n_samples]
435
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
436
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
437
+
438
+ class_0 = features_umap#[index_0]
439
+ class_1 = features_umap#[index_1]
440
+ print("Generating latent plots : Done")
441
+
442
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
443
+
444
+
445
+ elif downstream_model == "DefaultRegressor":
446
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
447
+ model = TransformedTargetRegressor(regressor=regressor,
448
+ transformer=MinMaxScaler(feature_range=(-1, 1))
449
+ ).fit(x_batch, y_batch)
450
+
451
+ y_prob = model.predict(x_batch_test)
452
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
453
+
454
+ print(f"RMSE Score: {RMSE_score:.4f}")
455
+ result = f"RMSE Score: {RMSE_score:.4f}"
456
+
457
+ print("Generating latent plots")
458
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
459
+ verbose=False)
460
+ n_samples = np.minimum(1000, len(x_batch))
461
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
462
+ try:x = y_batch.values[:n_samples]
463
+ except:x = y_batch[:n_samples]
464
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
465
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
466
+
467
+ class_0 = features_umap#[index_0]
468
+ class_1 = features_umap#[index_1]
469
+ print("Generating latent plots : Done")
470
+
471
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
472
+
473
+
474
+ def multi_modal(model_list,dataset, downstream_model,params):
475
+ print(model_list)
476
+ data = avail_datasets()
477
+ df = pd.DataFrame(data)
478
+ list(df["Dataset"].values)
479
+
480
+ if dataset in list(df["Dataset"].values):
481
+ task = dataset
482
+ predefined = True
483
+ else:
484
+ predefined = False
485
+ components = dataset.split(",")
486
+ train_data = pd.read_csv(components[0])[components[2]]
487
+ test_data = pd.read_csv(components[1])[components[2]]
488
+
489
+ y_batch = pd.read_csv(components[0])[components[3]]
490
+ y_batch_test = pd.read_csv(components[1])[components[3]]
491
+
492
+ print("Custom Dataset loaded")
493
+
494
+
495
+ data = avail_models(raw=True)
496
+ df = pd.DataFrame(data)
497
+ list(df["Name"].values)
498
+
499
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED":"smi-ted"}
500
+ #if set(model_list).issubset(list(df["Name"].values)):
501
+ if set(model_list).issubset(list(alias.keys())):
502
+ for i, model in enumerate(model_list):
503
+ if model in alias.keys():
504
+ model_type = alias[model]
505
+ else:
506
+ model_type = model
507
+
508
+ if i == 0:
509
+ if predefined:
510
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
511
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
512
+ print(f" Loaded representation/{task}_{model_type}.pkl")
513
+ else:
514
+ x_batch, x_batch_test = get_representation(train_data, test_data, model_type)
515
+ x_batch = pd.DataFrame(x_batch)
516
+ x_batch_test = pd.DataFrame(x_batch_test)
517
+
518
+ else:
519
+ if predefined:
520
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
521
+ x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
522
+ print(f" Loaded representation/{task}_{model_type}.pkl")
523
+ else:
524
+ x_batch_1, x_batch_test_1 = get_representation(train_data, test_data, model_type)
525
+ x_batch_1 = pd.DataFrame(x_batch_1)
526
+ x_batch_test_1 = pd.DataFrame(x_batch_test_1)
527
+
528
+ x_batch = pd.concat([x_batch, x_batch_1], axis=1)
529
+ x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
530
+
531
+
532
+ else:
533
+ print("Model not available")
534
+ return
535
+
536
+ num_columns = x_batch_test.shape[1]
537
+ x_batch_test.columns = [f'{i + 1}' for i in range(num_columns)]
538
+
539
+ num_columns = x_batch.shape[1]
540
+ x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
541
+
542
+
543
+ print(f"Representations loaded successfully")
544
+ try:
545
+ with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1:
546
+ class_0, class_1 = pickle.load(f1)
547
+ except:
548
+ print("Generating latent plots")
549
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
550
+ verbose=False)
551
+ n_samples = np.minimum(1000, len(x_batch))
552
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
553
+
554
+ if "Classifier" in downstream_model:
555
+ try:x = y_batch.values[:n_samples]
556
+ except:x = y_batch[:n_samples]
557
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
558
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
559
+
560
+ class_0 = features_umap[index_0]
561
+ class_1 = features_umap[index_1]
562
+
563
+ else:
564
+ class_0 = features_umap
565
+ class_1 = features_umap
566
+
567
+ print("Generating latent plots : Done")
568
+
569
+ print(f" Calculating ROC AUC Score ...")
570
+
571
+
572
+ if downstream_model == "XGBClassifier":
573
+ xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
574
+ xgb_predict_concat.fit(x_batch, y_batch)
575
+
576
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
577
+
578
+
579
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
580
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
581
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
582
+
583
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
584
+
585
+ #vizualize(x_batch_test, y_batch_test)
586
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
587
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
588
+
589
+ return result, roc_auc,fpr, tpr, class_0, class_1
590
+
591
+ elif downstream_model == "DefaultClassifier":
592
+ xgb_predict_concat = XGBClassifier()#n_estimators=5000, learning_rate=0.01, max_depth=10)
593
+ xgb_predict_concat.fit(x_batch, y_batch)
594
+
595
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
596
+
597
+
598
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
599
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
600
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
601
+
602
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
603
+
604
+ #vizualize(x_batch_test, y_batch_test)
605
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
606
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
607
+
608
+ return result, roc_auc,fpr, tpr, class_0, class_1
609
+
610
+ elif downstream_model == "SVR":
611
+ regressor = SVR(**params)
612
+ model = TransformedTargetRegressor(regressor= regressor,
613
+ transformer = MinMaxScaler(feature_range=(-1, 1))
614
+ ).fit(x_batch,y_batch)
615
+
616
+ y_prob = model.predict(x_batch_test)
617
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
618
+
619
+ print(f"RMSE Score: {RMSE_score:.4f}")
620
+ result = f"RMSE Score: {RMSE_score:.4f}"
621
+
622
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
623
+
624
+ elif downstream_model == "Linear Regression":
625
+ regressor = LinearRegression(**params)
626
+ model = TransformedTargetRegressor(regressor=regressor,
627
+ transformer=MinMaxScaler(feature_range=(-1, 1))
628
+ ).fit(x_batch, y_batch)
629
+
630
+ y_prob = model.predict(x_batch_test)
631
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
632
+
633
+ print(f"RMSE Score: {RMSE_score:.4f}")
634
+ result = f"RMSE Score: {RMSE_score:.4f}"
635
+
636
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
637
+
638
+ elif downstream_model == "Kernel Ridge":
639
+ regressor = KernelRidge(**params)
640
+ model = TransformedTargetRegressor(regressor=regressor,
641
+ transformer=MinMaxScaler(feature_range=(-1, 1))
642
+ ).fit(x_batch, y_batch)
643
+
644
+ y_prob = model.predict(x_batch_test)
645
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
646
+
647
+ print(f"RMSE Score: {RMSE_score:.4f}")
648
+ result = f"RMSE Score: {RMSE_score:.4f}"
649
+
650
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
651
+
652
+ elif downstream_model == "DefaultRegressor":
653
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
654
+ model = TransformedTargetRegressor(regressor=regressor,
655
+ transformer=MinMaxScaler(feature_range=(-1, 1))
656
+ ).fit(x_batch, y_batch)
657
+
658
+ y_prob = model.predict(x_batch_test)
659
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
660
+
661
+ print(f"RMSE Score: {RMSE_score:.4f}")
662
+ result = f"RMSE Score: {RMSE_score:.4f}"
663
+
664
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
665
+
666
+
667
+
668
+
669
+
670
+
models/mhg_model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/mhg_model/README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mhg-gnn
2
+
3
+ This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
4
+
5
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
6
+
7
+ ![mhg-gnn](images/mhg_example1.png)
8
+
9
+ ## Introduction
10
+
11
+ We present MHG-GNN, an autoencoder architecture
12
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
13
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
14
+ demonstrate high predictive performance on molecular graph data.
15
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
16
+
17
+ ## Table of Contents
18
+
19
+ 1. [Getting Started](#getting-started)
20
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
21
+ 2. [Installation](#installation)
22
+ 2. [Feature Extraction](#feature-extraction)
23
+
24
+ ## Getting Started
25
+
26
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
27
+
28
+ ### Pretrained Models and Training Logs
29
+
30
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
31
+
32
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
33
+
34
+ ### Installation
35
+
36
+ We recommend to create a virtual environment. For example:
37
+
38
+ ```
39
+ python3 -m venv .venv
40
+ . .venv/bin/activate
41
+ ```
42
+
43
+ Type the following command once the virtual environment is activated:
44
+
45
+ ```
46
+ git clone [email protected]:CMD-TRL/mhg-gnn.git
47
+ cd ./mhg-gnn
48
+ pip install .
49
+ ```
50
+
51
+ ## Feature Extraction
52
+
53
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
54
+
55
+ To load mhg-gnn, you can simply use:
56
+
57
+ ```python
58
+ import torch
59
+ import load
60
+
61
+ model = load.load()
62
+ ```
63
+
64
+ To encode SMILES into embeddings, you can use:
65
+
66
+ ```python
67
+ with torch.no_grad():
68
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
69
+ ```
70
+
71
+ For decoder, you can use the function, so you can return from embeddings to SMILES strings:
72
+
73
+ ```python
74
+ orig = model.decode(repr)
75
+ ```
models/mhg_model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
models/mhg_model/graph_grammar/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
12
+
13
+ """ Title """
14
+
15
+ __author__ = "Hiroshi Kajino <[email protected]>"
16
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
17
+ __version__ = "0.1"
18
+ __date__ = "Jan 1 2018"
19
+
models/mhg_model/graph_grammar/algo/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/algo/tree_decomposition.py ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from copy import deepcopy
22
+ from itertools import combinations
23
+ from ..hypergraph import Hypergraph
24
+ import networkx as nx
25
+ import numpy as np
26
+
27
+
28
+ class CliqueTree(nx.Graph):
29
+ ''' clique tree object
30
+
31
+ Attributes
32
+ ----------
33
+ hg : Hypergraph
34
+ This hypergraph will be decomposed.
35
+ root_hg : Hypergraph
36
+ Hypergraph on the root node.
37
+ ident_node_dict : dict
38
+ ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
39
+ '''
40
+ def __init__(self, hg=None, **kwargs):
41
+ self.hg = deepcopy(hg)
42
+ if self.hg is not None:
43
+ self.ident_node_dict = self.hg.get_identical_node_dict()
44
+ else:
45
+ self.ident_node_dict = {}
46
+ super().__init__(**kwargs)
47
+
48
+ @property
49
+ def root_hg(self):
50
+ ''' return the hypergraph on the root node
51
+ '''
52
+ return self.nodes[0]['subhg']
53
+
54
+ @root_hg.setter
55
+ def root_hg(self, hypergraph):
56
+ ''' set the hypergraph on the root node
57
+ '''
58
+ self.nodes[0]['subhg'] = hypergraph
59
+
60
+ def insert_subhg(self, subhypergraph: Hypergraph) -> None:
61
+ ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
62
+
63
+ Parameters
64
+ ----------
65
+ subhg : Hypergraph
66
+ '''
67
+ num_nodes = self.number_of_nodes()
68
+ self.add_node(num_nodes, subhg=subhypergraph)
69
+ self.add_edge(num_nodes, 0)
70
+ adj_nodes = deepcopy(list(self.adj[0].keys()))
71
+ for each_node in adj_nodes:
72
+ if len(self.nodes[each_node]["subhg"].nodes.intersection(
73
+ self.nodes[num_nodes]["subhg"].nodes)\
74
+ - self.root_hg.nodes) != 0 and each_node != num_nodes:
75
+ self.remove_edge(0, each_node)
76
+ self.add_edge(each_node, num_nodes)
77
+
78
+ def to_irredundant(self) -> None:
79
+ ''' convert the clique tree to be irredundant
80
+ '''
81
+ for each_node in self.hg.nodes:
82
+ subtree = self.subgraph([
83
+ each_tree_node for each_tree_node in self.nodes()\
84
+ if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
85
+ leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
86
+ redundant_leaf_node_list = []
87
+ for each_leaf_node in leaf_node_list:
88
+ if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
89
+ redundant_leaf_node_list.append(each_leaf_node)
90
+ for each_red_leaf_node in redundant_leaf_node_list:
91
+ current_node = each_red_leaf_node
92
+ while subtree.degree(current_node) == 1 \
93
+ and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
94
+ self.nodes[current_node]["subhg"].remove_node(each_node)
95
+ remove_node = current_node
96
+ current_node = list(dict(subtree[remove_node]).keys())[0]
97
+ subtree.remove_node(remove_node)
98
+
99
+ fixed_node_set = deepcopy(self.nodes)
100
+ for each_node in fixed_node_set:
101
+ if self.nodes[each_node]["subhg"].num_edges == 0:
102
+ if len(self[each_node]) == 1:
103
+ self.remove_node(each_node)
104
+ elif len(self[each_node]) == 2:
105
+ self.add_edge(*self[each_node])
106
+ self.remove_node(each_node)
107
+ else:
108
+ pass
109
+ else:
110
+ pass
111
+
112
+ redundant = True
113
+ while redundant:
114
+ redundant = False
115
+ fixed_edge_set = deepcopy(self.edges)
116
+ remove_node_set = set()
117
+ for node_1, node_2 in fixed_edge_set:
118
+ if node_1 in remove_node_set or node_2 in remove_node_set:
119
+ pass
120
+ else:
121
+ if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
122
+ redundant = True
123
+ adj_node_list = set(self.adj[node_1]) - {node_2}
124
+ self.remove_node(node_1)
125
+ remove_node_set.add(node_1)
126
+ for each_node in adj_node_list:
127
+ self.add_edge(node_2, each_node)
128
+
129
+ elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
130
+ redundant = True
131
+ adj_node_list = set(self.adj[node_2]) - {node_1}
132
+ self.remove_node(node_2)
133
+ remove_node_set.add(node_2)
134
+ for each_node in adj_node_list:
135
+ self.add_edge(node_1, each_node)
136
+
137
+ def node_update(self, key_node: str, subhg) -> None:
138
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
139
+
140
+ Parameters
141
+ ----------
142
+ key_node : str
143
+ key node that must be removed.
144
+ subhg : Hypegraph
145
+ """
146
+ for each_edge in subhg.edges:
147
+ self.root_hg.remove_edge(each_edge)
148
+ self.root_hg.remove_nodes(self.ident_node_dict[key_node])
149
+
150
+ adj_node_list = list(subhg.nodes)
151
+ for each_node in subhg.nodes:
152
+ if each_node not in self.ident_node_dict[key_node]:
153
+ if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
154
+ self.root_hg.remove_node(each_node)
155
+ adj_node_list.remove(each_node)
156
+ else:
157
+ adj_node_list.remove(each_node)
158
+
159
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
160
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
161
+ self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
162
+
163
+ subhg.remove_edges_with_attr({'tmp' : True})
164
+ self.insert_subhg(subhg)
165
+
166
+ def update(self, subhg, remove_nodes=False):
167
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
168
+
169
+ Parameters
170
+ ----------
171
+ subhg : Hypegraph
172
+ """
173
+ for each_edge in subhg.edges:
174
+ self.root_hg.remove_edge(each_edge)
175
+ if remove_nodes:
176
+ remove_edge_list = []
177
+ for each_edge in self.root_hg.edges:
178
+ if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
179
+ and self.root_hg.edge_attr(each_edge).get('tmp', False):
180
+ remove_edge_list.append(each_edge)
181
+ self.root_hg.remove_edges(remove_edge_list)
182
+
183
+ adj_node_list = list(subhg.nodes)
184
+ for each_node in subhg.nodes:
185
+ if self.root_hg.degree(each_node) == 0:
186
+ self.root_hg.remove_node(each_node)
187
+ adj_node_list.remove(each_node)
188
+
189
+ if len(adj_node_list) != 1 and not remove_nodes:
190
+ self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
191
+ '''
192
+ else:
193
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
194
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
195
+ self.root_hg.add_edge(
196
+ [each_node_1, each_node_2], attr_dict=dict(tmp=True))
197
+ '''
198
+ subhg.remove_edges_with_attr({'tmp':True})
199
+ self.insert_subhg(subhg)
200
+
201
+
202
+ def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
203
+ if mode == 'standard':
204
+ degree_dict = hg.degrees()
205
+ min_deg_node = min(degree_dict, key=degree_dict.get)
206
+ min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
207
+ return min_deg_node, min_deg_subhg
208
+ elif mode == 'mol':
209
+ degree_dict = hg.degrees()
210
+ min_deg = min(degree_dict.values())
211
+ min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
212
+ min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
213
+ for each_min_deg_node in min_deg_node_list]
214
+ best_score = np.inf
215
+ best_idx = -1
216
+ for each_idx in range(len(min_deg_subhg_list)):
217
+ if min_deg_subhg_list[each_idx].num_nodes < best_score:
218
+ best_idx = each_idx
219
+ return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
220
+ else:
221
+ raise ValueError
222
+
223
+
224
+ def tree_decomposition(hg, irredundant=True):
225
+ """ compute a tree decomposition of the input hypergraph
226
+
227
+ Parameters
228
+ ----------
229
+ hg : Hypergraph
230
+ hypergraph to be decomposed
231
+ irredundant : bool
232
+ if True, irredundant tree decomposition will be computed.
233
+
234
+ Returns
235
+ -------
236
+ clique_tree : nx.Graph
237
+ each node contains a subhypergraph of `hg`
238
+ """
239
+ org_hg = hg.copy()
240
+ ident_node_dict = hg.get_identical_node_dict()
241
+ clique_tree = CliqueTree(org_hg)
242
+ clique_tree.add_node(0, subhg=org_hg)
243
+ while True:
244
+ degree_dict = org_hg.degrees()
245
+ min_deg_node = min(degree_dict, key=degree_dict.get)
246
+ min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
247
+ if org_hg.nodes == min_deg_subhg.nodes:
248
+ break
249
+
250
+ # org_hg and min_deg_subhg are divided
251
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
252
+
253
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
254
+
255
+ if irredundant:
256
+ clique_tree.to_irredundant()
257
+ return clique_tree
258
+
259
+
260
+ def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
261
+ ''' compute a tree decomposition given a hyperedge replacement grammar.
262
+ the resultant clique tree should induce a less compact HRG.
263
+
264
+ Parameters
265
+ ----------
266
+ hg : Hypergraph
267
+ hypergraph to be decomposed
268
+ hrg : HyperedgeReplacementGrammar
269
+ current HRG
270
+ irredundant : bool
271
+ if True, irredundant tree decomposition will be computed.
272
+
273
+ Returns
274
+ -------
275
+ clique_tree : nx.Graph
276
+ each node contains a subhypergraph of `hg`
277
+ '''
278
+ org_hg = hg.copy()
279
+ ident_node_dict = hg.get_identical_node_dict()
280
+ clique_tree = CliqueTree(org_hg)
281
+ clique_tree.add_node(0, subhg=org_hg)
282
+ root_node = 0
283
+
284
+ # construct a clique tree using HRG
285
+ success_any = True
286
+ while success_any:
287
+ success_any = False
288
+ for each_prod_rule in hrg.prod_rule_list:
289
+ org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
290
+ if success:
291
+ if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
292
+ success_any = True
293
+ subhg.remove_edges_with_attr({'terminal' : False})
294
+ clique_tree.root_hg = org_hg
295
+ clique_tree.insert_subhg(subhg)
296
+
297
+ clique_tree.root_hg = org_hg
298
+
299
+ for each_edge in deepcopy(org_hg.edges):
300
+ if not org_hg.edge_attr(each_edge)['terminal']:
301
+ node_list = org_hg.nodes_in_edge(each_edge)
302
+ org_hg.remove_edge(each_edge)
303
+
304
+ for each_node_1, each_node_2 in combinations(node_list, 2):
305
+ if not org_hg.is_adj(each_node_1, each_node_2):
306
+ org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
307
+
308
+ # construct a clique tree using the existing algorithm
309
+ degree_dict = org_hg.degrees()
310
+ if degree_dict:
311
+ while True:
312
+ min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
313
+ if org_hg.nodes == min_deg_subhg.nodes: break
314
+
315
+ # org_hg and min_deg_subhg are divided
316
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
317
+
318
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
319
+ if irredundant:
320
+ clique_tree.to_irredundant()
321
+
322
+ if return_root:
323
+ if root_node == 0 and 0 not in clique_tree.nodes:
324
+ root_node = clique_tree.number_of_nodes()
325
+ while root_node not in clique_tree.nodes:
326
+ root_node -= 1
327
+ elif root_node not in clique_tree.nodes:
328
+ while root_node not in clique_tree.nodes:
329
+ root_node -= 1
330
+ else:
331
+ pass
332
+ return clique_tree, root_node
333
+ else:
334
+ return clique_tree
335
+
336
+
337
+ def tree_decomposition_from_leaf(hg, irredundant=True):
338
+ """ compute a tree decomposition of the input hypergraph
339
+
340
+ Parameters
341
+ ----------
342
+ hg : Hypergraph
343
+ hypergraph to be decomposed
344
+ irredundant : bool
345
+ if True, irredundant tree decomposition will be computed.
346
+
347
+ Returns
348
+ -------
349
+ clique_tree : nx.Graph
350
+ each node contains a subhypergraph of `hg`
351
+ """
352
+ def apply_normal_decomposition(clique_tree):
353
+ degree_dict = clique_tree.root_hg.degrees()
354
+ min_deg_node = min(degree_dict, key=degree_dict.get)
355
+ min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
356
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
357
+ return clique_tree, False
358
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
359
+ return clique_tree, True
360
+
361
+ def apply_min_edge_deg_decomposition(clique_tree):
362
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
363
+ non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
364
+ if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
365
+ if not non_tmp_edge_list:
366
+ return clique_tree, False
367
+ min_deg_edge = None
368
+ min_deg = np.inf
369
+ for each_edge in non_tmp_edge_list:
370
+ if min_deg > edge_degree_dict[each_edge]:
371
+ min_deg_edge = each_edge
372
+ min_deg = edge_degree_dict[each_edge]
373
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
374
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
375
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
376
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
377
+ return clique_tree, False
378
+ clique_tree.update(min_deg_subhg)
379
+ return clique_tree, True
380
+
381
+ org_hg = hg.copy()
382
+ clique_tree = CliqueTree(org_hg)
383
+ clique_tree.add_node(0, subhg=org_hg)
384
+
385
+ success = True
386
+ while success:
387
+ clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
388
+ if not success:
389
+ clique_tree, success = apply_normal_decomposition(clique_tree)
390
+
391
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
392
+ if irredundant:
393
+ clique_tree.to_irredundant()
394
+ return clique_tree
395
+
396
+ def topological_tree_decomposition(
397
+ hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
398
+ ''' compute a tree decomposition of the input hypergraph
399
+
400
+ Parameters
401
+ ----------
402
+ hg : Hypergraph
403
+ hypergraph to be decomposed
404
+ irredundant : bool
405
+ if True, irredundant tree decomposition will be computed.
406
+
407
+ Returns
408
+ -------
409
+ clique_tree : CliqueTree
410
+ each node contains a subhypergraph of `hg`
411
+ '''
412
+ def _contract_tree(clique_tree):
413
+ ''' contract a single leaf
414
+
415
+ Parameters
416
+ ----------
417
+ clique_tree : CliqueTree
418
+
419
+ Returns
420
+ -------
421
+ CliqueTree, bool
422
+ bool represents whether this operation succeeds or not.
423
+ '''
424
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
425
+ leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
426
+ if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
427
+ and edge_degree_dict[each_edge] == 1]
428
+ if not leaf_edge_list:
429
+ return clique_tree, False
430
+ min_deg_edge = leaf_edge_list[0]
431
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
432
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
433
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
434
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
435
+ return clique_tree, False
436
+ clique_tree.update(min_deg_subhg)
437
+ return clique_tree, True
438
+
439
+ def _rip_labels_from_cycles(clique_tree, org_hg):
440
+ ''' rip hyperedge-labels off
441
+
442
+ Parameters
443
+ ----------
444
+ clique_tree : CliqueTree
445
+ org_hg : Hypergraph
446
+
447
+ Returns
448
+ -------
449
+ CliqueTree, bool
450
+ bool represents whether this operation succeeds or not.
451
+ '''
452
+ ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
453
+ for each_edge in clique_tree.root_hg.edges:
454
+ if each_edge in org_hg.edges:
455
+ if org_hg.in_cycle(each_edge):
456
+ node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
457
+ subhg = clique_tree.root_hg.get_subhg(
458
+ node_list, [each_edge], ident_node_dict)
459
+ if clique_tree.root_hg.nodes == subhg.nodes:
460
+ return clique_tree, False
461
+ clique_tree.update(subhg)
462
+ '''
463
+ in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
464
+ if not all(in_cycle_dict.values()):
465
+ node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
466
+ node_list = [node_not_in_cycle]
467
+ node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
468
+ edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
469
+ import pdb; pdb.set_trace()
470
+ subhg = clique_tree.root_hg.get_subhg(
471
+ node_list, edge_list, ident_node_dict)
472
+
473
+ clique_tree.update(subhg)
474
+ '''
475
+ return clique_tree, True
476
+ return clique_tree, False
477
+
478
+ def _shrink_cycle(clique_tree):
479
+ ''' shrink a cycle
480
+
481
+ Parameters
482
+ ----------
483
+ clique_tree : CliqueTree
484
+
485
+ Returns
486
+ -------
487
+ CliqueTree, bool
488
+ bool represents whether this operation succeeds or not.
489
+ '''
490
+ def filter_subhg(subhg, hg, key_node):
491
+ num_nodes_cycle = 0
492
+ nodes_in_cycle_list = []
493
+ for each_node in subhg.nodes:
494
+ if hg.in_cycle(each_node):
495
+ num_nodes_cycle += 1
496
+ if each_node != key_node:
497
+ nodes_in_cycle_list.append(each_node)
498
+ if num_nodes_cycle > 3:
499
+ break
500
+ if num_nodes_cycle != 3:
501
+ return False
502
+ else:
503
+ for each_edge in hg.edges:
504
+ if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
505
+ return False
506
+ return True
507
+
508
+ #ident_node_dict = hg.get_identical_node_dict()
509
+ ident_node_dict = clique_tree.ident_node_dict
510
+ for each_node in clique_tree.root_hg.nodes:
511
+ if clique_tree.root_hg.in_cycle(each_node)\
512
+ and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
513
+ clique_tree.root_hg,
514
+ each_node):
515
+ target_node = each_node
516
+ target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
517
+ if clique_tree.root_hg.nodes == target_subhg.nodes:
518
+ return clique_tree, False
519
+ clique_tree.update(target_subhg)
520
+ return clique_tree, True
521
+ return clique_tree, False
522
+
523
+ def _contract_cycles(clique_tree):
524
+ '''
525
+ remove a subhypergraph that looks like a cycle on a leaf.
526
+
527
+ Parameters
528
+ ----------
529
+ clique_tree : CliqueTree
530
+
531
+ Returns
532
+ -------
533
+ CliqueTree, bool
534
+ bool represents whether this operation succeeds or not.
535
+ '''
536
+ def _divide_hg(hg):
537
+ ''' divide a hypergraph into subhypergraphs such that
538
+ each subhypergraph is connected to each other in a tree-like way.
539
+
540
+ Parameters
541
+ ----------
542
+ hg : Hypergraph
543
+
544
+ Returns
545
+ -------
546
+ list of Hypergraphs
547
+ each element corresponds to a subhypergraph of `hg`
548
+ '''
549
+ for each_node in hg.nodes:
550
+ if hg.is_dividable(each_node):
551
+ adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
552
+ '''
553
+ if any(adj_edges_dict.values()):
554
+ import pdb; pdb.set_trace()
555
+ edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
556
+ subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
557
+ return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
558
+ else:
559
+ '''
560
+ subhg1, subhg2 = hg.divide(each_node)
561
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
562
+ return [hg]
563
+
564
+ def _is_leaf(hg, divided_subhg) -> bool:
565
+ ''' judge whether subhg is a leaf-like in the original hypergraph
566
+
567
+ Parameters
568
+ ----------
569
+ hg : Hypergraph
570
+ divided_subhg : Hypergraph
571
+ `divided_subhg` is a subhypergraph of `hg`
572
+
573
+ Returns
574
+ -------
575
+ bool
576
+ '''
577
+ '''
578
+ adj_edges_set = set([])
579
+ for each_node in divided_subhg.nodes:
580
+ adj_edges_set.update(set(hg.adj_edges(each_node)))
581
+
582
+
583
+ _hg = deepcopy(hg)
584
+ _hg.remove_subhg(divided_subhg)
585
+ if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
586
+ import pdb; pdb.set_trace()
587
+ return len(adj_edges_set - divided_subhg.edges) == 1
588
+ '''
589
+ _hg = deepcopy(hg)
590
+ _hg.remove_subhg(divided_subhg)
591
+ return nx.is_connected(_hg.hg)
592
+
593
+ subhg_list = _divide_hg(clique_tree.root_hg)
594
+ if len(subhg_list) == 1:
595
+ return clique_tree, False
596
+ else:
597
+ while len(subhg_list) > 1:
598
+ max_leaf_subhg = None
599
+ for each_subhg in subhg_list:
600
+ if _is_leaf(clique_tree.root_hg, each_subhg):
601
+ if max_leaf_subhg is None:
602
+ max_leaf_subhg = each_subhg
603
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
604
+ max_leaf_subhg = each_subhg
605
+ clique_tree.update(max_leaf_subhg)
606
+ subhg_list.remove(max_leaf_subhg)
607
+ return clique_tree, True
608
+
609
+ org_hg = hg.copy()
610
+ clique_tree = CliqueTree(org_hg)
611
+ clique_tree.add_node(0, subhg=org_hg)
612
+
613
+ success = True
614
+ while success:
615
+ '''
616
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
617
+ if not success:
618
+ clique_tree, success = _contract_cycles(clique_tree)
619
+ '''
620
+ clique_tree, success = _contract_tree(clique_tree)
621
+ if not success:
622
+ if rip_labels:
623
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
624
+ if not success:
625
+ if shrink_cycle:
626
+ clique_tree, success = _shrink_cycle(clique_tree)
627
+ if not success:
628
+ if contract_cycles:
629
+ clique_tree, success = _contract_cycles(clique_tree)
630
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
631
+ if irredundant:
632
+ clique_tree.to_irredundant()
633
+ return clique_tree
634
+
635
+ def molecular_tree_decomposition(hg, irredundant=True):
636
+ """ compute a tree decomposition of the input molecular hypergraph
637
+
638
+ Parameters
639
+ ----------
640
+ hg : Hypergraph
641
+ molecular hypergraph to be decomposed
642
+ irredundant : bool
643
+ if True, irredundant tree decomposition will be computed.
644
+
645
+ Returns
646
+ -------
647
+ clique_tree : CliqueTree
648
+ each node contains a subhypergraph of `hg`
649
+ """
650
+ def _divide_hg(hg):
651
+ ''' divide a hypergraph into subhypergraphs such that
652
+ each subhypergraph is connected to each other in a tree-like way.
653
+
654
+ Parameters
655
+ ----------
656
+ hg : Hypergraph
657
+
658
+ Returns
659
+ -------
660
+ list of Hypergraphs
661
+ each element corresponds to a subhypergraph of `hg`
662
+ '''
663
+ is_ring = False
664
+ for each_node in hg.nodes:
665
+ if hg.node_attr(each_node)['is_in_ring']:
666
+ is_ring = True
667
+ if not hg.node_attr(each_node)['is_in_ring'] \
668
+ and hg.degree(each_node) == 2:
669
+ subhg1, subhg2 = hg.divide(each_node)
670
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
671
+
672
+ if is_ring:
673
+ subhg_list = []
674
+ remove_edge_list = []
675
+ remove_node_list = []
676
+ for each_edge in hg.edges:
677
+ node_list = hg.nodes_in_edge(each_edge)
678
+ subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
679
+ subhg_list.append(subhg)
680
+ remove_edge_list.append(each_edge)
681
+ for each_node in node_list:
682
+ if not hg.node_attr(each_node)['is_in_ring']:
683
+ remove_node_list.append(each_node)
684
+ hg.remove_edges(remove_edge_list)
685
+ hg.remove_nodes(remove_node_list, False)
686
+ return subhg_list + [hg]
687
+ else:
688
+ return [hg]
689
+
690
+ org_hg = hg.copy()
691
+ clique_tree = CliqueTree(org_hg)
692
+ clique_tree.add_node(0, subhg=org_hg)
693
+
694
+ subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
695
+ #_subhg_list = deepcopy(subhg_list)
696
+ if len(subhg_list) == 1:
697
+ pass
698
+ else:
699
+ while len(subhg_list) > 1:
700
+ max_leaf_subhg = None
701
+ for each_subhg in subhg_list:
702
+ if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
703
+ if max_leaf_subhg is None:
704
+ max_leaf_subhg = each_subhg
705
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
706
+ max_leaf_subhg = each_subhg
707
+
708
+ if max_leaf_subhg is None:
709
+ for each_subhg in subhg_list:
710
+ if _is_ring_label(clique_tree.root_hg, each_subhg):
711
+ if max_leaf_subhg is None:
712
+ max_leaf_subhg = each_subhg
713
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
714
+ max_leaf_subhg = each_subhg
715
+ if max_leaf_subhg is not None:
716
+ clique_tree.update(max_leaf_subhg)
717
+ subhg_list.remove(max_leaf_subhg)
718
+ else:
719
+ for each_subhg in subhg_list:
720
+ if _is_leaf(clique_tree.root_hg, each_subhg):
721
+ if max_leaf_subhg is None:
722
+ max_leaf_subhg = each_subhg
723
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
724
+ max_leaf_subhg = each_subhg
725
+ if max_leaf_subhg is not None:
726
+ clique_tree.update(max_leaf_subhg, True)
727
+ subhg_list.remove(max_leaf_subhg)
728
+ else:
729
+ break
730
+ if len(subhg_list) > 1:
731
+ '''
732
+ for each_idx, each_subhg in enumerate(subhg_list):
733
+ each_subhg.draw(f'{each_idx}', True)
734
+ clique_tree.root_hg.draw('root', True)
735
+ import pickle
736
+ with open('buggy_hg.pkl', 'wb') as f:
737
+ pickle.dump(hg, f)
738
+ return clique_tree, subhg_list, _subhg_list
739
+ '''
740
+ raise RuntimeError('bug in tree decomposition algorithm')
741
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
742
+
743
+ '''
744
+ for each_tree_node in clique_tree.adj[0]:
745
+ subhg = clique_tree.nodes[each_tree_node]['subhg']
746
+ for each_edge in subhg.edges:
747
+ if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
748
+ clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
749
+ '''
750
+ if irredundant:
751
+ clique_tree.to_irredundant()
752
+ return clique_tree #, _subhg_list
753
+
754
+ def _is_leaf(hg, subhg) -> bool:
755
+ ''' judge whether subhg is a leaf-like in the original hypergraph
756
+
757
+ Parameters
758
+ ----------
759
+ hg : Hypergraph
760
+ subhg : Hypergraph
761
+ `subhg` is a subhypergraph of `hg`
762
+
763
+ Returns
764
+ -------
765
+ bool
766
+ '''
767
+ if len(subhg.edges) == 0:
768
+ adj_edge_set = set([])
769
+ subhg_edge_set = set([])
770
+ for each_edge in hg.edges:
771
+ if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
772
+ subhg_edge_set.add(each_edge)
773
+ for each_node in subhg.nodes:
774
+ adj_edge_set.update(set(hg.adj_edges(each_node)))
775
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
776
+ return True
777
+ else:
778
+ return False
779
+ elif len(subhg.edges) == 1:
780
+ adj_edge_set = set([])
781
+ subhg_edge_set = subhg.edges
782
+ for each_node in subhg.nodes:
783
+ for each_adj_edge in hg.adj_edges(each_node):
784
+ adj_edge_set.add(each_adj_edge)
785
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
786
+ return True
787
+ else:
788
+ return False
789
+ else:
790
+ raise ValueError('subhg should be nodes only or one-edge hypergraph.')
791
+
792
+ def _is_ring_label(hg, subhg):
793
+ if len(subhg.edges) != 1:
794
+ return False
795
+ edge_name = list(subhg.edges)[0]
796
+ #assert edge_name in hg.edges, f'{edge_name}'
797
+ is_in_ring = False
798
+ for each_node in subhg.nodes:
799
+ if subhg.node_attr(each_node)['is_in_ring']:
800
+ is_in_ring = True
801
+ else:
802
+ adj_edge_list = list(hg.adj_edges(each_node))
803
+ adj_edge_list.remove(edge_name)
804
+ if len(adj_edge_list) == 1:
805
+ if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
806
+ return False
807
+ elif len(adj_edge_list) == 0:
808
+ pass
809
+ else:
810
+ raise ValueError
811
+ if is_in_ring:
812
+ return True
813
+ else:
814
+ return False
815
+
816
+ def _is_ring(hg):
817
+ for each_node in hg.nodes:
818
+ if not hg.node_attr(each_node)['is_in_ring']:
819
+ return False
820
+ return True
821
+
models/mhg_model/graph_grammar/graph_grammar/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/graph_grammar/base.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from abc import ABCMeta, abstractmethod
22
+
23
+ class GraphGrammarBase(metaclass=ABCMeta):
24
+ @abstractmethod
25
+ def learn(self):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def sample(self):
30
+ pass
models/mhg_model/graph_grammar/graph_grammar/corpus.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from collections import Counter
22
+ from functools import partial
23
+ from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
24
+ from networkx.algorithms.isomorphism import GraphMatcher
25
+ import os
26
+
27
+
28
+ class CliqueTreeCorpus(object):
29
+
30
+ ''' clique tree corpus
31
+
32
+ Attributes
33
+ ----------
34
+ clique_tree_list : list of CliqueTree
35
+ subhg_list : list of Hypergraph
36
+ '''
37
+
38
+ def __init__(self):
39
+ self.clique_tree_list = []
40
+ self.subhg_list = []
41
+
42
+ @property
43
+ def size(self):
44
+ return len(self.subhg_list)
45
+
46
+ def add_clique_tree(self, clique_tree):
47
+ for each_node in clique_tree.nodes:
48
+ subhg = clique_tree.nodes[each_node]['subhg']
49
+ subhg_idx = self.add_subhg(subhg)
50
+ clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
51
+ self.clique_tree_list.append(clique_tree)
52
+
53
+ def add_to_subhg_list(self, clique_tree, root_node):
54
+ parent_node_dict = {}
55
+ current_node = None
56
+ parent_node_dict[root_node] = None
57
+ stack = [root_node]
58
+ while stack:
59
+ current_node = stack.pop()
60
+ current_subhg = clique_tree.nodes[current_node]['subhg']
61
+ for each_child in clique_tree.adj[current_node]:
62
+ if each_child != parent_node_dict[current_node]:
63
+ stack.append(each_child)
64
+ parent_node_dict[each_child] = current_node
65
+ if parent_node_dict[current_node] is not None:
66
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
67
+ common, _ = common_node_list(parent_subhg, current_subhg)
68
+ parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
69
+
70
+ parent_node_dict = {}
71
+ current_node = None
72
+ parent_node_dict[root_node] = None
73
+ stack = [root_node]
74
+ while stack:
75
+ current_node = stack.pop()
76
+ current_subhg = clique_tree.nodes[current_node]['subhg']
77
+ for each_child in clique_tree.adj[current_node]:
78
+ if each_child != parent_node_dict[current_node]:
79
+ stack.append(each_child)
80
+ parent_node_dict[each_child] = current_node
81
+ if parent_node_dict[current_node] is not None:
82
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
83
+ common, _ = common_node_list(parent_subhg, current_subhg)
84
+ for each_idx, each_node in enumerate(common):
85
+ current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
86
+
87
+ subhg_idx, is_new = self.add_subhg(current_subhg)
88
+ clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
89
+ return clique_tree
90
+
91
+ def add_subhg(self, subhg):
92
+ if len(self.subhg_list) == 0:
93
+ node_dict = {}
94
+ for each_node in subhg.nodes:
95
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
96
+ node_list = []
97
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
98
+ node_list.append(each_key)
99
+ for each_idx, each_node in enumerate(node_list):
100
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
101
+ self.subhg_list.append(subhg)
102
+ return 0, True
103
+ else:
104
+ match = False
105
+ subhg_bond_symbol_counter \
106
+ = Counter([subhg.node_attr(each_node)['symbol'] \
107
+ for each_node in subhg.nodes])
108
+ subhg_atom_symbol_counter \
109
+ = Counter([subhg.edge_attr(each_edge).get('symbol', None) \
110
+ for each_edge in subhg.edges])
111
+ for each_idx, each_subhg in enumerate(self.subhg_list):
112
+ each_bond_symbol_counter \
113
+ = Counter([each_subhg.node_attr(each_node)['symbol'] \
114
+ for each_node in each_subhg.nodes])
115
+ each_atom_symbol_counter \
116
+ = Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
117
+ for each_edge in each_subhg.edges])
118
+ if not match \
119
+ and (subhg.num_nodes == each_subhg.num_nodes
120
+ and subhg.num_edges == each_subhg.num_edges
121
+ and subhg_bond_symbol_counter == each_bond_symbol_counter
122
+ and subhg_atom_symbol_counter == each_atom_symbol_counter):
123
+ gm = GraphMatcher(each_subhg.hg,
124
+ subhg.hg,
125
+ node_match=_easy_node_match,
126
+ edge_match=_edge_match)
127
+ try:
128
+ isomap = next(gm.isomorphisms_iter())
129
+ match = True
130
+ for each_node in each_subhg.nodes:
131
+ subhg.node_attr(isomap[each_node])['order4hrg'] \
132
+ = each_subhg.node_attr(each_node)['order4hrg']
133
+ if 'ext_id' in each_subhg.node_attr(each_node):
134
+ subhg.node_attr(isomap[each_node])['ext_id'] \
135
+ = each_subhg.node_attr(each_node)['ext_id']
136
+ return each_idx, False
137
+ except StopIteration:
138
+ match = False
139
+ if not match:
140
+ node_dict = {}
141
+ for each_node in subhg.nodes:
142
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
143
+ node_list = []
144
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
145
+ node_list.append(each_key)
146
+ for each_idx, each_node in enumerate(node_list):
147
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
148
+
149
+ #for each_idx, each_node in enumerate(subhg.nodes):
150
+ # subhg.node_attr(each_node)['order4hrg'] = each_idx
151
+ self.subhg_list.append(subhg)
152
+ return len(self.subhg_list) - 1, True
models/mhg_model/graph_grammar/graph_grammar/hrg.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from .corpus import CliqueTreeCorpus
22
+ from .base import GraphGrammarBase
23
+ from .symbols import TSymbol, NTSymbol, BondSymbol
24
+ from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
25
+ from ..hypergraph import Hypergraph
26
+ from collections import Counter
27
+ from copy import deepcopy
28
+ from ..algo.tree_decomposition import (
29
+ tree_decomposition,
30
+ tree_decomposition_with_hrg,
31
+ tree_decomposition_from_leaf,
32
+ topological_tree_decomposition,
33
+ molecular_tree_decomposition)
34
+ from functools import partial
35
+ from networkx.algorithms.isomorphism import GraphMatcher
36
+ from typing import List, Dict, Tuple
37
+ import networkx as nx
38
+ import numpy as np
39
+ import torch
40
+ import os
41
+ import random
42
+
43
+ DEBUG = False
44
+
45
+
46
+ class ProductionRule(object):
47
+ """ A class of a production rule
48
+
49
+ Attributes
50
+ ----------
51
+ lhs : Hypergraph or None
52
+ the left hand side of the production rule.
53
+ if None, the rule is a starting rule.
54
+ rhs : Hypergraph
55
+ the right hand side of the production rule.
56
+ """
57
+ def __init__(self, lhs, rhs):
58
+ self.lhs = lhs
59
+ self.rhs = rhs
60
+
61
+ @property
62
+ def is_start_rule(self) -> bool:
63
+ return self.lhs.num_nodes == 0
64
+
65
+ @property
66
+ def ext_node(self) -> Dict[int, str]:
67
+ """ return a dict of external nodes
68
+ """
69
+ if self.is_start_rule:
70
+ return {}
71
+ else:
72
+ ext_node_dict = {}
73
+ for each_node in self.lhs.nodes:
74
+ ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
75
+ return ext_node_dict
76
+
77
+ @property
78
+ def lhs_nt_symbol(self) -> NTSymbol:
79
+ if self.is_start_rule:
80
+ return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
81
+ else:
82
+ return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
83
+
84
+ def rhs_adj_mat(self, node_edge_list):
85
+ ''' return the adjacency matrix of rhs of the production rule
86
+ '''
87
+ return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
88
+
89
+ def draw(self, file_path=None):
90
+ return self.rhs.draw(file_path)
91
+
92
+ def is_same(self, prod_rule, ignore_order=False):
93
+ """ judge whether this production rule is
94
+ the same as the input one, `prod_rule`
95
+
96
+ Parameters
97
+ ----------
98
+ prod_rule : ProductionRule
99
+ production rule to be compared
100
+
101
+ Returns
102
+ -------
103
+ is_same : bool
104
+ isomap : dict
105
+ isomorphism of nodes and hyperedges.
106
+ ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
107
+ 'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
108
+ 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
109
+ key comes from `prod_rule`, value comes from `self`.
110
+ """
111
+ if self.is_start_rule:
112
+ if not prod_rule.is_start_rule:
113
+ return False, {}
114
+ else:
115
+ if prod_rule.is_start_rule:
116
+ return False, {}
117
+ else:
118
+ if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
119
+ return False, {}
120
+
121
+ if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
122
+ return False, {}
123
+ if prod_rule.rhs.num_edges != self.rhs.num_edges:
124
+ return False, {}
125
+
126
+ subhg_bond_symbol_counter \
127
+ = Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
128
+ for each_node in prod_rule.rhs.nodes])
129
+ each_bond_symbol_counter \
130
+ = Counter([self.rhs.node_attr(each_node)['symbol'] \
131
+ for each_node in self.rhs.nodes])
132
+ if subhg_bond_symbol_counter != each_bond_symbol_counter:
133
+ return False, {}
134
+
135
+ subhg_atom_symbol_counter \
136
+ = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
137
+ for each_edge in prod_rule.rhs.edges])
138
+ each_atom_symbol_counter \
139
+ = Counter([self.rhs.edge_attr(each_edge)['symbol'] \
140
+ for each_edge in self.rhs.edges])
141
+ if subhg_atom_symbol_counter != each_atom_symbol_counter:
142
+ return False, {}
143
+
144
+ gm = GraphMatcher(prod_rule.rhs.hg,
145
+ self.rhs.hg,
146
+ partial(_node_match_prod_rule,
147
+ ignore_order=ignore_order),
148
+ partial(_edge_match,
149
+ ignore_order=ignore_order))
150
+ try:
151
+ return True, next(gm.isomorphisms_iter())
152
+ except StopIteration:
153
+ return False, {}
154
+
155
+ def applied_to(self,
156
+ hg: Hypergraph,
157
+ edge: str) -> Tuple[Hypergraph, List[str]]:
158
+ """ augment `hg` by replacing `edge` with `self.rhs`.
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+ edge : str
164
+ `edge` must belong to `hg`
165
+
166
+ Returns
167
+ -------
168
+ hg : Hypergraph
169
+ resultant hypergraph
170
+ nt_edge_list : list
171
+ list of non-terminal edges
172
+ """
173
+ nt_edge_dict = {}
174
+ if self.is_start_rule:
175
+ if (edge is not None) or (hg is not None):
176
+ ValueError("edge and hg must be None for this prod rule.")
177
+ hg = Hypergraph()
178
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
179
+ for num_idx, each_node in enumerate(self.rhs.nodes):
180
+ hg.add_node(f"bond_{num_idx}",
181
+ #attr_dict=deepcopy(self.rhs.node_attr(each_node)))
182
+ attr_dict=self.rhs.node_attr(each_node))
183
+ node_map_rhs[each_node] = f"bond_{num_idx}"
184
+ for each_edge in self.rhs.edges:
185
+ node_list = []
186
+ for each_node in self.rhs.nodes_in_edge(each_edge):
187
+ node_list.append(node_map_rhs[each_node])
188
+ if isinstance(self.rhs.nodes_in_edge(each_edge), set):
189
+ node_list = set(node_list)
190
+ edge_id = hg.add_edge(
191
+ node_list,
192
+ #attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
193
+ attr_dict=self.rhs.edge_attr(each_edge))
194
+ if "nt_idx" in hg.edge_attr(edge_id):
195
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
196
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
197
+ return hg, nt_edge_list
198
+ else:
199
+ if edge not in hg.edges:
200
+ raise ValueError("the input hyperedge does not exist.")
201
+ if hg.edge_attr(edge)["terminal"]:
202
+ raise ValueError("the input hyperedge is terminal.")
203
+ if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
204
+ print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
205
+ raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
206
+ if DEBUG:
207
+ for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
208
+ other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
209
+ attr = deepcopy(self.lhs.node_attr(other_node))
210
+ attr.pop('ext_id')
211
+ if hg.node_attr(each_node) != attr:
212
+ raise ValueError('node attributes are inconsistent.')
213
+
214
+ # order of nodes that belong to the non-terminal edge in hg
215
+ nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
216
+ nt_order_dict_inv = {} # order -> hg_node
217
+ for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
218
+ nt_order_dict[each_node] = each_idx
219
+ nt_order_dict_inv[each_idx] = each_node
220
+
221
+ # construct a node_map_rhs: rhs -> new hg
222
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
223
+ node_idx = hg.num_nodes
224
+ for each_node in self.rhs.nodes:
225
+ if "ext_id" in self.rhs.node_attr(each_node):
226
+ node_map_rhs[each_node] \
227
+ = nt_order_dict_inv[
228
+ self.rhs.node_attr(each_node)["ext_id"]]
229
+ else:
230
+ node_map_rhs[each_node] = f"bond_{node_idx}"
231
+ node_idx += 1
232
+
233
+ # delete non-terminal
234
+ hg.remove_edge(edge)
235
+
236
+ # add nodes to hg
237
+ for each_node in self.rhs.nodes:
238
+ hg.add_node(node_map_rhs[each_node],
239
+ attr_dict=self.rhs.node_attr(each_node))
240
+
241
+ # add hyperedges to hg
242
+ for each_edge in self.rhs.edges:
243
+ node_list_hg = []
244
+ for each_node in self.rhs.nodes_in_edge(each_edge):
245
+ node_list_hg.append(node_map_rhs[each_node])
246
+ edge_id = hg.add_edge(
247
+ node_list_hg,
248
+ attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
249
+ if "nt_idx" in hg.edge_attr(edge_id):
250
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
251
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
252
+ return hg, nt_edge_list
253
+
254
+ def revert(self, hg: Hypergraph, return_subhg=False):
255
+ ''' revert applying this production rule.
256
+ i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
257
+ this method replaces the subhypergraph with a non-terminal hyperedge.
258
+
259
+ Parameters
260
+ ----------
261
+ hg : Hypergraph
262
+ hypergraph to be reverted
263
+ return_subhg : bool
264
+ if True, the removed subhypergraph will be returned.
265
+
266
+ Returns
267
+ -------
268
+ hg : Hypergraph
269
+ the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
270
+ success : bool
271
+ this indicates whether reverting is successed or not.
272
+ '''
273
+ gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
274
+ edge_match=_edge_match)
275
+ try:
276
+ # in case when the matched subhg is connected to the other part via external nodes and more.
277
+ not_iso = True
278
+ while not_iso:
279
+ isomap = next(gm.subgraph_isomorphisms_iter())
280
+ adj_node_set = set([]) # reachable nodes from the internal nodes
281
+ subhg_node_set = set(isomap.keys()) # nodes in subhg
282
+ for each_node in subhg_node_set:
283
+ adj_node_set.add(each_node)
284
+ if isomap[each_node] not in self.ext_node.values():
285
+ adj_node_set.update(hg.hg.adj[each_node])
286
+ if adj_node_set == subhg_node_set:
287
+ not_iso = False
288
+ else:
289
+ if return_subhg:
290
+ return hg, False, Hypergraph()
291
+ else:
292
+ return hg, False
293
+ inv_isomap = {v: k for k, v in isomap.items()}
294
+ '''
295
+ isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
296
+ 'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
297
+ where keys come from `hg` and values come from `self.rhs`
298
+ '''
299
+ except StopIteration:
300
+ if return_subhg:
301
+ return hg, False, Hypergraph()
302
+ else:
303
+ return hg, False
304
+
305
+ if return_subhg:
306
+ subhg = Hypergraph()
307
+ for each_node in hg.nodes:
308
+ if each_node in isomap:
309
+ subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
310
+ for each_edge in hg.edges:
311
+ if each_edge in isomap:
312
+ subhg.add_edge(hg.nodes_in_edge(each_edge),
313
+ attr_dict=hg.edge_attr(each_edge),
314
+ edge_name=each_edge)
315
+ subhg.edge_idx = hg.edge_idx
316
+
317
+ # remove subhg except for the externael nodes
318
+ for each_key, each_val in isomap.items():
319
+ if each_key.startswith('e'):
320
+ hg.remove_edge(each_key)
321
+ for each_key, each_val in isomap.items():
322
+ if each_key.startswith('bond_'):
323
+ if each_val not in self.ext_node.values():
324
+ hg.remove_node(each_key)
325
+
326
+ # add non-terminal hyperedge
327
+ nt_node_list = []
328
+ for each_ext_id in self.ext_node.keys():
329
+ nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
330
+
331
+ hg.add_edge(nt_node_list,
332
+ attr_dict=dict(
333
+ terminal=False,
334
+ symbol=self.lhs_nt_symbol))
335
+ if return_subhg:
336
+ return hg, True, subhg
337
+ else:
338
+ return hg, True
339
+
340
+
341
+ class ProductionRuleCorpus(object):
342
+
343
+ '''
344
+ A corpus of production rules.
345
+ This class maintains
346
+ (i) list of unique production rules,
347
+ (ii) list of unique edge symbols (both terminal and non-terminal), and
348
+ (iii) list of unique node symbols.
349
+
350
+ Attributes
351
+ ----------
352
+ prod_rule_list : list
353
+ list of unique production rules
354
+ edge_symbol_list : list
355
+ list of unique symbols (including both terminal and non-terminal)
356
+ node_symbol_list : list
357
+ list of node symbols
358
+ nt_symbol_list : list
359
+ list of unique lhs symbols
360
+ ext_id_list : list
361
+ list of ext_ids
362
+ lhs_in_prod_rule : array
363
+ a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
364
+ '''
365
+
366
+ def __init__(self):
367
+ self.prod_rule_list = []
368
+ self.edge_symbol_list = []
369
+ self.edge_symbol_dict = {}
370
+ self.node_symbol_list = []
371
+ self.node_symbol_dict = {}
372
+ self.nt_symbol_list = []
373
+ self.ext_id_list = []
374
+ self._lhs_in_prod_rule = None
375
+ self.lhs_in_prod_rule_row_list = []
376
+ self.lhs_in_prod_rule_col_list = []
377
+
378
+ @property
379
+ def lhs_in_prod_rule(self):
380
+ if self._lhs_in_prod_rule is None:
381
+ self._lhs_in_prod_rule = torch.sparse.FloatTensor(
382
+ torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
383
+ torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
384
+ torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
385
+ ).to_dense()
386
+ return self._lhs_in_prod_rule
387
+
388
+ @property
389
+ def num_prod_rule(self):
390
+ ''' return the number of production rules
391
+
392
+ Returns
393
+ -------
394
+ int : the number of unique production rules
395
+ '''
396
+ return len(self.prod_rule_list)
397
+
398
+ @property
399
+ def start_rule_list(self):
400
+ ''' return a list of start rules
401
+
402
+ Returns
403
+ -------
404
+ list : list of start rules
405
+ '''
406
+ start_rule_list = []
407
+ for each_prod_rule in self.prod_rule_list:
408
+ if each_prod_rule.is_start_rule:
409
+ start_rule_list.append(each_prod_rule)
410
+ return start_rule_list
411
+
412
+ @property
413
+ def num_edge_symbol(self):
414
+ return len(self.edge_symbol_list)
415
+
416
+ @property
417
+ def num_node_symbol(self):
418
+ return len(self.node_symbol_list)
419
+
420
+ @property
421
+ def num_ext_id(self):
422
+ return len(self.ext_id_list)
423
+
424
+ def construct_feature_vectors(self):
425
+ ''' this method constructs feature vectors for the production rules collected so far.
426
+ currently, NTSymbol and TSymbol are treated in the same manner.
427
+ '''
428
+ feature_id_dict = {}
429
+ feature_id_dict['TSymbol'] = 0
430
+ feature_id_dict['NTSymbol'] = 1
431
+ feature_id_dict['BondSymbol'] = 2
432
+ for each_edge_symbol in self.edge_symbol_list:
433
+ for each_attr in each_edge_symbol.__dict__.keys():
434
+ each_val = each_edge_symbol.__dict__[each_attr]
435
+ if isinstance(each_val, list):
436
+ each_val = tuple(each_val)
437
+ if (each_attr, each_val) not in feature_id_dict:
438
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
439
+
440
+ for each_node_symbol in self.node_symbol_list:
441
+ for each_attr in each_node_symbol.__dict__.keys():
442
+ each_val = each_node_symbol.__dict__[each_attr]
443
+ if isinstance(each_val, list):
444
+ each_val = tuple(each_val)
445
+ if (each_attr, each_val) not in feature_id_dict:
446
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
447
+ for each_ext_id in self.ext_id_list:
448
+ feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
449
+ dim = len(feature_id_dict)
450
+
451
+ feature_dict = {}
452
+ for each_edge_symbol in self.edge_symbol_list:
453
+ idx_list = []
454
+ idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
455
+ for each_attr in each_edge_symbol.__dict__.keys():
456
+ each_val = each_edge_symbol.__dict__[each_attr]
457
+ if isinstance(each_val, list):
458
+ each_val = tuple(each_val)
459
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
460
+ feature = torch.sparse.LongTensor(
461
+ torch.LongTensor([idx_list]),
462
+ torch.ones(len(idx_list)),
463
+ torch.Size([len(feature_id_dict)])
464
+ )
465
+ feature_dict[each_edge_symbol] = feature
466
+
467
+ for each_node_symbol in self.node_symbol_list:
468
+ idx_list = []
469
+ idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
470
+ for each_attr in each_node_symbol.__dict__.keys():
471
+ each_val = each_node_symbol.__dict__[each_attr]
472
+ if isinstance(each_val, list):
473
+ each_val = tuple(each_val)
474
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
475
+ feature = torch.sparse.LongTensor(
476
+ torch.LongTensor([idx_list]),
477
+ torch.ones(len(idx_list)),
478
+ torch.Size([len(feature_id_dict)])
479
+ )
480
+ feature_dict[each_node_symbol] = feature
481
+ for each_ext_id in self.ext_id_list:
482
+ idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
483
+ feature_dict[('ext_id', each_ext_id)] \
484
+ = torch.sparse.LongTensor(
485
+ torch.LongTensor([idx_list]),
486
+ torch.ones(len(idx_list)),
487
+ torch.Size([len(feature_id_dict)])
488
+ )
489
+ return feature_dict, dim
490
+
491
+ def edge_symbol_idx(self, symbol):
492
+ return self.edge_symbol_dict[symbol]
493
+
494
+ def node_symbol_idx(self, symbol):
495
+ return self.node_symbol_dict[symbol]
496
+
497
+ def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
498
+ """ return whether the input production rule is new or not, and its production rule id.
499
+ Production rules are regarded as the same if
500
+ i) there exists a one-to-one mapping of nodes and edges, and
501
+ ii) all the attributes associated with nodes and hyperedges are the same.
502
+
503
+ Parameters
504
+ ----------
505
+ prod_rule : ProductionRule
506
+
507
+ Returns
508
+ -------
509
+ prod_rule_id : int
510
+ production rule index. if new, a new index will be assigned.
511
+ prod_rule : ProductionRule
512
+ """
513
+ num_lhs = len(self.nt_symbol_list)
514
+ for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
515
+ is_same, isomap = prod_rule.is_same(each_prod_rule)
516
+ if is_same:
517
+ # we do not care about edge and node names, but care about the order of non-terminal edges.
518
+ for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
519
+ if key.startswith("bond_"):
520
+ continue
521
+
522
+ # rewrite `nt_idx` in `prod_rule` for further processing
523
+ if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
524
+ if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
525
+ raise ValueError
526
+ prod_rule.rhs.set_edge_attr(
527
+ val,
528
+ {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
529
+ return each_idx, prod_rule
530
+ self.prod_rule_list.append(prod_rule)
531
+ self._update_edge_symbol_list(prod_rule)
532
+ self._update_node_symbol_list(prod_rule)
533
+ self._update_ext_id_list(prod_rule)
534
+
535
+ lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
536
+ self.lhs_in_prod_rule_row_list.append(lhs_idx)
537
+ self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
538
+ self._lhs_in_prod_rule = None
539
+ return len(self.prod_rule_list)-1, prod_rule
540
+
541
+ def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
542
+ return self.prod_rule_list[prod_rule_idx]
543
+
544
+ def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
545
+ ''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
546
+
547
+ Parameters
548
+ ----------
549
+ unmasked_logit_array : array-like, length `num_prod_rule`
550
+ nt_symbol : NTSymbol
551
+ '''
552
+ if not isinstance(unmasked_logit_array, np.ndarray):
553
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
554
+ if deterministic:
555
+ prob = masked_softmax(unmasked_logit_array,
556
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
557
+ return self.prod_rule_list[np.argmax(prob)]
558
+ else:
559
+ return np.random.choice(
560
+ self.prod_rule_list, 1,
561
+ p=masked_softmax(unmasked_logit_array,
562
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
563
+
564
+ def masked_logprob(self, unmasked_logit_array, nt_symbol):
565
+ if not isinstance(unmasked_logit_array, np.ndarray):
566
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
567
+ prob = masked_softmax(unmasked_logit_array,
568
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
569
+ return np.log(prob)
570
+
571
+ def _update_edge_symbol_list(self, prod_rule: ProductionRule):
572
+ ''' update edge symbol list
573
+
574
+ Parameters
575
+ ----------
576
+ prod_rule : ProductionRule
577
+ '''
578
+ if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
579
+ self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
580
+
581
+ for each_edge in prod_rule.rhs.edges:
582
+ if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
583
+ edge_symbol_idx = len(self.edge_symbol_list)
584
+ self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
585
+ self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
586
+ else:
587
+ edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
588
+ prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
589
+ pass
590
+
591
+ def _update_node_symbol_list(self, prod_rule: ProductionRule):
592
+ ''' update node symbol list
593
+
594
+ Parameters
595
+ ----------
596
+ prod_rule : ProductionRule
597
+ '''
598
+ for each_node in prod_rule.rhs.nodes:
599
+ if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
600
+ node_symbol_idx = len(self.node_symbol_list)
601
+ self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
602
+ self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
603
+ else:
604
+ node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
605
+ prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
606
+
607
+ def _update_ext_id_list(self, prod_rule: ProductionRule):
608
+ for each_node in prod_rule.rhs.nodes:
609
+ if 'ext_id' in prod_rule.rhs.node_attr(each_node):
610
+ if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
611
+ self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
612
+
613
+
614
+ class HyperedgeReplacementGrammar(GraphGrammarBase):
615
+ """
616
+ Learn a hyperedge replacement grammar from a set of hypergraphs.
617
+
618
+ Attributes
619
+ ----------
620
+ prod_rule_list : list of ProductionRule
621
+ production rules learned from the input hypergraphs
622
+ """
623
+ def __init__(self,
624
+ tree_decomposition=molecular_tree_decomposition,
625
+ ignore_order=False, **kwargs):
626
+ from functools import partial
627
+ self.prod_rule_corpus = ProductionRuleCorpus()
628
+ self.clique_tree_corpus = CliqueTreeCorpus()
629
+ self.ignore_order = ignore_order
630
+ self.tree_decomposition = partial(tree_decomposition, **kwargs)
631
+
632
+ @property
633
+ def num_prod_rule(self):
634
+ ''' return the number of production rules
635
+
636
+ Returns
637
+ -------
638
+ int : the number of unique production rules
639
+ '''
640
+ return self.prod_rule_corpus.num_prod_rule
641
+
642
+ @property
643
+ def start_rule_list(self):
644
+ ''' return a list of start rules
645
+
646
+ Returns
647
+ -------
648
+ list : list of start rules
649
+ '''
650
+ return self.prod_rule_corpus.start_rule_list
651
+
652
+ @property
653
+ def prod_rule_list(self):
654
+ return self.prod_rule_corpus.prod_rule_list
655
+
656
+ def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
657
+ """ learn from a list of hypergraphs
658
+
659
+ Parameters
660
+ ----------
661
+ hg_list : list of Hypergraph
662
+
663
+ Returns
664
+ -------
665
+ prod_rule_seq_list : list of integers
666
+ each element corresponds to a sequence of production rules to generate each hypergraph.
667
+ """
668
+ prod_rule_seq_list = []
669
+ idx = 0
670
+ for each_idx, each_hg in enumerate(hg_list):
671
+ clique_tree = self.tree_decomposition(each_hg)
672
+
673
+ # get a pair of myself and children
674
+ root_node = _find_root(clique_tree)
675
+ clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
676
+ prod_rule_seq = []
677
+ stack = []
678
+
679
+ children = sorted(list(clique_tree[root_node].keys()))
680
+
681
+ # extract a temporary production rule
682
+ prod_rule = extract_prod_rule(
683
+ None,
684
+ clique_tree.nodes[root_node]["subhg"],
685
+ [clique_tree.nodes[each_child]["subhg"]
686
+ for each_child in children],
687
+ clique_tree.nodes[root_node].get('subhg_idx', None))
688
+
689
+ # update the production rule list
690
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
691
+ children = reorder_children(root_node,
692
+ children,
693
+ prod_rule,
694
+ clique_tree)
695
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
696
+ prod_rule_seq.append(prod_rule_id)
697
+
698
+ while len(stack) != 0:
699
+ # get a triple of parent, myself, and children
700
+ parent, myself = stack.pop()
701
+ children = sorted(list(dict(clique_tree[myself]).keys()))
702
+ children.remove(parent)
703
+
704
+ # extract a temp prod rule
705
+ prod_rule = extract_prod_rule(
706
+ clique_tree.nodes[parent]["subhg"],
707
+ clique_tree.nodes[myself]["subhg"],
708
+ [clique_tree.nodes[each_child]["subhg"]
709
+ for each_child in children],
710
+ clique_tree.nodes[myself].get('subhg_idx', None))
711
+
712
+ # update the prod rule list
713
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
714
+ children = reorder_children(myself,
715
+ children,
716
+ prod_rule,
717
+ clique_tree)
718
+ stack.extend([(myself, each_child)
719
+ for each_child in children[::-1]])
720
+ prod_rule_seq.append(prod_rule_id)
721
+ prod_rule_seq_list.append(prod_rule_seq)
722
+ if (each_idx+1) % print_freq == 0:
723
+ msg = f'#(molecules processed)={each_idx+1}\t'\
724
+ f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
725
+ logger(msg)
726
+ if each_idx > max_mol:
727
+ break
728
+
729
+ print(f'corpus_size = {self.clique_tree_corpus.size}')
730
+ return prod_rule_seq_list
731
+
732
+ def sample(self, z, deterministic=False):
733
+ """ sample a new hypergraph from HRG.
734
+
735
+ Parameters
736
+ ----------
737
+ z : array-like, shape (len, num_prod_rule)
738
+ logit
739
+ deterministic : bool
740
+ if True, deterministic sampling
741
+
742
+ Returns
743
+ -------
744
+ Hypergraph
745
+ """
746
+ seq_idx = 0
747
+ stack = []
748
+ z = z[:, :-1]
749
+ init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
750
+ is_aromatic=False,
751
+ bond_symbol_list=[]),
752
+ deterministic=deterministic)
753
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
754
+ stack = deepcopy(nt_edge_list[::-1])
755
+ while len(stack) != 0 and seq_idx < z.shape[0]-1:
756
+ seq_idx += 1
757
+ nt_edge = stack.pop()
758
+ nt_symbol = hg.edge_attr(nt_edge)['symbol']
759
+ prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
760
+ hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
761
+ stack.extend(nt_edge_list[::-1])
762
+ if len(stack) != 0:
763
+ raise RuntimeError(f'{len(stack)} non-terminals are left.')
764
+ return hg
765
+
766
+ def construct(self, prod_rule_seq):
767
+ """ construct a hypergraph following `prod_rule_seq`
768
+
769
+ Parameters
770
+ ----------
771
+ prod_rule_seq : list of integers
772
+ a sequence of production rules.
773
+
774
+ Returns
775
+ -------
776
+ UndirectedHypergraph
777
+ """
778
+ seq_idx = 0
779
+ init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
780
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
781
+ stack = deepcopy(nt_edge_list[::-1])
782
+ while len(stack) != 0:
783
+ seq_idx += 1
784
+ nt_edge = stack.pop()
785
+ hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
786
+ stack.extend(nt_edge_list[::-1])
787
+ return hg
788
+
789
+ def update_prod_rule_list(self, prod_rule):
790
+ """ return whether the input production rule is new or not, and its production rule id.
791
+ Production rules are regarded as the same if
792
+ i) there exists a one-to-one mapping of nodes and edges, and
793
+ ii) all the attributes associated with nodes and hyperedges are the same.
794
+
795
+ Parameters
796
+ ----------
797
+ prod_rule : ProductionRule
798
+
799
+ Returns
800
+ -------
801
+ is_new : bool
802
+ if True, this production rule is new
803
+ prod_rule_id : int
804
+ production rule index. if new, a new index will be assigned.
805
+ """
806
+ return self.prod_rule_corpus.append(prod_rule)
807
+
808
+
809
+ class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
810
+ '''
811
+ This class learns HRG incrementally leveraging the previously obtained production rules.
812
+ '''
813
+ def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
814
+ self.prod_rule_list = []
815
+ self.tree_decomposition = tree_decomposition
816
+ self.ignore_order = ignore_order
817
+
818
+ def learn(self, hg_list):
819
+ """ learn from a list of hypergraphs
820
+
821
+ Parameters
822
+ ----------
823
+ hg_list : list of UndirectedHypergraph
824
+
825
+ Returns
826
+ -------
827
+ prod_rule_seq_list : list of integers
828
+ each element corresponds to a sequence of production rules to generate each hypergraph.
829
+ """
830
+ prod_rule_seq_list = []
831
+ for each_hg in hg_list:
832
+ clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
833
+
834
+ prod_rule_seq = []
835
+ stack = []
836
+
837
+ # get a pair of myself and children
838
+ children = sorted(list(clique_tree[root_node].keys()))
839
+
840
+ # extract a temporary production rule
841
+ prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
842
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
843
+
844
+ # update the production rule list
845
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
846
+ children = reorder_children(root_node, children, prod_rule, clique_tree)
847
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
848
+ prod_rule_seq.append(prod_rule_id)
849
+
850
+ while len(stack) != 0:
851
+ # get a triple of parent, myself, and children
852
+ parent, myself = stack.pop()
853
+ children = sorted(list(dict(clique_tree[myself]).keys()))
854
+ children.remove(parent)
855
+
856
+ # extract a temp prod rule
857
+ prod_rule = extract_prod_rule(
858
+ clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
859
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
860
+
861
+ # update the prod rule list
862
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
863
+ children = reorder_children(myself, children, prod_rule, clique_tree)
864
+ stack.extend([(myself, each_child) for each_child in children[::-1]])
865
+ prod_rule_seq.append(prod_rule_id)
866
+ prod_rule_seq_list.append(prod_rule_seq)
867
+ self._compute_stats()
868
+ return prod_rule_seq_list
869
+
870
+
871
+ def reorder_children(myself, children, prod_rule, clique_tree):
872
+ """ reorder children so that they match the order in `prod_rule`.
873
+
874
+ Parameters
875
+ ----------
876
+ myself : int
877
+ children : list of int
878
+ prod_rule : ProductionRule
879
+ clique_tree : nx.Graph
880
+
881
+ Returns
882
+ -------
883
+ new_children : list of str
884
+ reordered children
885
+ """
886
+ perm = {} # key : `nt_idx`, val : child
887
+ for each_edge in prod_rule.rhs.edges:
888
+ if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
889
+ for each_child in children:
890
+ common_node_set = set(
891
+ common_node_list(clique_tree.nodes[myself]["subhg"],
892
+ clique_tree.nodes[each_child]["subhg"])[0])
893
+ if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
894
+ assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
895
+ perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
896
+ new_children = []
897
+ assert len(perm) == len(children)
898
+ for i in range(len(perm)):
899
+ new_children.append(perm[i])
900
+ return new_children
901
+
902
+
903
+ def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
904
+ """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
905
+
906
+ Parameters
907
+ ----------
908
+ parent_hg : Hypergraph
909
+ myself_hg : Hypergraph
910
+ children_hg_list : list of Hypergraph
911
+
912
+ Returns
913
+ -------
914
+ ProductionRule, consisting of
915
+ lhs : Hypergraph or None
916
+ rhs : Hypergraph
917
+ """
918
+ def _add_ext_node(hg, ext_nodes):
919
+ """ mark nodes to be external (ordered ids are assigned)
920
+
921
+ Parameters
922
+ ----------
923
+ hg : UndirectedHypergraph
924
+ ext_nodes : list of str
925
+ list of external nodes
926
+
927
+ Returns
928
+ -------
929
+ hg : Hypergraph
930
+ nodes in `ext_nodes` are marked to be external
931
+ """
932
+ ext_id = 0
933
+ ext_id_exists = []
934
+ for each_node in ext_nodes:
935
+ ext_id_exists.append('ext_id' in hg.node_attr(each_node))
936
+ if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
937
+ raise ValueError
938
+ if not all(ext_id_exists):
939
+ for each_node in ext_nodes:
940
+ hg.node_attr(each_node)['ext_id'] = ext_id
941
+ ext_id += 1
942
+ return hg
943
+
944
+ def _check_aromatic(hg, node_list):
945
+ is_aromatic = False
946
+ node_aromatic_list = []
947
+ for each_node in node_list:
948
+ if hg.node_attr(each_node)['symbol'].is_aromatic:
949
+ is_aromatic = True
950
+ node_aromatic_list.append(True)
951
+ else:
952
+ node_aromatic_list.append(False)
953
+ return is_aromatic, node_aromatic_list
954
+
955
+ def _check_ring(hg):
956
+ for each_edge in hg.edges:
957
+ if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
958
+ return False
959
+ return True
960
+
961
+ if parent_hg is None:
962
+ lhs = Hypergraph()
963
+ node_list = []
964
+ else:
965
+ lhs = Hypergraph()
966
+ node_list, edge_exists = common_node_list(parent_hg, myself_hg)
967
+ for each_node in node_list:
968
+ lhs.add_node(each_node,
969
+ deepcopy(myself_hg.node_attr(each_node)))
970
+ is_aromatic, _ = _check_aromatic(parent_hg, node_list)
971
+ for_ring = _check_ring(myself_hg)
972
+ bond_symbol_list = []
973
+ for each_node in node_list:
974
+ bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
975
+ lhs.add_edge(
976
+ node_list,
977
+ attr_dict=dict(
978
+ terminal=False,
979
+ edge_exists=edge_exists,
980
+ symbol=NTSymbol(
981
+ degree=len(node_list),
982
+ is_aromatic=is_aromatic,
983
+ bond_symbol_list=bond_symbol_list,
984
+ for_ring=for_ring)))
985
+ try:
986
+ lhs = _add_ext_node(lhs, node_list)
987
+ except ValueError:
988
+ import pdb; pdb.set_trace()
989
+
990
+ rhs = remove_tmp_edge(deepcopy(myself_hg))
991
+ #rhs = remove_ext_node(rhs)
992
+ #rhs = remove_nt_edge(rhs)
993
+ try:
994
+ rhs = _add_ext_node(rhs, node_list)
995
+ except ValueError:
996
+ import pdb; pdb.set_trace()
997
+
998
+ nt_idx = 0
999
+ if children_hg_list is not None:
1000
+ for each_child_hg in children_hg_list:
1001
+ node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
1002
+ is_aromatic, _ = _check_aromatic(myself_hg, node_list)
1003
+ for_ring = _check_ring(each_child_hg)
1004
+ bond_symbol_list = []
1005
+ for each_node in node_list:
1006
+ bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
1007
+ rhs.add_edge(
1008
+ node_list,
1009
+ attr_dict=dict(
1010
+ terminal=False,
1011
+ nt_idx=nt_idx,
1012
+ edge_exists=edge_exists,
1013
+ symbol=NTSymbol(degree=len(node_list),
1014
+ is_aromatic=is_aromatic,
1015
+ bond_symbol_list=bond_symbol_list,
1016
+ for_ring=for_ring)))
1017
+ nt_idx += 1
1018
+ prod_rule = ProductionRule(lhs, rhs)
1019
+ prod_rule.subhg_idx = subhg_idx
1020
+ if DEBUG:
1021
+ if sorted(list(prod_rule.ext_node.keys())) \
1022
+ != list(np.arange(len(prod_rule.ext_node))):
1023
+ raise RuntimeError('ext_id is not continuous')
1024
+ return prod_rule
1025
+
1026
+
1027
+ def _find_root(clique_tree):
1028
+ max_node = None
1029
+ num_nodes_max = -np.inf
1030
+ for each_node in clique_tree.nodes:
1031
+ if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
1032
+ max_node = each_node
1033
+ num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
1034
+ '''
1035
+ children = sorted(list(clique_tree[each_node].keys()))
1036
+ prod_rule = extract_prod_rule(None,
1037
+ clique_tree.nodes[each_node]["subhg"],
1038
+ [clique_tree.nodes[each_child]["subhg"]
1039
+ for each_child in children])
1040
+ for each_start_rule in start_rule_list:
1041
+ if prod_rule.is_same(each_start_rule):
1042
+ return each_node
1043
+ '''
1044
+ return max_node
1045
+
1046
+ def remove_ext_node(hg):
1047
+ for each_node in hg.nodes:
1048
+ hg.node_attr(each_node).pop('ext_id', None)
1049
+ return hg
1050
+
1051
+ def remove_nt_edge(hg):
1052
+ remove_edge_list = []
1053
+ for each_edge in hg.edges:
1054
+ if not hg.edge_attr(each_edge)['terminal']:
1055
+ remove_edge_list.append(each_edge)
1056
+ hg.remove_edges(remove_edge_list)
1057
+ return hg
1058
+
1059
+ def remove_tmp_edge(hg):
1060
+ remove_edge_list = []
1061
+ for each_edge in hg.edges:
1062
+ if hg.edge_attr(each_edge).get('tmp', False):
1063
+ remove_edge_list.append(each_edge)
1064
+ hg.remove_edges(remove_edge_list)
1065
+ return hg
models/mhg_model/graph_grammar/graph_grammar/symbols.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+
15
+ """ Title """
16
+
17
+ __author__ = "Hiroshi Kajino <[email protected]>"
18
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
19
+ __version__ = "0.1"
20
+ __date__ = "Jan 1 2018"
21
+
22
+ from typing import List
23
+
24
+ class TSymbol(object):
25
+
26
+ ''' terminal symbol
27
+
28
+ Attributes
29
+ ----------
30
+ degree : int
31
+ the number of nodes in a hyperedge
32
+ is_aromatic : bool
33
+ whether or not the hyperedge is in an aromatic ring
34
+ symbol : str
35
+ atomic symbol
36
+ num_explicit_Hs : int
37
+ the number of hydrogens associated to this hyperedge
38
+ formal_charge : int
39
+ charge
40
+ chirality : int
41
+ chirality
42
+ '''
43
+
44
+ def __init__(self, degree, is_aromatic,
45
+ symbol, num_explicit_Hs, formal_charge, chirality):
46
+ self.degree = degree
47
+ self.is_aromatic = is_aromatic
48
+ self.symbol = symbol
49
+ self.num_explicit_Hs = num_explicit_Hs
50
+ self.formal_charge = formal_charge
51
+ self.chirality = chirality
52
+
53
+ @property
54
+ def terminal(self):
55
+ return True
56
+
57
+ def __eq__(self, other):
58
+ if not isinstance(other, TSymbol):
59
+ return False
60
+ if self.degree != other.degree:
61
+ return False
62
+ if self.is_aromatic != other.is_aromatic:
63
+ return False
64
+ if self.symbol != other.symbol:
65
+ return False
66
+ if self.num_explicit_Hs != other.num_explicit_Hs:
67
+ return False
68
+ if self.formal_charge != other.formal_charge:
69
+ return False
70
+ if self.chirality != other.chirality:
71
+ return False
72
+ return True
73
+
74
+ def __hash__(self):
75
+ return self.__str__().__hash__()
76
+
77
+ def __str__(self):
78
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
79
+ f'symbol={self.symbol}, '\
80
+ f'num_explicit_Hs={self.num_explicit_Hs}, '\
81
+ f'formal_charge={self.formal_charge}, chirality={self.chirality}'
82
+
83
+
84
+ class NTSymbol(object):
85
+
86
+ ''' non-terminal symbol
87
+
88
+ Attributes
89
+ ----------
90
+ degree : int
91
+ degree of the hyperedge
92
+ is_aromatic : bool
93
+ if True, at least one of the associated bonds must be aromatic.
94
+ node_aromatic_list : list of bool
95
+ indicate whether each of the nodes is aromatic or not.
96
+ bond_type_list : list of int
97
+ bond type of each node"
98
+ '''
99
+
100
+ def __init__(self, degree: int, is_aromatic: bool,
101
+ bond_symbol_list: list,
102
+ for_ring=False):
103
+ self.degree = degree
104
+ self.is_aromatic = is_aromatic
105
+ self.for_ring = for_ring
106
+ self.bond_symbol_list = bond_symbol_list
107
+
108
+ @property
109
+ def terminal(self) -> bool:
110
+ return False
111
+
112
+ @property
113
+ def symbol(self):
114
+ return f'NT{self.degree}'
115
+
116
+ def __eq__(self, other) -> bool:
117
+ if not isinstance(other, NTSymbol):
118
+ return False
119
+
120
+ if self.degree != other.degree:
121
+ return False
122
+ if self.is_aromatic != other.is_aromatic:
123
+ return False
124
+ if self.for_ring != other.for_ring:
125
+ return False
126
+ if len(self.bond_symbol_list) != len(other.bond_symbol_list):
127
+ return False
128
+ for each_idx in range(len(self.bond_symbol_list)):
129
+ if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
130
+ return False
131
+ return True
132
+
133
+ def __hash__(self):
134
+ return self.__str__().__hash__()
135
+
136
+ def __str__(self) -> str:
137
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
138
+ f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
139
+ f'for_ring={self.for_ring}'
140
+
141
+
142
+ class BondSymbol(object):
143
+
144
+
145
+ ''' Bond symbol
146
+
147
+ Attributes
148
+ ----------
149
+ is_aromatic : bool
150
+ if True, at least one of the associated bonds must be aromatic.
151
+ bond_type : int
152
+ bond type of each node"
153
+ '''
154
+
155
+ def __init__(self, is_aromatic: bool,
156
+ bond_type: int,
157
+ stereo: int):
158
+ self.is_aromatic = is_aromatic
159
+ self.bond_type = bond_type
160
+ self.stereo = stereo
161
+
162
+ def __eq__(self, other) -> bool:
163
+ if not isinstance(other, BondSymbol):
164
+ return False
165
+
166
+ if self.is_aromatic != other.is_aromatic:
167
+ return False
168
+ if self.bond_type != other.bond_type:
169
+ return False
170
+ if self.stereo != other.stereo:
171
+ return False
172
+ return True
173
+
174
+ def __hash__(self):
175
+ return self.__str__().__hash__()
176
+
177
+ def __str__(self) -> str:
178
+ return f'is_aromatic={self.is_aromatic}, '\
179
+ f'bond_type={self.bond_type}, '\
180
+ f'stereo={self.stereo}, '
models/mhg_model/graph_grammar/graph_grammar/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from ..hypergraph import Hypergraph
22
+ from copy import deepcopy
23
+ from typing import List
24
+ import numpy as np
25
+
26
+
27
+ def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
28
+ """ return a list of common nodes
29
+
30
+ Parameters
31
+ ----------
32
+ hg1, hg2 : Hypergraph
33
+
34
+ Returns
35
+ -------
36
+ list of str
37
+ list of common nodes
38
+ """
39
+ if hg1 is None or hg2 is None:
40
+ return [], False
41
+ else:
42
+ node_set = hg1.nodes.intersection(hg2.nodes)
43
+ node_dict = {}
44
+ if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
45
+ for each_node in node_set:
46
+ node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
47
+ else:
48
+ for each_node in node_set:
49
+ node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
50
+ node_list = []
51
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
52
+ node_list.append(each_key)
53
+ edge_name = hg1.has_edge(node_list, ignore_order=True)
54
+ if edge_name:
55
+ if not hg1.edge_attr(edge_name).get('terminal', True):
56
+ node_list = hg1.nodes_in_edge(edge_name)
57
+ return node_list, True
58
+ else:
59
+ return node_list, False
60
+
61
+
62
+ def _node_match(node1, node2):
63
+ # if the nodes are hyperedges, `atom_attr` determines the match
64
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
65
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
66
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
67
+ # bond_symbol
68
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
69
+ else:
70
+ return False
71
+
72
+ def _easy_node_match(node1, node2):
73
+ # if the nodes are hyperedges, `atom_attr` determines the match
74
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
75
+ return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
76
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
77
+ # bond_symbol
78
+ return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
79
+ and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
80
+ else:
81
+ return False
82
+
83
+
84
+ def _node_match_prod_rule(node1, node2, ignore_order=False):
85
+ # if the nodes are hyperedges, `atom_attr` determines the match
86
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
87
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
88
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
89
+ # ext_id, order4hrg, bond_symbol
90
+ if ignore_order:
91
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
92
+ else:
93
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
94
+ and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
95
+ else:
96
+ return False
97
+
98
+
99
+ def _edge_match(edge1, edge2, ignore_order=False):
100
+ #return True
101
+ if ignore_order:
102
+ return True
103
+ else:
104
+ return edge1["order"] == edge2["order"]
105
+
106
+ def masked_softmax(logit, mask):
107
+ ''' compute a probability distribution from logit
108
+
109
+ Parameters
110
+ ----------
111
+ logit : array-like, length D
112
+ each element indicates how each dimension is likely to be chosen
113
+ (the larger, the more likely)
114
+ mask : array-like, length D
115
+ each element is either 0 or 1.
116
+ if 0, the dimension is ignored
117
+ when computing the probability distribution.
118
+
119
+ Returns
120
+ -------
121
+ prob_dist : array, length D
122
+ probability distribution computed from logit.
123
+ if `mask[d] = 0`, `prob_dist[d] = 0`.
124
+ '''
125
+ if logit.shape != mask.shape:
126
+ raise ValueError('logit and mask must have the same shape')
127
+ c = np.max(logit)
128
+ exp_logit = np.exp(logit - c) * mask
129
+ sum_exp_logit = exp_logit @ mask
130
+ return exp_logit / sum_exp_logit
models/mhg_model/graph_grammar/hypergraph.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 31 2018"
20
+
21
+ from copy import deepcopy
22
+ from typing import List, Dict, Tuple
23
+ import networkx as nx
24
+ import numpy as np
25
+ import os
26
+
27
+
28
+ class Hypergraph(object):
29
+ '''
30
+ A class of a hypergraph.
31
+ Each hyperedge can be ordered. For the ordered case,
32
+ edges adjacent to the hyperedge node are labeled by their orders.
33
+
34
+ Attributes
35
+ ----------
36
+ hg : nx.Graph
37
+ a bipartite graph representation of a hypergraph
38
+ edge_idx : int
39
+ total number of hyperedges that exist so far
40
+ '''
41
+ def __init__(self):
42
+ self.hg = nx.Graph()
43
+ self.edge_idx = 0
44
+ self.nodes = set([])
45
+ self.num_nodes = 0
46
+ self.edges = set([])
47
+ self.num_edges = 0
48
+ self.nodes_in_edge_dict = {}
49
+
50
+ def add_node(self, node: str, attr_dict=None):
51
+ ''' add a node to hypergraph
52
+
53
+ Parameters
54
+ ----------
55
+ node : str
56
+ node name
57
+ attr_dict : dict
58
+ dictionary of node attributes
59
+ '''
60
+ self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
61
+ if node not in self.nodes:
62
+ self.num_nodes += 1
63
+ self.nodes.add(node)
64
+
65
+ def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
66
+ ''' add an edge consisting of nodes `node_list`
67
+
68
+ Parameters
69
+ ----------
70
+ node_list : list
71
+ ordered list of nodes that consist the edge
72
+ attr_dict : dict
73
+ dictionary of edge attributes
74
+ '''
75
+ if edge_name is None:
76
+ edge = 'e{}'.format(self.edge_idx)
77
+ else:
78
+ assert edge_name not in self.edges
79
+ edge = edge_name
80
+ self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
81
+ if edge not in self.edges:
82
+ self.num_edges += 1
83
+ self.edges.add(edge)
84
+ self.nodes_in_edge_dict[edge] = node_list
85
+ if type(node_list) == list:
86
+ for node_idx, each_node in enumerate(node_list):
87
+ self.hg.add_edge(edge, each_node, order=node_idx)
88
+ if each_node not in self.nodes:
89
+ self.num_nodes += 1
90
+ self.nodes.add(each_node)
91
+
92
+ elif type(node_list) == set:
93
+ for each_node in node_list:
94
+ self.hg.add_edge(edge, each_node, order=-1)
95
+ if each_node not in self.nodes:
96
+ self.num_nodes += 1
97
+ self.nodes.add(each_node)
98
+ else:
99
+ raise ValueError
100
+ self.edge_idx += 1
101
+ return edge
102
+
103
+ def remove_node(self, node: str, remove_connected_edges=True):
104
+ ''' remove a node
105
+
106
+ Parameters
107
+ ----------
108
+ node : str
109
+ node name
110
+ remove_connected_edges : bool
111
+ if True, remove edges that are adjacent to the node
112
+ '''
113
+ if remove_connected_edges:
114
+ connected_edges = deepcopy(self.adj_edges(node))
115
+ for each_edge in connected_edges:
116
+ self.remove_edge(each_edge)
117
+ self.hg.remove_node(node)
118
+ self.num_nodes -= 1
119
+ self.nodes.remove(node)
120
+
121
+ def remove_nodes(self, node_iter, remove_connected_edges=True):
122
+ ''' remove a set of nodes
123
+
124
+ Parameters
125
+ ----------
126
+ node_iter : iterator of strings
127
+ nodes to be removed
128
+ remove_connected_edges : bool
129
+ if True, remove edges that are adjacent to the node
130
+ '''
131
+ for each_node in node_iter:
132
+ self.remove_node(each_node, remove_connected_edges)
133
+
134
+ def remove_edge(self, edge: str):
135
+ ''' remove an edge
136
+
137
+ Parameters
138
+ ----------
139
+ edge : str
140
+ edge to be removed
141
+ '''
142
+ self.hg.remove_node(edge)
143
+ self.edges.remove(edge)
144
+ self.num_edges -= 1
145
+ self.nodes_in_edge_dict.pop(edge)
146
+
147
+ def remove_edges(self, edge_iter):
148
+ ''' remove a set of edges
149
+
150
+ Parameters
151
+ ----------
152
+ edge_iter : iterator of strings
153
+ edges to be removed
154
+ '''
155
+ for each_edge in edge_iter:
156
+ self.remove_edge(each_edge)
157
+
158
+ def remove_edges_with_attr(self, edge_attr_dict):
159
+ remove_edge_list = []
160
+ for each_edge in self.edges:
161
+ satisfy = True
162
+ for each_key, each_val in edge_attr_dict.items():
163
+ if not satisfy:
164
+ break
165
+ try:
166
+ if self.edge_attr(each_edge)[each_key] != each_val:
167
+ satisfy = False
168
+ except KeyError:
169
+ satisfy = False
170
+ if satisfy:
171
+ remove_edge_list.append(each_edge)
172
+ self.remove_edges(remove_edge_list)
173
+
174
+ def remove_subhg(self, subhg):
175
+ ''' remove subhypergraph.
176
+ all of the hyperedges are removed.
177
+ each node of subhg is removed if its degree becomes 0 after removing hyperedges.
178
+
179
+ Parameters
180
+ ----------
181
+ subhg : Hypergraph
182
+ '''
183
+ for each_edge in subhg.edges:
184
+ self.remove_edge(each_edge)
185
+ for each_node in subhg.nodes:
186
+ if self.degree(each_node) == 0:
187
+ self.remove_node(each_node)
188
+
189
+ def nodes_in_edge(self, edge):
190
+ ''' return an ordered list of nodes in a given edge.
191
+
192
+ Parameters
193
+ ----------
194
+ edge : str
195
+ edge whose nodes are returned
196
+
197
+ Returns
198
+ -------
199
+ list or set
200
+ ordered list or set of nodes that belong to the edge
201
+ '''
202
+ if edge.startswith('e'):
203
+ return self.nodes_in_edge_dict[edge]
204
+ else:
205
+ adj_node_list = self.hg.adj[edge]
206
+ adj_node_order_list = []
207
+ adj_node_name_list = []
208
+ for each_node in adj_node_list:
209
+ adj_node_order_list.append(adj_node_list[each_node]['order'])
210
+ adj_node_name_list.append(each_node)
211
+ if adj_node_order_list == [-1] * len(adj_node_order_list):
212
+ return set(adj_node_name_list)
213
+ else:
214
+ return [adj_node_name_list[each_idx] for each_idx
215
+ in np.argsort(adj_node_order_list)]
216
+
217
+ def adj_edges(self, node):
218
+ ''' return a dict of adjacent hyperedges
219
+
220
+ Parameters
221
+ ----------
222
+ node : str
223
+
224
+ Returns
225
+ -------
226
+ set
227
+ set of edges that are adjacent to `node`
228
+ '''
229
+ return self.hg.adj[node]
230
+
231
+ def adj_nodes(self, node):
232
+ ''' return a set of adjacent nodes
233
+
234
+ Parameters
235
+ ----------
236
+ node : str
237
+
238
+ Returns
239
+ -------
240
+ set
241
+ set of nodes that are adjacent to `node`
242
+ '''
243
+ node_set = set([])
244
+ for each_adj_edge in self.adj_edges(node):
245
+ node_set.update(set(self.nodes_in_edge(each_adj_edge)))
246
+ node_set.discard(node)
247
+ return node_set
248
+
249
+ def has_edge(self, node_list, ignore_order=False):
250
+ for each_edge in self.edges:
251
+ if ignore_order:
252
+ if set(self.nodes_in_edge(each_edge)) == set(node_list):
253
+ return each_edge
254
+ else:
255
+ if self.nodes_in_edge(each_edge) == node_list:
256
+ return each_edge
257
+ return False
258
+
259
+ def degree(self, node):
260
+ return len(self.hg.adj[node])
261
+
262
+ def degrees(self):
263
+ return {each_node: self.degree(each_node) for each_node in self.nodes}
264
+
265
+ def edge_degree(self, edge):
266
+ return len(self.nodes_in_edge(edge))
267
+
268
+ def edge_degrees(self):
269
+ return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
270
+
271
+ def is_adj(self, node1, node2):
272
+ return node1 in self.adj_nodes(node2)
273
+
274
+ def adj_subhg(self, node, ident_node_dict=None):
275
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
276
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
277
+
278
+ Parameters
279
+ ----------
280
+ node : str
281
+ ident_node_dict : dict
282
+ dict containing identical nodes. see `get_identical_node_dict` for more details
283
+
284
+ Returns
285
+ -------
286
+ subhg : Hypergraph
287
+ """
288
+ if ident_node_dict is None:
289
+ ident_node_dict = self.get_identical_node_dict()
290
+ adj_node_set = set(ident_node_dict[node])
291
+ adj_edge_set = set([])
292
+ for each_node in ident_node_dict[node]:
293
+ adj_edge_set.update(set(self.adj_edges(each_node)))
294
+ fixed_adj_edge_set = deepcopy(adj_edge_set)
295
+ for each_edge in fixed_adj_edge_set:
296
+ other_nodes = self.nodes_in_edge(each_edge)
297
+ adj_node_set.update(other_nodes)
298
+
299
+ # if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
300
+ for each_node in other_nodes:
301
+ for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
302
+ if len(set(self.nodes_in_edge(other_edge)) \
303
+ - set(self.nodes_in_edge(each_edge))) == 0:
304
+ adj_edge_set.update(set([other_edge]))
305
+ subhg = Hypergraph()
306
+ for each_node in adj_node_set:
307
+ subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
308
+ for each_edge in adj_edge_set:
309
+ subhg.add_edge(self.nodes_in_edge(each_edge),
310
+ attr_dict=self.edge_attr(each_edge),
311
+ edge_name=each_edge)
312
+ subhg.edge_idx = self.edge_idx
313
+ return subhg
314
+
315
+ def get_subhg(self, node_list, edge_list, ident_node_dict=None):
316
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
317
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
318
+
319
+ Parameters
320
+ ----------
321
+ node : str
322
+ ident_node_dict : dict
323
+ dict containing identical nodes. see `get_identical_node_dict` for more details
324
+
325
+ Returns
326
+ -------
327
+ subhg : Hypergraph
328
+ """
329
+ if ident_node_dict is None:
330
+ ident_node_dict = self.get_identical_node_dict()
331
+ adj_node_set = set([])
332
+ for each_node in node_list:
333
+ adj_node_set.update(set(ident_node_dict[each_node]))
334
+ adj_edge_set = set(edge_list)
335
+
336
+ subhg = Hypergraph()
337
+ for each_node in adj_node_set:
338
+ subhg.add_node(each_node,
339
+ attr_dict=deepcopy(self.node_attr(each_node)))
340
+ for each_edge in adj_edge_set:
341
+ subhg.add_edge(self.nodes_in_edge(each_edge),
342
+ attr_dict=deepcopy(self.edge_attr(each_edge)),
343
+ edge_name=each_edge)
344
+ subhg.edge_idx = self.edge_idx
345
+ return subhg
346
+
347
+ def copy(self):
348
+ ''' return a copy of the object
349
+
350
+ Returns
351
+ -------
352
+ Hypergraph
353
+ '''
354
+ return deepcopy(self)
355
+
356
+ def node_attr(self, node):
357
+ return self.hg.nodes[node]['attr_dict']
358
+
359
+ def edge_attr(self, edge):
360
+ return self.hg.nodes[edge]['attr_dict']
361
+
362
+ def set_node_attr(self, node, attr_dict):
363
+ for each_key, each_val in attr_dict.items():
364
+ self.hg.nodes[node]['attr_dict'][each_key] = each_val
365
+
366
+ def set_edge_attr(self, edge, attr_dict):
367
+ for each_key, each_val in attr_dict.items():
368
+ self.hg.nodes[edge]['attr_dict'][each_key] = each_val
369
+
370
+ def get_identical_node_dict(self):
371
+ ''' get identical nodes
372
+ nodes are identical if they share the same set of adjacent edges.
373
+
374
+ Returns
375
+ -------
376
+ ident_node_dict : dict
377
+ ident_node_dict[node] returns a list of nodes that are identical to `node`.
378
+ '''
379
+ ident_node_dict = {}
380
+ for each_node in self.nodes:
381
+ ident_node_list = []
382
+ for each_other_node in self.nodes:
383
+ if each_other_node == each_node:
384
+ ident_node_list.append(each_other_node)
385
+ elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
386
+ and len(self.adj_edges(each_node)) != 0:
387
+ ident_node_list.append(each_other_node)
388
+ ident_node_dict[each_node] = ident_node_list
389
+ return ident_node_dict
390
+ '''
391
+ ident_node_dict = {}
392
+ for each_node in self.nodes:
393
+ ident_node_dict[each_node] = [each_node]
394
+ return ident_node_dict
395
+ '''
396
+
397
+ def get_leaf_edge(self):
398
+ ''' get an edge that is incident only to one edge
399
+
400
+ Returns
401
+ -------
402
+ if exists, return a leaf edge. otherwise, return None.
403
+ '''
404
+ for each_edge in self.edges:
405
+ if len(self.adj_nodes(each_edge)) == 1:
406
+ if 'tmp' not in self.edge_attr(each_edge):
407
+ return each_edge
408
+ return None
409
+
410
+ def get_nontmp_edge(self):
411
+ for each_edge in self.edges:
412
+ if 'tmp' not in self.edge_attr(each_edge):
413
+ return each_edge
414
+ return None
415
+
416
+ def is_subhg(self, hg):
417
+ ''' return whether this hypergraph is a subhypergraph of `hg`
418
+
419
+ Returns
420
+ -------
421
+ True if self \in hg,
422
+ False otherwise.
423
+ '''
424
+ for each_node in self.nodes:
425
+ if each_node not in hg.nodes:
426
+ return False
427
+ for each_edge in self.edges:
428
+ if each_edge not in hg.edges:
429
+ return False
430
+ return True
431
+
432
+ def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
433
+ ''' if `node` is in a cycle, then return True. otherwise, False.
434
+
435
+ Parameters
436
+ ----------
437
+ node : str
438
+ node in a hypergraph
439
+ visited : list
440
+ list of visited nodes, used for recursion
441
+ parent : str
442
+ parent node, used to eliminate a cycle consisting of two nodes and one edge.
443
+
444
+ Returns
445
+ -------
446
+ bool
447
+ '''
448
+ if visited is None:
449
+ visited = []
450
+ if parent == '':
451
+ visited = []
452
+ if root_node == '':
453
+ root_node = node
454
+ visited.append(node)
455
+ for each_adj_node in self.adj_nodes(node):
456
+ if each_adj_node not in visited:
457
+ if self.in_cycle(each_adj_node, visited, node, root_node):
458
+ return True
459
+ elif each_adj_node != parent and each_adj_node == root_node:
460
+ return True
461
+ return False
462
+
463
+
464
+ def draw(self, file_path=None, with_node=False, with_edge_name=False):
465
+ ''' draw hypergraph
466
+ '''
467
+ import graphviz
468
+ G = graphviz.Graph(format='png')
469
+ for each_node in self.nodes:
470
+ if 'ext_id' in self.node_attr(each_node):
471
+ G.node(each_node, label='',
472
+ shape='circle', width='0.1', height='0.1', style='filled',
473
+ fillcolor='black')
474
+ else:
475
+ if with_node:
476
+ G.node(each_node, label='',
477
+ shape='circle', width='0.1', height='0.1', style='filled',
478
+ fillcolor='gray')
479
+ edge_list = []
480
+ for each_edge in self.edges:
481
+ if self.edge_attr(each_edge).get('terminal', False):
482
+ G.node(each_edge,
483
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
484
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
485
+ fontcolor='black', shape='square')
486
+ elif self.edge_attr(each_edge).get('tmp', False):
487
+ G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
488
+ fontcolor='black', shape='square')
489
+ else:
490
+ G.node(each_edge,
491
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
492
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
493
+ fontcolor='black', shape='square', style='filled')
494
+ if with_node:
495
+ for each_node in self.nodes_in_edge(each_edge):
496
+ G.edge(each_edge, each_node)
497
+ else:
498
+ for each_node in self.nodes_in_edge(each_edge):
499
+ if 'ext_id' in self.node_attr(each_node)\
500
+ and set([each_node, each_edge]) not in edge_list:
501
+ G.edge(each_edge, each_node)
502
+ edge_list.append(set([each_node, each_edge]))
503
+ for each_other_edge in self.adj_nodes(each_edge):
504
+ if set([each_edge, each_other_edge]) not in edge_list:
505
+ num_bond = 0
506
+ common_node_set = set(self.nodes_in_edge(each_edge))\
507
+ .intersection(set(self.nodes_in_edge(each_other_edge)))
508
+ for each_node in common_node_set:
509
+ if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
510
+ num_bond += self.node_attr(each_node)['symbol'].bond_type
511
+ elif self.node_attr(each_node)['symbol'].bond_type in [12]:
512
+ num_bond += 1
513
+ else:
514
+ raise NotImplementedError('unsupported bond type')
515
+ for _ in range(num_bond):
516
+ G.edge(each_edge, each_other_edge)
517
+ edge_list.append(set([each_edge, each_other_edge]))
518
+ if file_path is not None:
519
+ G.render(file_path, cleanup=True)
520
+ #os.remove(file_path)
521
+ return G
522
+
523
+ def is_dividable(self, node):
524
+ _hg = deepcopy(self.hg)
525
+ _hg.remove_node(node)
526
+ return (not nx.is_connected(_hg))
527
+
528
+ def divide(self, node):
529
+ subhg_list = []
530
+
531
+ hg_wo_node = deepcopy(self)
532
+ hg_wo_node.remove_node(node, remove_connected_edges=False)
533
+ connected_components = nx.connected_components(hg_wo_node.hg)
534
+ for each_component in connected_components:
535
+ node_list = [node]
536
+ edge_list = []
537
+ node_list.extend([each_node for each_node in each_component
538
+ if each_node.startswith('bond_')])
539
+ edge_list.extend([each_edge for each_edge in each_component
540
+ if each_edge.startswith('e')])
541
+ subhg_list.append(self.get_subhg(node_list, edge_list))
542
+ #subhg_list[-1].set_node_attr(node, {'divided': True})
543
+ return subhg_list
544
+
models/mhg_model/graph_grammar/io/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/io/smi.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 12 2018"
20
+
21
+ from copy import deepcopy
22
+ from rdkit import Chem
23
+ from rdkit import RDLogger
24
+ import networkx as nx
25
+ import numpy as np
26
+ from ..hypergraph import Hypergraph
27
+ from ..graph_grammar.symbols import TSymbol, BondSymbol
28
+
29
+ # supress warnings
30
+ lg = RDLogger.logger()
31
+ lg.setLevel(RDLogger.CRITICAL)
32
+
33
+
34
+ class HGGen(object):
35
+ """
36
+ load .smi file and yield a hypergraph.
37
+
38
+ Attributes
39
+ ----------
40
+ path_to_file : str
41
+ path to .smi file
42
+ kekulize : bool
43
+ kekulize or not
44
+ add_Hs : bool
45
+ add implicit hydrogens to the molecule or not.
46
+ all_single : bool
47
+ if True, all multiple bonds are summarized into a single bond with some attributes
48
+
49
+ Yields
50
+ ------
51
+ Hypergraph
52
+ """
53
+ def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
54
+ self.num_line = 1
55
+ self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
56
+ self.kekulize = kekulize
57
+ self.add_Hs = add_Hs
58
+ self.all_single = all_single
59
+
60
+ def __iter__(self):
61
+ return self
62
+
63
+ def __next__(self):
64
+ '''
65
+ each_mol = None
66
+ while each_mol is None:
67
+ each_mol = next(self.mol_gen)
68
+ '''
69
+ # not ignoring parse errors
70
+ each_mol = next(self.mol_gen)
71
+ if each_mol is None:
72
+ raise ValueError(f'incorrect smiles in line {self.num_line}')
73
+ else:
74
+ self.num_line += 1
75
+ return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
76
+
77
+
78
+ def mol_to_bipartite(mol, kekulize):
79
+ """
80
+ get a bipartite representation of a molecule.
81
+
82
+ Parameters
83
+ ----------
84
+ mol : rdkit.Chem.rdchem.Mol
85
+ molecule object
86
+
87
+ Returns
88
+ -------
89
+ nx.Graph
90
+ a bipartite graph representing which bond is connected to which atoms.
91
+ """
92
+ try:
93
+ mol = standardize_stereo(mol)
94
+ except KeyError:
95
+ print(Chem.MolToSmiles(mol))
96
+ raise KeyError
97
+
98
+ if kekulize:
99
+ Chem.Kekulize(mol)
100
+
101
+ bipartite_g = nx.Graph()
102
+ for each_atom in mol.GetAtoms():
103
+ bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
104
+ atom_attr=atom_attr(each_atom, kekulize))
105
+
106
+ for each_bond in mol.GetBonds():
107
+ bond_idx = each_bond.GetIdx()
108
+ bipartite_g.add_node(
109
+ f"bond_{bond_idx}",
110
+ bond_attr=bond_attr(each_bond, kekulize))
111
+ bipartite_g.add_edge(
112
+ f"atom_{each_bond.GetBeginAtomIdx()}",
113
+ f"bond_{bond_idx}")
114
+ bipartite_g.add_edge(
115
+ f"atom_{each_bond.GetEndAtomIdx()}",
116
+ f"bond_{bond_idx}")
117
+ return bipartite_g
118
+
119
+
120
+ def mol_to_hg(mol, kekulize, add_Hs):
121
+ """
122
+ get a bipartite representation of a molecule.
123
+
124
+ Parameters
125
+ ----------
126
+ mol : rdkit.Chem.rdchem.Mol
127
+ molecule object
128
+ kekulize : bool
129
+ kekulize or not
130
+ add_Hs : bool
131
+ add implicit hydrogens to the molecule or not.
132
+
133
+ Returns
134
+ -------
135
+ Hypergraph
136
+ """
137
+ if add_Hs:
138
+ mol = Chem.AddHs(mol)
139
+
140
+ if kekulize:
141
+ Chem.Kekulize(mol)
142
+
143
+ bipartite_g = mol_to_bipartite(mol, kekulize)
144
+ hg = Hypergraph()
145
+ for each_atom in [each_node for each_node in bipartite_g.nodes()
146
+ if each_node.startswith('atom_')]:
147
+ node_set = set([])
148
+ for each_bond in bipartite_g.adj[each_atom]:
149
+ hg.add_node(each_bond,
150
+ attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
151
+ node_set.add(each_bond)
152
+ hg.add_edge(node_set,
153
+ attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
154
+ return hg
155
+
156
+
157
+ def hg_to_mol(hg, verbose=False):
158
+ """ convert a hypergraph into Mol object
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+
164
+ Returns
165
+ -------
166
+ mol : Chem.RWMol
167
+ """
168
+ mol = Chem.RWMol()
169
+ atom_dict = {}
170
+ bond_set = set([])
171
+ for each_edge in hg.edges:
172
+ atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
173
+ atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
174
+ atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
175
+ atom.SetChiralTag(
176
+ Chem.rdchem.ChiralType.values[
177
+ hg.edge_attr(each_edge)['symbol'].chirality])
178
+ atom_idx = mol.AddAtom(atom)
179
+ atom_dict[each_edge] = atom_idx
180
+
181
+ for each_node in hg.nodes:
182
+ edge_1, edge_2 = hg.adj_edges(each_node)
183
+ if edge_1+edge_2 not in bond_set:
184
+ if hg.node_attr(each_node)['symbol'].bond_type <= 3:
185
+ num_bond = hg.node_attr(each_node)['symbol'].bond_type
186
+ elif hg.node_attr(each_node)['symbol'].bond_type == 12:
187
+ num_bond = 1
188
+ else:
189
+ raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
190
+ _ = mol.AddBond(atom_dict[edge_1],
191
+ atom_dict[edge_2],
192
+ order=Chem.rdchem.BondType.values[num_bond])
193
+ bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
194
+
195
+ # stereo
196
+ mol.GetBondWithIdx(bond_idx).SetStereo(
197
+ Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
198
+ bond_set.update([edge_1+edge_2])
199
+ bond_set.update([edge_2+edge_1])
200
+ mol.UpdatePropertyCache()
201
+ mol = mol.GetMol()
202
+ not_stereo_mol = deepcopy(mol)
203
+ if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
204
+ raise RuntimeError('no valid molecule was obtained.')
205
+ try:
206
+ mol = set_stereo(mol)
207
+ is_stereo = True
208
+ except:
209
+ import traceback
210
+ traceback.print_exc()
211
+ is_stereo = False
212
+ mol_tmp = deepcopy(mol)
213
+ Chem.SetAromaticity(mol_tmp)
214
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
215
+ mol = mol_tmp
216
+ else:
217
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
218
+ mol = not_stereo_mol
219
+ mol.UpdatePropertyCache()
220
+ Chem.GetSymmSSSR(mol)
221
+ mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
222
+ if verbose:
223
+ return mol, is_stereo
224
+ else:
225
+ return mol
226
+
227
+ def hgs_to_mols(hg_list, ignore_error=False):
228
+ if ignore_error:
229
+ mol_list = []
230
+ for each_hg in hg_list:
231
+ try:
232
+ mol = hg_to_mol(each_hg)
233
+ except:
234
+ mol = None
235
+ mol_list.append(mol)
236
+ else:
237
+ mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
238
+ return mol_list
239
+
240
+ def hgs_to_smiles(hg_list, ignore_error=False):
241
+ mol_list = hgs_to_mols(hg_list, ignore_error)
242
+ smiles_list = []
243
+ for each_mol in mol_list:
244
+ try:
245
+ smiles_list.append(
246
+ Chem.MolToSmiles(
247
+ Chem.MolFromSmiles(
248
+ Chem.MolToSmiles(
249
+ each_mol))))
250
+ except:
251
+ smiles_list.append(None)
252
+ return smiles_list
253
+
254
+ def atom_attr(atom, kekulize):
255
+ """
256
+ get atom's attributes
257
+
258
+ Parameters
259
+ ----------
260
+ atom : rdkit.Chem.rdchem.Atom
261
+ kekulize : bool
262
+ kekulize or not
263
+
264
+ Returns
265
+ -------
266
+ atom_attr : dict
267
+ "is_aromatic" : bool
268
+ the atom is aromatic or not.
269
+ "smarts" : str
270
+ SMARTS representation of the atom.
271
+ """
272
+ if kekulize:
273
+ return {'terminal': True,
274
+ 'is_in_ring': atom.IsInRing(),
275
+ 'symbol': TSymbol(degree=0,
276
+ #degree=atom.GetTotalDegree(),
277
+ is_aromatic=False,
278
+ symbol=atom.GetSymbol(),
279
+ num_explicit_Hs=atom.GetNumExplicitHs(),
280
+ formal_charge=atom.GetFormalCharge(),
281
+ chirality=atom.GetChiralTag().real
282
+ )}
283
+ else:
284
+ return {'terminal': True,
285
+ 'is_in_ring': atom.IsInRing(),
286
+ 'symbol': TSymbol(degree=0,
287
+ #degree=atom.GetTotalDegree(),
288
+ is_aromatic=atom.GetIsAromatic(),
289
+ symbol=atom.GetSymbol(),
290
+ num_explicit_Hs=atom.GetNumExplicitHs(),
291
+ formal_charge=atom.GetFormalCharge(),
292
+ chirality=atom.GetChiralTag().real
293
+ )}
294
+
295
+ def bond_attr(bond, kekulize):
296
+ """
297
+ get atom's attributes
298
+
299
+ Parameters
300
+ ----------
301
+ bond : rdkit.Chem.rdchem.Bond
302
+ kekulize : bool
303
+ kekulize or not
304
+
305
+ Returns
306
+ -------
307
+ bond_attr : dict
308
+ "bond_type" : int
309
+ {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
310
+ 1: rdkit.Chem.rdchem.BondType.SINGLE,
311
+ 2: rdkit.Chem.rdchem.BondType.DOUBLE,
312
+ 3: rdkit.Chem.rdchem.BondType.TRIPLE,
313
+ 4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
314
+ 5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
315
+ 6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
316
+ 7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
317
+ 8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
318
+ 9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
319
+ 10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
320
+ 11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
321
+ 12: rdkit.Chem.rdchem.BondType.AROMATIC,
322
+ 13: rdkit.Chem.rdchem.BondType.IONIC,
323
+ 14: rdkit.Chem.rdchem.BondType.HYDROGEN,
324
+ 15: rdkit.Chem.rdchem.BondType.THREECENTER,
325
+ 16: rdkit.Chem.rdchem.BondType.DATIVEONE,
326
+ 17: rdkit.Chem.rdchem.BondType.DATIVE,
327
+ 18: rdkit.Chem.rdchem.BondType.DATIVEL,
328
+ 19: rdkit.Chem.rdchem.BondType.DATIVER,
329
+ 20: rdkit.Chem.rdchem.BondType.OTHER,
330
+ 21: rdkit.Chem.rdchem.BondType.ZERO}
331
+ """
332
+ if kekulize:
333
+ is_aromatic = False
334
+ if bond.GetBondType().real == 12:
335
+ bond_type = 1
336
+ else:
337
+ bond_type = bond.GetBondType().real
338
+ else:
339
+ is_aromatic = bond.GetIsAromatic()
340
+ bond_type = bond.GetBondType().real
341
+ return {'symbol': BondSymbol(is_aromatic=is_aromatic,
342
+ bond_type=bond_type,
343
+ stereo=int(bond.GetStereo())),
344
+ 'is_in_ring': bond.IsInRing()}
345
+
346
+
347
+ def standardize_stereo(mol):
348
+ '''
349
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
350
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
351
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
352
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
353
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
354
+
355
+ '''
356
+ # mol = Chem.AddHs(mol) # this removes CIPRank !!!
357
+ for each_bond in mol.GetBonds():
358
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
359
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
360
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
361
+ atom_idx_1 = each_bond.GetStereoAtoms()[0]
362
+ atom_idx_2 = each_bond.GetStereoAtoms()[1]
363
+ if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
364
+ begin_atom_idx = atom_idx_1
365
+ end_atom_idx = atom_idx_2
366
+ else:
367
+ begin_atom_idx = atom_idx_2
368
+ end_atom_idx = atom_idx_1
369
+
370
+ begin_another_atom_idx = None
371
+ assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
372
+ for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
373
+ each_neighbor_idx = each_neighbor.GetIdx()
374
+ if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
375
+ begin_another_atom_idx = each_neighbor_idx
376
+
377
+ end_another_atom_idx = None
378
+ assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
379
+ for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
380
+ each_neighbor_idx = each_neighbor.GetIdx()
381
+ if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
382
+ end_another_atom_idx = each_neighbor_idx
383
+
384
+ '''
385
+ relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
386
+ '''
387
+ begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
388
+ end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
389
+ try:
390
+ begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
391
+ except:
392
+ begin_another_atom_rank = np.inf
393
+ try:
394
+ end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
395
+ except:
396
+ end_another_atom_rank = np.inf
397
+ if begin_atom_rank < begin_another_atom_rank\
398
+ and end_atom_rank < end_another_atom_rank:
399
+ pass
400
+ elif begin_atom_rank < begin_another_atom_rank\
401
+ and end_atom_rank > end_another_atom_rank:
402
+ # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
403
+ if each_bond.GetStereo() == 2:
404
+ # set stereo
405
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
406
+ # set bond dir
407
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
408
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
409
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
410
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
411
+ elif each_bond.GetStereo() == 3:
412
+ # set stereo
413
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
414
+ # set bond dir
415
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
416
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
417
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
418
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
419
+ else:
420
+ raise ValueError
421
+ each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
422
+ elif begin_atom_rank > begin_another_atom_rank\
423
+ and end_atom_rank < end_another_atom_rank:
424
+ # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
425
+ if each_bond.GetStereo() == 2:
426
+ # set stereo
427
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
428
+ # set bond dir
429
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
430
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
431
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
432
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
433
+ elif each_bond.GetStereo() == 3:
434
+ # set stereo
435
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
436
+ # set bond dir
437
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
438
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
439
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
440
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
441
+ else:
442
+ raise ValueError
443
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
444
+ elif begin_atom_rank > begin_another_atom_rank\
445
+ and end_atom_rank > end_another_atom_rank:
446
+ # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
447
+ if each_bond.GetStereo() == 2:
448
+ # set bond dir
449
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
450
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
451
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
452
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
453
+ elif each_bond.GetStereo() == 3:
454
+ # set bond dir
455
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
456
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
457
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
458
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
459
+ else:
460
+ raise ValueError
461
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
462
+ else:
463
+ raise RuntimeError
464
+ return mol
465
+
466
+
467
+ def set_stereo(mol):
468
+ '''
469
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
470
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
471
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
472
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
473
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
474
+ '''
475
+ _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
476
+ Chem.Kekulize(_mol, True)
477
+ substruct_match = mol.GetSubstructMatch(_mol)
478
+ if not substruct_match:
479
+ ''' mol and _mol are kekulized.
480
+ sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
481
+ '''
482
+ Chem.SetAromaticity(mol)
483
+ Chem.SetAromaticity(_mol)
484
+ substruct_match = mol.GetSubstructMatch(_mol)
485
+ try:
486
+ atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
487
+ except:
488
+ raise ValueError('two molecules obtained from the same data do not match.')
489
+
490
+ for each_bond in mol.GetBonds():
491
+ begin_atom_idx = each_bond.GetBeginAtomIdx()
492
+ end_atom_idx = each_bond.GetEndAtomIdx()
493
+ _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
494
+ _bond.SetStereo(each_bond.GetStereo())
495
+
496
+ mol = _mol
497
+ for each_bond in mol.GetBonds():
498
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
499
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
500
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
501
+ begin_atom_idx_set = set([each_neighbor.GetIdx()
502
+ for each_neighbor
503
+ in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
504
+ if each_neighbor.GetIdx() != end_stereo_atom_idx])
505
+ end_atom_idx_set = set([each_neighbor.GetIdx()
506
+ for each_neighbor
507
+ in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
508
+ if each_neighbor.GetIdx() != begin_stereo_atom_idx])
509
+ if not begin_atom_idx_set:
510
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
511
+ continue
512
+ if not end_atom_idx_set:
513
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
514
+ continue
515
+ if len(begin_atom_idx_set) == 1:
516
+ begin_atom_idx = begin_atom_idx_set.pop()
517
+ begin_another_atom_idx = None
518
+ if len(end_atom_idx_set) == 1:
519
+ end_atom_idx = end_atom_idx_set.pop()
520
+ end_another_atom_idx = None
521
+ if len(begin_atom_idx_set) == 2:
522
+ atom_idx_1 = begin_atom_idx_set.pop()
523
+ atom_idx_2 = begin_atom_idx_set.pop()
524
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
525
+ begin_atom_idx = atom_idx_1
526
+ begin_another_atom_idx = atom_idx_2
527
+ else:
528
+ begin_atom_idx = atom_idx_2
529
+ begin_another_atom_idx = atom_idx_1
530
+ if len(end_atom_idx_set) == 2:
531
+ atom_idx_1 = end_atom_idx_set.pop()
532
+ atom_idx_2 = end_atom_idx_set.pop()
533
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
534
+ end_atom_idx = atom_idx_1
535
+ end_another_atom_idx = atom_idx_2
536
+ else:
537
+ end_atom_idx = atom_idx_2
538
+ end_another_atom_idx = atom_idx_1
539
+
540
+ if each_bond.GetStereo() == 2: # same side
541
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
542
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
543
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
544
+ elif each_bond.GetStereo() == 3: # opposite side
545
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
546
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
547
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
548
+ else:
549
+ raise ValueError
550
+ return mol
551
+
552
+
553
+ def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
554
+ if atom_idx_1 is None or atom_idx_2 is None:
555
+ return mol
556
+ else:
557
+ mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
558
+ return mol
559
+
models/mhg_model/graph_grammar/nn/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
models/mhg_model/graph_grammar/nn/dataset.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Apr 18 2018"
20
+
21
+ from torch.utils.data import Dataset, DataLoader
22
+ import torch
23
+ import numpy as np
24
+
25
+
26
+ def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
27
+ ''' pad left
28
+
29
+ Parameters
30
+ ----------
31
+ sentence_list : list of sequences of integers
32
+ max_len : int
33
+ maximum length of sentences.
34
+ if a sentence is shorter than `max_len`, its left part is padded.
35
+ pad_idx : int
36
+ integer for padding
37
+ inverse : bool
38
+ if True, the sequence is inversed.
39
+
40
+ Returns
41
+ -------
42
+ List of torch.LongTensor
43
+ each sentence is left-padded.
44
+ '''
45
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
46
+
47
+ if max_in_list > max_len:
48
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
49
+
50
+ if inverse:
51
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1]) for each_sen in sentence_list]
52
+ else:
53
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen) for each_sen in sentence_list]
54
+
55
+
56
+ def right_padding(sentence_list, max_len, pad_idx=-1):
57
+ ''' pad right
58
+
59
+ Parameters
60
+ ----------
61
+ sentence_list : list of sequences of integers
62
+ max_len : int
63
+ maximum length of sentences.
64
+ if a sentence is shorter than `max_len`, its right part is padded.
65
+ pad_idx : int
66
+ integer for padding
67
+
68
+ Returns
69
+ -------
70
+ List of torch.LongTensor
71
+ each sentence is right-padded.
72
+ '''
73
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
74
+ if max_in_list > max_len:
75
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
76
+
77
+ return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen))) for each_sen in sentence_list]
78
+
79
+
80
+ class HRGDataset(Dataset):
81
+
82
+ '''
83
+ A class of HRG data
84
+ '''
85
+
86
+ def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
87
+ self.hrg = hrg
88
+ self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
89
+ max_len,
90
+ inverse=inversed_input)
91
+
92
+ self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
93
+ self.inserved_input = inversed_input
94
+ self.target_val_list = target_val_list
95
+ if target_val_list is not None:
96
+ if len(prod_rule_seq_list) != len(target_val_list):
97
+ raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')
98
+
99
+ def __len__(self):
100
+ return len(self.left_prod_rule_seq_list)
101
+
102
+ def __getitem__(self, idx):
103
+ if self.target_val_list is not None:
104
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
105
+ else:
106
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]
107
+
108
+ @property
109
+ def vocab_size(self):
110
+ return self.hrg.num_prod_rule
111
+
112
+ def batch_padding(each_batch, batch_size, padding_idx):
113
+ num_pad = batch_size - len(each_batch[0])
114
+ if num_pad:
115
+ each_batch[0] = torch.cat([each_batch[0],
116
+ padding_idx * torch.ones((batch_size - len(each_batch[0]),
117
+ len(each_batch[0][0])), dtype=torch.int64)], dim=0)
118
+ each_batch[1] = torch.cat([each_batch[1],
119
+ padding_idx * torch.ones((batch_size - len(each_batch[1]),
120
+ len(each_batch[1][0])), dtype=torch.int64)], dim=0)
121
+ return each_batch, num_pad
models/mhg_model/graph_grammar/nn/decoder.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
+ class DecoderBase(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.hidden_dict = {}
33
+
34
+ @abc.abstractmethod
35
+ def forward_one_step(self, tgt_emb_in):
36
+ ''' one-step forward model
37
+
38
+ Parameters
39
+ ----------
40
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
41
+
42
+ Returns
43
+ -------
44
+ Tensor, shape (batch_size, hidden_dim)
45
+ '''
46
+ tgt_emb_out = None
47
+ return tgt_emb_out
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+ @abc.abstractmethod
56
+ def feed_hidden(self, hidden_dict_0):
57
+ for each_hidden in self.hidden_dict.keys():
58
+ self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]
59
+
60
+
61
+ class GRUDecoder(DecoderBase):
62
+
63
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
64
+ dropout: float, batch_size: int, use_gpu: bool,
65
+ no_dropout=False):
66
+ super().__init__()
67
+ self.input_dim = input_dim
68
+ self.hidden_dim = hidden_dim
69
+ self.num_layers = num_layers
70
+ self.dropout = dropout
71
+ self.batch_size = batch_size
72
+ self.use_gpu = use_gpu
73
+ self.model = nn.GRU(input_size=self.input_dim,
74
+ hidden_size=self.hidden_dim,
75
+ num_layers=self.num_layers,
76
+ batch_first=True,
77
+ bidirectional=False,
78
+ dropout=self.dropout if not no_dropout else 0
79
+ )
80
+ if self.use_gpu:
81
+ self.model.cuda()
82
+ self.init_hidden()
83
+
84
+ def init_hidden(self):
85
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
86
+ self.batch_size,
87
+ self.hidden_dim),
88
+ requires_grad=False)
89
+ if self.use_gpu:
90
+ self.hidden_dict['h'] = self.hidden_dict['h'].cuda()
91
+
92
+ def forward_one_step(self, tgt_emb_in):
93
+ ''' one-step forward model
94
+
95
+ Parameters
96
+ ----------
97
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
98
+
99
+ Returns
100
+ -------
101
+ Tensor, shape (batch_size, hidden_dim)
102
+ '''
103
+ tgt_emb_out, self.hidden_dict['h'] \
104
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
105
+ self.hidden_dict['h'])
106
+ return tgt_emb_out
107
+
108
+
109
+ class LSTMDecoder(DecoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.dropout = dropout
119
+ self.batch_size = batch_size
120
+ self.use_gpu = use_gpu
121
+ self.model = nn.LSTM(input_size=self.input_dim,
122
+ hidden_size=self.hidden_dim,
123
+ num_layers=self.num_layers,
124
+ batch_first=True,
125
+ bidirectional=False,
126
+ dropout=self.dropout if not no_dropout else 0)
127
+ if self.use_gpu:
128
+ self.model.cuda()
129
+ self.init_hidden()
130
+
131
+ def init_hidden(self):
132
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
133
+ self.batch_size,
134
+ self.hidden_dim),
135
+ requires_grad=False)
136
+ self.hidden_dict['c'] = torch.zeros((self.num_layers,
137
+ self.batch_size,
138
+ self.hidden_dim),
139
+ requires_grad=False)
140
+ if self.use_gpu:
141
+ for each_hidden in self.hidden_dict.keys():
142
+ self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()
143
+
144
+ def forward_one_step(self, tgt_emb_in):
145
+ ''' one-step forward model
146
+
147
+ Parameters
148
+ ----------
149
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
150
+
151
+ Returns
152
+ -------
153
+ Tensor, shape (batch_size, hidden_dim)
154
+ '''
155
+ tgt_hidden_out, self.hidden_dict['h'], self.hidden_dict['c'] \
156
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
157
+ self.hidden_dict['h'], self.hidden_dict['c'])
158
+ return tgt_hidden_out
models/mhg_model/graph_grammar/nn/encoder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+ from typing import List
28
+
29
+
30
+ class EncoderBase(nn.Module):
31
+
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ @abc.abstractmethod
36
+ def forward(self, in_seq):
37
+ ''' forward model
38
+
39
+ Parameters
40
+ ----------
41
+ in_seq_emb : Variable, shape (batch_size, max_len, input_dim)
42
+
43
+ Returns
44
+ -------
45
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
46
+ '''
47
+ pass
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+
56
+ class GRUEncoder(EncoderBase):
57
+
58
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
59
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
60
+ no_dropout=False):
61
+ super().__init__()
62
+ self.input_dim = input_dim
63
+ self.hidden_dim = hidden_dim
64
+ self.num_layers = num_layers
65
+ self.bidirectional = bidirectional
66
+ self.dropout = dropout
67
+ self.batch_size = batch_size
68
+ self.use_gpu = use_gpu
69
+ self.model = nn.GRU(input_size=self.input_dim,
70
+ hidden_size=self.hidden_dim,
71
+ num_layers=self.num_layers,
72
+ batch_first=True,
73
+ bidirectional=self.bidirectional,
74
+ dropout=self.dropout if not no_dropout else 0)
75
+ if self.use_gpu:
76
+ self.model.cuda()
77
+ self.init_hidden()
78
+
79
+
80
+ def init_hidden(self):
81
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
82
+ self.batch_size,
83
+ self.hidden_dim),
84
+ requires_grad=False)
85
+ if self.use_gpu:
86
+ self.h0 = self.h0.cuda()
87
+
88
+ def forward(self, in_seq_emb):
89
+ ''' forward model
90
+
91
+ Parameters
92
+ ----------
93
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
94
+
95
+ Returns
96
+ -------
97
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
98
+ '''
99
+ max_len = in_seq_emb.size(1)
100
+ hidden_seq_emb, self.h0 = self.model(
101
+ in_seq_emb, self.h0)
102
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
103
+ max_len,
104
+ 1 + self.bidirectional,
105
+ self.hidden_dim)
106
+ return hidden_seq_emb
107
+
108
+
109
+ class LSTMEncoder(EncoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.bidirectional = bidirectional
119
+ self.dropout = dropout
120
+ self.batch_size = batch_size
121
+ self.use_gpu = use_gpu
122
+ self.model = nn.LSTM(input_size=self.input_dim,
123
+ hidden_size=self.hidden_dim,
124
+ num_layers=self.num_layers,
125
+ batch_first=True,
126
+ bidirectional=self.bidirectional,
127
+ dropout=self.dropout if not no_dropout else 0)
128
+ if self.use_gpu:
129
+ self.model.cuda()
130
+ self.init_hidden()
131
+
132
+ def init_hidden(self):
133
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
134
+ self.batch_size,
135
+ self.hidden_dim),
136
+ requires_grad=False)
137
+ self.c0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
138
+ self.batch_size,
139
+ self.hidden_dim),
140
+ requires_grad=False)
141
+ if self.use_gpu:
142
+ self.h0 = self.h0.cuda()
143
+ self.c0 = self.c0.cuda()
144
+
145
+ def forward(self, in_seq_emb):
146
+ ''' forward model
147
+
148
+ Parameters
149
+ ----------
150
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
151
+
152
+ Returns
153
+ -------
154
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
155
+ '''
156
+ max_len = in_seq_emb.size(1)
157
+ hidden_seq_emb, (self.h0, self.c0) = self.model(
158
+ in_seq_emb, (self.h0, self.c0))
159
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
160
+ max_len,
161
+ 1 + self.bidirectional,
162
+ self.hidden_dim)
163
+ return hidden_seq_emb
164
+
165
+
166
+ class FullConnectedEncoder(EncoderBase):
167
+
168
+ def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
169
+ batch_size: int, use_gpu: bool):
170
+ super().__init__()
171
+ self.input_dim = input_dim
172
+ self.hidden_dim = hidden_dim
173
+ self.max_len = max_len
174
+ self.hidden_dim_list = hidden_dim_list
175
+ self.use_gpu = use_gpu
176
+ in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
177
+ self.linear_list = nn.ModuleList(
178
+ [nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])\
179
+ for each_idx in range(len(in_out_dim_list) - 1)])
180
+
181
+ def forward(self, in_seq_emb):
182
+ ''' forward model
183
+
184
+ Parameters
185
+ ----------
186
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
187
+
188
+ Returns
189
+ -------
190
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
191
+ '''
192
+ batch_size = in_seq_emb.size(0)
193
+ x = in_seq_emb.view(batch_size, -1)
194
+ for each_linear in self.linear_list:
195
+ x = F.relu(each_linear(x))
196
+ return x.view(batch_size, 1, -1)
197
+
198
+ def init_hidden(self):
199
+ pass
models/mhg_model/graph_grammar/nn/graph.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
25
+ from torch import nn
26
+ from torch.autograd import Variable
27
+
28
+ class MolecularProdRuleEmbedding(nn.Module):
29
+
30
+ ''' molecular fingerprint layer
31
+ '''
32
+
33
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
34
+ out_dim=32, element_embed_dim=32,
35
+ num_layers=3, padding_idx=None, use_gpu=False):
36
+ super().__init__()
37
+ if padding_idx is not None:
38
+ assert padding_idx == -1, 'padding_idx must be -1.'
39
+ self.prod_rule_corpus = prod_rule_corpus
40
+ self.layer2layer_activation = layer2layer_activation
41
+ self.layer2out_activation = layer2out_activation
42
+ self.out_dim = out_dim
43
+ self.element_embed_dim = element_embed_dim
44
+ self.num_layers = num_layers
45
+ self.padding_idx = padding_idx
46
+ self.use_gpu = use_gpu
47
+
48
+ self.layer2layer_list = []
49
+ self.layer2out_list = []
50
+
51
+ if self.use_gpu:
52
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
53
+ self.element_embed_dim, requires_grad=True).cuda()
54
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
55
+ self.element_embed_dim, requires_grad=True).cuda()
56
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
57
+ self.element_embed_dim, requires_grad=True).cuda()
58
+ for _ in range(num_layers):
59
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
60
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
61
+ else:
62
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
63
+ self.element_embed_dim, requires_grad=True)
64
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
65
+ self.element_embed_dim, requires_grad=True)
66
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
67
+ self.element_embed_dim, requires_grad=True)
68
+ for _ in range(num_layers):
69
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
70
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
71
+
72
+
73
+ def forward(self, prod_rule_idx_seq):
74
+ ''' forward model for mini-batch
75
+
76
+ Parameters
77
+ ----------
78
+ prod_rule_idx_seq : (batch_size, length)
79
+
80
+ Returns
81
+ -------
82
+ Variable, shape (batch_size, length, out_dim)
83
+ '''
84
+ batch_size, length = prod_rule_idx_seq.shape
85
+ if self.use_gpu:
86
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
87
+ else:
88
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
89
+ for each_batch_idx in range(batch_size):
90
+ for each_idx in range(length):
91
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
92
+ continue
93
+ else:
94
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
95
+ layer_wise_embed_dict = {each_edge: self.atom_embed[
96
+ each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
97
+ for each_edge in each_prod_rule.rhs.edges}
98
+ layer_wise_embed_dict.update({each_node: self.bond_embed[
99
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
100
+ for each_node in each_prod_rule.rhs.nodes})
101
+ for each_node in each_prod_rule.rhs.nodes:
102
+ if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
103
+ layer_wise_embed_dict[each_node] \
104
+ = layer_wise_embed_dict[each_node] \
105
+ + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
106
+
107
+ for each_layer in range(self.num_layers):
108
+ next_layer_embed_dict = {}
109
+ for each_edge in each_prod_rule.rhs.edges:
110
+ v = layer_wise_embed_dict[each_edge]
111
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
112
+ v = v + layer_wise_embed_dict[each_node]
113
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
114
+ out[each_batch_idx, each_idx, :] \
115
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
116
+ for each_node in each_prod_rule.rhs.nodes:
117
+ v = layer_wise_embed_dict[each_node]
118
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
119
+ v = v + layer_wise_embed_dict[each_edge]
120
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
121
+ out[each_batch_idx, each_idx, :]\
122
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
123
+ layer_wise_embed_dict = next_layer_embed_dict
124
+
125
+ return out
126
+
127
+
128
+ class MolecularProdRuleEmbeddingLastLayer(nn.Module):
129
+
130
+ ''' molecular fingerprint layer
131
+ '''
132
+
133
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
134
+ out_dim=32, element_embed_dim=32,
135
+ num_layers=3, padding_idx=None, use_gpu=False):
136
+ super().__init__()
137
+ if padding_idx is not None:
138
+ assert padding_idx == -1, 'padding_idx must be -1.'
139
+ self.prod_rule_corpus = prod_rule_corpus
140
+ self.layer2layer_activation = layer2layer_activation
141
+ self.layer2out_activation = layer2out_activation
142
+ self.out_dim = out_dim
143
+ self.element_embed_dim = element_embed_dim
144
+ self.num_layers = num_layers
145
+ self.padding_idx = padding_idx
146
+ self.use_gpu = use_gpu
147
+
148
+ self.layer2layer_list = []
149
+ self.layer2out_list = []
150
+
151
+ if self.use_gpu:
152
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
153
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
154
+ for _ in range(num_layers+1):
155
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
156
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
157
+ else:
158
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
159
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
160
+ for _ in range(num_layers+1):
161
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
162
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
163
+
164
+
165
+ def forward(self, prod_rule_idx_seq):
166
+ ''' forward model for mini-batch
167
+
168
+ Parameters
169
+ ----------
170
+ prod_rule_idx_seq : (batch_size, length)
171
+
172
+ Returns
173
+ -------
174
+ Variable, shape (batch_size, length, out_dim)
175
+ '''
176
+ batch_size, length = prod_rule_idx_seq.shape
177
+ if self.use_gpu:
178
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
179
+ else:
180
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
181
+ for each_batch_idx in range(batch_size):
182
+ for each_idx in range(length):
183
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
184
+ continue
185
+ else:
186
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
187
+
188
+ if self.use_gpu:
189
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
190
+ Variable(torch.LongTensor(
191
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
192
+ ), requires_grad=False).cuda())
193
+ for each_edge in each_prod_rule.rhs.edges}
194
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
195
+ Variable(
196
+ torch.LongTensor([
197
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
198
+ requires_grad=False).cuda()
199
+ ) for each_node in each_prod_rule.rhs.nodes})
200
+ else:
201
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
202
+ Variable(torch.LongTensor(
203
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
204
+ ), requires_grad=False))
205
+ for each_edge in each_prod_rule.rhs.edges}
206
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
207
+ Variable(
208
+ torch.LongTensor([
209
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
210
+ requires_grad=False)
211
+ ) for each_node in each_prod_rule.rhs.nodes})
212
+
213
+ for each_layer in range(self.num_layers):
214
+ next_layer_embed_dict = {}
215
+ for each_edge in each_prod_rule.rhs.edges:
216
+ v = layer_wise_embed_dict[each_edge]
217
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
218
+ v += layer_wise_embed_dict[each_node]
219
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
220
+ for each_node in each_prod_rule.rhs.nodes:
221
+ v = layer_wise_embed_dict[each_node]
222
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
223
+ v += layer_wise_embed_dict[each_edge]
224
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
225
+ layer_wise_embed_dict = next_layer_embed_dict
226
+ for each_edge in each_prod_rule.rhs.edges:
227
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
228
+ for each_edge in each_prod_rule.rhs.edges:
229
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
230
+
231
+ return out
232
+
233
+
234
+ class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
235
+
236
+ ''' molecular fingerprint layer
237
+ '''
238
+
239
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
240
+ out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
241
+ super().__init__()
242
+ if padding_idx is not None:
243
+ assert padding_idx == -1, 'padding_idx must be -1.'
244
+ self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
245
+ self.prod_rule_corpus = prod_rule_corpus
246
+ self.layer2layer_activation = layer2layer_activation
247
+ self.layer2out_activation = layer2out_activation
248
+ self.out_dim = out_dim
249
+ self.num_layers = num_layers
250
+ self.padding_idx = padding_idx
251
+ self.use_gpu = use_gpu
252
+
253
+ self.layer2layer_list = []
254
+ self.layer2out_list = []
255
+
256
+ if self.use_gpu:
257
+ for each_key in self.feature_dict:
258
+ self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
259
+ for _ in range(num_layers):
260
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
261
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
262
+ else:
263
+ for _ in range(num_layers):
264
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
265
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))
266
+
267
+
268
+ def forward(self, prod_rule_idx_seq):
269
+ ''' forward model for mini-batch
270
+
271
+ Parameters
272
+ ----------
273
+ prod_rule_idx_seq : (batch_size, length)
274
+
275
+ Returns
276
+ -------
277
+ Variable, shape (batch_size, length, out_dim)
278
+ '''
279
+ batch_size, length = prod_rule_idx_seq.shape
280
+ if self.use_gpu:
281
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
282
+ else:
283
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
284
+ for each_batch_idx in range(batch_size):
285
+ for each_idx in range(length):
286
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
287
+ continue
288
+ else:
289
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
290
+ edge_list = sorted(list(each_prod_rule.rhs.edges))
291
+ node_list = sorted(list(each_prod_rule.rhs.nodes))
292
+ adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list)))
293
+ if self.use_gpu:
294
+ adj_mat = adj_mat.cuda()
295
+ layer_wise_embed = [
296
+ self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
297
+ for each_edge in edge_list]\
298
+ + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
299
+ for each_node in node_list]
300
+ for each_node in each_prod_rule.ext_node.values():
301
+ layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
302
+ = layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
303
+ + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
304
+ layer_wise_embed = torch.stack(layer_wise_embed)
305
+
306
+ for each_layer in range(self.num_layers):
307
+ message = adj_mat @ layer_wise_embed
308
+ next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
309
+ out[each_batch_idx, each_idx, :] \
310
+ = out[each_batch_idx, each_idx, :] \
311
+ + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
312
+ layer_wise_embed = next_layer_embed
313
+ return out
models/mhg_model/images/mhg_example.png ADDED
models/mhg_model/images/mhg_example1.png ADDED
models/mhg_model/images/mhg_example2.png ADDED
models/mhg_model/load.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ from rdkit import Chem
12
+ import torch
13
+ from torch_geometric.utils.smiles import from_smiles
14
+
15
+ from typing import Any, Dict, List, Optional, Union
16
+ from typing_extensions import Self
17
+
18
+ from .graph_grammar.io.smi import hg_to_mol
19
+ from .models.mhgvae import GrammarGINVAE
20
+ from huggingface_hub import hf_hub_download
21
+
22
+
23
+ class PretrainedModelWrapper:
24
+ model: GrammarGINVAE
25
+
26
+ def __init__(self, model_dict: Dict[str, Any]) -> None:
27
+ json_params = model_dict['gnn_params']
28
+ encoder_params = json_params['encoder_params']
29
+ encoder_params['node_feature_size'] = model_dict['num_features']
30
+ encoder_params['edge_feature_size'] = model_dict['num_edge_features']
31
+ self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
32
+ decoder_params=json_params['decoder_params'],
33
+ prod_rule_embed_params=json_params["prod_rule_embed_params"],
34
+ batch_size=512, max_len=model_dict['max_length'])
35
+ self.model.load_state_dict(model_dict['model_state_dict'])
36
+
37
+ self.model.eval()
38
+
39
+ def to(self, device: Union[str, int, torch.device]) -> Self:
40
+ dev_type = type(device)
41
+ if dev_type != torch.device:
42
+ if dev_type == str or torch.cuda.is_available():
43
+ device = torch.device(device)
44
+ else:
45
+ device = torch.device("mps", device)
46
+
47
+ self.model = self.model.to(device)
48
+ return self
49
+
50
+ def encode(self, data: List[str]) -> List[torch.tensor]:
51
+ # Need to encode them into a graph nn
52
+ output = []
53
+ for d in data:
54
+ params = next(self.model.parameters())
55
+ g = from_smiles(d)
56
+ if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'):
57
+ g.to(params.device)
58
+ ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
59
+ output.append(ltvec[0])
60
+ return output
61
+
62
+ def decode(self, data: List[torch.tensor]) -> List[str]:
63
+ output = []
64
+ for d in data:
65
+ mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
66
+ z = self.model.reparameterize(mu, logvar)
67
+ flags, _, hgs = self.model.decode(z)
68
+ if flags[0]:
69
+ reconstructed_mol, _ = hg_to_mol(hgs[0], True)
70
+ output.append(Chem.MolToSmiles(reconstructed_mol))
71
+ else:
72
+ output.append(None)
73
+ return output
74
+
75
+
76
+ def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
77
+ PretrainedModelWrapper]:
78
+ repo_id = "ibm/materials.mhg-ged"
79
+ filename = "mhggnn_pretrained_model_0724_2023.pickle"
80
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename)
81
+ with open(file_path, "rb") as f:
82
+ model_dict = pickle.load(f)
83
+ return PretrainedModelWrapper(model_dict)
84
+ return None
models/mhg_model/mhg_gnn.egg-info/PKG-INFO ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: mhg-gnn
3
+ Version: 0.0
4
+ Summary: Package for mhg-gnn
5
+ Author: team
6
+ License: TBD
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.9
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: networkx>=2.8
11
+ Requires-Dist: numpy<2.0.0,>=1.23.5
12
+ Requires-Dist: pandas>=1.5.3
13
+ Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
14
+ Requires-Dist: torch>=2.0.0
15
+ Requires-Dist: torchinfo>=1.8.0
16
+ Requires-Dist: torch-geometric>=2.3.1
17
+
18
+ # mhg-gnn
19
+
20
+ This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
21
+
22
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
23
+
24
+ For more information contact: [email protected]
25
+
26
+ ![mhg-gnn](images/mhg_example1.png)
27
+
28
+ ## Introduction
29
+
30
+ We present MHG-GNN, an autoencoder architecture
31
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
32
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
33
+ demonstrate high predictive performance on molecular graph data.
34
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
35
+
36
+ ## Table of Contents
37
+
38
+ 1. [Getting Started](#getting-started)
39
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
40
+ 2. [Replicating Conda Environment](#replicating-conda-environment)
41
+ 2. [Feature Extraction](#feature-extraction)
42
+
43
+ ## Getting Started
44
+
45
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
46
+
47
+ ### Pretrained Models and Training Logs
48
+
49
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
50
+
51
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
52
+
53
+ ### Replacicating Conda Environment
54
+
55
+ Follow these steps to replicate our Conda environment and install the necessary libraries:
56
+
57
+ ```
58
+ conda create --name mhg-gnn-env python=3.8.18
59
+ conda activate mhg-gnn-env
60
+ ```
61
+
62
+ #### Install Packages with Conda
63
+
64
+ ```
65
+ conda install -c conda-forge networkx=2.8
66
+ conda install numpy=1.23.5
67
+ # conda install -c conda-forge rdkit=2022.9.4
68
+ conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
69
+ conda install -c conda-forge torchinfo=1.8.0
70
+ conda install pyg -c pyg
71
+ ```
72
+
73
+ #### Install Packages with pip
74
+ ```
75
+ pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
76
+ ```
77
+
78
+ ## Feature Extraction
79
+
80
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
81
+
82
+ To load mhg-gnn, you can simply use:
83
+
84
+ ```python
85
+ import torch
86
+ import load
87
+
88
+ model = load.load()
89
+ ```
90
+
91
+ To encode SMILES into embeddings, you can use:
92
+
93
+ ```python
94
+ with torch.no_grad():
95
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
96
+ ```
97
+
98
+ For decoder, you can use the function, so you can return from embeddings to SMILES strings:
99
+
100
+ ```python
101
+ orig = model.decode(repr)
102
+ ```
models/mhg_model/mhg_gnn.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ setup.cfg
3
+ setup.py
4
+ ./graph_grammar/__init__.py
5
+ ./graph_grammar/hypergraph.py
6
+ ./graph_grammar/algo/__init__.py
7
+ ./graph_grammar/algo/tree_decomposition.py
8
+ ./graph_grammar/graph_grammar/__init__.py
9
+ ./graph_grammar/graph_grammar/base.py
10
+ ./graph_grammar/graph_grammar/corpus.py
11
+ ./graph_grammar/graph_grammar/hrg.py
12
+ ./graph_grammar/graph_grammar/symbols.py
13
+ ./graph_grammar/graph_grammar/utils.py
14
+ ./graph_grammar/io/__init__.py
15
+ ./graph_grammar/io/smi.py
16
+ ./graph_grammar/nn/__init__.py
17
+ ./graph_grammar/nn/dataset.py
18
+ ./graph_grammar/nn/decoder.py
19
+ ./graph_grammar/nn/encoder.py
20
+ ./graph_grammar/nn/graph.py
21
+ ./models/__init__.py
22
+ ./models/mhgvae.py
23
+ graph_grammar/__init__.py
24
+ graph_grammar/hypergraph.py
25
+ graph_grammar/algo/__init__.py
26
+ graph_grammar/algo/tree_decomposition.py
27
+ graph_grammar/graph_grammar/__init__.py
28
+ graph_grammar/graph_grammar/base.py
29
+ graph_grammar/graph_grammar/corpus.py
30
+ graph_grammar/graph_grammar/hrg.py
31
+ graph_grammar/graph_grammar/symbols.py
32
+ graph_grammar/graph_grammar/utils.py
33
+ graph_grammar/io/__init__.py
34
+ graph_grammar/io/smi.py
35
+ graph_grammar/nn/__init__.py
36
+ graph_grammar/nn/dataset.py
37
+ graph_grammar/nn/decoder.py
38
+ graph_grammar/nn/encoder.py
39
+ graph_grammar/nn/graph.py
40
+ mhg_gnn.egg-info/PKG-INFO
41
+ mhg_gnn.egg-info/SOURCES.txt
42
+ mhg_gnn.egg-info/dependency_links.txt
43
+ mhg_gnn.egg-info/requires.txt
44
+ mhg_gnn.egg-info/top_level.txt
45
+ models/__init__.py
46
+ models/mhgvae.py
models/mhg_model/mhg_gnn.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
models/mhg_model/mhg_gnn.egg-info/requires.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ networkx>=2.8
2
+ numpy<2.0.0,>=1.23.5
3
+ pandas>=1.5.3
4
+ rdkit-pypi<2023.9.6,>=2022.9.4
5
+ torch>=2.0.0
6
+ torchinfo>=1.8.0
7
+ torch-geometric>=2.3.1
models/mhg_model/mhg_gnn.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ graph_grammar
2
+ models
models/mhg_model/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #