init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .idea/.gitignore +3 -0
- .idea/inspectionProfiles/Project_Default.xml +29 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/materials.mhg-ged.iml +12 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- README.md +78 -3
- __init__.py +5 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/load.cpython-310.pyc +0 -0
- graph_grammar/.DS_Store +0 -0
- graph_grammar/__init__.py +19 -0
- graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/__pycache__/hypergraph.cpython-310.pyc +0 -0
- graph_grammar/algo/__init__.py +20 -0
- graph_grammar/algo/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc +0 -0
- graph_grammar/algo/tree_decomposition.py +821 -0
- graph_grammar/graph_grammar/__init__.py +20 -0
- graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc +0 -0
- graph_grammar/graph_grammar/base.py +30 -0
- graph_grammar/graph_grammar/corpus.py +152 -0
- graph_grammar/graph_grammar/hrg.py +1065 -0
- graph_grammar/graph_grammar/symbols.py +180 -0
- graph_grammar/graph_grammar/utils.py +130 -0
- graph_grammar/hypergraph.py +544 -0
- graph_grammar/io/__init__.py +20 -0
- graph_grammar/io/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/io/__pycache__/smi.cpython-310.pyc +0 -0
- graph_grammar/io/smi.py +559 -0
- graph_grammar/nn/__init__.py +11 -0
- graph_grammar/nn/__pycache__/__init__.cpython-310.pyc +0 -0
- graph_grammar/nn/__pycache__/decoder.cpython-310.pyc +0 -0
- graph_grammar/nn/__pycache__/encoder.cpython-310.pyc +0 -0
- graph_grammar/nn/dataset.py +121 -0
- graph_grammar/nn/decoder.py +158 -0
- graph_grammar/nn/encoder.py +199 -0
- graph_grammar/nn/graph.py +313 -0
- images/mhg_example.png +0 -0
- images/mhg_example1.png +0 -0
- images/mhg_example2.png +0 -0
- load.py +83 -0
- mhg_gnn.egg-info/PKG-INFO +102 -0
- mhg_gnn.egg-info/SOURCES.txt +46 -0
.DS_Store
ADDED
Binary file (10.2 kB). View file
|
|
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
+
<option name="ignoredPackages">
|
6 |
+
<value>
|
7 |
+
<list size="16">
|
8 |
+
<item index="0" class="java.lang.String" itemvalue="accelerate" />
|
9 |
+
<item index="1" class="java.lang.String" itemvalue="matplotlib" />
|
10 |
+
<item index="2" class="java.lang.String" itemvalue="torch-geometric" />
|
11 |
+
<item index="3" class="java.lang.String" itemvalue="torchinfo" />
|
12 |
+
<item index="4" class="java.lang.String" itemvalue="caikit" />
|
13 |
+
<item index="5" class="java.lang.String" itemvalue="pytorch-fast-transformers" />
|
14 |
+
<item index="6" class="java.lang.String" itemvalue="e3nn" />
|
15 |
+
<item index="7" class="java.lang.String" itemvalue="rdkit" />
|
16 |
+
<item index="8" class="java.lang.String" itemvalue="PyImpetus" />
|
17 |
+
<item index="9" class="java.lang.String" itemvalue="torch-scatter" />
|
18 |
+
<item index="10" class="java.lang.String" itemvalue="torch-nl" />
|
19 |
+
<item index="11" class="java.lang.String" itemvalue="torch-sparse" />
|
20 |
+
<item index="12" class="java.lang.String" itemvalue="mordred" />
|
21 |
+
<item index="13" class="java.lang.String" itemvalue="xgboost" />
|
22 |
+
<item index="14" class="java.lang.String" itemvalue="mamba-ssm" />
|
23 |
+
<item index="15" class="java.lang.String" itemvalue="evaluate" />
|
24 |
+
</list>
|
25 |
+
</value>
|
26 |
+
</option>
|
27 |
+
</inspection_tool>
|
28 |
+
</profile>
|
29 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/materials.mhg-ged.iml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="inheritedJdk" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
<component name="PyDocumentationSettings">
|
9 |
+
<option name="format" value="NUMPY" />
|
10 |
+
<option name="myDocStringFormat" value="NumPy" />
|
11 |
+
</component>
|
12 |
+
</module>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/materials.mhg-ged.iml" filepath="$PROJECT_DIR$/.idea/materials.mhg-ged.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
README.md
CHANGED
@@ -1,3 +1,78 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
# mhg-gnn
|
5 |
+
|
6 |
+
This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
|
7 |
+
|
8 |
+
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
9 |
+
|
10 |
+
![mhg-gnn](images/mhg_example1.png)
|
11 |
+
|
12 |
+
## Introduction
|
13 |
+
|
14 |
+
We present MHG-GNN, an autoencoder architecture
|
15 |
+
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
16 |
+
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
17 |
+
demonstrate high predictive performance on molecular graph data.
|
18 |
+
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
19 |
+
|
20 |
+
## Table of Contents
|
21 |
+
|
22 |
+
1. [Getting Started](#getting-started)
|
23 |
+
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
24 |
+
2. [Installation](#installation)
|
25 |
+
2. [Feature Extraction](#feature-extraction)
|
26 |
+
|
27 |
+
## Getting Started
|
28 |
+
|
29 |
+
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
30 |
+
|
31 |
+
### Pretrained Models and Training Logs
|
32 |
+
|
33 |
+
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
|
34 |
+
|
35 |
+
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
36 |
+
|
37 |
+
### Installation
|
38 |
+
|
39 |
+
We recommend to create a virtual environment. For example:
|
40 |
+
|
41 |
+
```
|
42 |
+
python3 -m venv .venv
|
43 |
+
. .venv/bin/activate
|
44 |
+
```
|
45 |
+
|
46 |
+
Type the following command once the virtual environment is activated:
|
47 |
+
|
48 |
+
```
|
49 |
+
git clone [email protected]:CMD-TRL/mhg-gnn.git
|
50 |
+
cd ./mhg-gnn
|
51 |
+
pip install .
|
52 |
+
```
|
53 |
+
|
54 |
+
## Feature Extraction
|
55 |
+
|
56 |
+
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
57 |
+
|
58 |
+
To load mhg-gnn, you can simply use:
|
59 |
+
|
60 |
+
```python
|
61 |
+
import torch
|
62 |
+
import load
|
63 |
+
|
64 |
+
model = load.load()
|
65 |
+
```
|
66 |
+
|
67 |
+
To encode SMILES into embeddings, you can use:
|
68 |
+
|
69 |
+
```python
|
70 |
+
with torch.no_grad():
|
71 |
+
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
72 |
+
```
|
73 |
+
|
74 |
+
For decoder, you can use the function, so you can return from embeddings to SMILES strings:
|
75 |
+
|
76 |
+
```python
|
77 |
+
orig = model.decode(repr)
|
78 |
+
```
|
__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (214 Bytes). View file
|
|
__pycache__/load.cpython-310.pyc
ADDED
Binary file (3.04 kB). View file
|
|
graph_grammar/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
graph_grammar/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|
12 |
+
|
13 |
+
""" Title """
|
14 |
+
|
15 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
16 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
17 |
+
__version__ = "0.1"
|
18 |
+
__date__ = "Jan 1 2018"
|
19 |
+
|
graph_grammar/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (666 Bytes). View file
|
|
graph_grammar/__pycache__/hypergraph.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
graph_grammar/algo/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
graph_grammar/algo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (659 Bytes). View file
|
|
graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc
ADDED
Binary file (19.5 kB). View file
|
|
graph_grammar/algo/tree_decomposition.py
ADDED
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from itertools import combinations
|
23 |
+
from ..hypergraph import Hypergraph
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTree(nx.Graph):
|
29 |
+
''' clique tree object
|
30 |
+
|
31 |
+
Attributes
|
32 |
+
----------
|
33 |
+
hg : Hypergraph
|
34 |
+
This hypergraph will be decomposed.
|
35 |
+
root_hg : Hypergraph
|
36 |
+
Hypergraph on the root node.
|
37 |
+
ident_node_dict : dict
|
38 |
+
ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
|
39 |
+
'''
|
40 |
+
def __init__(self, hg=None, **kwargs):
|
41 |
+
self.hg = deepcopy(hg)
|
42 |
+
if self.hg is not None:
|
43 |
+
self.ident_node_dict = self.hg.get_identical_node_dict()
|
44 |
+
else:
|
45 |
+
self.ident_node_dict = {}
|
46 |
+
super().__init__(**kwargs)
|
47 |
+
|
48 |
+
@property
|
49 |
+
def root_hg(self):
|
50 |
+
''' return the hypergraph on the root node
|
51 |
+
'''
|
52 |
+
return self.nodes[0]['subhg']
|
53 |
+
|
54 |
+
@root_hg.setter
|
55 |
+
def root_hg(self, hypergraph):
|
56 |
+
''' set the hypergraph on the root node
|
57 |
+
'''
|
58 |
+
self.nodes[0]['subhg'] = hypergraph
|
59 |
+
|
60 |
+
def insert_subhg(self, subhypergraph: Hypergraph) -> None:
|
61 |
+
''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
|
62 |
+
|
63 |
+
Parameters
|
64 |
+
----------
|
65 |
+
subhg : Hypergraph
|
66 |
+
'''
|
67 |
+
num_nodes = self.number_of_nodes()
|
68 |
+
self.add_node(num_nodes, subhg=subhypergraph)
|
69 |
+
self.add_edge(num_nodes, 0)
|
70 |
+
adj_nodes = deepcopy(list(self.adj[0].keys()))
|
71 |
+
for each_node in adj_nodes:
|
72 |
+
if len(self.nodes[each_node]["subhg"].nodes.intersection(
|
73 |
+
self.nodes[num_nodes]["subhg"].nodes)\
|
74 |
+
- self.root_hg.nodes) != 0 and each_node != num_nodes:
|
75 |
+
self.remove_edge(0, each_node)
|
76 |
+
self.add_edge(each_node, num_nodes)
|
77 |
+
|
78 |
+
def to_irredundant(self) -> None:
|
79 |
+
''' convert the clique tree to be irredundant
|
80 |
+
'''
|
81 |
+
for each_node in self.hg.nodes:
|
82 |
+
subtree = self.subgraph([
|
83 |
+
each_tree_node for each_tree_node in self.nodes()\
|
84 |
+
if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
|
85 |
+
leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
|
86 |
+
redundant_leaf_node_list = []
|
87 |
+
for each_leaf_node in leaf_node_list:
|
88 |
+
if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
|
89 |
+
redundant_leaf_node_list.append(each_leaf_node)
|
90 |
+
for each_red_leaf_node in redundant_leaf_node_list:
|
91 |
+
current_node = each_red_leaf_node
|
92 |
+
while subtree.degree(current_node) == 1 \
|
93 |
+
and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
|
94 |
+
self.nodes[current_node]["subhg"].remove_node(each_node)
|
95 |
+
remove_node = current_node
|
96 |
+
current_node = list(dict(subtree[remove_node]).keys())[0]
|
97 |
+
subtree.remove_node(remove_node)
|
98 |
+
|
99 |
+
fixed_node_set = deepcopy(self.nodes)
|
100 |
+
for each_node in fixed_node_set:
|
101 |
+
if self.nodes[each_node]["subhg"].num_edges == 0:
|
102 |
+
if len(self[each_node]) == 1:
|
103 |
+
self.remove_node(each_node)
|
104 |
+
elif len(self[each_node]) == 2:
|
105 |
+
self.add_edge(*self[each_node])
|
106 |
+
self.remove_node(each_node)
|
107 |
+
else:
|
108 |
+
pass
|
109 |
+
else:
|
110 |
+
pass
|
111 |
+
|
112 |
+
redundant = True
|
113 |
+
while redundant:
|
114 |
+
redundant = False
|
115 |
+
fixed_edge_set = deepcopy(self.edges)
|
116 |
+
remove_node_set = set()
|
117 |
+
for node_1, node_2 in fixed_edge_set:
|
118 |
+
if node_1 in remove_node_set or node_2 in remove_node_set:
|
119 |
+
pass
|
120 |
+
else:
|
121 |
+
if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
|
122 |
+
redundant = True
|
123 |
+
adj_node_list = set(self.adj[node_1]) - {node_2}
|
124 |
+
self.remove_node(node_1)
|
125 |
+
remove_node_set.add(node_1)
|
126 |
+
for each_node in adj_node_list:
|
127 |
+
self.add_edge(node_2, each_node)
|
128 |
+
|
129 |
+
elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
|
130 |
+
redundant = True
|
131 |
+
adj_node_list = set(self.adj[node_2]) - {node_1}
|
132 |
+
self.remove_node(node_2)
|
133 |
+
remove_node_set.add(node_2)
|
134 |
+
for each_node in adj_node_list:
|
135 |
+
self.add_edge(node_1, each_node)
|
136 |
+
|
137 |
+
def node_update(self, key_node: str, subhg) -> None:
|
138 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
key_node : str
|
143 |
+
key node that must be removed.
|
144 |
+
subhg : Hypegraph
|
145 |
+
"""
|
146 |
+
for each_edge in subhg.edges:
|
147 |
+
self.root_hg.remove_edge(each_edge)
|
148 |
+
self.root_hg.remove_nodes(self.ident_node_dict[key_node])
|
149 |
+
|
150 |
+
adj_node_list = list(subhg.nodes)
|
151 |
+
for each_node in subhg.nodes:
|
152 |
+
if each_node not in self.ident_node_dict[key_node]:
|
153 |
+
if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
|
154 |
+
self.root_hg.remove_node(each_node)
|
155 |
+
adj_node_list.remove(each_node)
|
156 |
+
else:
|
157 |
+
adj_node_list.remove(each_node)
|
158 |
+
|
159 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
160 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
161 |
+
self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
|
162 |
+
|
163 |
+
subhg.remove_edges_with_attr({'tmp' : True})
|
164 |
+
self.insert_subhg(subhg)
|
165 |
+
|
166 |
+
def update(self, subhg, remove_nodes=False):
|
167 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
168 |
+
|
169 |
+
Parameters
|
170 |
+
----------
|
171 |
+
subhg : Hypegraph
|
172 |
+
"""
|
173 |
+
for each_edge in subhg.edges:
|
174 |
+
self.root_hg.remove_edge(each_edge)
|
175 |
+
if remove_nodes:
|
176 |
+
remove_edge_list = []
|
177 |
+
for each_edge in self.root_hg.edges:
|
178 |
+
if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
|
179 |
+
and self.root_hg.edge_attr(each_edge).get('tmp', False):
|
180 |
+
remove_edge_list.append(each_edge)
|
181 |
+
self.root_hg.remove_edges(remove_edge_list)
|
182 |
+
|
183 |
+
adj_node_list = list(subhg.nodes)
|
184 |
+
for each_node in subhg.nodes:
|
185 |
+
if self.root_hg.degree(each_node) == 0:
|
186 |
+
self.root_hg.remove_node(each_node)
|
187 |
+
adj_node_list.remove(each_node)
|
188 |
+
|
189 |
+
if len(adj_node_list) != 1 and not remove_nodes:
|
190 |
+
self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
|
191 |
+
'''
|
192 |
+
else:
|
193 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
194 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
195 |
+
self.root_hg.add_edge(
|
196 |
+
[each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
197 |
+
'''
|
198 |
+
subhg.remove_edges_with_attr({'tmp':True})
|
199 |
+
self.insert_subhg(subhg)
|
200 |
+
|
201 |
+
|
202 |
+
def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
|
203 |
+
if mode == 'standard':
|
204 |
+
degree_dict = hg.degrees()
|
205 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
206 |
+
min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
|
207 |
+
return min_deg_node, min_deg_subhg
|
208 |
+
elif mode == 'mol':
|
209 |
+
degree_dict = hg.degrees()
|
210 |
+
min_deg = min(degree_dict.values())
|
211 |
+
min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
|
212 |
+
min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
|
213 |
+
for each_min_deg_node in min_deg_node_list]
|
214 |
+
best_score = np.inf
|
215 |
+
best_idx = -1
|
216 |
+
for each_idx in range(len(min_deg_subhg_list)):
|
217 |
+
if min_deg_subhg_list[each_idx].num_nodes < best_score:
|
218 |
+
best_idx = each_idx
|
219 |
+
return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
|
220 |
+
else:
|
221 |
+
raise ValueError
|
222 |
+
|
223 |
+
|
224 |
+
def tree_decomposition(hg, irredundant=True):
|
225 |
+
""" compute a tree decomposition of the input hypergraph
|
226 |
+
|
227 |
+
Parameters
|
228 |
+
----------
|
229 |
+
hg : Hypergraph
|
230 |
+
hypergraph to be decomposed
|
231 |
+
irredundant : bool
|
232 |
+
if True, irredundant tree decomposition will be computed.
|
233 |
+
|
234 |
+
Returns
|
235 |
+
-------
|
236 |
+
clique_tree : nx.Graph
|
237 |
+
each node contains a subhypergraph of `hg`
|
238 |
+
"""
|
239 |
+
org_hg = hg.copy()
|
240 |
+
ident_node_dict = hg.get_identical_node_dict()
|
241 |
+
clique_tree = CliqueTree(org_hg)
|
242 |
+
clique_tree.add_node(0, subhg=org_hg)
|
243 |
+
while True:
|
244 |
+
degree_dict = org_hg.degrees()
|
245 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
246 |
+
min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
|
247 |
+
if org_hg.nodes == min_deg_subhg.nodes:
|
248 |
+
break
|
249 |
+
|
250 |
+
# org_hg and min_deg_subhg are divided
|
251 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
252 |
+
|
253 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
254 |
+
|
255 |
+
if irredundant:
|
256 |
+
clique_tree.to_irredundant()
|
257 |
+
return clique_tree
|
258 |
+
|
259 |
+
|
260 |
+
def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
|
261 |
+
''' compute a tree decomposition given a hyperedge replacement grammar.
|
262 |
+
the resultant clique tree should induce a less compact HRG.
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
hg : Hypergraph
|
267 |
+
hypergraph to be decomposed
|
268 |
+
hrg : HyperedgeReplacementGrammar
|
269 |
+
current HRG
|
270 |
+
irredundant : bool
|
271 |
+
if True, irredundant tree decomposition will be computed.
|
272 |
+
|
273 |
+
Returns
|
274 |
+
-------
|
275 |
+
clique_tree : nx.Graph
|
276 |
+
each node contains a subhypergraph of `hg`
|
277 |
+
'''
|
278 |
+
org_hg = hg.copy()
|
279 |
+
ident_node_dict = hg.get_identical_node_dict()
|
280 |
+
clique_tree = CliqueTree(org_hg)
|
281 |
+
clique_tree.add_node(0, subhg=org_hg)
|
282 |
+
root_node = 0
|
283 |
+
|
284 |
+
# construct a clique tree using HRG
|
285 |
+
success_any = True
|
286 |
+
while success_any:
|
287 |
+
success_any = False
|
288 |
+
for each_prod_rule in hrg.prod_rule_list:
|
289 |
+
org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
|
290 |
+
if success:
|
291 |
+
if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
|
292 |
+
success_any = True
|
293 |
+
subhg.remove_edges_with_attr({'terminal' : False})
|
294 |
+
clique_tree.root_hg = org_hg
|
295 |
+
clique_tree.insert_subhg(subhg)
|
296 |
+
|
297 |
+
clique_tree.root_hg = org_hg
|
298 |
+
|
299 |
+
for each_edge in deepcopy(org_hg.edges):
|
300 |
+
if not org_hg.edge_attr(each_edge)['terminal']:
|
301 |
+
node_list = org_hg.nodes_in_edge(each_edge)
|
302 |
+
org_hg.remove_edge(each_edge)
|
303 |
+
|
304 |
+
for each_node_1, each_node_2 in combinations(node_list, 2):
|
305 |
+
if not org_hg.is_adj(each_node_1, each_node_2):
|
306 |
+
org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
307 |
+
|
308 |
+
# construct a clique tree using the existing algorithm
|
309 |
+
degree_dict = org_hg.degrees()
|
310 |
+
if degree_dict:
|
311 |
+
while True:
|
312 |
+
min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
|
313 |
+
if org_hg.nodes == min_deg_subhg.nodes: break
|
314 |
+
|
315 |
+
# org_hg and min_deg_subhg are divided
|
316 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
317 |
+
|
318 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
319 |
+
if irredundant:
|
320 |
+
clique_tree.to_irredundant()
|
321 |
+
|
322 |
+
if return_root:
|
323 |
+
if root_node == 0 and 0 not in clique_tree.nodes:
|
324 |
+
root_node = clique_tree.number_of_nodes()
|
325 |
+
while root_node not in clique_tree.nodes:
|
326 |
+
root_node -= 1
|
327 |
+
elif root_node not in clique_tree.nodes:
|
328 |
+
while root_node not in clique_tree.nodes:
|
329 |
+
root_node -= 1
|
330 |
+
else:
|
331 |
+
pass
|
332 |
+
return clique_tree, root_node
|
333 |
+
else:
|
334 |
+
return clique_tree
|
335 |
+
|
336 |
+
|
337 |
+
def tree_decomposition_from_leaf(hg, irredundant=True):
|
338 |
+
""" compute a tree decomposition of the input hypergraph
|
339 |
+
|
340 |
+
Parameters
|
341 |
+
----------
|
342 |
+
hg : Hypergraph
|
343 |
+
hypergraph to be decomposed
|
344 |
+
irredundant : bool
|
345 |
+
if True, irredundant tree decomposition will be computed.
|
346 |
+
|
347 |
+
Returns
|
348 |
+
-------
|
349 |
+
clique_tree : nx.Graph
|
350 |
+
each node contains a subhypergraph of `hg`
|
351 |
+
"""
|
352 |
+
def apply_normal_decomposition(clique_tree):
|
353 |
+
degree_dict = clique_tree.root_hg.degrees()
|
354 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
355 |
+
min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
|
356 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
357 |
+
return clique_tree, False
|
358 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
359 |
+
return clique_tree, True
|
360 |
+
|
361 |
+
def apply_min_edge_deg_decomposition(clique_tree):
|
362 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
363 |
+
non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
364 |
+
if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
|
365 |
+
if not non_tmp_edge_list:
|
366 |
+
return clique_tree, False
|
367 |
+
min_deg_edge = None
|
368 |
+
min_deg = np.inf
|
369 |
+
for each_edge in non_tmp_edge_list:
|
370 |
+
if min_deg > edge_degree_dict[each_edge]:
|
371 |
+
min_deg_edge = each_edge
|
372 |
+
min_deg = edge_degree_dict[each_edge]
|
373 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
374 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
375 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
376 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
377 |
+
return clique_tree, False
|
378 |
+
clique_tree.update(min_deg_subhg)
|
379 |
+
return clique_tree, True
|
380 |
+
|
381 |
+
org_hg = hg.copy()
|
382 |
+
clique_tree = CliqueTree(org_hg)
|
383 |
+
clique_tree.add_node(0, subhg=org_hg)
|
384 |
+
|
385 |
+
success = True
|
386 |
+
while success:
|
387 |
+
clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
|
388 |
+
if not success:
|
389 |
+
clique_tree, success = apply_normal_decomposition(clique_tree)
|
390 |
+
|
391 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
392 |
+
if irredundant:
|
393 |
+
clique_tree.to_irredundant()
|
394 |
+
return clique_tree
|
395 |
+
|
396 |
+
def topological_tree_decomposition(
|
397 |
+
hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
|
398 |
+
''' compute a tree decomposition of the input hypergraph
|
399 |
+
|
400 |
+
Parameters
|
401 |
+
----------
|
402 |
+
hg : Hypergraph
|
403 |
+
hypergraph to be decomposed
|
404 |
+
irredundant : bool
|
405 |
+
if True, irredundant tree decomposition will be computed.
|
406 |
+
|
407 |
+
Returns
|
408 |
+
-------
|
409 |
+
clique_tree : CliqueTree
|
410 |
+
each node contains a subhypergraph of `hg`
|
411 |
+
'''
|
412 |
+
def _contract_tree(clique_tree):
|
413 |
+
''' contract a single leaf
|
414 |
+
|
415 |
+
Parameters
|
416 |
+
----------
|
417 |
+
clique_tree : CliqueTree
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
CliqueTree, bool
|
422 |
+
bool represents whether this operation succeeds or not.
|
423 |
+
'''
|
424 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
425 |
+
leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
426 |
+
if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
|
427 |
+
and edge_degree_dict[each_edge] == 1]
|
428 |
+
if not leaf_edge_list:
|
429 |
+
return clique_tree, False
|
430 |
+
min_deg_edge = leaf_edge_list[0]
|
431 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
432 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
433 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
434 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
435 |
+
return clique_tree, False
|
436 |
+
clique_tree.update(min_deg_subhg)
|
437 |
+
return clique_tree, True
|
438 |
+
|
439 |
+
def _rip_labels_from_cycles(clique_tree, org_hg):
|
440 |
+
''' rip hyperedge-labels off
|
441 |
+
|
442 |
+
Parameters
|
443 |
+
----------
|
444 |
+
clique_tree : CliqueTree
|
445 |
+
org_hg : Hypergraph
|
446 |
+
|
447 |
+
Returns
|
448 |
+
-------
|
449 |
+
CliqueTree, bool
|
450 |
+
bool represents whether this operation succeeds or not.
|
451 |
+
'''
|
452 |
+
ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
|
453 |
+
for each_edge in clique_tree.root_hg.edges:
|
454 |
+
if each_edge in org_hg.edges:
|
455 |
+
if org_hg.in_cycle(each_edge):
|
456 |
+
node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
|
457 |
+
subhg = clique_tree.root_hg.get_subhg(
|
458 |
+
node_list, [each_edge], ident_node_dict)
|
459 |
+
if clique_tree.root_hg.nodes == subhg.nodes:
|
460 |
+
return clique_tree, False
|
461 |
+
clique_tree.update(subhg)
|
462 |
+
'''
|
463 |
+
in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
|
464 |
+
if not all(in_cycle_dict.values()):
|
465 |
+
node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
|
466 |
+
node_list = [node_not_in_cycle]
|
467 |
+
node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
|
468 |
+
edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
|
469 |
+
import pdb; pdb.set_trace()
|
470 |
+
subhg = clique_tree.root_hg.get_subhg(
|
471 |
+
node_list, edge_list, ident_node_dict)
|
472 |
+
|
473 |
+
clique_tree.update(subhg)
|
474 |
+
'''
|
475 |
+
return clique_tree, True
|
476 |
+
return clique_tree, False
|
477 |
+
|
478 |
+
def _shrink_cycle(clique_tree):
|
479 |
+
''' shrink a cycle
|
480 |
+
|
481 |
+
Parameters
|
482 |
+
----------
|
483 |
+
clique_tree : CliqueTree
|
484 |
+
|
485 |
+
Returns
|
486 |
+
-------
|
487 |
+
CliqueTree, bool
|
488 |
+
bool represents whether this operation succeeds or not.
|
489 |
+
'''
|
490 |
+
def filter_subhg(subhg, hg, key_node):
|
491 |
+
num_nodes_cycle = 0
|
492 |
+
nodes_in_cycle_list = []
|
493 |
+
for each_node in subhg.nodes:
|
494 |
+
if hg.in_cycle(each_node):
|
495 |
+
num_nodes_cycle += 1
|
496 |
+
if each_node != key_node:
|
497 |
+
nodes_in_cycle_list.append(each_node)
|
498 |
+
if num_nodes_cycle > 3:
|
499 |
+
break
|
500 |
+
if num_nodes_cycle != 3:
|
501 |
+
return False
|
502 |
+
else:
|
503 |
+
for each_edge in hg.edges:
|
504 |
+
if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
|
505 |
+
return False
|
506 |
+
return True
|
507 |
+
|
508 |
+
#ident_node_dict = hg.get_identical_node_dict()
|
509 |
+
ident_node_dict = clique_tree.ident_node_dict
|
510 |
+
for each_node in clique_tree.root_hg.nodes:
|
511 |
+
if clique_tree.root_hg.in_cycle(each_node)\
|
512 |
+
and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
|
513 |
+
clique_tree.root_hg,
|
514 |
+
each_node):
|
515 |
+
target_node = each_node
|
516 |
+
target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
|
517 |
+
if clique_tree.root_hg.nodes == target_subhg.nodes:
|
518 |
+
return clique_tree, False
|
519 |
+
clique_tree.update(target_subhg)
|
520 |
+
return clique_tree, True
|
521 |
+
return clique_tree, False
|
522 |
+
|
523 |
+
def _contract_cycles(clique_tree):
|
524 |
+
'''
|
525 |
+
remove a subhypergraph that looks like a cycle on a leaf.
|
526 |
+
|
527 |
+
Parameters
|
528 |
+
----------
|
529 |
+
clique_tree : CliqueTree
|
530 |
+
|
531 |
+
Returns
|
532 |
+
-------
|
533 |
+
CliqueTree, bool
|
534 |
+
bool represents whether this operation succeeds or not.
|
535 |
+
'''
|
536 |
+
def _divide_hg(hg):
|
537 |
+
''' divide a hypergraph into subhypergraphs such that
|
538 |
+
each subhypergraph is connected to each other in a tree-like way.
|
539 |
+
|
540 |
+
Parameters
|
541 |
+
----------
|
542 |
+
hg : Hypergraph
|
543 |
+
|
544 |
+
Returns
|
545 |
+
-------
|
546 |
+
list of Hypergraphs
|
547 |
+
each element corresponds to a subhypergraph of `hg`
|
548 |
+
'''
|
549 |
+
for each_node in hg.nodes:
|
550 |
+
if hg.is_dividable(each_node):
|
551 |
+
adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
|
552 |
+
'''
|
553 |
+
if any(adj_edges_dict.values()):
|
554 |
+
import pdb; pdb.set_trace()
|
555 |
+
edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
|
556 |
+
subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
|
557 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
|
558 |
+
else:
|
559 |
+
'''
|
560 |
+
subhg1, subhg2 = hg.divide(each_node)
|
561 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
562 |
+
return [hg]
|
563 |
+
|
564 |
+
def _is_leaf(hg, divided_subhg) -> bool:
|
565 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
566 |
+
|
567 |
+
Parameters
|
568 |
+
----------
|
569 |
+
hg : Hypergraph
|
570 |
+
divided_subhg : Hypergraph
|
571 |
+
`divided_subhg` is a subhypergraph of `hg`
|
572 |
+
|
573 |
+
Returns
|
574 |
+
-------
|
575 |
+
bool
|
576 |
+
'''
|
577 |
+
'''
|
578 |
+
adj_edges_set = set([])
|
579 |
+
for each_node in divided_subhg.nodes:
|
580 |
+
adj_edges_set.update(set(hg.adj_edges(each_node)))
|
581 |
+
|
582 |
+
|
583 |
+
_hg = deepcopy(hg)
|
584 |
+
_hg.remove_subhg(divided_subhg)
|
585 |
+
if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
|
586 |
+
import pdb; pdb.set_trace()
|
587 |
+
return len(adj_edges_set - divided_subhg.edges) == 1
|
588 |
+
'''
|
589 |
+
_hg = deepcopy(hg)
|
590 |
+
_hg.remove_subhg(divided_subhg)
|
591 |
+
return nx.is_connected(_hg.hg)
|
592 |
+
|
593 |
+
subhg_list = _divide_hg(clique_tree.root_hg)
|
594 |
+
if len(subhg_list) == 1:
|
595 |
+
return clique_tree, False
|
596 |
+
else:
|
597 |
+
while len(subhg_list) > 1:
|
598 |
+
max_leaf_subhg = None
|
599 |
+
for each_subhg in subhg_list:
|
600 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
601 |
+
if max_leaf_subhg is None:
|
602 |
+
max_leaf_subhg = each_subhg
|
603 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
604 |
+
max_leaf_subhg = each_subhg
|
605 |
+
clique_tree.update(max_leaf_subhg)
|
606 |
+
subhg_list.remove(max_leaf_subhg)
|
607 |
+
return clique_tree, True
|
608 |
+
|
609 |
+
org_hg = hg.copy()
|
610 |
+
clique_tree = CliqueTree(org_hg)
|
611 |
+
clique_tree.add_node(0, subhg=org_hg)
|
612 |
+
|
613 |
+
success = True
|
614 |
+
while success:
|
615 |
+
'''
|
616 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
617 |
+
if not success:
|
618 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
619 |
+
'''
|
620 |
+
clique_tree, success = _contract_tree(clique_tree)
|
621 |
+
if not success:
|
622 |
+
if rip_labels:
|
623 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
624 |
+
if not success:
|
625 |
+
if shrink_cycle:
|
626 |
+
clique_tree, success = _shrink_cycle(clique_tree)
|
627 |
+
if not success:
|
628 |
+
if contract_cycles:
|
629 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
630 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
631 |
+
if irredundant:
|
632 |
+
clique_tree.to_irredundant()
|
633 |
+
return clique_tree
|
634 |
+
|
635 |
+
def molecular_tree_decomposition(hg, irredundant=True):
|
636 |
+
""" compute a tree decomposition of the input molecular hypergraph
|
637 |
+
|
638 |
+
Parameters
|
639 |
+
----------
|
640 |
+
hg : Hypergraph
|
641 |
+
molecular hypergraph to be decomposed
|
642 |
+
irredundant : bool
|
643 |
+
if True, irredundant tree decomposition will be computed.
|
644 |
+
|
645 |
+
Returns
|
646 |
+
-------
|
647 |
+
clique_tree : CliqueTree
|
648 |
+
each node contains a subhypergraph of `hg`
|
649 |
+
"""
|
650 |
+
def _divide_hg(hg):
|
651 |
+
''' divide a hypergraph into subhypergraphs such that
|
652 |
+
each subhypergraph is connected to each other in a tree-like way.
|
653 |
+
|
654 |
+
Parameters
|
655 |
+
----------
|
656 |
+
hg : Hypergraph
|
657 |
+
|
658 |
+
Returns
|
659 |
+
-------
|
660 |
+
list of Hypergraphs
|
661 |
+
each element corresponds to a subhypergraph of `hg`
|
662 |
+
'''
|
663 |
+
is_ring = False
|
664 |
+
for each_node in hg.nodes:
|
665 |
+
if hg.node_attr(each_node)['is_in_ring']:
|
666 |
+
is_ring = True
|
667 |
+
if not hg.node_attr(each_node)['is_in_ring'] \
|
668 |
+
and hg.degree(each_node) == 2:
|
669 |
+
subhg1, subhg2 = hg.divide(each_node)
|
670 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
671 |
+
|
672 |
+
if is_ring:
|
673 |
+
subhg_list = []
|
674 |
+
remove_edge_list = []
|
675 |
+
remove_node_list = []
|
676 |
+
for each_edge in hg.edges:
|
677 |
+
node_list = hg.nodes_in_edge(each_edge)
|
678 |
+
subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
|
679 |
+
subhg_list.append(subhg)
|
680 |
+
remove_edge_list.append(each_edge)
|
681 |
+
for each_node in node_list:
|
682 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
683 |
+
remove_node_list.append(each_node)
|
684 |
+
hg.remove_edges(remove_edge_list)
|
685 |
+
hg.remove_nodes(remove_node_list, False)
|
686 |
+
return subhg_list + [hg]
|
687 |
+
else:
|
688 |
+
return [hg]
|
689 |
+
|
690 |
+
org_hg = hg.copy()
|
691 |
+
clique_tree = CliqueTree(org_hg)
|
692 |
+
clique_tree.add_node(0, subhg=org_hg)
|
693 |
+
|
694 |
+
subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
|
695 |
+
#_subhg_list = deepcopy(subhg_list)
|
696 |
+
if len(subhg_list) == 1:
|
697 |
+
pass
|
698 |
+
else:
|
699 |
+
while len(subhg_list) > 1:
|
700 |
+
max_leaf_subhg = None
|
701 |
+
for each_subhg in subhg_list:
|
702 |
+
if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
|
703 |
+
if max_leaf_subhg is None:
|
704 |
+
max_leaf_subhg = each_subhg
|
705 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
706 |
+
max_leaf_subhg = each_subhg
|
707 |
+
|
708 |
+
if max_leaf_subhg is None:
|
709 |
+
for each_subhg in subhg_list:
|
710 |
+
if _is_ring_label(clique_tree.root_hg, each_subhg):
|
711 |
+
if max_leaf_subhg is None:
|
712 |
+
max_leaf_subhg = each_subhg
|
713 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
714 |
+
max_leaf_subhg = each_subhg
|
715 |
+
if max_leaf_subhg is not None:
|
716 |
+
clique_tree.update(max_leaf_subhg)
|
717 |
+
subhg_list.remove(max_leaf_subhg)
|
718 |
+
else:
|
719 |
+
for each_subhg in subhg_list:
|
720 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
721 |
+
if max_leaf_subhg is None:
|
722 |
+
max_leaf_subhg = each_subhg
|
723 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
724 |
+
max_leaf_subhg = each_subhg
|
725 |
+
if max_leaf_subhg is not None:
|
726 |
+
clique_tree.update(max_leaf_subhg, True)
|
727 |
+
subhg_list.remove(max_leaf_subhg)
|
728 |
+
else:
|
729 |
+
break
|
730 |
+
if len(subhg_list) > 1:
|
731 |
+
'''
|
732 |
+
for each_idx, each_subhg in enumerate(subhg_list):
|
733 |
+
each_subhg.draw(f'{each_idx}', True)
|
734 |
+
clique_tree.root_hg.draw('root', True)
|
735 |
+
import pickle
|
736 |
+
with open('buggy_hg.pkl', 'wb') as f:
|
737 |
+
pickle.dump(hg, f)
|
738 |
+
return clique_tree, subhg_list, _subhg_list
|
739 |
+
'''
|
740 |
+
raise RuntimeError('bug in tree decomposition algorithm')
|
741 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
742 |
+
|
743 |
+
'''
|
744 |
+
for each_tree_node in clique_tree.adj[0]:
|
745 |
+
subhg = clique_tree.nodes[each_tree_node]['subhg']
|
746 |
+
for each_edge in subhg.edges:
|
747 |
+
if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
|
748 |
+
clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
|
749 |
+
'''
|
750 |
+
if irredundant:
|
751 |
+
clique_tree.to_irredundant()
|
752 |
+
return clique_tree #, _subhg_list
|
753 |
+
|
754 |
+
def _is_leaf(hg, subhg) -> bool:
|
755 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
756 |
+
|
757 |
+
Parameters
|
758 |
+
----------
|
759 |
+
hg : Hypergraph
|
760 |
+
subhg : Hypergraph
|
761 |
+
`subhg` is a subhypergraph of `hg`
|
762 |
+
|
763 |
+
Returns
|
764 |
+
-------
|
765 |
+
bool
|
766 |
+
'''
|
767 |
+
if len(subhg.edges) == 0:
|
768 |
+
adj_edge_set = set([])
|
769 |
+
subhg_edge_set = set([])
|
770 |
+
for each_edge in hg.edges:
|
771 |
+
if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
|
772 |
+
subhg_edge_set.add(each_edge)
|
773 |
+
for each_node in subhg.nodes:
|
774 |
+
adj_edge_set.update(set(hg.adj_edges(each_node)))
|
775 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
776 |
+
return True
|
777 |
+
else:
|
778 |
+
return False
|
779 |
+
elif len(subhg.edges) == 1:
|
780 |
+
adj_edge_set = set([])
|
781 |
+
subhg_edge_set = subhg.edges
|
782 |
+
for each_node in subhg.nodes:
|
783 |
+
for each_adj_edge in hg.adj_edges(each_node):
|
784 |
+
adj_edge_set.add(each_adj_edge)
|
785 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
786 |
+
return True
|
787 |
+
else:
|
788 |
+
return False
|
789 |
+
else:
|
790 |
+
raise ValueError('subhg should be nodes only or one-edge hypergraph.')
|
791 |
+
|
792 |
+
def _is_ring_label(hg, subhg):
|
793 |
+
if len(subhg.edges) != 1:
|
794 |
+
return False
|
795 |
+
edge_name = list(subhg.edges)[0]
|
796 |
+
#assert edge_name in hg.edges, f'{edge_name}'
|
797 |
+
is_in_ring = False
|
798 |
+
for each_node in subhg.nodes:
|
799 |
+
if subhg.node_attr(each_node)['is_in_ring']:
|
800 |
+
is_in_ring = True
|
801 |
+
else:
|
802 |
+
adj_edge_list = list(hg.adj_edges(each_node))
|
803 |
+
adj_edge_list.remove(edge_name)
|
804 |
+
if len(adj_edge_list) == 1:
|
805 |
+
if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
|
806 |
+
return False
|
807 |
+
elif len(adj_edge_list) == 0:
|
808 |
+
pass
|
809 |
+
else:
|
810 |
+
raise ValueError
|
811 |
+
if is_in_ring:
|
812 |
+
return True
|
813 |
+
else:
|
814 |
+
return False
|
815 |
+
|
816 |
+
def _is_ring(hg):
|
817 |
+
for each_node in hg.nodes:
|
818 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
819 |
+
return False
|
820 |
+
return True
|
821 |
+
|
graph_grammar/graph_grammar/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (680 Bytes). View file
|
|
graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc
ADDED
Binary file (1.17 kB). View file
|
|
graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc
ADDED
Binary file (4.71 kB). View file
|
|
graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc
ADDED
Binary file (29.1 kB). View file
|
|
graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc
ADDED
Binary file (5.38 kB). View file
|
|
graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.63 kB). View file
|
|
graph_grammar/graph_grammar/base.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from abc import ABCMeta, abstractmethod
|
22 |
+
|
23 |
+
class GraphGrammarBase(metaclass=ABCMeta):
|
24 |
+
@abstractmethod
|
25 |
+
def learn(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def sample(self):
|
30 |
+
pass
|
graph_grammar/graph_grammar/corpus.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from collections import Counter
|
22 |
+
from functools import partial
|
23 |
+
from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
|
24 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTreeCorpus(object):
|
29 |
+
|
30 |
+
''' clique tree corpus
|
31 |
+
|
32 |
+
Attributes
|
33 |
+
----------
|
34 |
+
clique_tree_list : list of CliqueTree
|
35 |
+
subhg_list : list of Hypergraph
|
36 |
+
'''
|
37 |
+
|
38 |
+
def __init__(self):
|
39 |
+
self.clique_tree_list = []
|
40 |
+
self.subhg_list = []
|
41 |
+
|
42 |
+
@property
|
43 |
+
def size(self):
|
44 |
+
return len(self.subhg_list)
|
45 |
+
|
46 |
+
def add_clique_tree(self, clique_tree):
|
47 |
+
for each_node in clique_tree.nodes:
|
48 |
+
subhg = clique_tree.nodes[each_node]['subhg']
|
49 |
+
subhg_idx = self.add_subhg(subhg)
|
50 |
+
clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
|
51 |
+
self.clique_tree_list.append(clique_tree)
|
52 |
+
|
53 |
+
def add_to_subhg_list(self, clique_tree, root_node):
|
54 |
+
parent_node_dict = {}
|
55 |
+
current_node = None
|
56 |
+
parent_node_dict[root_node] = None
|
57 |
+
stack = [root_node]
|
58 |
+
while stack:
|
59 |
+
current_node = stack.pop()
|
60 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
61 |
+
for each_child in clique_tree.adj[current_node]:
|
62 |
+
if each_child != parent_node_dict[current_node]:
|
63 |
+
stack.append(each_child)
|
64 |
+
parent_node_dict[each_child] = current_node
|
65 |
+
if parent_node_dict[current_node] is not None:
|
66 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
67 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
68 |
+
parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
|
69 |
+
|
70 |
+
parent_node_dict = {}
|
71 |
+
current_node = None
|
72 |
+
parent_node_dict[root_node] = None
|
73 |
+
stack = [root_node]
|
74 |
+
while stack:
|
75 |
+
current_node = stack.pop()
|
76 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
77 |
+
for each_child in clique_tree.adj[current_node]:
|
78 |
+
if each_child != parent_node_dict[current_node]:
|
79 |
+
stack.append(each_child)
|
80 |
+
parent_node_dict[each_child] = current_node
|
81 |
+
if parent_node_dict[current_node] is not None:
|
82 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
83 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
84 |
+
for each_idx, each_node in enumerate(common):
|
85 |
+
current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
|
86 |
+
|
87 |
+
subhg_idx, is_new = self.add_subhg(current_subhg)
|
88 |
+
clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
|
89 |
+
return clique_tree
|
90 |
+
|
91 |
+
def add_subhg(self, subhg):
|
92 |
+
if len(self.subhg_list) == 0:
|
93 |
+
node_dict = {}
|
94 |
+
for each_node in subhg.nodes:
|
95 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
96 |
+
node_list = []
|
97 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
98 |
+
node_list.append(each_key)
|
99 |
+
for each_idx, each_node in enumerate(node_list):
|
100 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
101 |
+
self.subhg_list.append(subhg)
|
102 |
+
return 0, True
|
103 |
+
else:
|
104 |
+
match = False
|
105 |
+
subhg_bond_symbol_counter \
|
106 |
+
= Counter([subhg.node_attr(each_node)['symbol'] \
|
107 |
+
for each_node in subhg.nodes])
|
108 |
+
subhg_atom_symbol_counter \
|
109 |
+
= Counter([subhg.edge_attr(each_edge).get('symbol', None) \
|
110 |
+
for each_edge in subhg.edges])
|
111 |
+
for each_idx, each_subhg in enumerate(self.subhg_list):
|
112 |
+
each_bond_symbol_counter \
|
113 |
+
= Counter([each_subhg.node_attr(each_node)['symbol'] \
|
114 |
+
for each_node in each_subhg.nodes])
|
115 |
+
each_atom_symbol_counter \
|
116 |
+
= Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
|
117 |
+
for each_edge in each_subhg.edges])
|
118 |
+
if not match \
|
119 |
+
and (subhg.num_nodes == each_subhg.num_nodes
|
120 |
+
and subhg.num_edges == each_subhg.num_edges
|
121 |
+
and subhg_bond_symbol_counter == each_bond_symbol_counter
|
122 |
+
and subhg_atom_symbol_counter == each_atom_symbol_counter):
|
123 |
+
gm = GraphMatcher(each_subhg.hg,
|
124 |
+
subhg.hg,
|
125 |
+
node_match=_easy_node_match,
|
126 |
+
edge_match=_edge_match)
|
127 |
+
try:
|
128 |
+
isomap = next(gm.isomorphisms_iter())
|
129 |
+
match = True
|
130 |
+
for each_node in each_subhg.nodes:
|
131 |
+
subhg.node_attr(isomap[each_node])['order4hrg'] \
|
132 |
+
= each_subhg.node_attr(each_node)['order4hrg']
|
133 |
+
if 'ext_id' in each_subhg.node_attr(each_node):
|
134 |
+
subhg.node_attr(isomap[each_node])['ext_id'] \
|
135 |
+
= each_subhg.node_attr(each_node)['ext_id']
|
136 |
+
return each_idx, False
|
137 |
+
except StopIteration:
|
138 |
+
match = False
|
139 |
+
if not match:
|
140 |
+
node_dict = {}
|
141 |
+
for each_node in subhg.nodes:
|
142 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
143 |
+
node_list = []
|
144 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
145 |
+
node_list.append(each_key)
|
146 |
+
for each_idx, each_node in enumerate(node_list):
|
147 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
148 |
+
|
149 |
+
#for each_idx, each_node in enumerate(subhg.nodes):
|
150 |
+
# subhg.node_attr(each_node)['order4hrg'] = each_idx
|
151 |
+
self.subhg_list.append(subhg)
|
152 |
+
return len(self.subhg_list) - 1, True
|
graph_grammar/graph_grammar/hrg.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from .corpus import CliqueTreeCorpus
|
22 |
+
from .base import GraphGrammarBase
|
23 |
+
from .symbols import TSymbol, NTSymbol, BondSymbol
|
24 |
+
from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
|
25 |
+
from ..hypergraph import Hypergraph
|
26 |
+
from collections import Counter
|
27 |
+
from copy import deepcopy
|
28 |
+
from ..algo.tree_decomposition import (
|
29 |
+
tree_decomposition,
|
30 |
+
tree_decomposition_with_hrg,
|
31 |
+
tree_decomposition_from_leaf,
|
32 |
+
topological_tree_decomposition,
|
33 |
+
molecular_tree_decomposition)
|
34 |
+
from functools import partial
|
35 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
36 |
+
from typing import List, Dict, Tuple
|
37 |
+
import networkx as nx
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import os
|
41 |
+
import random
|
42 |
+
|
43 |
+
DEBUG = False
|
44 |
+
|
45 |
+
|
46 |
+
class ProductionRule(object):
|
47 |
+
""" A class of a production rule
|
48 |
+
|
49 |
+
Attributes
|
50 |
+
----------
|
51 |
+
lhs : Hypergraph or None
|
52 |
+
the left hand side of the production rule.
|
53 |
+
if None, the rule is a starting rule.
|
54 |
+
rhs : Hypergraph
|
55 |
+
the right hand side of the production rule.
|
56 |
+
"""
|
57 |
+
def __init__(self, lhs, rhs):
|
58 |
+
self.lhs = lhs
|
59 |
+
self.rhs = rhs
|
60 |
+
|
61 |
+
@property
|
62 |
+
def is_start_rule(self) -> bool:
|
63 |
+
return self.lhs.num_nodes == 0
|
64 |
+
|
65 |
+
@property
|
66 |
+
def ext_node(self) -> Dict[int, str]:
|
67 |
+
""" return a dict of external nodes
|
68 |
+
"""
|
69 |
+
if self.is_start_rule:
|
70 |
+
return {}
|
71 |
+
else:
|
72 |
+
ext_node_dict = {}
|
73 |
+
for each_node in self.lhs.nodes:
|
74 |
+
ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
|
75 |
+
return ext_node_dict
|
76 |
+
|
77 |
+
@property
|
78 |
+
def lhs_nt_symbol(self) -> NTSymbol:
|
79 |
+
if self.is_start_rule:
|
80 |
+
return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
|
81 |
+
else:
|
82 |
+
return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
|
83 |
+
|
84 |
+
def rhs_adj_mat(self, node_edge_list):
|
85 |
+
''' return the adjacency matrix of rhs of the production rule
|
86 |
+
'''
|
87 |
+
return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
|
88 |
+
|
89 |
+
def draw(self, file_path=None):
|
90 |
+
return self.rhs.draw(file_path)
|
91 |
+
|
92 |
+
def is_same(self, prod_rule, ignore_order=False):
|
93 |
+
""" judge whether this production rule is
|
94 |
+
the same as the input one, `prod_rule`
|
95 |
+
|
96 |
+
Parameters
|
97 |
+
----------
|
98 |
+
prod_rule : ProductionRule
|
99 |
+
production rule to be compared
|
100 |
+
|
101 |
+
Returns
|
102 |
+
-------
|
103 |
+
is_same : bool
|
104 |
+
isomap : dict
|
105 |
+
isomorphism of nodes and hyperedges.
|
106 |
+
ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
|
107 |
+
'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
|
108 |
+
'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
|
109 |
+
key comes from `prod_rule`, value comes from `self`.
|
110 |
+
"""
|
111 |
+
if self.is_start_rule:
|
112 |
+
if not prod_rule.is_start_rule:
|
113 |
+
return False, {}
|
114 |
+
else:
|
115 |
+
if prod_rule.is_start_rule:
|
116 |
+
return False, {}
|
117 |
+
else:
|
118 |
+
if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
|
119 |
+
return False, {}
|
120 |
+
|
121 |
+
if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
|
122 |
+
return False, {}
|
123 |
+
if prod_rule.rhs.num_edges != self.rhs.num_edges:
|
124 |
+
return False, {}
|
125 |
+
|
126 |
+
subhg_bond_symbol_counter \
|
127 |
+
= Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
|
128 |
+
for each_node in prod_rule.rhs.nodes])
|
129 |
+
each_bond_symbol_counter \
|
130 |
+
= Counter([self.rhs.node_attr(each_node)['symbol'] \
|
131 |
+
for each_node in self.rhs.nodes])
|
132 |
+
if subhg_bond_symbol_counter != each_bond_symbol_counter:
|
133 |
+
return False, {}
|
134 |
+
|
135 |
+
subhg_atom_symbol_counter \
|
136 |
+
= Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
|
137 |
+
for each_edge in prod_rule.rhs.edges])
|
138 |
+
each_atom_symbol_counter \
|
139 |
+
= Counter([self.rhs.edge_attr(each_edge)['symbol'] \
|
140 |
+
for each_edge in self.rhs.edges])
|
141 |
+
if subhg_atom_symbol_counter != each_atom_symbol_counter:
|
142 |
+
return False, {}
|
143 |
+
|
144 |
+
gm = GraphMatcher(prod_rule.rhs.hg,
|
145 |
+
self.rhs.hg,
|
146 |
+
partial(_node_match_prod_rule,
|
147 |
+
ignore_order=ignore_order),
|
148 |
+
partial(_edge_match,
|
149 |
+
ignore_order=ignore_order))
|
150 |
+
try:
|
151 |
+
return True, next(gm.isomorphisms_iter())
|
152 |
+
except StopIteration:
|
153 |
+
return False, {}
|
154 |
+
|
155 |
+
def applied_to(self,
|
156 |
+
hg: Hypergraph,
|
157 |
+
edge: str) -> Tuple[Hypergraph, List[str]]:
|
158 |
+
""" augment `hg` by replacing `edge` with `self.rhs`.
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
edge : str
|
164 |
+
`edge` must belong to `hg`
|
165 |
+
|
166 |
+
Returns
|
167 |
+
-------
|
168 |
+
hg : Hypergraph
|
169 |
+
resultant hypergraph
|
170 |
+
nt_edge_list : list
|
171 |
+
list of non-terminal edges
|
172 |
+
"""
|
173 |
+
nt_edge_dict = {}
|
174 |
+
if self.is_start_rule:
|
175 |
+
if (edge is not None) or (hg is not None):
|
176 |
+
ValueError("edge and hg must be None for this prod rule.")
|
177 |
+
hg = Hypergraph()
|
178 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
179 |
+
for num_idx, each_node in enumerate(self.rhs.nodes):
|
180 |
+
hg.add_node(f"bond_{num_idx}",
|
181 |
+
#attr_dict=deepcopy(self.rhs.node_attr(each_node)))
|
182 |
+
attr_dict=self.rhs.node_attr(each_node))
|
183 |
+
node_map_rhs[each_node] = f"bond_{num_idx}"
|
184 |
+
for each_edge in self.rhs.edges:
|
185 |
+
node_list = []
|
186 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
187 |
+
node_list.append(node_map_rhs[each_node])
|
188 |
+
if isinstance(self.rhs.nodes_in_edge(each_edge), set):
|
189 |
+
node_list = set(node_list)
|
190 |
+
edge_id = hg.add_edge(
|
191 |
+
node_list,
|
192 |
+
#attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
|
193 |
+
attr_dict=self.rhs.edge_attr(each_edge))
|
194 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
195 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
196 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
197 |
+
return hg, nt_edge_list
|
198 |
+
else:
|
199 |
+
if edge not in hg.edges:
|
200 |
+
raise ValueError("the input hyperedge does not exist.")
|
201 |
+
if hg.edge_attr(edge)["terminal"]:
|
202 |
+
raise ValueError("the input hyperedge is terminal.")
|
203 |
+
if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
|
204 |
+
print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
|
205 |
+
raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
|
206 |
+
if DEBUG:
|
207 |
+
for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
208 |
+
other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
|
209 |
+
attr = deepcopy(self.lhs.node_attr(other_node))
|
210 |
+
attr.pop('ext_id')
|
211 |
+
if hg.node_attr(each_node) != attr:
|
212 |
+
raise ValueError('node attributes are inconsistent.')
|
213 |
+
|
214 |
+
# order of nodes that belong to the non-terminal edge in hg
|
215 |
+
nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
|
216 |
+
nt_order_dict_inv = {} # order -> hg_node
|
217 |
+
for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
218 |
+
nt_order_dict[each_node] = each_idx
|
219 |
+
nt_order_dict_inv[each_idx] = each_node
|
220 |
+
|
221 |
+
# construct a node_map_rhs: rhs -> new hg
|
222 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
223 |
+
node_idx = hg.num_nodes
|
224 |
+
for each_node in self.rhs.nodes:
|
225 |
+
if "ext_id" in self.rhs.node_attr(each_node):
|
226 |
+
node_map_rhs[each_node] \
|
227 |
+
= nt_order_dict_inv[
|
228 |
+
self.rhs.node_attr(each_node)["ext_id"]]
|
229 |
+
else:
|
230 |
+
node_map_rhs[each_node] = f"bond_{node_idx}"
|
231 |
+
node_idx += 1
|
232 |
+
|
233 |
+
# delete non-terminal
|
234 |
+
hg.remove_edge(edge)
|
235 |
+
|
236 |
+
# add nodes to hg
|
237 |
+
for each_node in self.rhs.nodes:
|
238 |
+
hg.add_node(node_map_rhs[each_node],
|
239 |
+
attr_dict=self.rhs.node_attr(each_node))
|
240 |
+
|
241 |
+
# add hyperedges to hg
|
242 |
+
for each_edge in self.rhs.edges:
|
243 |
+
node_list_hg = []
|
244 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
245 |
+
node_list_hg.append(node_map_rhs[each_node])
|
246 |
+
edge_id = hg.add_edge(
|
247 |
+
node_list_hg,
|
248 |
+
attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
|
249 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
250 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
251 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
252 |
+
return hg, nt_edge_list
|
253 |
+
|
254 |
+
def revert(self, hg: Hypergraph, return_subhg=False):
|
255 |
+
''' revert applying this production rule.
|
256 |
+
i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
|
257 |
+
this method replaces the subhypergraph with a non-terminal hyperedge.
|
258 |
+
|
259 |
+
Parameters
|
260 |
+
----------
|
261 |
+
hg : Hypergraph
|
262 |
+
hypergraph to be reverted
|
263 |
+
return_subhg : bool
|
264 |
+
if True, the removed subhypergraph will be returned.
|
265 |
+
|
266 |
+
Returns
|
267 |
+
-------
|
268 |
+
hg : Hypergraph
|
269 |
+
the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
|
270 |
+
success : bool
|
271 |
+
this indicates whether reverting is successed or not.
|
272 |
+
'''
|
273 |
+
gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
|
274 |
+
edge_match=_edge_match)
|
275 |
+
try:
|
276 |
+
# in case when the matched subhg is connected to the other part via external nodes and more.
|
277 |
+
not_iso = True
|
278 |
+
while not_iso:
|
279 |
+
isomap = next(gm.subgraph_isomorphisms_iter())
|
280 |
+
adj_node_set = set([]) # reachable nodes from the internal nodes
|
281 |
+
subhg_node_set = set(isomap.keys()) # nodes in subhg
|
282 |
+
for each_node in subhg_node_set:
|
283 |
+
adj_node_set.add(each_node)
|
284 |
+
if isomap[each_node] not in self.ext_node.values():
|
285 |
+
adj_node_set.update(hg.hg.adj[each_node])
|
286 |
+
if adj_node_set == subhg_node_set:
|
287 |
+
not_iso = False
|
288 |
+
else:
|
289 |
+
if return_subhg:
|
290 |
+
return hg, False, Hypergraph()
|
291 |
+
else:
|
292 |
+
return hg, False
|
293 |
+
inv_isomap = {v: k for k, v in isomap.items()}
|
294 |
+
'''
|
295 |
+
isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
|
296 |
+
'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
|
297 |
+
where keys come from `hg` and values come from `self.rhs`
|
298 |
+
'''
|
299 |
+
except StopIteration:
|
300 |
+
if return_subhg:
|
301 |
+
return hg, False, Hypergraph()
|
302 |
+
else:
|
303 |
+
return hg, False
|
304 |
+
|
305 |
+
if return_subhg:
|
306 |
+
subhg = Hypergraph()
|
307 |
+
for each_node in hg.nodes:
|
308 |
+
if each_node in isomap:
|
309 |
+
subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
|
310 |
+
for each_edge in hg.edges:
|
311 |
+
if each_edge in isomap:
|
312 |
+
subhg.add_edge(hg.nodes_in_edge(each_edge),
|
313 |
+
attr_dict=hg.edge_attr(each_edge),
|
314 |
+
edge_name=each_edge)
|
315 |
+
subhg.edge_idx = hg.edge_idx
|
316 |
+
|
317 |
+
# remove subhg except for the externael nodes
|
318 |
+
for each_key, each_val in isomap.items():
|
319 |
+
if each_key.startswith('e'):
|
320 |
+
hg.remove_edge(each_key)
|
321 |
+
for each_key, each_val in isomap.items():
|
322 |
+
if each_key.startswith('bond_'):
|
323 |
+
if each_val not in self.ext_node.values():
|
324 |
+
hg.remove_node(each_key)
|
325 |
+
|
326 |
+
# add non-terminal hyperedge
|
327 |
+
nt_node_list = []
|
328 |
+
for each_ext_id in self.ext_node.keys():
|
329 |
+
nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
|
330 |
+
|
331 |
+
hg.add_edge(nt_node_list,
|
332 |
+
attr_dict=dict(
|
333 |
+
terminal=False,
|
334 |
+
symbol=self.lhs_nt_symbol))
|
335 |
+
if return_subhg:
|
336 |
+
return hg, True, subhg
|
337 |
+
else:
|
338 |
+
return hg, True
|
339 |
+
|
340 |
+
|
341 |
+
class ProductionRuleCorpus(object):
|
342 |
+
|
343 |
+
'''
|
344 |
+
A corpus of production rules.
|
345 |
+
This class maintains
|
346 |
+
(i) list of unique production rules,
|
347 |
+
(ii) list of unique edge symbols (both terminal and non-terminal), and
|
348 |
+
(iii) list of unique node symbols.
|
349 |
+
|
350 |
+
Attributes
|
351 |
+
----------
|
352 |
+
prod_rule_list : list
|
353 |
+
list of unique production rules
|
354 |
+
edge_symbol_list : list
|
355 |
+
list of unique symbols (including both terminal and non-terminal)
|
356 |
+
node_symbol_list : list
|
357 |
+
list of node symbols
|
358 |
+
nt_symbol_list : list
|
359 |
+
list of unique lhs symbols
|
360 |
+
ext_id_list : list
|
361 |
+
list of ext_ids
|
362 |
+
lhs_in_prod_rule : array
|
363 |
+
a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
|
364 |
+
'''
|
365 |
+
|
366 |
+
def __init__(self):
|
367 |
+
self.prod_rule_list = []
|
368 |
+
self.edge_symbol_list = []
|
369 |
+
self.edge_symbol_dict = {}
|
370 |
+
self.node_symbol_list = []
|
371 |
+
self.node_symbol_dict = {}
|
372 |
+
self.nt_symbol_list = []
|
373 |
+
self.ext_id_list = []
|
374 |
+
self._lhs_in_prod_rule = None
|
375 |
+
self.lhs_in_prod_rule_row_list = []
|
376 |
+
self.lhs_in_prod_rule_col_list = []
|
377 |
+
|
378 |
+
@property
|
379 |
+
def lhs_in_prod_rule(self):
|
380 |
+
if self._lhs_in_prod_rule is None:
|
381 |
+
self._lhs_in_prod_rule = torch.sparse.FloatTensor(
|
382 |
+
torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
|
383 |
+
torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
|
384 |
+
torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
|
385 |
+
).to_dense()
|
386 |
+
return self._lhs_in_prod_rule
|
387 |
+
|
388 |
+
@property
|
389 |
+
def num_prod_rule(self):
|
390 |
+
''' return the number of production rules
|
391 |
+
|
392 |
+
Returns
|
393 |
+
-------
|
394 |
+
int : the number of unique production rules
|
395 |
+
'''
|
396 |
+
return len(self.prod_rule_list)
|
397 |
+
|
398 |
+
@property
|
399 |
+
def start_rule_list(self):
|
400 |
+
''' return a list of start rules
|
401 |
+
|
402 |
+
Returns
|
403 |
+
-------
|
404 |
+
list : list of start rules
|
405 |
+
'''
|
406 |
+
start_rule_list = []
|
407 |
+
for each_prod_rule in self.prod_rule_list:
|
408 |
+
if each_prod_rule.is_start_rule:
|
409 |
+
start_rule_list.append(each_prod_rule)
|
410 |
+
return start_rule_list
|
411 |
+
|
412 |
+
@property
|
413 |
+
def num_edge_symbol(self):
|
414 |
+
return len(self.edge_symbol_list)
|
415 |
+
|
416 |
+
@property
|
417 |
+
def num_node_symbol(self):
|
418 |
+
return len(self.node_symbol_list)
|
419 |
+
|
420 |
+
@property
|
421 |
+
def num_ext_id(self):
|
422 |
+
return len(self.ext_id_list)
|
423 |
+
|
424 |
+
def construct_feature_vectors(self):
|
425 |
+
''' this method constructs feature vectors for the production rules collected so far.
|
426 |
+
currently, NTSymbol and TSymbol are treated in the same manner.
|
427 |
+
'''
|
428 |
+
feature_id_dict = {}
|
429 |
+
feature_id_dict['TSymbol'] = 0
|
430 |
+
feature_id_dict['NTSymbol'] = 1
|
431 |
+
feature_id_dict['BondSymbol'] = 2
|
432 |
+
for each_edge_symbol in self.edge_symbol_list:
|
433 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
434 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
435 |
+
if isinstance(each_val, list):
|
436 |
+
each_val = tuple(each_val)
|
437 |
+
if (each_attr, each_val) not in feature_id_dict:
|
438 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
439 |
+
|
440 |
+
for each_node_symbol in self.node_symbol_list:
|
441 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
442 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
443 |
+
if isinstance(each_val, list):
|
444 |
+
each_val = tuple(each_val)
|
445 |
+
if (each_attr, each_val) not in feature_id_dict:
|
446 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
447 |
+
for each_ext_id in self.ext_id_list:
|
448 |
+
feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
|
449 |
+
dim = len(feature_id_dict)
|
450 |
+
|
451 |
+
feature_dict = {}
|
452 |
+
for each_edge_symbol in self.edge_symbol_list:
|
453 |
+
idx_list = []
|
454 |
+
idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
|
455 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
456 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
457 |
+
if isinstance(each_val, list):
|
458 |
+
each_val = tuple(each_val)
|
459 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
460 |
+
feature = torch.sparse.LongTensor(
|
461 |
+
torch.LongTensor([idx_list]),
|
462 |
+
torch.ones(len(idx_list)),
|
463 |
+
torch.Size([len(feature_id_dict)])
|
464 |
+
)
|
465 |
+
feature_dict[each_edge_symbol] = feature
|
466 |
+
|
467 |
+
for each_node_symbol in self.node_symbol_list:
|
468 |
+
idx_list = []
|
469 |
+
idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
|
470 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
471 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
472 |
+
if isinstance(each_val, list):
|
473 |
+
each_val = tuple(each_val)
|
474 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
475 |
+
feature = torch.sparse.LongTensor(
|
476 |
+
torch.LongTensor([idx_list]),
|
477 |
+
torch.ones(len(idx_list)),
|
478 |
+
torch.Size([len(feature_id_dict)])
|
479 |
+
)
|
480 |
+
feature_dict[each_node_symbol] = feature
|
481 |
+
for each_ext_id in self.ext_id_list:
|
482 |
+
idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
|
483 |
+
feature_dict[('ext_id', each_ext_id)] \
|
484 |
+
= torch.sparse.LongTensor(
|
485 |
+
torch.LongTensor([idx_list]),
|
486 |
+
torch.ones(len(idx_list)),
|
487 |
+
torch.Size([len(feature_id_dict)])
|
488 |
+
)
|
489 |
+
return feature_dict, dim
|
490 |
+
|
491 |
+
def edge_symbol_idx(self, symbol):
|
492 |
+
return self.edge_symbol_dict[symbol]
|
493 |
+
|
494 |
+
def node_symbol_idx(self, symbol):
|
495 |
+
return self.node_symbol_dict[symbol]
|
496 |
+
|
497 |
+
def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
|
498 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
499 |
+
Production rules are regarded as the same if
|
500 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
501 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
502 |
+
|
503 |
+
Parameters
|
504 |
+
----------
|
505 |
+
prod_rule : ProductionRule
|
506 |
+
|
507 |
+
Returns
|
508 |
+
-------
|
509 |
+
prod_rule_id : int
|
510 |
+
production rule index. if new, a new index will be assigned.
|
511 |
+
prod_rule : ProductionRule
|
512 |
+
"""
|
513 |
+
num_lhs = len(self.nt_symbol_list)
|
514 |
+
for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
|
515 |
+
is_same, isomap = prod_rule.is_same(each_prod_rule)
|
516 |
+
if is_same:
|
517 |
+
# we do not care about edge and node names, but care about the order of non-terminal edges.
|
518 |
+
for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
|
519 |
+
if key.startswith("bond_"):
|
520 |
+
continue
|
521 |
+
|
522 |
+
# rewrite `nt_idx` in `prod_rule` for further processing
|
523 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
|
524 |
+
if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
|
525 |
+
raise ValueError
|
526 |
+
prod_rule.rhs.set_edge_attr(
|
527 |
+
val,
|
528 |
+
{'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
|
529 |
+
return each_idx, prod_rule
|
530 |
+
self.prod_rule_list.append(prod_rule)
|
531 |
+
self._update_edge_symbol_list(prod_rule)
|
532 |
+
self._update_node_symbol_list(prod_rule)
|
533 |
+
self._update_ext_id_list(prod_rule)
|
534 |
+
|
535 |
+
lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
|
536 |
+
self.lhs_in_prod_rule_row_list.append(lhs_idx)
|
537 |
+
self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
|
538 |
+
self._lhs_in_prod_rule = None
|
539 |
+
return len(self.prod_rule_list)-1, prod_rule
|
540 |
+
|
541 |
+
def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
|
542 |
+
return self.prod_rule_list[prod_rule_idx]
|
543 |
+
|
544 |
+
def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
|
545 |
+
''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
|
546 |
+
|
547 |
+
Parameters
|
548 |
+
----------
|
549 |
+
unmasked_logit_array : array-like, length `num_prod_rule`
|
550 |
+
nt_symbol : NTSymbol
|
551 |
+
'''
|
552 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
553 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
554 |
+
if deterministic:
|
555 |
+
prob = masked_softmax(unmasked_logit_array,
|
556 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
557 |
+
return self.prod_rule_list[np.argmax(prob)]
|
558 |
+
else:
|
559 |
+
return np.random.choice(
|
560 |
+
self.prod_rule_list, 1,
|
561 |
+
p=masked_softmax(unmasked_logit_array,
|
562 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
|
563 |
+
|
564 |
+
def masked_logprob(self, unmasked_logit_array, nt_symbol):
|
565 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
566 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
567 |
+
prob = masked_softmax(unmasked_logit_array,
|
568 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
569 |
+
return np.log(prob)
|
570 |
+
|
571 |
+
def _update_edge_symbol_list(self, prod_rule: ProductionRule):
|
572 |
+
''' update edge symbol list
|
573 |
+
|
574 |
+
Parameters
|
575 |
+
----------
|
576 |
+
prod_rule : ProductionRule
|
577 |
+
'''
|
578 |
+
if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
|
579 |
+
self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
|
580 |
+
|
581 |
+
for each_edge in prod_rule.rhs.edges:
|
582 |
+
if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
|
583 |
+
edge_symbol_idx = len(self.edge_symbol_list)
|
584 |
+
self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
|
585 |
+
self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
|
586 |
+
else:
|
587 |
+
edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
|
588 |
+
prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
|
589 |
+
pass
|
590 |
+
|
591 |
+
def _update_node_symbol_list(self, prod_rule: ProductionRule):
|
592 |
+
''' update node symbol list
|
593 |
+
|
594 |
+
Parameters
|
595 |
+
----------
|
596 |
+
prod_rule : ProductionRule
|
597 |
+
'''
|
598 |
+
for each_node in prod_rule.rhs.nodes:
|
599 |
+
if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
|
600 |
+
node_symbol_idx = len(self.node_symbol_list)
|
601 |
+
self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
|
602 |
+
self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
|
603 |
+
else:
|
604 |
+
node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
|
605 |
+
prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
|
606 |
+
|
607 |
+
def _update_ext_id_list(self, prod_rule: ProductionRule):
|
608 |
+
for each_node in prod_rule.rhs.nodes:
|
609 |
+
if 'ext_id' in prod_rule.rhs.node_attr(each_node):
|
610 |
+
if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
|
611 |
+
self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
|
612 |
+
|
613 |
+
|
614 |
+
class HyperedgeReplacementGrammar(GraphGrammarBase):
|
615 |
+
"""
|
616 |
+
Learn a hyperedge replacement grammar from a set of hypergraphs.
|
617 |
+
|
618 |
+
Attributes
|
619 |
+
----------
|
620 |
+
prod_rule_list : list of ProductionRule
|
621 |
+
production rules learned from the input hypergraphs
|
622 |
+
"""
|
623 |
+
def __init__(self,
|
624 |
+
tree_decomposition=molecular_tree_decomposition,
|
625 |
+
ignore_order=False, **kwargs):
|
626 |
+
from functools import partial
|
627 |
+
self.prod_rule_corpus = ProductionRuleCorpus()
|
628 |
+
self.clique_tree_corpus = CliqueTreeCorpus()
|
629 |
+
self.ignore_order = ignore_order
|
630 |
+
self.tree_decomposition = partial(tree_decomposition, **kwargs)
|
631 |
+
|
632 |
+
@property
|
633 |
+
def num_prod_rule(self):
|
634 |
+
''' return the number of production rules
|
635 |
+
|
636 |
+
Returns
|
637 |
+
-------
|
638 |
+
int : the number of unique production rules
|
639 |
+
'''
|
640 |
+
return self.prod_rule_corpus.num_prod_rule
|
641 |
+
|
642 |
+
@property
|
643 |
+
def start_rule_list(self):
|
644 |
+
''' return a list of start rules
|
645 |
+
|
646 |
+
Returns
|
647 |
+
-------
|
648 |
+
list : list of start rules
|
649 |
+
'''
|
650 |
+
return self.prod_rule_corpus.start_rule_list
|
651 |
+
|
652 |
+
@property
|
653 |
+
def prod_rule_list(self):
|
654 |
+
return self.prod_rule_corpus.prod_rule_list
|
655 |
+
|
656 |
+
def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
|
657 |
+
""" learn from a list of hypergraphs
|
658 |
+
|
659 |
+
Parameters
|
660 |
+
----------
|
661 |
+
hg_list : list of Hypergraph
|
662 |
+
|
663 |
+
Returns
|
664 |
+
-------
|
665 |
+
prod_rule_seq_list : list of integers
|
666 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
667 |
+
"""
|
668 |
+
prod_rule_seq_list = []
|
669 |
+
idx = 0
|
670 |
+
for each_idx, each_hg in enumerate(hg_list):
|
671 |
+
clique_tree = self.tree_decomposition(each_hg)
|
672 |
+
|
673 |
+
# get a pair of myself and children
|
674 |
+
root_node = _find_root(clique_tree)
|
675 |
+
clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
|
676 |
+
prod_rule_seq = []
|
677 |
+
stack = []
|
678 |
+
|
679 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
680 |
+
|
681 |
+
# extract a temporary production rule
|
682 |
+
prod_rule = extract_prod_rule(
|
683 |
+
None,
|
684 |
+
clique_tree.nodes[root_node]["subhg"],
|
685 |
+
[clique_tree.nodes[each_child]["subhg"]
|
686 |
+
for each_child in children],
|
687 |
+
clique_tree.nodes[root_node].get('subhg_idx', None))
|
688 |
+
|
689 |
+
# update the production rule list
|
690 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
691 |
+
children = reorder_children(root_node,
|
692 |
+
children,
|
693 |
+
prod_rule,
|
694 |
+
clique_tree)
|
695 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
696 |
+
prod_rule_seq.append(prod_rule_id)
|
697 |
+
|
698 |
+
while len(stack) != 0:
|
699 |
+
# get a triple of parent, myself, and children
|
700 |
+
parent, myself = stack.pop()
|
701 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
702 |
+
children.remove(parent)
|
703 |
+
|
704 |
+
# extract a temp prod rule
|
705 |
+
prod_rule = extract_prod_rule(
|
706 |
+
clique_tree.nodes[parent]["subhg"],
|
707 |
+
clique_tree.nodes[myself]["subhg"],
|
708 |
+
[clique_tree.nodes[each_child]["subhg"]
|
709 |
+
for each_child in children],
|
710 |
+
clique_tree.nodes[myself].get('subhg_idx', None))
|
711 |
+
|
712 |
+
# update the prod rule list
|
713 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
714 |
+
children = reorder_children(myself,
|
715 |
+
children,
|
716 |
+
prod_rule,
|
717 |
+
clique_tree)
|
718 |
+
stack.extend([(myself, each_child)
|
719 |
+
for each_child in children[::-1]])
|
720 |
+
prod_rule_seq.append(prod_rule_id)
|
721 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
722 |
+
if (each_idx+1) % print_freq == 0:
|
723 |
+
msg = f'#(molecules processed)={each_idx+1}\t'\
|
724 |
+
f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
|
725 |
+
logger(msg)
|
726 |
+
if each_idx > max_mol:
|
727 |
+
break
|
728 |
+
|
729 |
+
print(f'corpus_size = {self.clique_tree_corpus.size}')
|
730 |
+
return prod_rule_seq_list
|
731 |
+
|
732 |
+
def sample(self, z, deterministic=False):
|
733 |
+
""" sample a new hypergraph from HRG.
|
734 |
+
|
735 |
+
Parameters
|
736 |
+
----------
|
737 |
+
z : array-like, shape (len, num_prod_rule)
|
738 |
+
logit
|
739 |
+
deterministic : bool
|
740 |
+
if True, deterministic sampling
|
741 |
+
|
742 |
+
Returns
|
743 |
+
-------
|
744 |
+
Hypergraph
|
745 |
+
"""
|
746 |
+
seq_idx = 0
|
747 |
+
stack = []
|
748 |
+
z = z[:, :-1]
|
749 |
+
init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
|
750 |
+
is_aromatic=False,
|
751 |
+
bond_symbol_list=[]),
|
752 |
+
deterministic=deterministic)
|
753 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
754 |
+
stack = deepcopy(nt_edge_list[::-1])
|
755 |
+
while len(stack) != 0 and seq_idx < z.shape[0]-1:
|
756 |
+
seq_idx += 1
|
757 |
+
nt_edge = stack.pop()
|
758 |
+
nt_symbol = hg.edge_attr(nt_edge)['symbol']
|
759 |
+
prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
|
760 |
+
hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
|
761 |
+
stack.extend(nt_edge_list[::-1])
|
762 |
+
if len(stack) != 0:
|
763 |
+
raise RuntimeError(f'{len(stack)} non-terminals are left.')
|
764 |
+
return hg
|
765 |
+
|
766 |
+
def construct(self, prod_rule_seq):
|
767 |
+
""" construct a hypergraph following `prod_rule_seq`
|
768 |
+
|
769 |
+
Parameters
|
770 |
+
----------
|
771 |
+
prod_rule_seq : list of integers
|
772 |
+
a sequence of production rules.
|
773 |
+
|
774 |
+
Returns
|
775 |
+
-------
|
776 |
+
UndirectedHypergraph
|
777 |
+
"""
|
778 |
+
seq_idx = 0
|
779 |
+
init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
|
780 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
781 |
+
stack = deepcopy(nt_edge_list[::-1])
|
782 |
+
while len(stack) != 0:
|
783 |
+
seq_idx += 1
|
784 |
+
nt_edge = stack.pop()
|
785 |
+
hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
|
786 |
+
stack.extend(nt_edge_list[::-1])
|
787 |
+
return hg
|
788 |
+
|
789 |
+
def update_prod_rule_list(self, prod_rule):
|
790 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
791 |
+
Production rules are regarded as the same if
|
792 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
793 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
794 |
+
|
795 |
+
Parameters
|
796 |
+
----------
|
797 |
+
prod_rule : ProductionRule
|
798 |
+
|
799 |
+
Returns
|
800 |
+
-------
|
801 |
+
is_new : bool
|
802 |
+
if True, this production rule is new
|
803 |
+
prod_rule_id : int
|
804 |
+
production rule index. if new, a new index will be assigned.
|
805 |
+
"""
|
806 |
+
return self.prod_rule_corpus.append(prod_rule)
|
807 |
+
|
808 |
+
|
809 |
+
class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
|
810 |
+
'''
|
811 |
+
This class learns HRG incrementally leveraging the previously obtained production rules.
|
812 |
+
'''
|
813 |
+
def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
|
814 |
+
self.prod_rule_list = []
|
815 |
+
self.tree_decomposition = tree_decomposition
|
816 |
+
self.ignore_order = ignore_order
|
817 |
+
|
818 |
+
def learn(self, hg_list):
|
819 |
+
""" learn from a list of hypergraphs
|
820 |
+
|
821 |
+
Parameters
|
822 |
+
----------
|
823 |
+
hg_list : list of UndirectedHypergraph
|
824 |
+
|
825 |
+
Returns
|
826 |
+
-------
|
827 |
+
prod_rule_seq_list : list of integers
|
828 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
829 |
+
"""
|
830 |
+
prod_rule_seq_list = []
|
831 |
+
for each_hg in hg_list:
|
832 |
+
clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
|
833 |
+
|
834 |
+
prod_rule_seq = []
|
835 |
+
stack = []
|
836 |
+
|
837 |
+
# get a pair of myself and children
|
838 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
839 |
+
|
840 |
+
# extract a temporary production rule
|
841 |
+
prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
|
842 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
843 |
+
|
844 |
+
# update the production rule list
|
845 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
846 |
+
children = reorder_children(root_node, children, prod_rule, clique_tree)
|
847 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
848 |
+
prod_rule_seq.append(prod_rule_id)
|
849 |
+
|
850 |
+
while len(stack) != 0:
|
851 |
+
# get a triple of parent, myself, and children
|
852 |
+
parent, myself = stack.pop()
|
853 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
854 |
+
children.remove(parent)
|
855 |
+
|
856 |
+
# extract a temp prod rule
|
857 |
+
prod_rule = extract_prod_rule(
|
858 |
+
clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
|
859 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
860 |
+
|
861 |
+
# update the prod rule list
|
862 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
863 |
+
children = reorder_children(myself, children, prod_rule, clique_tree)
|
864 |
+
stack.extend([(myself, each_child) for each_child in children[::-1]])
|
865 |
+
prod_rule_seq.append(prod_rule_id)
|
866 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
867 |
+
self._compute_stats()
|
868 |
+
return prod_rule_seq_list
|
869 |
+
|
870 |
+
|
871 |
+
def reorder_children(myself, children, prod_rule, clique_tree):
|
872 |
+
""" reorder children so that they match the order in `prod_rule`.
|
873 |
+
|
874 |
+
Parameters
|
875 |
+
----------
|
876 |
+
myself : int
|
877 |
+
children : list of int
|
878 |
+
prod_rule : ProductionRule
|
879 |
+
clique_tree : nx.Graph
|
880 |
+
|
881 |
+
Returns
|
882 |
+
-------
|
883 |
+
new_children : list of str
|
884 |
+
reordered children
|
885 |
+
"""
|
886 |
+
perm = {} # key : `nt_idx`, val : child
|
887 |
+
for each_edge in prod_rule.rhs.edges:
|
888 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
|
889 |
+
for each_child in children:
|
890 |
+
common_node_set = set(
|
891 |
+
common_node_list(clique_tree.nodes[myself]["subhg"],
|
892 |
+
clique_tree.nodes[each_child]["subhg"])[0])
|
893 |
+
if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
|
894 |
+
assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
|
895 |
+
perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
|
896 |
+
new_children = []
|
897 |
+
assert len(perm) == len(children)
|
898 |
+
for i in range(len(perm)):
|
899 |
+
new_children.append(perm[i])
|
900 |
+
return new_children
|
901 |
+
|
902 |
+
|
903 |
+
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
|
904 |
+
""" extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
|
905 |
+
|
906 |
+
Parameters
|
907 |
+
----------
|
908 |
+
parent_hg : Hypergraph
|
909 |
+
myself_hg : Hypergraph
|
910 |
+
children_hg_list : list of Hypergraph
|
911 |
+
|
912 |
+
Returns
|
913 |
+
-------
|
914 |
+
ProductionRule, consisting of
|
915 |
+
lhs : Hypergraph or None
|
916 |
+
rhs : Hypergraph
|
917 |
+
"""
|
918 |
+
def _add_ext_node(hg, ext_nodes):
|
919 |
+
""" mark nodes to be external (ordered ids are assigned)
|
920 |
+
|
921 |
+
Parameters
|
922 |
+
----------
|
923 |
+
hg : UndirectedHypergraph
|
924 |
+
ext_nodes : list of str
|
925 |
+
list of external nodes
|
926 |
+
|
927 |
+
Returns
|
928 |
+
-------
|
929 |
+
hg : Hypergraph
|
930 |
+
nodes in `ext_nodes` are marked to be external
|
931 |
+
"""
|
932 |
+
ext_id = 0
|
933 |
+
ext_id_exists = []
|
934 |
+
for each_node in ext_nodes:
|
935 |
+
ext_id_exists.append('ext_id' in hg.node_attr(each_node))
|
936 |
+
if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
|
937 |
+
raise ValueError
|
938 |
+
if not all(ext_id_exists):
|
939 |
+
for each_node in ext_nodes:
|
940 |
+
hg.node_attr(each_node)['ext_id'] = ext_id
|
941 |
+
ext_id += 1
|
942 |
+
return hg
|
943 |
+
|
944 |
+
def _check_aromatic(hg, node_list):
|
945 |
+
is_aromatic = False
|
946 |
+
node_aromatic_list = []
|
947 |
+
for each_node in node_list:
|
948 |
+
if hg.node_attr(each_node)['symbol'].is_aromatic:
|
949 |
+
is_aromatic = True
|
950 |
+
node_aromatic_list.append(True)
|
951 |
+
else:
|
952 |
+
node_aromatic_list.append(False)
|
953 |
+
return is_aromatic, node_aromatic_list
|
954 |
+
|
955 |
+
def _check_ring(hg):
|
956 |
+
for each_edge in hg.edges:
|
957 |
+
if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
|
958 |
+
return False
|
959 |
+
return True
|
960 |
+
|
961 |
+
if parent_hg is None:
|
962 |
+
lhs = Hypergraph()
|
963 |
+
node_list = []
|
964 |
+
else:
|
965 |
+
lhs = Hypergraph()
|
966 |
+
node_list, edge_exists = common_node_list(parent_hg, myself_hg)
|
967 |
+
for each_node in node_list:
|
968 |
+
lhs.add_node(each_node,
|
969 |
+
deepcopy(myself_hg.node_attr(each_node)))
|
970 |
+
is_aromatic, _ = _check_aromatic(parent_hg, node_list)
|
971 |
+
for_ring = _check_ring(myself_hg)
|
972 |
+
bond_symbol_list = []
|
973 |
+
for each_node in node_list:
|
974 |
+
bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
|
975 |
+
lhs.add_edge(
|
976 |
+
node_list,
|
977 |
+
attr_dict=dict(
|
978 |
+
terminal=False,
|
979 |
+
edge_exists=edge_exists,
|
980 |
+
symbol=NTSymbol(
|
981 |
+
degree=len(node_list),
|
982 |
+
is_aromatic=is_aromatic,
|
983 |
+
bond_symbol_list=bond_symbol_list,
|
984 |
+
for_ring=for_ring)))
|
985 |
+
try:
|
986 |
+
lhs = _add_ext_node(lhs, node_list)
|
987 |
+
except ValueError:
|
988 |
+
import pdb; pdb.set_trace()
|
989 |
+
|
990 |
+
rhs = remove_tmp_edge(deepcopy(myself_hg))
|
991 |
+
#rhs = remove_ext_node(rhs)
|
992 |
+
#rhs = remove_nt_edge(rhs)
|
993 |
+
try:
|
994 |
+
rhs = _add_ext_node(rhs, node_list)
|
995 |
+
except ValueError:
|
996 |
+
import pdb; pdb.set_trace()
|
997 |
+
|
998 |
+
nt_idx = 0
|
999 |
+
if children_hg_list is not None:
|
1000 |
+
for each_child_hg in children_hg_list:
|
1001 |
+
node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
|
1002 |
+
is_aromatic, _ = _check_aromatic(myself_hg, node_list)
|
1003 |
+
for_ring = _check_ring(each_child_hg)
|
1004 |
+
bond_symbol_list = []
|
1005 |
+
for each_node in node_list:
|
1006 |
+
bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
|
1007 |
+
rhs.add_edge(
|
1008 |
+
node_list,
|
1009 |
+
attr_dict=dict(
|
1010 |
+
terminal=False,
|
1011 |
+
nt_idx=nt_idx,
|
1012 |
+
edge_exists=edge_exists,
|
1013 |
+
symbol=NTSymbol(degree=len(node_list),
|
1014 |
+
is_aromatic=is_aromatic,
|
1015 |
+
bond_symbol_list=bond_symbol_list,
|
1016 |
+
for_ring=for_ring)))
|
1017 |
+
nt_idx += 1
|
1018 |
+
prod_rule = ProductionRule(lhs, rhs)
|
1019 |
+
prod_rule.subhg_idx = subhg_idx
|
1020 |
+
if DEBUG:
|
1021 |
+
if sorted(list(prod_rule.ext_node.keys())) \
|
1022 |
+
!= list(np.arange(len(prod_rule.ext_node))):
|
1023 |
+
raise RuntimeError('ext_id is not continuous')
|
1024 |
+
return prod_rule
|
1025 |
+
|
1026 |
+
|
1027 |
+
def _find_root(clique_tree):
|
1028 |
+
max_node = None
|
1029 |
+
num_nodes_max = -np.inf
|
1030 |
+
for each_node in clique_tree.nodes:
|
1031 |
+
if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
|
1032 |
+
max_node = each_node
|
1033 |
+
num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
|
1034 |
+
'''
|
1035 |
+
children = sorted(list(clique_tree[each_node].keys()))
|
1036 |
+
prod_rule = extract_prod_rule(None,
|
1037 |
+
clique_tree.nodes[each_node]["subhg"],
|
1038 |
+
[clique_tree.nodes[each_child]["subhg"]
|
1039 |
+
for each_child in children])
|
1040 |
+
for each_start_rule in start_rule_list:
|
1041 |
+
if prod_rule.is_same(each_start_rule):
|
1042 |
+
return each_node
|
1043 |
+
'''
|
1044 |
+
return max_node
|
1045 |
+
|
1046 |
+
def remove_ext_node(hg):
|
1047 |
+
for each_node in hg.nodes:
|
1048 |
+
hg.node_attr(each_node).pop('ext_id', None)
|
1049 |
+
return hg
|
1050 |
+
|
1051 |
+
def remove_nt_edge(hg):
|
1052 |
+
remove_edge_list = []
|
1053 |
+
for each_edge in hg.edges:
|
1054 |
+
if not hg.edge_attr(each_edge)['terminal']:
|
1055 |
+
remove_edge_list.append(each_edge)
|
1056 |
+
hg.remove_edges(remove_edge_list)
|
1057 |
+
return hg
|
1058 |
+
|
1059 |
+
def remove_tmp_edge(hg):
|
1060 |
+
remove_edge_list = []
|
1061 |
+
for each_edge in hg.edges:
|
1062 |
+
if hg.edge_attr(each_edge).get('tmp', False):
|
1063 |
+
remove_edge_list.append(each_edge)
|
1064 |
+
hg.remove_edges(remove_edge_list)
|
1065 |
+
return hg
|
graph_grammar/graph_grammar/symbols.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
""" Title """
|
16 |
+
|
17 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
18 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
19 |
+
__version__ = "0.1"
|
20 |
+
__date__ = "Jan 1 2018"
|
21 |
+
|
22 |
+
from typing import List
|
23 |
+
|
24 |
+
class TSymbol(object):
|
25 |
+
|
26 |
+
''' terminal symbol
|
27 |
+
|
28 |
+
Attributes
|
29 |
+
----------
|
30 |
+
degree : int
|
31 |
+
the number of nodes in a hyperedge
|
32 |
+
is_aromatic : bool
|
33 |
+
whether or not the hyperedge is in an aromatic ring
|
34 |
+
symbol : str
|
35 |
+
atomic symbol
|
36 |
+
num_explicit_Hs : int
|
37 |
+
the number of hydrogens associated to this hyperedge
|
38 |
+
formal_charge : int
|
39 |
+
charge
|
40 |
+
chirality : int
|
41 |
+
chirality
|
42 |
+
'''
|
43 |
+
|
44 |
+
def __init__(self, degree, is_aromatic,
|
45 |
+
symbol, num_explicit_Hs, formal_charge, chirality):
|
46 |
+
self.degree = degree
|
47 |
+
self.is_aromatic = is_aromatic
|
48 |
+
self.symbol = symbol
|
49 |
+
self.num_explicit_Hs = num_explicit_Hs
|
50 |
+
self.formal_charge = formal_charge
|
51 |
+
self.chirality = chirality
|
52 |
+
|
53 |
+
@property
|
54 |
+
def terminal(self):
|
55 |
+
return True
|
56 |
+
|
57 |
+
def __eq__(self, other):
|
58 |
+
if not isinstance(other, TSymbol):
|
59 |
+
return False
|
60 |
+
if self.degree != other.degree:
|
61 |
+
return False
|
62 |
+
if self.is_aromatic != other.is_aromatic:
|
63 |
+
return False
|
64 |
+
if self.symbol != other.symbol:
|
65 |
+
return False
|
66 |
+
if self.num_explicit_Hs != other.num_explicit_Hs:
|
67 |
+
return False
|
68 |
+
if self.formal_charge != other.formal_charge:
|
69 |
+
return False
|
70 |
+
if self.chirality != other.chirality:
|
71 |
+
return False
|
72 |
+
return True
|
73 |
+
|
74 |
+
def __hash__(self):
|
75 |
+
return self.__str__().__hash__()
|
76 |
+
|
77 |
+
def __str__(self):
|
78 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
79 |
+
f'symbol={self.symbol}, '\
|
80 |
+
f'num_explicit_Hs={self.num_explicit_Hs}, '\
|
81 |
+
f'formal_charge={self.formal_charge}, chirality={self.chirality}'
|
82 |
+
|
83 |
+
|
84 |
+
class NTSymbol(object):
|
85 |
+
|
86 |
+
''' non-terminal symbol
|
87 |
+
|
88 |
+
Attributes
|
89 |
+
----------
|
90 |
+
degree : int
|
91 |
+
degree of the hyperedge
|
92 |
+
is_aromatic : bool
|
93 |
+
if True, at least one of the associated bonds must be aromatic.
|
94 |
+
node_aromatic_list : list of bool
|
95 |
+
indicate whether each of the nodes is aromatic or not.
|
96 |
+
bond_type_list : list of int
|
97 |
+
bond type of each node"
|
98 |
+
'''
|
99 |
+
|
100 |
+
def __init__(self, degree: int, is_aromatic: bool,
|
101 |
+
bond_symbol_list: list,
|
102 |
+
for_ring=False):
|
103 |
+
self.degree = degree
|
104 |
+
self.is_aromatic = is_aromatic
|
105 |
+
self.for_ring = for_ring
|
106 |
+
self.bond_symbol_list = bond_symbol_list
|
107 |
+
|
108 |
+
@property
|
109 |
+
def terminal(self) -> bool:
|
110 |
+
return False
|
111 |
+
|
112 |
+
@property
|
113 |
+
def symbol(self):
|
114 |
+
return f'NT{self.degree}'
|
115 |
+
|
116 |
+
def __eq__(self, other) -> bool:
|
117 |
+
if not isinstance(other, NTSymbol):
|
118 |
+
return False
|
119 |
+
|
120 |
+
if self.degree != other.degree:
|
121 |
+
return False
|
122 |
+
if self.is_aromatic != other.is_aromatic:
|
123 |
+
return False
|
124 |
+
if self.for_ring != other.for_ring:
|
125 |
+
return False
|
126 |
+
if len(self.bond_symbol_list) != len(other.bond_symbol_list):
|
127 |
+
return False
|
128 |
+
for each_idx in range(len(self.bond_symbol_list)):
|
129 |
+
if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
|
130 |
+
return False
|
131 |
+
return True
|
132 |
+
|
133 |
+
def __hash__(self):
|
134 |
+
return self.__str__().__hash__()
|
135 |
+
|
136 |
+
def __str__(self) -> str:
|
137 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
138 |
+
f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
|
139 |
+
f'for_ring={self.for_ring}'
|
140 |
+
|
141 |
+
|
142 |
+
class BondSymbol(object):
|
143 |
+
|
144 |
+
|
145 |
+
''' Bond symbol
|
146 |
+
|
147 |
+
Attributes
|
148 |
+
----------
|
149 |
+
is_aromatic : bool
|
150 |
+
if True, at least one of the associated bonds must be aromatic.
|
151 |
+
bond_type : int
|
152 |
+
bond type of each node"
|
153 |
+
'''
|
154 |
+
|
155 |
+
def __init__(self, is_aromatic: bool,
|
156 |
+
bond_type: int,
|
157 |
+
stereo: int):
|
158 |
+
self.is_aromatic = is_aromatic
|
159 |
+
self.bond_type = bond_type
|
160 |
+
self.stereo = stereo
|
161 |
+
|
162 |
+
def __eq__(self, other) -> bool:
|
163 |
+
if not isinstance(other, BondSymbol):
|
164 |
+
return False
|
165 |
+
|
166 |
+
if self.is_aromatic != other.is_aromatic:
|
167 |
+
return False
|
168 |
+
if self.bond_type != other.bond_type:
|
169 |
+
return False
|
170 |
+
if self.stereo != other.stereo:
|
171 |
+
return False
|
172 |
+
return True
|
173 |
+
|
174 |
+
def __hash__(self):
|
175 |
+
return self.__str__().__hash__()
|
176 |
+
|
177 |
+
def __str__(self) -> str:
|
178 |
+
return f'is_aromatic={self.is_aromatic}, '\
|
179 |
+
f'bond_type={self.bond_type}, '\
|
180 |
+
f'stereo={self.stereo}, '
|
graph_grammar/graph_grammar/utils.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from ..hypergraph import Hypergraph
|
22 |
+
from copy import deepcopy
|
23 |
+
from typing import List
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
|
27 |
+
def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
|
28 |
+
""" return a list of common nodes
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
hg1, hg2 : Hypergraph
|
33 |
+
|
34 |
+
Returns
|
35 |
+
-------
|
36 |
+
list of str
|
37 |
+
list of common nodes
|
38 |
+
"""
|
39 |
+
if hg1 is None or hg2 is None:
|
40 |
+
return [], False
|
41 |
+
else:
|
42 |
+
node_set = hg1.nodes.intersection(hg2.nodes)
|
43 |
+
node_dict = {}
|
44 |
+
if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
|
45 |
+
for each_node in node_set:
|
46 |
+
node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
|
47 |
+
else:
|
48 |
+
for each_node in node_set:
|
49 |
+
node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
|
50 |
+
node_list = []
|
51 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
52 |
+
node_list.append(each_key)
|
53 |
+
edge_name = hg1.has_edge(node_list, ignore_order=True)
|
54 |
+
if edge_name:
|
55 |
+
if not hg1.edge_attr(edge_name).get('terminal', True):
|
56 |
+
node_list = hg1.nodes_in_edge(edge_name)
|
57 |
+
return node_list, True
|
58 |
+
else:
|
59 |
+
return node_list, False
|
60 |
+
|
61 |
+
|
62 |
+
def _node_match(node1, node2):
|
63 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
64 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
65 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
66 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
67 |
+
# bond_symbol
|
68 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
69 |
+
else:
|
70 |
+
return False
|
71 |
+
|
72 |
+
def _easy_node_match(node1, node2):
|
73 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
74 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
75 |
+
return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
|
76 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
77 |
+
# bond_symbol
|
78 |
+
return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
|
79 |
+
and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
80 |
+
else:
|
81 |
+
return False
|
82 |
+
|
83 |
+
|
84 |
+
def _node_match_prod_rule(node1, node2, ignore_order=False):
|
85 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
86 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
87 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
88 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
89 |
+
# ext_id, order4hrg, bond_symbol
|
90 |
+
if ignore_order:
|
91 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
92 |
+
else:
|
93 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
|
94 |
+
and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
|
95 |
+
else:
|
96 |
+
return False
|
97 |
+
|
98 |
+
|
99 |
+
def _edge_match(edge1, edge2, ignore_order=False):
|
100 |
+
#return True
|
101 |
+
if ignore_order:
|
102 |
+
return True
|
103 |
+
else:
|
104 |
+
return edge1["order"] == edge2["order"]
|
105 |
+
|
106 |
+
def masked_softmax(logit, mask):
|
107 |
+
''' compute a probability distribution from logit
|
108 |
+
|
109 |
+
Parameters
|
110 |
+
----------
|
111 |
+
logit : array-like, length D
|
112 |
+
each element indicates how each dimension is likely to be chosen
|
113 |
+
(the larger, the more likely)
|
114 |
+
mask : array-like, length D
|
115 |
+
each element is either 0 or 1.
|
116 |
+
if 0, the dimension is ignored
|
117 |
+
when computing the probability distribution.
|
118 |
+
|
119 |
+
Returns
|
120 |
+
-------
|
121 |
+
prob_dist : array, length D
|
122 |
+
probability distribution computed from logit.
|
123 |
+
if `mask[d] = 0`, `prob_dist[d] = 0`.
|
124 |
+
'''
|
125 |
+
if logit.shape != mask.shape:
|
126 |
+
raise ValueError('logit and mask must have the same shape')
|
127 |
+
c = np.max(logit)
|
128 |
+
exp_logit = np.exp(logit - c) * mask
|
129 |
+
sum_exp_logit = exp_logit @ mask
|
130 |
+
return exp_logit / sum_exp_logit
|
graph_grammar/hypergraph.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 31 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from typing import List, Dict, Tuple
|
23 |
+
import networkx as nx
|
24 |
+
import numpy as np
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class Hypergraph(object):
|
29 |
+
'''
|
30 |
+
A class of a hypergraph.
|
31 |
+
Each hyperedge can be ordered. For the ordered case,
|
32 |
+
edges adjacent to the hyperedge node are labeled by their orders.
|
33 |
+
|
34 |
+
Attributes
|
35 |
+
----------
|
36 |
+
hg : nx.Graph
|
37 |
+
a bipartite graph representation of a hypergraph
|
38 |
+
edge_idx : int
|
39 |
+
total number of hyperedges that exist so far
|
40 |
+
'''
|
41 |
+
def __init__(self):
|
42 |
+
self.hg = nx.Graph()
|
43 |
+
self.edge_idx = 0
|
44 |
+
self.nodes = set([])
|
45 |
+
self.num_nodes = 0
|
46 |
+
self.edges = set([])
|
47 |
+
self.num_edges = 0
|
48 |
+
self.nodes_in_edge_dict = {}
|
49 |
+
|
50 |
+
def add_node(self, node: str, attr_dict=None):
|
51 |
+
''' add a node to hypergraph
|
52 |
+
|
53 |
+
Parameters
|
54 |
+
----------
|
55 |
+
node : str
|
56 |
+
node name
|
57 |
+
attr_dict : dict
|
58 |
+
dictionary of node attributes
|
59 |
+
'''
|
60 |
+
self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
|
61 |
+
if node not in self.nodes:
|
62 |
+
self.num_nodes += 1
|
63 |
+
self.nodes.add(node)
|
64 |
+
|
65 |
+
def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
|
66 |
+
''' add an edge consisting of nodes `node_list`
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
node_list : list
|
71 |
+
ordered list of nodes that consist the edge
|
72 |
+
attr_dict : dict
|
73 |
+
dictionary of edge attributes
|
74 |
+
'''
|
75 |
+
if edge_name is None:
|
76 |
+
edge = 'e{}'.format(self.edge_idx)
|
77 |
+
else:
|
78 |
+
assert edge_name not in self.edges
|
79 |
+
edge = edge_name
|
80 |
+
self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
|
81 |
+
if edge not in self.edges:
|
82 |
+
self.num_edges += 1
|
83 |
+
self.edges.add(edge)
|
84 |
+
self.nodes_in_edge_dict[edge] = node_list
|
85 |
+
if type(node_list) == list:
|
86 |
+
for node_idx, each_node in enumerate(node_list):
|
87 |
+
self.hg.add_edge(edge, each_node, order=node_idx)
|
88 |
+
if each_node not in self.nodes:
|
89 |
+
self.num_nodes += 1
|
90 |
+
self.nodes.add(each_node)
|
91 |
+
|
92 |
+
elif type(node_list) == set:
|
93 |
+
for each_node in node_list:
|
94 |
+
self.hg.add_edge(edge, each_node, order=-1)
|
95 |
+
if each_node not in self.nodes:
|
96 |
+
self.num_nodes += 1
|
97 |
+
self.nodes.add(each_node)
|
98 |
+
else:
|
99 |
+
raise ValueError
|
100 |
+
self.edge_idx += 1
|
101 |
+
return edge
|
102 |
+
|
103 |
+
def remove_node(self, node: str, remove_connected_edges=True):
|
104 |
+
''' remove a node
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
node : str
|
109 |
+
node name
|
110 |
+
remove_connected_edges : bool
|
111 |
+
if True, remove edges that are adjacent to the node
|
112 |
+
'''
|
113 |
+
if remove_connected_edges:
|
114 |
+
connected_edges = deepcopy(self.adj_edges(node))
|
115 |
+
for each_edge in connected_edges:
|
116 |
+
self.remove_edge(each_edge)
|
117 |
+
self.hg.remove_node(node)
|
118 |
+
self.num_nodes -= 1
|
119 |
+
self.nodes.remove(node)
|
120 |
+
|
121 |
+
def remove_nodes(self, node_iter, remove_connected_edges=True):
|
122 |
+
''' remove a set of nodes
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
node_iter : iterator of strings
|
127 |
+
nodes to be removed
|
128 |
+
remove_connected_edges : bool
|
129 |
+
if True, remove edges that are adjacent to the node
|
130 |
+
'''
|
131 |
+
for each_node in node_iter:
|
132 |
+
self.remove_node(each_node, remove_connected_edges)
|
133 |
+
|
134 |
+
def remove_edge(self, edge: str):
|
135 |
+
''' remove an edge
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
edge : str
|
140 |
+
edge to be removed
|
141 |
+
'''
|
142 |
+
self.hg.remove_node(edge)
|
143 |
+
self.edges.remove(edge)
|
144 |
+
self.num_edges -= 1
|
145 |
+
self.nodes_in_edge_dict.pop(edge)
|
146 |
+
|
147 |
+
def remove_edges(self, edge_iter):
|
148 |
+
''' remove a set of edges
|
149 |
+
|
150 |
+
Parameters
|
151 |
+
----------
|
152 |
+
edge_iter : iterator of strings
|
153 |
+
edges to be removed
|
154 |
+
'''
|
155 |
+
for each_edge in edge_iter:
|
156 |
+
self.remove_edge(each_edge)
|
157 |
+
|
158 |
+
def remove_edges_with_attr(self, edge_attr_dict):
|
159 |
+
remove_edge_list = []
|
160 |
+
for each_edge in self.edges:
|
161 |
+
satisfy = True
|
162 |
+
for each_key, each_val in edge_attr_dict.items():
|
163 |
+
if not satisfy:
|
164 |
+
break
|
165 |
+
try:
|
166 |
+
if self.edge_attr(each_edge)[each_key] != each_val:
|
167 |
+
satisfy = False
|
168 |
+
except KeyError:
|
169 |
+
satisfy = False
|
170 |
+
if satisfy:
|
171 |
+
remove_edge_list.append(each_edge)
|
172 |
+
self.remove_edges(remove_edge_list)
|
173 |
+
|
174 |
+
def remove_subhg(self, subhg):
|
175 |
+
''' remove subhypergraph.
|
176 |
+
all of the hyperedges are removed.
|
177 |
+
each node of subhg is removed if its degree becomes 0 after removing hyperedges.
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
subhg : Hypergraph
|
182 |
+
'''
|
183 |
+
for each_edge in subhg.edges:
|
184 |
+
self.remove_edge(each_edge)
|
185 |
+
for each_node in subhg.nodes:
|
186 |
+
if self.degree(each_node) == 0:
|
187 |
+
self.remove_node(each_node)
|
188 |
+
|
189 |
+
def nodes_in_edge(self, edge):
|
190 |
+
''' return an ordered list of nodes in a given edge.
|
191 |
+
|
192 |
+
Parameters
|
193 |
+
----------
|
194 |
+
edge : str
|
195 |
+
edge whose nodes are returned
|
196 |
+
|
197 |
+
Returns
|
198 |
+
-------
|
199 |
+
list or set
|
200 |
+
ordered list or set of nodes that belong to the edge
|
201 |
+
'''
|
202 |
+
if edge.startswith('e'):
|
203 |
+
return self.nodes_in_edge_dict[edge]
|
204 |
+
else:
|
205 |
+
adj_node_list = self.hg.adj[edge]
|
206 |
+
adj_node_order_list = []
|
207 |
+
adj_node_name_list = []
|
208 |
+
for each_node in adj_node_list:
|
209 |
+
adj_node_order_list.append(adj_node_list[each_node]['order'])
|
210 |
+
adj_node_name_list.append(each_node)
|
211 |
+
if adj_node_order_list == [-1] * len(adj_node_order_list):
|
212 |
+
return set(adj_node_name_list)
|
213 |
+
else:
|
214 |
+
return [adj_node_name_list[each_idx] for each_idx
|
215 |
+
in np.argsort(adj_node_order_list)]
|
216 |
+
|
217 |
+
def adj_edges(self, node):
|
218 |
+
''' return a dict of adjacent hyperedges
|
219 |
+
|
220 |
+
Parameters
|
221 |
+
----------
|
222 |
+
node : str
|
223 |
+
|
224 |
+
Returns
|
225 |
+
-------
|
226 |
+
set
|
227 |
+
set of edges that are adjacent to `node`
|
228 |
+
'''
|
229 |
+
return self.hg.adj[node]
|
230 |
+
|
231 |
+
def adj_nodes(self, node):
|
232 |
+
''' return a set of adjacent nodes
|
233 |
+
|
234 |
+
Parameters
|
235 |
+
----------
|
236 |
+
node : str
|
237 |
+
|
238 |
+
Returns
|
239 |
+
-------
|
240 |
+
set
|
241 |
+
set of nodes that are adjacent to `node`
|
242 |
+
'''
|
243 |
+
node_set = set([])
|
244 |
+
for each_adj_edge in self.adj_edges(node):
|
245 |
+
node_set.update(set(self.nodes_in_edge(each_adj_edge)))
|
246 |
+
node_set.discard(node)
|
247 |
+
return node_set
|
248 |
+
|
249 |
+
def has_edge(self, node_list, ignore_order=False):
|
250 |
+
for each_edge in self.edges:
|
251 |
+
if ignore_order:
|
252 |
+
if set(self.nodes_in_edge(each_edge)) == set(node_list):
|
253 |
+
return each_edge
|
254 |
+
else:
|
255 |
+
if self.nodes_in_edge(each_edge) == node_list:
|
256 |
+
return each_edge
|
257 |
+
return False
|
258 |
+
|
259 |
+
def degree(self, node):
|
260 |
+
return len(self.hg.adj[node])
|
261 |
+
|
262 |
+
def degrees(self):
|
263 |
+
return {each_node: self.degree(each_node) for each_node in self.nodes}
|
264 |
+
|
265 |
+
def edge_degree(self, edge):
|
266 |
+
return len(self.nodes_in_edge(edge))
|
267 |
+
|
268 |
+
def edge_degrees(self):
|
269 |
+
return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
|
270 |
+
|
271 |
+
def is_adj(self, node1, node2):
|
272 |
+
return node1 in self.adj_nodes(node2)
|
273 |
+
|
274 |
+
def adj_subhg(self, node, ident_node_dict=None):
|
275 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
276 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
277 |
+
|
278 |
+
Parameters
|
279 |
+
----------
|
280 |
+
node : str
|
281 |
+
ident_node_dict : dict
|
282 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
283 |
+
|
284 |
+
Returns
|
285 |
+
-------
|
286 |
+
subhg : Hypergraph
|
287 |
+
"""
|
288 |
+
if ident_node_dict is None:
|
289 |
+
ident_node_dict = self.get_identical_node_dict()
|
290 |
+
adj_node_set = set(ident_node_dict[node])
|
291 |
+
adj_edge_set = set([])
|
292 |
+
for each_node in ident_node_dict[node]:
|
293 |
+
adj_edge_set.update(set(self.adj_edges(each_node)))
|
294 |
+
fixed_adj_edge_set = deepcopy(adj_edge_set)
|
295 |
+
for each_edge in fixed_adj_edge_set:
|
296 |
+
other_nodes = self.nodes_in_edge(each_edge)
|
297 |
+
adj_node_set.update(other_nodes)
|
298 |
+
|
299 |
+
# if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
|
300 |
+
for each_node in other_nodes:
|
301 |
+
for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
|
302 |
+
if len(set(self.nodes_in_edge(other_edge)) \
|
303 |
+
- set(self.nodes_in_edge(each_edge))) == 0:
|
304 |
+
adj_edge_set.update(set([other_edge]))
|
305 |
+
subhg = Hypergraph()
|
306 |
+
for each_node in adj_node_set:
|
307 |
+
subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
|
308 |
+
for each_edge in adj_edge_set:
|
309 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
310 |
+
attr_dict=self.edge_attr(each_edge),
|
311 |
+
edge_name=each_edge)
|
312 |
+
subhg.edge_idx = self.edge_idx
|
313 |
+
return subhg
|
314 |
+
|
315 |
+
def get_subhg(self, node_list, edge_list, ident_node_dict=None):
|
316 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
317 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
318 |
+
|
319 |
+
Parameters
|
320 |
+
----------
|
321 |
+
node : str
|
322 |
+
ident_node_dict : dict
|
323 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
324 |
+
|
325 |
+
Returns
|
326 |
+
-------
|
327 |
+
subhg : Hypergraph
|
328 |
+
"""
|
329 |
+
if ident_node_dict is None:
|
330 |
+
ident_node_dict = self.get_identical_node_dict()
|
331 |
+
adj_node_set = set([])
|
332 |
+
for each_node in node_list:
|
333 |
+
adj_node_set.update(set(ident_node_dict[each_node]))
|
334 |
+
adj_edge_set = set(edge_list)
|
335 |
+
|
336 |
+
subhg = Hypergraph()
|
337 |
+
for each_node in adj_node_set:
|
338 |
+
subhg.add_node(each_node,
|
339 |
+
attr_dict=deepcopy(self.node_attr(each_node)))
|
340 |
+
for each_edge in adj_edge_set:
|
341 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
342 |
+
attr_dict=deepcopy(self.edge_attr(each_edge)),
|
343 |
+
edge_name=each_edge)
|
344 |
+
subhg.edge_idx = self.edge_idx
|
345 |
+
return subhg
|
346 |
+
|
347 |
+
def copy(self):
|
348 |
+
''' return a copy of the object
|
349 |
+
|
350 |
+
Returns
|
351 |
+
-------
|
352 |
+
Hypergraph
|
353 |
+
'''
|
354 |
+
return deepcopy(self)
|
355 |
+
|
356 |
+
def node_attr(self, node):
|
357 |
+
return self.hg.nodes[node]['attr_dict']
|
358 |
+
|
359 |
+
def edge_attr(self, edge):
|
360 |
+
return self.hg.nodes[edge]['attr_dict']
|
361 |
+
|
362 |
+
def set_node_attr(self, node, attr_dict):
|
363 |
+
for each_key, each_val in attr_dict.items():
|
364 |
+
self.hg.nodes[node]['attr_dict'][each_key] = each_val
|
365 |
+
|
366 |
+
def set_edge_attr(self, edge, attr_dict):
|
367 |
+
for each_key, each_val in attr_dict.items():
|
368 |
+
self.hg.nodes[edge]['attr_dict'][each_key] = each_val
|
369 |
+
|
370 |
+
def get_identical_node_dict(self):
|
371 |
+
''' get identical nodes
|
372 |
+
nodes are identical if they share the same set of adjacent edges.
|
373 |
+
|
374 |
+
Returns
|
375 |
+
-------
|
376 |
+
ident_node_dict : dict
|
377 |
+
ident_node_dict[node] returns a list of nodes that are identical to `node`.
|
378 |
+
'''
|
379 |
+
ident_node_dict = {}
|
380 |
+
for each_node in self.nodes:
|
381 |
+
ident_node_list = []
|
382 |
+
for each_other_node in self.nodes:
|
383 |
+
if each_other_node == each_node:
|
384 |
+
ident_node_list.append(each_other_node)
|
385 |
+
elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
|
386 |
+
and len(self.adj_edges(each_node)) != 0:
|
387 |
+
ident_node_list.append(each_other_node)
|
388 |
+
ident_node_dict[each_node] = ident_node_list
|
389 |
+
return ident_node_dict
|
390 |
+
'''
|
391 |
+
ident_node_dict = {}
|
392 |
+
for each_node in self.nodes:
|
393 |
+
ident_node_dict[each_node] = [each_node]
|
394 |
+
return ident_node_dict
|
395 |
+
'''
|
396 |
+
|
397 |
+
def get_leaf_edge(self):
|
398 |
+
''' get an edge that is incident only to one edge
|
399 |
+
|
400 |
+
Returns
|
401 |
+
-------
|
402 |
+
if exists, return a leaf edge. otherwise, return None.
|
403 |
+
'''
|
404 |
+
for each_edge in self.edges:
|
405 |
+
if len(self.adj_nodes(each_edge)) == 1:
|
406 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
407 |
+
return each_edge
|
408 |
+
return None
|
409 |
+
|
410 |
+
def get_nontmp_edge(self):
|
411 |
+
for each_edge in self.edges:
|
412 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
413 |
+
return each_edge
|
414 |
+
return None
|
415 |
+
|
416 |
+
def is_subhg(self, hg):
|
417 |
+
''' return whether this hypergraph is a subhypergraph of `hg`
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
True if self \in hg,
|
422 |
+
False otherwise.
|
423 |
+
'''
|
424 |
+
for each_node in self.nodes:
|
425 |
+
if each_node not in hg.nodes:
|
426 |
+
return False
|
427 |
+
for each_edge in self.edges:
|
428 |
+
if each_edge not in hg.edges:
|
429 |
+
return False
|
430 |
+
return True
|
431 |
+
|
432 |
+
def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
|
433 |
+
''' if `node` is in a cycle, then return True. otherwise, False.
|
434 |
+
|
435 |
+
Parameters
|
436 |
+
----------
|
437 |
+
node : str
|
438 |
+
node in a hypergraph
|
439 |
+
visited : list
|
440 |
+
list of visited nodes, used for recursion
|
441 |
+
parent : str
|
442 |
+
parent node, used to eliminate a cycle consisting of two nodes and one edge.
|
443 |
+
|
444 |
+
Returns
|
445 |
+
-------
|
446 |
+
bool
|
447 |
+
'''
|
448 |
+
if visited is None:
|
449 |
+
visited = []
|
450 |
+
if parent == '':
|
451 |
+
visited = []
|
452 |
+
if root_node == '':
|
453 |
+
root_node = node
|
454 |
+
visited.append(node)
|
455 |
+
for each_adj_node in self.adj_nodes(node):
|
456 |
+
if each_adj_node not in visited:
|
457 |
+
if self.in_cycle(each_adj_node, visited, node, root_node):
|
458 |
+
return True
|
459 |
+
elif each_adj_node != parent and each_adj_node == root_node:
|
460 |
+
return True
|
461 |
+
return False
|
462 |
+
|
463 |
+
|
464 |
+
def draw(self, file_path=None, with_node=False, with_edge_name=False):
|
465 |
+
''' draw hypergraph
|
466 |
+
'''
|
467 |
+
import graphviz
|
468 |
+
G = graphviz.Graph(format='png')
|
469 |
+
for each_node in self.nodes:
|
470 |
+
if 'ext_id' in self.node_attr(each_node):
|
471 |
+
G.node(each_node, label='',
|
472 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
473 |
+
fillcolor='black')
|
474 |
+
else:
|
475 |
+
if with_node:
|
476 |
+
G.node(each_node, label='',
|
477 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
478 |
+
fillcolor='gray')
|
479 |
+
edge_list = []
|
480 |
+
for each_edge in self.edges:
|
481 |
+
if self.edge_attr(each_edge).get('terminal', False):
|
482 |
+
G.node(each_edge,
|
483 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
484 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
485 |
+
fontcolor='black', shape='square')
|
486 |
+
elif self.edge_attr(each_edge).get('tmp', False):
|
487 |
+
G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
|
488 |
+
fontcolor='black', shape='square')
|
489 |
+
else:
|
490 |
+
G.node(each_edge,
|
491 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
492 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
493 |
+
fontcolor='black', shape='square', style='filled')
|
494 |
+
if with_node:
|
495 |
+
for each_node in self.nodes_in_edge(each_edge):
|
496 |
+
G.edge(each_edge, each_node)
|
497 |
+
else:
|
498 |
+
for each_node in self.nodes_in_edge(each_edge):
|
499 |
+
if 'ext_id' in self.node_attr(each_node)\
|
500 |
+
and set([each_node, each_edge]) not in edge_list:
|
501 |
+
G.edge(each_edge, each_node)
|
502 |
+
edge_list.append(set([each_node, each_edge]))
|
503 |
+
for each_other_edge in self.adj_nodes(each_edge):
|
504 |
+
if set([each_edge, each_other_edge]) not in edge_list:
|
505 |
+
num_bond = 0
|
506 |
+
common_node_set = set(self.nodes_in_edge(each_edge))\
|
507 |
+
.intersection(set(self.nodes_in_edge(each_other_edge)))
|
508 |
+
for each_node in common_node_set:
|
509 |
+
if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
|
510 |
+
num_bond += self.node_attr(each_node)['symbol'].bond_type
|
511 |
+
elif self.node_attr(each_node)['symbol'].bond_type in [12]:
|
512 |
+
num_bond += 1
|
513 |
+
else:
|
514 |
+
raise NotImplementedError('unsupported bond type')
|
515 |
+
for _ in range(num_bond):
|
516 |
+
G.edge(each_edge, each_other_edge)
|
517 |
+
edge_list.append(set([each_edge, each_other_edge]))
|
518 |
+
if file_path is not None:
|
519 |
+
G.render(file_path, cleanup=True)
|
520 |
+
#os.remove(file_path)
|
521 |
+
return G
|
522 |
+
|
523 |
+
def is_dividable(self, node):
|
524 |
+
_hg = deepcopy(self.hg)
|
525 |
+
_hg.remove_node(node)
|
526 |
+
return (not nx.is_connected(_hg))
|
527 |
+
|
528 |
+
def divide(self, node):
|
529 |
+
subhg_list = []
|
530 |
+
|
531 |
+
hg_wo_node = deepcopy(self)
|
532 |
+
hg_wo_node.remove_node(node, remove_connected_edges=False)
|
533 |
+
connected_components = nx.connected_components(hg_wo_node.hg)
|
534 |
+
for each_component in connected_components:
|
535 |
+
node_list = [node]
|
536 |
+
edge_list = []
|
537 |
+
node_list.extend([each_node for each_node in each_component
|
538 |
+
if each_node.startswith('bond_')])
|
539 |
+
edge_list.extend([each_edge for each_edge in each_component
|
540 |
+
if each_edge.startswith('e')])
|
541 |
+
subhg_list.append(self.get_subhg(node_list, edge_list))
|
542 |
+
#subhg_list[-1].set_node_attr(node, {'divided': True})
|
543 |
+
return subhg_list
|
544 |
+
|
graph_grammar/io/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
graph_grammar/io/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (669 Bytes). View file
|
|
graph_grammar/io/__pycache__/smi.cpython-310.pyc
ADDED
Binary file (12.9 kB). View file
|
|
graph_grammar/io/smi.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 12 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from rdkit import Chem
|
23 |
+
from rdkit import RDLogger
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
from ..hypergraph import Hypergraph
|
27 |
+
from ..graph_grammar.symbols import TSymbol, BondSymbol
|
28 |
+
|
29 |
+
# supress warnings
|
30 |
+
lg = RDLogger.logger()
|
31 |
+
lg.setLevel(RDLogger.CRITICAL)
|
32 |
+
|
33 |
+
|
34 |
+
class HGGen(object):
|
35 |
+
"""
|
36 |
+
load .smi file and yield a hypergraph.
|
37 |
+
|
38 |
+
Attributes
|
39 |
+
----------
|
40 |
+
path_to_file : str
|
41 |
+
path to .smi file
|
42 |
+
kekulize : bool
|
43 |
+
kekulize or not
|
44 |
+
add_Hs : bool
|
45 |
+
add implicit hydrogens to the molecule or not.
|
46 |
+
all_single : bool
|
47 |
+
if True, all multiple bonds are summarized into a single bond with some attributes
|
48 |
+
|
49 |
+
Yields
|
50 |
+
------
|
51 |
+
Hypergraph
|
52 |
+
"""
|
53 |
+
def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
|
54 |
+
self.num_line = 1
|
55 |
+
self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
|
56 |
+
self.kekulize = kekulize
|
57 |
+
self.add_Hs = add_Hs
|
58 |
+
self.all_single = all_single
|
59 |
+
|
60 |
+
def __iter__(self):
|
61 |
+
return self
|
62 |
+
|
63 |
+
def __next__(self):
|
64 |
+
'''
|
65 |
+
each_mol = None
|
66 |
+
while each_mol is None:
|
67 |
+
each_mol = next(self.mol_gen)
|
68 |
+
'''
|
69 |
+
# not ignoring parse errors
|
70 |
+
each_mol = next(self.mol_gen)
|
71 |
+
if each_mol is None:
|
72 |
+
raise ValueError(f'incorrect smiles in line {self.num_line}')
|
73 |
+
else:
|
74 |
+
self.num_line += 1
|
75 |
+
return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
|
76 |
+
|
77 |
+
|
78 |
+
def mol_to_bipartite(mol, kekulize):
|
79 |
+
"""
|
80 |
+
get a bipartite representation of a molecule.
|
81 |
+
|
82 |
+
Parameters
|
83 |
+
----------
|
84 |
+
mol : rdkit.Chem.rdchem.Mol
|
85 |
+
molecule object
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
nx.Graph
|
90 |
+
a bipartite graph representing which bond is connected to which atoms.
|
91 |
+
"""
|
92 |
+
try:
|
93 |
+
mol = standardize_stereo(mol)
|
94 |
+
except KeyError:
|
95 |
+
print(Chem.MolToSmiles(mol))
|
96 |
+
raise KeyError
|
97 |
+
|
98 |
+
if kekulize:
|
99 |
+
Chem.Kekulize(mol)
|
100 |
+
|
101 |
+
bipartite_g = nx.Graph()
|
102 |
+
for each_atom in mol.GetAtoms():
|
103 |
+
bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
|
104 |
+
atom_attr=atom_attr(each_atom, kekulize))
|
105 |
+
|
106 |
+
for each_bond in mol.GetBonds():
|
107 |
+
bond_idx = each_bond.GetIdx()
|
108 |
+
bipartite_g.add_node(
|
109 |
+
f"bond_{bond_idx}",
|
110 |
+
bond_attr=bond_attr(each_bond, kekulize))
|
111 |
+
bipartite_g.add_edge(
|
112 |
+
f"atom_{each_bond.GetBeginAtomIdx()}",
|
113 |
+
f"bond_{bond_idx}")
|
114 |
+
bipartite_g.add_edge(
|
115 |
+
f"atom_{each_bond.GetEndAtomIdx()}",
|
116 |
+
f"bond_{bond_idx}")
|
117 |
+
return bipartite_g
|
118 |
+
|
119 |
+
|
120 |
+
def mol_to_hg(mol, kekulize, add_Hs):
|
121 |
+
"""
|
122 |
+
get a bipartite representation of a molecule.
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
mol : rdkit.Chem.rdchem.Mol
|
127 |
+
molecule object
|
128 |
+
kekulize : bool
|
129 |
+
kekulize or not
|
130 |
+
add_Hs : bool
|
131 |
+
add implicit hydrogens to the molecule or not.
|
132 |
+
|
133 |
+
Returns
|
134 |
+
-------
|
135 |
+
Hypergraph
|
136 |
+
"""
|
137 |
+
if add_Hs:
|
138 |
+
mol = Chem.AddHs(mol)
|
139 |
+
|
140 |
+
if kekulize:
|
141 |
+
Chem.Kekulize(mol)
|
142 |
+
|
143 |
+
bipartite_g = mol_to_bipartite(mol, kekulize)
|
144 |
+
hg = Hypergraph()
|
145 |
+
for each_atom in [each_node for each_node in bipartite_g.nodes()
|
146 |
+
if each_node.startswith('atom_')]:
|
147 |
+
node_set = set([])
|
148 |
+
for each_bond in bipartite_g.adj[each_atom]:
|
149 |
+
hg.add_node(each_bond,
|
150 |
+
attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
|
151 |
+
node_set.add(each_bond)
|
152 |
+
hg.add_edge(node_set,
|
153 |
+
attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
|
154 |
+
return hg
|
155 |
+
|
156 |
+
|
157 |
+
def hg_to_mol(hg, verbose=False):
|
158 |
+
""" convert a hypergraph into Mol object
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
mol : Chem.RWMol
|
167 |
+
"""
|
168 |
+
mol = Chem.RWMol()
|
169 |
+
atom_dict = {}
|
170 |
+
bond_set = set([])
|
171 |
+
for each_edge in hg.edges:
|
172 |
+
atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
|
173 |
+
atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
|
174 |
+
atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
|
175 |
+
atom.SetChiralTag(
|
176 |
+
Chem.rdchem.ChiralType.values[
|
177 |
+
hg.edge_attr(each_edge)['symbol'].chirality])
|
178 |
+
atom_idx = mol.AddAtom(atom)
|
179 |
+
atom_dict[each_edge] = atom_idx
|
180 |
+
|
181 |
+
for each_node in hg.nodes:
|
182 |
+
edge_1, edge_2 = hg.adj_edges(each_node)
|
183 |
+
if edge_1+edge_2 not in bond_set:
|
184 |
+
if hg.node_attr(each_node)['symbol'].bond_type <= 3:
|
185 |
+
num_bond = hg.node_attr(each_node)['symbol'].bond_type
|
186 |
+
elif hg.node_attr(each_node)['symbol'].bond_type == 12:
|
187 |
+
num_bond = 1
|
188 |
+
else:
|
189 |
+
raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
|
190 |
+
_ = mol.AddBond(atom_dict[edge_1],
|
191 |
+
atom_dict[edge_2],
|
192 |
+
order=Chem.rdchem.BondType.values[num_bond])
|
193 |
+
bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
|
194 |
+
|
195 |
+
# stereo
|
196 |
+
mol.GetBondWithIdx(bond_idx).SetStereo(
|
197 |
+
Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
|
198 |
+
bond_set.update([edge_1+edge_2])
|
199 |
+
bond_set.update([edge_2+edge_1])
|
200 |
+
mol.UpdatePropertyCache()
|
201 |
+
mol = mol.GetMol()
|
202 |
+
not_stereo_mol = deepcopy(mol)
|
203 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
|
204 |
+
raise RuntimeError('no valid molecule was obtained.')
|
205 |
+
try:
|
206 |
+
mol = set_stereo(mol)
|
207 |
+
is_stereo = True
|
208 |
+
except:
|
209 |
+
import traceback
|
210 |
+
traceback.print_exc()
|
211 |
+
is_stereo = False
|
212 |
+
mol_tmp = deepcopy(mol)
|
213 |
+
Chem.SetAromaticity(mol_tmp)
|
214 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
|
215 |
+
mol = mol_tmp
|
216 |
+
else:
|
217 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
|
218 |
+
mol = not_stereo_mol
|
219 |
+
mol.UpdatePropertyCache()
|
220 |
+
Chem.GetSymmSSSR(mol)
|
221 |
+
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
222 |
+
if verbose:
|
223 |
+
return mol, is_stereo
|
224 |
+
else:
|
225 |
+
return mol
|
226 |
+
|
227 |
+
def hgs_to_mols(hg_list, ignore_error=False):
|
228 |
+
if ignore_error:
|
229 |
+
mol_list = []
|
230 |
+
for each_hg in hg_list:
|
231 |
+
try:
|
232 |
+
mol = hg_to_mol(each_hg)
|
233 |
+
except:
|
234 |
+
mol = None
|
235 |
+
mol_list.append(mol)
|
236 |
+
else:
|
237 |
+
mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
|
238 |
+
return mol_list
|
239 |
+
|
240 |
+
def hgs_to_smiles(hg_list, ignore_error=False):
|
241 |
+
mol_list = hgs_to_mols(hg_list, ignore_error)
|
242 |
+
smiles_list = []
|
243 |
+
for each_mol in mol_list:
|
244 |
+
try:
|
245 |
+
smiles_list.append(
|
246 |
+
Chem.MolToSmiles(
|
247 |
+
Chem.MolFromSmiles(
|
248 |
+
Chem.MolToSmiles(
|
249 |
+
each_mol))))
|
250 |
+
except:
|
251 |
+
smiles_list.append(None)
|
252 |
+
return smiles_list
|
253 |
+
|
254 |
+
def atom_attr(atom, kekulize):
|
255 |
+
"""
|
256 |
+
get atom's attributes
|
257 |
+
|
258 |
+
Parameters
|
259 |
+
----------
|
260 |
+
atom : rdkit.Chem.rdchem.Atom
|
261 |
+
kekulize : bool
|
262 |
+
kekulize or not
|
263 |
+
|
264 |
+
Returns
|
265 |
+
-------
|
266 |
+
atom_attr : dict
|
267 |
+
"is_aromatic" : bool
|
268 |
+
the atom is aromatic or not.
|
269 |
+
"smarts" : str
|
270 |
+
SMARTS representation of the atom.
|
271 |
+
"""
|
272 |
+
if kekulize:
|
273 |
+
return {'terminal': True,
|
274 |
+
'is_in_ring': atom.IsInRing(),
|
275 |
+
'symbol': TSymbol(degree=0,
|
276 |
+
#degree=atom.GetTotalDegree(),
|
277 |
+
is_aromatic=False,
|
278 |
+
symbol=atom.GetSymbol(),
|
279 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
280 |
+
formal_charge=atom.GetFormalCharge(),
|
281 |
+
chirality=atom.GetChiralTag().real
|
282 |
+
)}
|
283 |
+
else:
|
284 |
+
return {'terminal': True,
|
285 |
+
'is_in_ring': atom.IsInRing(),
|
286 |
+
'symbol': TSymbol(degree=0,
|
287 |
+
#degree=atom.GetTotalDegree(),
|
288 |
+
is_aromatic=atom.GetIsAromatic(),
|
289 |
+
symbol=atom.GetSymbol(),
|
290 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
291 |
+
formal_charge=atom.GetFormalCharge(),
|
292 |
+
chirality=atom.GetChiralTag().real
|
293 |
+
)}
|
294 |
+
|
295 |
+
def bond_attr(bond, kekulize):
|
296 |
+
"""
|
297 |
+
get atom's attributes
|
298 |
+
|
299 |
+
Parameters
|
300 |
+
----------
|
301 |
+
bond : rdkit.Chem.rdchem.Bond
|
302 |
+
kekulize : bool
|
303 |
+
kekulize or not
|
304 |
+
|
305 |
+
Returns
|
306 |
+
-------
|
307 |
+
bond_attr : dict
|
308 |
+
"bond_type" : int
|
309 |
+
{0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
|
310 |
+
1: rdkit.Chem.rdchem.BondType.SINGLE,
|
311 |
+
2: rdkit.Chem.rdchem.BondType.DOUBLE,
|
312 |
+
3: rdkit.Chem.rdchem.BondType.TRIPLE,
|
313 |
+
4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
|
314 |
+
5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
|
315 |
+
6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
|
316 |
+
7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
|
317 |
+
8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
|
318 |
+
9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
|
319 |
+
10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
|
320 |
+
11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
|
321 |
+
12: rdkit.Chem.rdchem.BondType.AROMATIC,
|
322 |
+
13: rdkit.Chem.rdchem.BondType.IONIC,
|
323 |
+
14: rdkit.Chem.rdchem.BondType.HYDROGEN,
|
324 |
+
15: rdkit.Chem.rdchem.BondType.THREECENTER,
|
325 |
+
16: rdkit.Chem.rdchem.BondType.DATIVEONE,
|
326 |
+
17: rdkit.Chem.rdchem.BondType.DATIVE,
|
327 |
+
18: rdkit.Chem.rdchem.BondType.DATIVEL,
|
328 |
+
19: rdkit.Chem.rdchem.BondType.DATIVER,
|
329 |
+
20: rdkit.Chem.rdchem.BondType.OTHER,
|
330 |
+
21: rdkit.Chem.rdchem.BondType.ZERO}
|
331 |
+
"""
|
332 |
+
if kekulize:
|
333 |
+
is_aromatic = False
|
334 |
+
if bond.GetBondType().real == 12:
|
335 |
+
bond_type = 1
|
336 |
+
else:
|
337 |
+
bond_type = bond.GetBondType().real
|
338 |
+
else:
|
339 |
+
is_aromatic = bond.GetIsAromatic()
|
340 |
+
bond_type = bond.GetBondType().real
|
341 |
+
return {'symbol': BondSymbol(is_aromatic=is_aromatic,
|
342 |
+
bond_type=bond_type,
|
343 |
+
stereo=int(bond.GetStereo())),
|
344 |
+
'is_in_ring': bond.IsInRing()}
|
345 |
+
|
346 |
+
|
347 |
+
def standardize_stereo(mol):
|
348 |
+
'''
|
349 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
350 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
351 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
352 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
353 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
354 |
+
|
355 |
+
'''
|
356 |
+
# mol = Chem.AddHs(mol) # this removes CIPRank !!!
|
357 |
+
for each_bond in mol.GetBonds():
|
358 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
359 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
360 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
361 |
+
atom_idx_1 = each_bond.GetStereoAtoms()[0]
|
362 |
+
atom_idx_2 = each_bond.GetStereoAtoms()[1]
|
363 |
+
if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
|
364 |
+
begin_atom_idx = atom_idx_1
|
365 |
+
end_atom_idx = atom_idx_2
|
366 |
+
else:
|
367 |
+
begin_atom_idx = atom_idx_2
|
368 |
+
end_atom_idx = atom_idx_1
|
369 |
+
|
370 |
+
begin_another_atom_idx = None
|
371 |
+
assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
|
372 |
+
for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
|
373 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
374 |
+
if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
|
375 |
+
begin_another_atom_idx = each_neighbor_idx
|
376 |
+
|
377 |
+
end_another_atom_idx = None
|
378 |
+
assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
|
379 |
+
for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
|
380 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
381 |
+
if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
|
382 |
+
end_another_atom_idx = each_neighbor_idx
|
383 |
+
|
384 |
+
'''
|
385 |
+
relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
|
386 |
+
'''
|
387 |
+
begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
|
388 |
+
end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
|
389 |
+
try:
|
390 |
+
begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
|
391 |
+
except:
|
392 |
+
begin_another_atom_rank = np.inf
|
393 |
+
try:
|
394 |
+
end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
|
395 |
+
except:
|
396 |
+
end_another_atom_rank = np.inf
|
397 |
+
if begin_atom_rank < begin_another_atom_rank\
|
398 |
+
and end_atom_rank < end_another_atom_rank:
|
399 |
+
pass
|
400 |
+
elif begin_atom_rank < begin_another_atom_rank\
|
401 |
+
and end_atom_rank > end_another_atom_rank:
|
402 |
+
# (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
|
403 |
+
if each_bond.GetStereo() == 2:
|
404 |
+
# set stereo
|
405 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
406 |
+
# set bond dir
|
407 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
408 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
409 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
410 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
411 |
+
elif each_bond.GetStereo() == 3:
|
412 |
+
# set stereo
|
413 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
414 |
+
# set bond dir
|
415 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
416 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
417 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
418 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
419 |
+
else:
|
420 |
+
raise ValueError
|
421 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
|
422 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
423 |
+
and end_atom_rank < end_another_atom_rank:
|
424 |
+
# (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
|
425 |
+
if each_bond.GetStereo() == 2:
|
426 |
+
# set stereo
|
427 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
428 |
+
# set bond dir
|
429 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
430 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
431 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
432 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
433 |
+
elif each_bond.GetStereo() == 3:
|
434 |
+
# set stereo
|
435 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
436 |
+
# set bond dir
|
437 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
438 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
439 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
440 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
441 |
+
else:
|
442 |
+
raise ValueError
|
443 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
|
444 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
445 |
+
and end_atom_rank > end_another_atom_rank:
|
446 |
+
# begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
|
447 |
+
if each_bond.GetStereo() == 2:
|
448 |
+
# set bond dir
|
449 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
450 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
451 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
452 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
453 |
+
elif each_bond.GetStereo() == 3:
|
454 |
+
# set bond dir
|
455 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
456 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
457 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
458 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
459 |
+
else:
|
460 |
+
raise ValueError
|
461 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
|
462 |
+
else:
|
463 |
+
raise RuntimeError
|
464 |
+
return mol
|
465 |
+
|
466 |
+
|
467 |
+
def set_stereo(mol):
|
468 |
+
'''
|
469 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
470 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
471 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
472 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
473 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
474 |
+
'''
|
475 |
+
_mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
476 |
+
Chem.Kekulize(_mol, True)
|
477 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
478 |
+
if not substruct_match:
|
479 |
+
''' mol and _mol are kekulized.
|
480 |
+
sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
|
481 |
+
'''
|
482 |
+
Chem.SetAromaticity(mol)
|
483 |
+
Chem.SetAromaticity(_mol)
|
484 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
485 |
+
try:
|
486 |
+
atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
|
487 |
+
except:
|
488 |
+
raise ValueError('two molecules obtained from the same data do not match.')
|
489 |
+
|
490 |
+
for each_bond in mol.GetBonds():
|
491 |
+
begin_atom_idx = each_bond.GetBeginAtomIdx()
|
492 |
+
end_atom_idx = each_bond.GetEndAtomIdx()
|
493 |
+
_bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
|
494 |
+
_bond.SetStereo(each_bond.GetStereo())
|
495 |
+
|
496 |
+
mol = _mol
|
497 |
+
for each_bond in mol.GetBonds():
|
498 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
499 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
500 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
501 |
+
begin_atom_idx_set = set([each_neighbor.GetIdx()
|
502 |
+
for each_neighbor
|
503 |
+
in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
|
504 |
+
if each_neighbor.GetIdx() != end_stereo_atom_idx])
|
505 |
+
end_atom_idx_set = set([each_neighbor.GetIdx()
|
506 |
+
for each_neighbor
|
507 |
+
in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
|
508 |
+
if each_neighbor.GetIdx() != begin_stereo_atom_idx])
|
509 |
+
if not begin_atom_idx_set:
|
510 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
511 |
+
continue
|
512 |
+
if not end_atom_idx_set:
|
513 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
514 |
+
continue
|
515 |
+
if len(begin_atom_idx_set) == 1:
|
516 |
+
begin_atom_idx = begin_atom_idx_set.pop()
|
517 |
+
begin_another_atom_idx = None
|
518 |
+
if len(end_atom_idx_set) == 1:
|
519 |
+
end_atom_idx = end_atom_idx_set.pop()
|
520 |
+
end_another_atom_idx = None
|
521 |
+
if len(begin_atom_idx_set) == 2:
|
522 |
+
atom_idx_1 = begin_atom_idx_set.pop()
|
523 |
+
atom_idx_2 = begin_atom_idx_set.pop()
|
524 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
525 |
+
begin_atom_idx = atom_idx_1
|
526 |
+
begin_another_atom_idx = atom_idx_2
|
527 |
+
else:
|
528 |
+
begin_atom_idx = atom_idx_2
|
529 |
+
begin_another_atom_idx = atom_idx_1
|
530 |
+
if len(end_atom_idx_set) == 2:
|
531 |
+
atom_idx_1 = end_atom_idx_set.pop()
|
532 |
+
atom_idx_2 = end_atom_idx_set.pop()
|
533 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
534 |
+
end_atom_idx = atom_idx_1
|
535 |
+
end_another_atom_idx = atom_idx_2
|
536 |
+
else:
|
537 |
+
end_atom_idx = atom_idx_2
|
538 |
+
end_another_atom_idx = atom_idx_1
|
539 |
+
|
540 |
+
if each_bond.GetStereo() == 2: # same side
|
541 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
542 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
543 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
544 |
+
elif each_bond.GetStereo() == 3: # opposite side
|
545 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
546 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
547 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
548 |
+
else:
|
549 |
+
raise ValueError
|
550 |
+
return mol
|
551 |
+
|
552 |
+
|
553 |
+
def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
|
554 |
+
if atom_idx_1 is None or atom_idx_2 is None:
|
555 |
+
return mol
|
556 |
+
else:
|
557 |
+
mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
|
558 |
+
return mol
|
559 |
+
|
graph_grammar/nn/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|
graph_grammar/nn/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (508 Bytes). View file
|
|
graph_grammar/nn/__pycache__/decoder.cpython-310.pyc
ADDED
Binary file (3.98 kB). View file
|
|
graph_grammar/nn/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (5.38 kB). View file
|
|
graph_grammar/nn/dataset.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Apr 18 2018"
|
20 |
+
|
21 |
+
from torch.utils.data import Dataset, DataLoader
|
22 |
+
import torch
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
|
26 |
+
def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
|
27 |
+
''' pad left
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
sentence_list : list of sequences of integers
|
32 |
+
max_len : int
|
33 |
+
maximum length of sentences.
|
34 |
+
if a sentence is shorter than `max_len`, its left part is padded.
|
35 |
+
pad_idx : int
|
36 |
+
integer for padding
|
37 |
+
inverse : bool
|
38 |
+
if True, the sequence is inversed.
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
List of torch.LongTensor
|
43 |
+
each sentence is left-padded.
|
44 |
+
'''
|
45 |
+
max_in_list = max([len(each_sen) for each_sen in sentence_list])
|
46 |
+
|
47 |
+
if max_in_list > max_len:
|
48 |
+
raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
|
49 |
+
|
50 |
+
if inverse:
|
51 |
+
return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1]) for each_sen in sentence_list]
|
52 |
+
else:
|
53 |
+
return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen) for each_sen in sentence_list]
|
54 |
+
|
55 |
+
|
56 |
+
def right_padding(sentence_list, max_len, pad_idx=-1):
|
57 |
+
''' pad right
|
58 |
+
|
59 |
+
Parameters
|
60 |
+
----------
|
61 |
+
sentence_list : list of sequences of integers
|
62 |
+
max_len : int
|
63 |
+
maximum length of sentences.
|
64 |
+
if a sentence is shorter than `max_len`, its right part is padded.
|
65 |
+
pad_idx : int
|
66 |
+
integer for padding
|
67 |
+
|
68 |
+
Returns
|
69 |
+
-------
|
70 |
+
List of torch.LongTensor
|
71 |
+
each sentence is right-padded.
|
72 |
+
'''
|
73 |
+
max_in_list = max([len(each_sen) for each_sen in sentence_list])
|
74 |
+
if max_in_list > max_len:
|
75 |
+
raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
|
76 |
+
|
77 |
+
return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen))) for each_sen in sentence_list]
|
78 |
+
|
79 |
+
|
80 |
+
class HRGDataset(Dataset):
|
81 |
+
|
82 |
+
'''
|
83 |
+
A class of HRG data
|
84 |
+
'''
|
85 |
+
|
86 |
+
def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
|
87 |
+
self.hrg = hrg
|
88 |
+
self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
|
89 |
+
max_len,
|
90 |
+
inverse=inversed_input)
|
91 |
+
|
92 |
+
self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
|
93 |
+
self.inserved_input = inversed_input
|
94 |
+
self.target_val_list = target_val_list
|
95 |
+
if target_val_list is not None:
|
96 |
+
if len(prod_rule_seq_list) != len(target_val_list):
|
97 |
+
raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')
|
98 |
+
|
99 |
+
def __len__(self):
|
100 |
+
return len(self.left_prod_rule_seq_list)
|
101 |
+
|
102 |
+
def __getitem__(self, idx):
|
103 |
+
if self.target_val_list is not None:
|
104 |
+
return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
|
105 |
+
else:
|
106 |
+
return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]
|
107 |
+
|
108 |
+
@property
|
109 |
+
def vocab_size(self):
|
110 |
+
return self.hrg.num_prod_rule
|
111 |
+
|
112 |
+
def batch_padding(each_batch, batch_size, padding_idx):
|
113 |
+
num_pad = batch_size - len(each_batch[0])
|
114 |
+
if num_pad:
|
115 |
+
each_batch[0] = torch.cat([each_batch[0],
|
116 |
+
padding_idx * torch.ones((batch_size - len(each_batch[0]),
|
117 |
+
len(each_batch[0][0])), dtype=torch.int64)], dim=0)
|
118 |
+
each_batch[1] = torch.cat([each_batch[1],
|
119 |
+
padding_idx * torch.ones((batch_size - len(each_batch[1]),
|
120 |
+
len(each_batch[1][0])), dtype=torch.int64)], dim=0)
|
121 |
+
return each_batch, num_pad
|
graph_grammar/nn/decoder.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Aug 9 2018"
|
20 |
+
|
21 |
+
|
22 |
+
import abc
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
|
28 |
+
class DecoderBase(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
super().__init__()
|
32 |
+
self.hidden_dict = {}
|
33 |
+
|
34 |
+
@abc.abstractmethod
|
35 |
+
def forward_one_step(self, tgt_emb_in):
|
36 |
+
''' one-step forward model
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
41 |
+
|
42 |
+
Returns
|
43 |
+
-------
|
44 |
+
Tensor, shape (batch_size, hidden_dim)
|
45 |
+
'''
|
46 |
+
tgt_emb_out = None
|
47 |
+
return tgt_emb_out
|
48 |
+
|
49 |
+
@abc.abstractmethod
|
50 |
+
def init_hidden(self):
|
51 |
+
''' initialize the hidden states
|
52 |
+
'''
|
53 |
+
pass
|
54 |
+
|
55 |
+
@abc.abstractmethod
|
56 |
+
def feed_hidden(self, hidden_dict_0):
|
57 |
+
for each_hidden in self.hidden_dict.keys():
|
58 |
+
self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]
|
59 |
+
|
60 |
+
|
61 |
+
class GRUDecoder(DecoderBase):
|
62 |
+
|
63 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
64 |
+
dropout: float, batch_size: int, use_gpu: bool,
|
65 |
+
no_dropout=False):
|
66 |
+
super().__init__()
|
67 |
+
self.input_dim = input_dim
|
68 |
+
self.hidden_dim = hidden_dim
|
69 |
+
self.num_layers = num_layers
|
70 |
+
self.dropout = dropout
|
71 |
+
self.batch_size = batch_size
|
72 |
+
self.use_gpu = use_gpu
|
73 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
74 |
+
hidden_size=self.hidden_dim,
|
75 |
+
num_layers=self.num_layers,
|
76 |
+
batch_first=True,
|
77 |
+
bidirectional=False,
|
78 |
+
dropout=self.dropout if not no_dropout else 0
|
79 |
+
)
|
80 |
+
if self.use_gpu:
|
81 |
+
self.model.cuda()
|
82 |
+
self.init_hidden()
|
83 |
+
|
84 |
+
def init_hidden(self):
|
85 |
+
self.hidden_dict['h'] = torch.zeros((self.num_layers,
|
86 |
+
self.batch_size,
|
87 |
+
self.hidden_dim),
|
88 |
+
requires_grad=False)
|
89 |
+
if self.use_gpu:
|
90 |
+
self.hidden_dict['h'] = self.hidden_dict['h'].cuda()
|
91 |
+
|
92 |
+
def forward_one_step(self, tgt_emb_in):
|
93 |
+
''' one-step forward model
|
94 |
+
|
95 |
+
Parameters
|
96 |
+
----------
|
97 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
98 |
+
|
99 |
+
Returns
|
100 |
+
-------
|
101 |
+
Tensor, shape (batch_size, hidden_dim)
|
102 |
+
'''
|
103 |
+
tgt_emb_out, self.hidden_dict['h'] \
|
104 |
+
= self.model(tgt_emb_in.view(self.batch_size, 1, -1),
|
105 |
+
self.hidden_dict['h'])
|
106 |
+
return tgt_emb_out
|
107 |
+
|
108 |
+
|
109 |
+
class LSTMDecoder(DecoderBase):
|
110 |
+
|
111 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
112 |
+
dropout: float, batch_size: int, use_gpu: bool,
|
113 |
+
no_dropout=False):
|
114 |
+
super().__init__()
|
115 |
+
self.input_dim = input_dim
|
116 |
+
self.hidden_dim = hidden_dim
|
117 |
+
self.num_layers = num_layers
|
118 |
+
self.dropout = dropout
|
119 |
+
self.batch_size = batch_size
|
120 |
+
self.use_gpu = use_gpu
|
121 |
+
self.model = nn.LSTM(input_size=self.input_dim,
|
122 |
+
hidden_size=self.hidden_dim,
|
123 |
+
num_layers=self.num_layers,
|
124 |
+
batch_first=True,
|
125 |
+
bidirectional=False,
|
126 |
+
dropout=self.dropout if not no_dropout else 0)
|
127 |
+
if self.use_gpu:
|
128 |
+
self.model.cuda()
|
129 |
+
self.init_hidden()
|
130 |
+
|
131 |
+
def init_hidden(self):
|
132 |
+
self.hidden_dict['h'] = torch.zeros((self.num_layers,
|
133 |
+
self.batch_size,
|
134 |
+
self.hidden_dim),
|
135 |
+
requires_grad=False)
|
136 |
+
self.hidden_dict['c'] = torch.zeros((self.num_layers,
|
137 |
+
self.batch_size,
|
138 |
+
self.hidden_dim),
|
139 |
+
requires_grad=False)
|
140 |
+
if self.use_gpu:
|
141 |
+
for each_hidden in self.hidden_dict.keys():
|
142 |
+
self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()
|
143 |
+
|
144 |
+
def forward_one_step(self, tgt_emb_in):
|
145 |
+
''' one-step forward model
|
146 |
+
|
147 |
+
Parameters
|
148 |
+
----------
|
149 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
150 |
+
|
151 |
+
Returns
|
152 |
+
-------
|
153 |
+
Tensor, shape (batch_size, hidden_dim)
|
154 |
+
'''
|
155 |
+
tgt_hidden_out, self.hidden_dict['h'], self.hidden_dict['c'] \
|
156 |
+
= self.model(tgt_emb_in.view(self.batch_size, 1, -1),
|
157 |
+
self.hidden_dict['h'], self.hidden_dict['c'])
|
158 |
+
return tgt_hidden_out
|
graph_grammar/nn/encoder.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Aug 9 2018"
|
20 |
+
|
21 |
+
|
22 |
+
import abc
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from torch import nn
|
27 |
+
from typing import List
|
28 |
+
|
29 |
+
|
30 |
+
class EncoderBase(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
@abc.abstractmethod
|
36 |
+
def forward(self, in_seq):
|
37 |
+
''' forward model
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
in_seq_emb : Variable, shape (batch_size, max_len, input_dim)
|
42 |
+
|
43 |
+
Returns
|
44 |
+
-------
|
45 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
46 |
+
'''
|
47 |
+
pass
|
48 |
+
|
49 |
+
@abc.abstractmethod
|
50 |
+
def init_hidden(self):
|
51 |
+
''' initialize the hidden states
|
52 |
+
'''
|
53 |
+
pass
|
54 |
+
|
55 |
+
|
56 |
+
class GRUEncoder(EncoderBase):
|
57 |
+
|
58 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
59 |
+
bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
|
60 |
+
no_dropout=False):
|
61 |
+
super().__init__()
|
62 |
+
self.input_dim = input_dim
|
63 |
+
self.hidden_dim = hidden_dim
|
64 |
+
self.num_layers = num_layers
|
65 |
+
self.bidirectional = bidirectional
|
66 |
+
self.dropout = dropout
|
67 |
+
self.batch_size = batch_size
|
68 |
+
self.use_gpu = use_gpu
|
69 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
70 |
+
hidden_size=self.hidden_dim,
|
71 |
+
num_layers=self.num_layers,
|
72 |
+
batch_first=True,
|
73 |
+
bidirectional=self.bidirectional,
|
74 |
+
dropout=self.dropout if not no_dropout else 0)
|
75 |
+
if self.use_gpu:
|
76 |
+
self.model.cuda()
|
77 |
+
self.init_hidden()
|
78 |
+
|
79 |
+
|
80 |
+
def init_hidden(self):
|
81 |
+
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
82 |
+
self.batch_size,
|
83 |
+
self.hidden_dim),
|
84 |
+
requires_grad=False)
|
85 |
+
if self.use_gpu:
|
86 |
+
self.h0 = self.h0.cuda()
|
87 |
+
|
88 |
+
def forward(self, in_seq_emb):
|
89 |
+
''' forward model
|
90 |
+
|
91 |
+
Parameters
|
92 |
+
----------
|
93 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
94 |
+
|
95 |
+
Returns
|
96 |
+
-------
|
97 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
98 |
+
'''
|
99 |
+
max_len = in_seq_emb.size(1)
|
100 |
+
hidden_seq_emb, self.h0 = self.model(
|
101 |
+
in_seq_emb, self.h0)
|
102 |
+
hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
|
103 |
+
max_len,
|
104 |
+
1 + self.bidirectional,
|
105 |
+
self.hidden_dim)
|
106 |
+
return hidden_seq_emb
|
107 |
+
|
108 |
+
|
109 |
+
class LSTMEncoder(EncoderBase):
|
110 |
+
|
111 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
112 |
+
bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
|
113 |
+
no_dropout=False):
|
114 |
+
super().__init__()
|
115 |
+
self.input_dim = input_dim
|
116 |
+
self.hidden_dim = hidden_dim
|
117 |
+
self.num_layers = num_layers
|
118 |
+
self.bidirectional = bidirectional
|
119 |
+
self.dropout = dropout
|
120 |
+
self.batch_size = batch_size
|
121 |
+
self.use_gpu = use_gpu
|
122 |
+
self.model = nn.LSTM(input_size=self.input_dim,
|
123 |
+
hidden_size=self.hidden_dim,
|
124 |
+
num_layers=self.num_layers,
|
125 |
+
batch_first=True,
|
126 |
+
bidirectional=self.bidirectional,
|
127 |
+
dropout=self.dropout if not no_dropout else 0)
|
128 |
+
if self.use_gpu:
|
129 |
+
self.model.cuda()
|
130 |
+
self.init_hidden()
|
131 |
+
|
132 |
+
def init_hidden(self):
|
133 |
+
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
134 |
+
self.batch_size,
|
135 |
+
self.hidden_dim),
|
136 |
+
requires_grad=False)
|
137 |
+
self.c0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
138 |
+
self.batch_size,
|
139 |
+
self.hidden_dim),
|
140 |
+
requires_grad=False)
|
141 |
+
if self.use_gpu:
|
142 |
+
self.h0 = self.h0.cuda()
|
143 |
+
self.c0 = self.c0.cuda()
|
144 |
+
|
145 |
+
def forward(self, in_seq_emb):
|
146 |
+
''' forward model
|
147 |
+
|
148 |
+
Parameters
|
149 |
+
----------
|
150 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
151 |
+
|
152 |
+
Returns
|
153 |
+
-------
|
154 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
155 |
+
'''
|
156 |
+
max_len = in_seq_emb.size(1)
|
157 |
+
hidden_seq_emb, (self.h0, self.c0) = self.model(
|
158 |
+
in_seq_emb, (self.h0, self.c0))
|
159 |
+
hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
|
160 |
+
max_len,
|
161 |
+
1 + self.bidirectional,
|
162 |
+
self.hidden_dim)
|
163 |
+
return hidden_seq_emb
|
164 |
+
|
165 |
+
|
166 |
+
class FullConnectedEncoder(EncoderBase):
|
167 |
+
|
168 |
+
def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
|
169 |
+
batch_size: int, use_gpu: bool):
|
170 |
+
super().__init__()
|
171 |
+
self.input_dim = input_dim
|
172 |
+
self.hidden_dim = hidden_dim
|
173 |
+
self.max_len = max_len
|
174 |
+
self.hidden_dim_list = hidden_dim_list
|
175 |
+
self.use_gpu = use_gpu
|
176 |
+
in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
|
177 |
+
self.linear_list = nn.ModuleList(
|
178 |
+
[nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])\
|
179 |
+
for each_idx in range(len(in_out_dim_list) - 1)])
|
180 |
+
|
181 |
+
def forward(self, in_seq_emb):
|
182 |
+
''' forward model
|
183 |
+
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
187 |
+
|
188 |
+
Returns
|
189 |
+
-------
|
190 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
191 |
+
'''
|
192 |
+
batch_size = in_seq_emb.size(0)
|
193 |
+
x = in_seq_emb.view(batch_size, -1)
|
194 |
+
for each_linear in self.linear_list:
|
195 |
+
x = F.relu(each_linear(x))
|
196 |
+
return x.view(batch_size, 1, -1)
|
197 |
+
|
198 |
+
def init_hidden(self):
|
199 |
+
pass
|
graph_grammar/nn/graph.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
|
25 |
+
from torch import nn
|
26 |
+
from torch.autograd import Variable
|
27 |
+
|
28 |
+
class MolecularProdRuleEmbedding(nn.Module):
|
29 |
+
|
30 |
+
''' molecular fingerprint layer
|
31 |
+
'''
|
32 |
+
|
33 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
34 |
+
out_dim=32, element_embed_dim=32,
|
35 |
+
num_layers=3, padding_idx=None, use_gpu=False):
|
36 |
+
super().__init__()
|
37 |
+
if padding_idx is not None:
|
38 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
39 |
+
self.prod_rule_corpus = prod_rule_corpus
|
40 |
+
self.layer2layer_activation = layer2layer_activation
|
41 |
+
self.layer2out_activation = layer2out_activation
|
42 |
+
self.out_dim = out_dim
|
43 |
+
self.element_embed_dim = element_embed_dim
|
44 |
+
self.num_layers = num_layers
|
45 |
+
self.padding_idx = padding_idx
|
46 |
+
self.use_gpu = use_gpu
|
47 |
+
|
48 |
+
self.layer2layer_list = []
|
49 |
+
self.layer2out_list = []
|
50 |
+
|
51 |
+
if self.use_gpu:
|
52 |
+
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
53 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
54 |
+
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
55 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
56 |
+
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
57 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
58 |
+
for _ in range(num_layers):
|
59 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
|
60 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
|
61 |
+
else:
|
62 |
+
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
63 |
+
self.element_embed_dim, requires_grad=True)
|
64 |
+
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
65 |
+
self.element_embed_dim, requires_grad=True)
|
66 |
+
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
67 |
+
self.element_embed_dim, requires_grad=True)
|
68 |
+
for _ in range(num_layers):
|
69 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
|
70 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, prod_rule_idx_seq):
|
74 |
+
''' forward model for mini-batch
|
75 |
+
|
76 |
+
Parameters
|
77 |
+
----------
|
78 |
+
prod_rule_idx_seq : (batch_size, length)
|
79 |
+
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
Variable, shape (batch_size, length, out_dim)
|
83 |
+
'''
|
84 |
+
batch_size, length = prod_rule_idx_seq.shape
|
85 |
+
if self.use_gpu:
|
86 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
87 |
+
else:
|
88 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
89 |
+
for each_batch_idx in range(batch_size):
|
90 |
+
for each_idx in range(length):
|
91 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
92 |
+
continue
|
93 |
+
else:
|
94 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
95 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed[
|
96 |
+
each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
97 |
+
for each_edge in each_prod_rule.rhs.edges}
|
98 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed[
|
99 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
|
100 |
+
for each_node in each_prod_rule.rhs.nodes})
|
101 |
+
for each_node in each_prod_rule.rhs.nodes:
|
102 |
+
if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
|
103 |
+
layer_wise_embed_dict[each_node] \
|
104 |
+
= layer_wise_embed_dict[each_node] \
|
105 |
+
+ self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
|
106 |
+
|
107 |
+
for each_layer in range(self.num_layers):
|
108 |
+
next_layer_embed_dict = {}
|
109 |
+
for each_edge in each_prod_rule.rhs.edges:
|
110 |
+
v = layer_wise_embed_dict[each_edge]
|
111 |
+
for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
|
112 |
+
v = v + layer_wise_embed_dict[each_node]
|
113 |
+
next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
114 |
+
out[each_batch_idx, each_idx, :] \
|
115 |
+
= out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
|
116 |
+
for each_node in each_prod_rule.rhs.nodes:
|
117 |
+
v = layer_wise_embed_dict[each_node]
|
118 |
+
for each_edge in each_prod_rule.rhs.adj_edges(each_node):
|
119 |
+
v = v + layer_wise_embed_dict[each_edge]
|
120 |
+
next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
121 |
+
out[each_batch_idx, each_idx, :]\
|
122 |
+
= out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
|
123 |
+
layer_wise_embed_dict = next_layer_embed_dict
|
124 |
+
|
125 |
+
return out
|
126 |
+
|
127 |
+
|
128 |
+
class MolecularProdRuleEmbeddingLastLayer(nn.Module):
|
129 |
+
|
130 |
+
''' molecular fingerprint layer
|
131 |
+
'''
|
132 |
+
|
133 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
134 |
+
out_dim=32, element_embed_dim=32,
|
135 |
+
num_layers=3, padding_idx=None, use_gpu=False):
|
136 |
+
super().__init__()
|
137 |
+
if padding_idx is not None:
|
138 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
139 |
+
self.prod_rule_corpus = prod_rule_corpus
|
140 |
+
self.layer2layer_activation = layer2layer_activation
|
141 |
+
self.layer2out_activation = layer2out_activation
|
142 |
+
self.out_dim = out_dim
|
143 |
+
self.element_embed_dim = element_embed_dim
|
144 |
+
self.num_layers = num_layers
|
145 |
+
self.padding_idx = padding_idx
|
146 |
+
self.use_gpu = use_gpu
|
147 |
+
|
148 |
+
self.layer2layer_list = []
|
149 |
+
self.layer2out_list = []
|
150 |
+
|
151 |
+
if self.use_gpu:
|
152 |
+
self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
|
153 |
+
self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
|
154 |
+
for _ in range(num_layers+1):
|
155 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
|
156 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
|
157 |
+
else:
|
158 |
+
self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
|
159 |
+
self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
|
160 |
+
for _ in range(num_layers+1):
|
161 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
|
162 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
|
163 |
+
|
164 |
+
|
165 |
+
def forward(self, prod_rule_idx_seq):
|
166 |
+
''' forward model for mini-batch
|
167 |
+
|
168 |
+
Parameters
|
169 |
+
----------
|
170 |
+
prod_rule_idx_seq : (batch_size, length)
|
171 |
+
|
172 |
+
Returns
|
173 |
+
-------
|
174 |
+
Variable, shape (batch_size, length, out_dim)
|
175 |
+
'''
|
176 |
+
batch_size, length = prod_rule_idx_seq.shape
|
177 |
+
if self.use_gpu:
|
178 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
179 |
+
else:
|
180 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
181 |
+
for each_batch_idx in range(batch_size):
|
182 |
+
for each_idx in range(length):
|
183 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
184 |
+
continue
|
185 |
+
else:
|
186 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
187 |
+
|
188 |
+
if self.use_gpu:
|
189 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed(
|
190 |
+
Variable(torch.LongTensor(
|
191 |
+
[each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
192 |
+
), requires_grad=False).cuda())
|
193 |
+
for each_edge in each_prod_rule.rhs.edges}
|
194 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed(
|
195 |
+
Variable(
|
196 |
+
torch.LongTensor([
|
197 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
|
198 |
+
requires_grad=False).cuda()
|
199 |
+
) for each_node in each_prod_rule.rhs.nodes})
|
200 |
+
else:
|
201 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed(
|
202 |
+
Variable(torch.LongTensor(
|
203 |
+
[each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
204 |
+
), requires_grad=False))
|
205 |
+
for each_edge in each_prod_rule.rhs.edges}
|
206 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed(
|
207 |
+
Variable(
|
208 |
+
torch.LongTensor([
|
209 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
|
210 |
+
requires_grad=False)
|
211 |
+
) for each_node in each_prod_rule.rhs.nodes})
|
212 |
+
|
213 |
+
for each_layer in range(self.num_layers):
|
214 |
+
next_layer_embed_dict = {}
|
215 |
+
for each_edge in each_prod_rule.rhs.edges:
|
216 |
+
v = layer_wise_embed_dict[each_edge]
|
217 |
+
for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
|
218 |
+
v += layer_wise_embed_dict[each_node]
|
219 |
+
next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
220 |
+
for each_node in each_prod_rule.rhs.nodes:
|
221 |
+
v = layer_wise_embed_dict[each_node]
|
222 |
+
for each_edge in each_prod_rule.rhs.adj_edges(each_node):
|
223 |
+
v += layer_wise_embed_dict[each_edge]
|
224 |
+
next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
225 |
+
layer_wise_embed_dict = next_layer_embed_dict
|
226 |
+
for each_edge in each_prod_rule.rhs.edges:
|
227 |
+
out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
|
228 |
+
for each_edge in each_prod_rule.rhs.edges:
|
229 |
+
out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
|
230 |
+
|
231 |
+
return out
|
232 |
+
|
233 |
+
|
234 |
+
class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
|
235 |
+
|
236 |
+
''' molecular fingerprint layer
|
237 |
+
'''
|
238 |
+
|
239 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
240 |
+
out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
|
241 |
+
super().__init__()
|
242 |
+
if padding_idx is not None:
|
243 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
244 |
+
self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
|
245 |
+
self.prod_rule_corpus = prod_rule_corpus
|
246 |
+
self.layer2layer_activation = layer2layer_activation
|
247 |
+
self.layer2out_activation = layer2out_activation
|
248 |
+
self.out_dim = out_dim
|
249 |
+
self.num_layers = num_layers
|
250 |
+
self.padding_idx = padding_idx
|
251 |
+
self.use_gpu = use_gpu
|
252 |
+
|
253 |
+
self.layer2layer_list = []
|
254 |
+
self.layer2out_list = []
|
255 |
+
|
256 |
+
if self.use_gpu:
|
257 |
+
for each_key in self.feature_dict:
|
258 |
+
self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
|
259 |
+
for _ in range(num_layers):
|
260 |
+
self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
|
261 |
+
self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
|
262 |
+
else:
|
263 |
+
for _ in range(num_layers):
|
264 |
+
self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
|
265 |
+
self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))
|
266 |
+
|
267 |
+
|
268 |
+
def forward(self, prod_rule_idx_seq):
|
269 |
+
''' forward model for mini-batch
|
270 |
+
|
271 |
+
Parameters
|
272 |
+
----------
|
273 |
+
prod_rule_idx_seq : (batch_size, length)
|
274 |
+
|
275 |
+
Returns
|
276 |
+
-------
|
277 |
+
Variable, shape (batch_size, length, out_dim)
|
278 |
+
'''
|
279 |
+
batch_size, length = prod_rule_idx_seq.shape
|
280 |
+
if self.use_gpu:
|
281 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
282 |
+
else:
|
283 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
284 |
+
for each_batch_idx in range(batch_size):
|
285 |
+
for each_idx in range(length):
|
286 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
287 |
+
continue
|
288 |
+
else:
|
289 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
290 |
+
edge_list = sorted(list(each_prod_rule.rhs.edges))
|
291 |
+
node_list = sorted(list(each_prod_rule.rhs.nodes))
|
292 |
+
adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list)))
|
293 |
+
if self.use_gpu:
|
294 |
+
adj_mat = adj_mat.cuda()
|
295 |
+
layer_wise_embed = [
|
296 |
+
self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
|
297 |
+
for each_edge in edge_list]\
|
298 |
+
+ [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
|
299 |
+
for each_node in node_list]
|
300 |
+
for each_node in each_prod_rule.ext_node.values():
|
301 |
+
layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
|
302 |
+
= layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
|
303 |
+
+ self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
|
304 |
+
layer_wise_embed = torch.stack(layer_wise_embed)
|
305 |
+
|
306 |
+
for each_layer in range(self.num_layers):
|
307 |
+
message = adj_mat @ layer_wise_embed
|
308 |
+
next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
|
309 |
+
out[each_batch_idx, each_idx, :] \
|
310 |
+
= out[each_batch_idx, each_idx, :] \
|
311 |
+
+ self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
|
312 |
+
layer_wise_embed = next_layer_embed
|
313 |
+
return out
|
images/mhg_example.png
ADDED
images/mhg_example1.png
ADDED
images/mhg_example2.png
ADDED
load.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
import os
|
8 |
+
import pickle
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from rdkit import Chem
|
12 |
+
import torch
|
13 |
+
from torch_geometric.utils.smiles import from_smiles
|
14 |
+
|
15 |
+
from typing import Any, Dict, List, Optional, Union
|
16 |
+
from typing_extensions import Self
|
17 |
+
|
18 |
+
from .graph_grammar.io.smi import hg_to_mol
|
19 |
+
from .models.mhgvae import GrammarGINVAE
|
20 |
+
|
21 |
+
|
22 |
+
class PretrainedModelWrapper:
|
23 |
+
model: GrammarGINVAE
|
24 |
+
|
25 |
+
def __init__(self, model_dict: Dict[str, Any]) -> None:
|
26 |
+
json_params = model_dict['gnn_params']
|
27 |
+
encoder_params = json_params['encoder_params']
|
28 |
+
encoder_params['node_feature_size'] = model_dict['num_features']
|
29 |
+
encoder_params['edge_feature_size'] = model_dict['num_edge_features']
|
30 |
+
self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
|
31 |
+
decoder_params=json_params['decoder_params'],
|
32 |
+
prod_rule_embed_params=json_params["prod_rule_embed_params"],
|
33 |
+
batch_size=512, max_len=model_dict['max_length'])
|
34 |
+
self.model.load_state_dict(model_dict['model_state_dict'])
|
35 |
+
|
36 |
+
self.model.eval()
|
37 |
+
|
38 |
+
def to(self, device: Union[str, int, torch.device]) -> Self:
|
39 |
+
dev_type = type(device)
|
40 |
+
if dev_type != torch.device:
|
41 |
+
if dev_type == str or torch.cuda.is_available():
|
42 |
+
device = torch.device(device)
|
43 |
+
else:
|
44 |
+
device = torch.device("mps", device)
|
45 |
+
|
46 |
+
self.model = self.model.to(device)
|
47 |
+
return self
|
48 |
+
|
49 |
+
def encode(self, data: List[str]) -> List[torch.tensor]:
|
50 |
+
# Need to encode them into a graph nn
|
51 |
+
output = []
|
52 |
+
for d in data:
|
53 |
+
params = next(self.model.parameters())
|
54 |
+
g = from_smiles(d)
|
55 |
+
if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'):
|
56 |
+
g.to(params.device)
|
57 |
+
ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
|
58 |
+
output.append(ltvec[0])
|
59 |
+
return output
|
60 |
+
|
61 |
+
def decode(self, data: List[torch.tensor]) -> List[str]:
|
62 |
+
output = []
|
63 |
+
for d in data:
|
64 |
+
mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
|
65 |
+
z = self.model.reparameterize(mu, logvar)
|
66 |
+
flags, _, hgs = self.model.decode(z)
|
67 |
+
if flags[0]:
|
68 |
+
reconstructed_mol, _ = hg_to_mol(hgs[0], True)
|
69 |
+
output.append(Chem.MolToSmiles(reconstructed_mol))
|
70 |
+
else:
|
71 |
+
output.append(None)
|
72 |
+
return output
|
73 |
+
|
74 |
+
|
75 |
+
def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
|
76 |
+
PretrainedModelWrapper]:
|
77 |
+
for p in sys.path:
|
78 |
+
file = p + "/" + model_name
|
79 |
+
if os.path.isfile(file):
|
80 |
+
with open(file, "rb") as f:
|
81 |
+
model_dict = pickle.load(f)
|
82 |
+
return PretrainedModelWrapper(model_dict)
|
83 |
+
return None
|
mhg_gnn.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: mhg-gnn
|
3 |
+
Version: 0.0
|
4 |
+
Summary: Package for mhg-gnn
|
5 |
+
Author: team
|
6 |
+
License: TBD
|
7 |
+
Classifier: Programming Language :: Python :: 3
|
8 |
+
Classifier: Programming Language :: Python :: 3.9
|
9 |
+
Description-Content-Type: text/markdown
|
10 |
+
Requires-Dist: networkx>=2.8
|
11 |
+
Requires-Dist: numpy<2.0.0,>=1.23.5
|
12 |
+
Requires-Dist: pandas>=1.5.3
|
13 |
+
Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
|
14 |
+
Requires-Dist: torch>=2.0.0
|
15 |
+
Requires-Dist: torchinfo>=1.8.0
|
16 |
+
Requires-Dist: torch-geometric>=2.3.1
|
17 |
+
|
18 |
+
# mhg-gnn
|
19 |
+
|
20 |
+
This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
|
21 |
+
|
22 |
+
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
23 |
+
|
24 |
+
For more information contact: [email protected]
|
25 |
+
|
26 |
+
![mhg-gnn](images/mhg_example1.png)
|
27 |
+
|
28 |
+
## Introduction
|
29 |
+
|
30 |
+
We present MHG-GNN, an autoencoder architecture
|
31 |
+
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
32 |
+
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
33 |
+
demonstrate high predictive performance on molecular graph data.
|
34 |
+
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
35 |
+
|
36 |
+
## Table of Contents
|
37 |
+
|
38 |
+
1. [Getting Started](#getting-started)
|
39 |
+
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
40 |
+
2. [Replicating Conda Environment](#replicating-conda-environment)
|
41 |
+
2. [Feature Extraction](#feature-extraction)
|
42 |
+
|
43 |
+
## Getting Started
|
44 |
+
|
45 |
+
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
46 |
+
|
47 |
+
### Pretrained Models and Training Logs
|
48 |
+
|
49 |
+
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
|
50 |
+
|
51 |
+
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
52 |
+
|
53 |
+
### Replacicating Conda Environment
|
54 |
+
|
55 |
+
Follow these steps to replicate our Conda environment and install the necessary libraries:
|
56 |
+
|
57 |
+
```
|
58 |
+
conda create --name mhg-gnn-env python=3.8.18
|
59 |
+
conda activate mhg-gnn-env
|
60 |
+
```
|
61 |
+
|
62 |
+
#### Install Packages with Conda
|
63 |
+
|
64 |
+
```
|
65 |
+
conda install -c conda-forge networkx=2.8
|
66 |
+
conda install numpy=1.23.5
|
67 |
+
# conda install -c conda-forge rdkit=2022.9.4
|
68 |
+
conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
|
69 |
+
conda install -c conda-forge torchinfo=1.8.0
|
70 |
+
conda install pyg -c pyg
|
71 |
+
```
|
72 |
+
|
73 |
+
#### Install Packages with pip
|
74 |
+
```
|
75 |
+
pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
|
76 |
+
```
|
77 |
+
|
78 |
+
## Feature Extraction
|
79 |
+
|
80 |
+
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
81 |
+
|
82 |
+
To load mhg-gnn, you can simply use:
|
83 |
+
|
84 |
+
```python
|
85 |
+
import torch
|
86 |
+
import load
|
87 |
+
|
88 |
+
model = load.load()
|
89 |
+
```
|
90 |
+
|
91 |
+
To encode SMILES into embeddings, you can use:
|
92 |
+
|
93 |
+
```python
|
94 |
+
with torch.no_grad():
|
95 |
+
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
96 |
+
```
|
97 |
+
|
98 |
+
For decoder, you can use the function, so you can return from embeddings to SMILES strings:
|
99 |
+
|
100 |
+
```python
|
101 |
+
orig = model.decode(repr)
|
102 |
+
```
|
mhg_gnn.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.cfg
|
3 |
+
setup.py
|
4 |
+
./graph_grammar/__init__.py
|
5 |
+
./graph_grammar/hypergraph.py
|
6 |
+
./graph_grammar/algo/__init__.py
|
7 |
+
./graph_grammar/algo/tree_decomposition.py
|
8 |
+
./graph_grammar/graph_grammar/__init__.py
|
9 |
+
./graph_grammar/graph_grammar/base.py
|
10 |
+
./graph_grammar/graph_grammar/corpus.py
|
11 |
+
./graph_grammar/graph_grammar/hrg.py
|
12 |
+
./graph_grammar/graph_grammar/symbols.py
|
13 |
+
./graph_grammar/graph_grammar/utils.py
|
14 |
+
./graph_grammar/io/__init__.py
|
15 |
+
./graph_grammar/io/smi.py
|
16 |
+
./graph_grammar/nn/__init__.py
|
17 |
+
./graph_grammar/nn/dataset.py
|
18 |
+
./graph_grammar/nn/decoder.py
|
19 |
+
./graph_grammar/nn/encoder.py
|
20 |
+
./graph_grammar/nn/graph.py
|
21 |
+
./models/__init__.py
|
22 |
+
./models/mhgvae.py
|
23 |
+
graph_grammar/__init__.py
|
24 |
+
graph_grammar/hypergraph.py
|
25 |
+
graph_grammar/algo/__init__.py
|
26 |
+
graph_grammar/algo/tree_decomposition.py
|
27 |
+
graph_grammar/graph_grammar/__init__.py
|
28 |
+
graph_grammar/graph_grammar/base.py
|
29 |
+
graph_grammar/graph_grammar/corpus.py
|
30 |
+
graph_grammar/graph_grammar/hrg.py
|
31 |
+
graph_grammar/graph_grammar/symbols.py
|
32 |
+
graph_grammar/graph_grammar/utils.py
|
33 |
+
graph_grammar/io/__init__.py
|
34 |
+
graph_grammar/io/smi.py
|
35 |
+
graph_grammar/nn/__init__.py
|
36 |
+
graph_grammar/nn/dataset.py
|
37 |
+
graph_grammar/nn/decoder.py
|
38 |
+
graph_grammar/nn/encoder.py
|
39 |
+
graph_grammar/nn/graph.py
|
40 |
+
mhg_gnn.egg-info/PKG-INFO
|
41 |
+
mhg_gnn.egg-info/SOURCES.txt
|
42 |
+
mhg_gnn.egg-info/dependency_links.txt
|
43 |
+
mhg_gnn.egg-info/requires.txt
|
44 |
+
mhg_gnn.egg-info/top_level.txt
|
45 |
+
models/__init__.py
|
46 |
+
models/mhgvae.py
|