Upload 36 files
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/pinder_datamodule.cpython-310.pyc +0 -0
- src/data/components/__init__.py +0 -0
- src/data/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/components/__pycache__/pinder_dataset.cpython-310.pyc +0 -0
- src/data/components/__pycache__/prepare_data.cpython-310.pyc +0 -0
- src/data/components/pinder_dataset.py +64 -0
- src/data/components/prepare_data.py +175 -0
- src/data/pinder_datamodule.py +167 -0
- src/eval.py +99 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/__pycache__/pinder_module.cpython-310.pyc +0 -0
- src/models/components/__init__.py +0 -0
- src/models/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc +0 -0
- src/models/components/__pycache__/utils.cpython-310.pyc +0 -0
- src/models/components/equivariant_mpnn.py +231 -0
- src/models/components/utils.py +100 -0
- src/models/pinder_module.py +297 -0
- src/train.py +133 -0
- src/utils/__init__.py +5 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/instantiators.cpython-310.pyc +0 -0
- src/utils/__pycache__/logging_utils.cpython-310.pyc +0 -0
- src/utils/__pycache__/pylogger.cpython-310.pyc +0 -0
- src/utils/__pycache__/rich_utils.cpython-310.pyc +0 -0
- src/utils/__pycache__/utils.cpython-310.pyc +0 -0
- src/utils/instantiators.py +56 -0
- src/utils/logging_utils.py +57 -0
- src/utils/pylogger.py +51 -0
- src/utils/rich_utils.py +103 -0
- src/utils/utils.py +119 -0
src/__init__.py
ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (138 Bytes).
src/data/__init__.py
ADDED
File without changes
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes).
src/data/__pycache__/pinder_datamodule.cpython-310.pyc
ADDED
Binary file (6.15 kB).
src/data/components/__init__.py
ADDED
File without changes
src/data/components/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes).
src/data/components/__pycache__/pinder_dataset.cpython-310.pyc
ADDED
Binary file (2.09 kB).
src/data/components/__pycache__/prepare_data.cpython-310.pyc
ADDED
Binary file (5.29 kB).
src/data/components/pinder_dataset.py
ADDED
@@ -0,0 +1,64 @@
from typing import List

import __main__
import rootutils
import torch
from torch_geometric.data import Dataset

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from src.data.components.prepare_data import CropPairedPDB

setattr(__main__, "CropPairedPDB", CropPairedPDB)


class PinderDataset(Dataset):
    """Pinder dataset.

    Args:
        Dataset: PyTorch Geometric Dataset.
    """

    def __init__(self, file_paths: List[str]) -> None:
        """Initialize the PinderDataset.

        Args:
            file_paths: List of file paths.
        """
        super().__init__()
        self.file_paths = file_paths

    @property
    def processed_file_names(self) -> List[str]:
        """Return the processed file names.

        Returns:
            List[str]: List of processed file paths.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: Length of the dataset
        """
        return len(self.processed_file_names)

    def get(self, idx) -> CropPairedPDB:
        """Get the data at the given index.

        Args:
            idx: Index of the data.

        Returns:
            CropPairedPDB: CropPairedPDB object.
        """
        data = torch.load(self.processed_file_names[idx], weights_only=False)
        return data


if __name__ == "__main__":
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths)
    print(dataset[0])
src/data/components/prepare_data.py
ADDED
@@ -0,0 +1,175 @@
import multiprocessing
import os
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

import rootutils
import torch
from loguru import logger
from pinder.core import PinderSystem, get_index
from pinder.core.loader.geodata import PairedPDB, structure2tensor
from pinder.core.loader.structure import Structure
from tqdm.auto import tqdm

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    logger.warning(
        "torch-cluster is not installed!"
        "Please install the appropriate library for your pytorch installation."
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False


def create_lr_files(system_id: str, apo_complex_path: str, save_path: str):
    apo_r_path = os.path.join(save_path, f"apo_r_{system_id}.pdb")
    apo_l_path = os.path.join(save_path, f"apo_l_{system_id}.pdb")
    native_path = apo_complex_path.with_name(apo_complex_path.stem + f"{system_id}.pdb")
    with open(native_path) as infile, open(apo_r_path, "w") as output_r, open(
        apo_l_path, "w"
    ) as output_l:

        for line in infile:
            # Check if the line is an ATOM or HETATM line and has a chain ID at position 21
            if line.startswith("ATOM") or line.startswith("HETATM"):
                chain_id = line[21]
                if chain_id == "R":
                    output_r.write(line)
                elif chain_id == "L":
                    output_l.write(line)
            else:
                # Write other lines (e.g., HEADER, REMARK) to both files
                output_r.write(line)
                output_l.write(line)
    return apo_r_path, apo_l_path


class CropPairedPDB(PairedPDB):
    @classmethod
    def from_crop_system(
        cls,
        system_id: str,
        root: str = "./data/",
        k: int = 10,
        add_edges: bool = True,
        predicted_structures: bool = True,
        split: str = "train",
    ) -> None:
        system = PinderSystem(system_id)
        # Create directories if they do not exist
        for subdir in ["apo", "holo", "predicted"]:
            os.makedirs(Path(root) / "raw" / subdir / split, exist_ok=True)

        try:
            holo_complex, apo_complex, pred_complex = system.create_masked_bound_unbound_complexes(
                renumber_residues=True
            )
            for complex_type, complex_obj in zip(
                ["apo", "holo", "predicted"], [apo_complex, holo_complex, pred_complex]
            ):
                complex_obj.to_pdb(
                    Path(root) / "raw" / complex_type / split / f"{system_id}_complex.pdb"
                )
        except Exception as e:
            logger.error(f"Error in writing PDB files: {e}, {system_id}")
            return None

        if predicted_structures:
            apo_complex = pred_complex
            save_path = os.path.join(root, "processed", "predicted", split)
        else:
            save_path = os.path.join(root, "processed", "apo", split)

        # create the directory if it does not exist
        os.makedirs(save_path, exist_ok=True)

        graph = cls.from_structure_pair(
            holo_complex=holo_complex,
            apo_complex=apo_complex,
            add_edges=add_edges,
            k=k,
        )
        torch.save(graph, os.path.join(save_path, f"{system_id}.pt"))

    @classmethod
    def from_structure_pair(
        cls,
        holo_complex: Structure,
        apo_complex: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        def get_structure_props(structure: Structure, start: int, end: Optional[int]):
            calpha = structure.filter("atom_name", mask=["CA"])
            return structure2tensor(
                atom_coordinates=structure.coords[start:end],
                atom_types=structure.atom_array.atom_name[start:end],
                element_types=structure.atom_array.element[start:end],
                residue_coordinates=calpha.coords[start:end],
                residue_types=calpha.atom_array.res_name[start:end],
                residue_ids=calpha.atom_array.res_id[start:end],
            )

        graph = cls()
        r_h = (holo_complex.dataframe["chain_id"] == "R").sum()
        r_a = (apo_complex.dataframe["chain_id"] == "R").sum()

        holo_r_props = get_structure_props(holo_complex, 0, r_h)
        holo_l_props = get_structure_props(holo_complex, r_h, None)
        apo_r_props = get_structure_props(apo_complex, 0, r_a)
        apo_l_props = get_structure_props(apo_complex, r_a, None)

        graph["ligand"].x = apo_l_props["atom_types"]
        graph["ligand"].pos = apo_l_props["atom_coordinates"]
        graph["receptor"].x = apo_r_props["atom_types"]
        graph["receptor"].pos = apo_r_props["atom_coordinates"]
        graph["ligand"].y = holo_l_props["atom_coordinates"]
        graph["receptor"].y = holo_r_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            graph["ligand", "ligand"].edge_index = knn_graph(graph["ligand"].pos, k=k)
            graph["receptor", "receptor"].edge_index = knn_graph(graph["receptor"].pos, k=k)

        return graph


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--n_jobs", type=int, default=20)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--predicted_structures", action="store_true")
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()

    predicted_structures = args.predicted_structures

    # get indices for train, validation, and test splits
    indices = get_index()

    if predicted_structures:
        query = '(split == "{split}") and ((apo_R == False and apo_L == False) and (predicted_R==True and predicted_L==True))'
    else:
        query = '(split == "{split}") and (apo_R == True and apo_L == True)'

    system_idx = indices.query(query.format(split=args.split)).reset_index(drop=True)

    system_ids = system_idx.id.tolist()

    def process_system_id(system_id: str):
        graph = CropPairedPDB.from_crop_system(
            system_id,
            predicted_structures=predicted_structures,
            k=args.k,
            split=args.split,
        )
        return graph

    with multiprocessing.Pool(args.n_jobs) as pool:
        results = list(tqdm(pool.imap(process_system_id, system_ids), total=len(system_ids)))
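Note: the script above writes one `.pt` graph per PINDER system. A minimal usage sketch — the system id is the one used by the example files elsewhere in this upload, and the specific flag values are placeholder assumptions:

# command line, one split at a time:
#   python src/data/components/prepare_data.py --split test --k 10 --n_jobs 20
# or programmatically, for a single system:
from src.data.components.prepare_data import CropPairedPDB

CropPairedPDB.from_crop_system(
    "1a19__A1_P11540--1a19__B1_P11540",  # example system id
    root="./data/",
    k=10,
    predicted_structures=False,  # use apo rather than predicted structures
    split="test",
)
# the graph is saved to ./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt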
src/data/pinder_datamodule.py
ADDED
@@ -0,0 +1,167 @@
import os
from typing import Any, Dict, Optional

import pandas as pd
import rootutils
from lightning import LightningDataModule
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset


class PINDERDataModule(LightningDataModule):
    """`LightningDataModule` for the PINDER dataset."""

    def __init__(
        self,
        data_dir: str = "data/processed",
        predicted_structures: bool = False,
        high_quality: bool = False,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = True,
    ) -> None:
        """Initialize the `PINDERDataModule`.

        Args:
            data_dir: Data for pinder. Defaults to "data/processed".
            predicted_structures: Whether to use predicted structures. Defaults to False.
            batch_size: Batch size. Defaults to 1.
            num_workers: Number of workers for parallel processing. Defaults to 0.
            pin_memory: Whether to pin memory. Defaults to True.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # get metadata
        metadata = pd.read_csv(os.path.join(self.hparams.data_dir, "metadata.csv"))

        def get_files(split: str, complex_types: list) -> list:
            file_df = metadata[
                (metadata["split"] == split) & (metadata["complex"].isin(complex_types))
            ]
            file_df["file_paths"] = file_df.apply(
                lambda row: os.path.join(
                    "./data/processed", row["complex"], row["split"], row["file_paths"]
                ),
                axis=1,
            )
            return file_df["file_paths"].tolist()

        complex_types = ["apo", "predicted"] if self.hparams.predicted_structures else ["apo"]
        self.train_files = get_files("train", complex_types)
        self.val_files = get_files("val", complex_types)
        self.test_files = get_files("test", complex_types)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # load and split datasets only if not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            self.data_train = PinderDataset(self.train_files)
            self.data_val = PinderDataset(self.val_files)
            self.data_test = PinderDataset(self.test_files)

    def train_dataloader(self) -> DataLoader:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    datamodule = PINDERDataModule()
    datamodule.setup()
    # print(datamodule.train_files[64])
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    test_loader = datamodule.test_dataloader()
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")
    print(f"Number of test batches: {len(test_loader)}")
    print(next(iter(train_loader)))
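Note: the datamodule reads `data/processed/metadata.csv` with columns `split`, `complex`, and `file_paths`, but no script in this upload generates that file. A minimal sketch of how it could be built from the processed `.pt` files — the column names are taken from `get_files` above, the rest is an assumption:

import os

import pandas as pd

rows = []
for complex_type in ("apo", "predicted"):
    for split in ("train", "val", "test"):
        split_dir = os.path.join("data/processed", complex_type, split)
        if not os.path.isdir(split_dir):
            continue
        for fname in sorted(os.listdir(split_dir)):
            if fname.endswith(".pt"):
                # one row per processed system graph
                rows.append({"split": split, "complex": complex_type, "file_paths": fname})

pd.DataFrame(rows).to_csv("data/processed/metadata.csv", index=False)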
src/eval.py
ADDED
@@ -0,0 +1,99 @@
from typing import Any, Dict, List, Tuple

import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
    RankedLogger,
    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions use trainer.predict(...)
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    main()
src/models/__init__.py
ADDED
File without changes
src/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (145 Bytes).
src/models/__pycache__/pinder_module.cpython-310.pyc
ADDED
Binary file (8.44 kB).
src/models/components/__init__.py
ADDED
File without changes
src/models/components/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (156 Bytes).
src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc
ADDED
Binary file (6.84 kB).
src/models/components/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.74 kB).
src/models/components/equivariant_mpnn.py
ADDED
@@ -0,0 +1,231 @@
import rootutils
import torch
from torch import nn
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter

# setup root dir and pythonpath
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.components.pinder_dataset import PinderDataset
from src.models.components.utils import (
    compute_euler_angles_from_rotation_matrices,
    compute_rotation_matrix_from_ortho6d,
)


class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
        r"""Message Passing Neural Network Layer

        This layer is equivariant to 3D rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension d
            edge_dim: (int) - edge feature dimension d_e
            aggr: (str) - aggregation function \oplus (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim

        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )

        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )  # MLP \psi
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
        )  # MLP \phi
        # ===========================================

        self.lin_out = Linear(emb_dim, out_dim)

    def forward(self, data):
        """
        The forward pass updates node features h via one round of message passing.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (e, 2) - pairs of edges (i, j)
            edge_attr: (e, d_e) - edge features

        Returns:
            out: [(n, d),(n,3)] - updated node features
        """
        h, pos, edge_index = data
        h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
        h_out = self.lin_out(h_out)
        return h_out, pos_out, edge_index
        # ==========================================

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute distance between nodes i and j (Euclidean distance)
        # distance_ij = torch.norm(pos_i - pos_j, dim=-1, keepdim=True)  # (e, 1)
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)

        # Concatenate node features, edge features, and distance
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)
        pos_diff = pos_diff * self.mlp_pos(msg)  # (e, 2d + d_e + 1)

        # (e, d)
        return msg, pos_diff

    def aggregate(self, inputs, index):
        """The aggregate function aggregates the messages from neighboring nodes,
        according to the chosen aggregation function ('sum' by default).

        Args:
            inputs: (e, d) - messages m_ij from destination to source nodes
            index: (e, 1) - list of source nodes for each edge/message in input

        Returns:
            aggr_out: (n, d) - aggregated messages m_i
        """
        msgs, pos_diffs = inputs

        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)

        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")

        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out

        upd_out = self.mlp_upd(torch.cat((h, msg_aggr), dim=-1))

        upd_pos = pos + pos_aggr

        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"


class PinderMPNNModel(Module):
    def __init__(self, input_dim=1, emb_dim=64, num_heads=5):
        """Message Passing Neural Network model for graph property prediction

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations and translations (the constituent MPNN layers
        are equivariant to 3D rotations and translations).

        Args:
            emb_dim: (int) - hidden dimension d
            input_dim: (int) - initial node feature dimension d_n
            edge_dim: (int) - edge feature dimension d_e
            out_dim: (int) - output dimension (fixed to 1)
        """
        super().__init__()

        # Linear projection for initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stack of MPNN layers
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
            # EquivariantMPNNLayer(256, 512, aggr="mean"),
            # EquivariantMPNNLayer(512, 512, aggr="mean"),
        )
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(64, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
            # EquivariantMPNNLayer(256, 512, aggr="mean"),
            # EquivariantMPNNLayer(512, 512, aggr="mean"),
        )

        # Cross-attention layer
        self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)

        # MLPs for translation prediction
        self.fc_translation_rec = nn.Linear(256 + 3, 3)
        self.fc_translation_lig = nn.Linear(256 + 3, 3)

    def forward(self, batch):
        """
        The main forward pass of the model.

        Args:
            batch: Same as in forward_rot_trans.

        Returns:
            transformed_ligands: List of tensors, each of shape (1, num_ligand_atoms, 3)
                representing the transformed ligand coordinates after applying the predicted
                rotation and translation.
        """
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)

        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos

        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )

        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )

        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)

        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)

        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)

        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)

        ortho_6d_rec = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        ortho_6d_lig = compute_rotation_matrix_from_ortho6d(attn_output_lig)

        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(ortho_6d_rec) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(ortho_6d_lig) * 180 / torch.pi

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l

        return receptor_coords, ligand_coords


if __name__ == "__main__":
    file_paths = ["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    dataset = PinderDataset(file_paths=file_paths * 3)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)
    batch = next(iter(loader))
    model = PinderMPNNModel()
    print("Number of parameters:", sum(p.numel() for p in model.parameters()))
    receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)
src/models/components/utils.py
ADDED
@@ -0,0 +1,100 @@
import torch


# batch*n
def normalize_vector(v):
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # batch
    eps = torch.tensor(1e-8, device=v.device)
    v_mag = torch.max(v_mag, eps)
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    return v


# u, v batch*n
def cross_product(u, v):
    batch = u.shape[0]
    # print (u.shape)
    # print (v.shape)
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)  # batch*3

    return out


# poses batch*6
def compute_rotation_matrix_from_ortho6d(poses):
    x_raw = poses[:, 0:3]  # batch*3
    y_raw = poses[:, 3:6]  # batch*3

    x = normalize_vector(x_raw)  # batch*3
    z = cross_product(x, y_raw)  # batch*3
    z = normalize_vector(z)  # batch*3
    y = cross_product(z, x)  # batch*3

    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # batch*3*3
    return matrix


# input batch*4*4 or batch*3*3
# output torch batch*3 x, y, z in radians
# the rotation is in the sequence of x,y,z
def compute_euler_angles_from_rotation_matrices(rotation_matrices):
    batch = rotation_matrices.shape[0]
    R = rotation_matrices
    sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0])
    singular = sy < 1e-6
    singular = singular.float()

    x = torch.atan2(R[:, 2, 1], R[:, 2, 2])
    y = torch.atan2(-R[:, 2, 0], sy)
    z = torch.atan2(R[:, 1, 0], R[:, 0, 0])

    xs = torch.atan2(-R[:, 1, 2], R[:, 1, 1])
    ys = torch.atan2(-R[:, 2, 0], sy)
    zs = R[:, 1, 0] * 0

    out_euler = torch.zeros(batch, 3, device=rotation_matrices.device)

    out_euler[:, 0] = x * (1 - singular) + xs * singular
    out_euler[:, 1] = y * (1 - singular) + ys * singular
    out_euler[:, 2] = z * (1 - singular) + zs * singular

    return out_euler


def get_R(x, y, z):
    """Get rotation matrix from three rotation angles (radians). right-handed.

    Args:
        x: rotation angle around x-axis
        y: rotation angle around y-axis
        z: rotation angle around z-axis
    Returns:
        R: [3, 3]. rotation matrix.
    """
    # x
    Rx = torch.tensor(
        [[1, 0, 0], [0, torch.cos(x), -torch.sin(x)], [0, torch.sin(x), torch.cos(x)]],
        device=x.device,
    )
    # y
    Ry = torch.tensor(
        [[torch.cos(y), 0, torch.sin(y)], [0, 1, 0], [-torch.sin(y), 0, torch.cos(y)]],
        device=y.device,
    )
    # z
    Rz = torch.tensor(
        [[torch.cos(z), -torch.sin(z), 0], [torch.sin(z), torch.cos(z), 0], [0, 0, 1]],
        device=z.device,
    )

    R = torch.mm(Rz, torch.mm(Ry, Rx))
    return R
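Note: compute_rotation_matrix_from_ortho6d implements the Gram-Schmidt-style 6D rotation parameterisation, so every output should be a proper rotation matrix. A small sanity-check sketch (the import path assumes the layout of this upload):

import torch

from src.models.components.utils import compute_rotation_matrix_from_ortho6d

poses = torch.randn(4, 6)  # arbitrary 6D vectors, batch of 4
R = compute_rotation_matrix_from_ortho6d(poses)  # (4, 3, 3)

# orthonormal columns and determinant +1 for every matrix in the batch
eye = torch.eye(3).expand(4, 3, 3)
print(torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5))
print(torch.allclose(torch.det(R), torch.ones(4), atol=1e-5))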
src/models/pinder_module.py
ADDED
@@ -0,0 +1,297 @@
from typing import Any, Dict, Tuple

import torch
from lightning import LightningModule
from torchmetrics import MeanMetric, MinMetric
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError


class PinderLitModule(LightningModule):
    """A `LightningModule` for the PINDER docking model.

    A `LightningModule` implements 8 key methods:

    ```python
    def __init__(self):
        # Define initialization code here.

    def setup(self, stage):
        # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
        # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
        # The complete training step.

    def validation_step(self, batch, batch_idx):
        # The complete validation step.

    def test_step(self, batch, batch_idx):
        # The complete test step.

    def predict_step(self, batch, batch_idx):
        # The complete predict step.

    def configure_optimizers(self):
        # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        """Initialize a `PinderLitModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.MSELoss()

        # metric objects for calculating and averaging accuracy across batches
        self.train_mse_ligand = MeanSquaredError()
        self.val_mse_ligand = MeanSquaredError()
        self.test_mse_ligand = MeanSquaredError()

        self.train_mse_receptor = MeanSquaredError()
        self.val_mse_receptor = MeanSquaredError()
        self.test_mse_receptor = MeanSquaredError()

        self.train_mae_receptor = MeanAbsoluteError()
        self.val_mae_receptor = MeanAbsoluteError()
        self.test_mae_receptor = MeanAbsoluteError()

        self.train_mae_ligand = MeanAbsoluteError()
        self.val_mae_ligand = MeanAbsoluteError()
        self.test_mae_ligand = MeanAbsoluteError()

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation mse
        self.val_mse_best = MinMetric()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

        :param x: A batch of receptor/ligand graphs.
        :return: The predicted receptor and ligand coordinates.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_mse_ligand.reset()
        self.val_mse_receptor.reset()
        self.val_mae_receptor.reset()
        self.val_mae_ligand.reset()
        self.val_mse_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of receptor/ligand graphs with target coordinates.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - The predicted receptor and ligand coordinates.
            - The target receptor and ligand coordinates.
        """

        receptor_coords, ligand_coords = self.forward(batch)
        loss_receptor = self.criterion(receptor_coords, batch["receptor"].y)
        loss_ligand = self.criterion(ligand_coords, batch["ligand"].y)
        loss = loss_receptor + loss_ligand
        return loss, receptor_coords, ligand_coords, batch["receptor"].y, batch["ligand"].y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data containing the input graphs and target coordinates.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.train_loss(loss)
        self.train_mse_ligand(ligand_coords, ligand_targets)
        self.train_mse_receptor(receptor_coords, receptor_targets)
        self.train_mae_ligand(ligand_coords, ligand_targets)
        self.train_mae_receptor(receptor_coords, receptor_targets)
        self.log("train/loss", self.train_loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log(
            "train/mse_ligand", self.train_mse_ligand, on_step=True, on_epoch=False, prog_bar=True
        )
        self.log(
            "train/mse_receptor",
            self.train_mse_receptor,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )
        self.log(
            "train/mae_ligand", self.train_mae_ligand, on_step=True, on_epoch=False, prog_bar=True
        )
        self.log(
            "train/mae_receptor",
            self.train_mae_receptor,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
        )

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        pass

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data containing the input graphs and target coordinates.
        :param batch_idx: The index of the current batch.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.val_loss(loss)
        self.val_mse_ligand(ligand_coords, ligand_targets)
        self.val_mse_receptor(receptor_coords, receptor_targets)
        self.val_mae_ligand(ligand_coords, ligand_targets)
        self.val_mae_receptor(receptor_coords, receptor_targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(
            "val/mse_ligand", self.val_mse_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mse_receptor", self.val_mse_receptor, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mae_ligand", self.val_mae_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "val/mae_receptor", self.val_mae_receptor, on_step=False, on_epoch=True, prog_bar=True
        )

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        acc = self.val_mse_ligand.compute()  # get current val acc
        self.val_mse_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_mse_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data containing the input graphs and target coordinates.
        :param batch_idx: The index of the current batch.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = self.model_step(
            batch
        )

        # update and log metrics
        self.test_loss(loss)
        self.test_mse_ligand(ligand_coords, ligand_targets)
        self.test_mse_receptor(receptor_coords, receptor_targets)
        self.test_mae_ligand(ligand_coords, ligand_targets)
        self.test_mae_receptor(receptor_coords, receptor_targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(
            "test/mse_ligand", self.test_mse_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "test/mse_receptor",
            self.test_mse_receptor,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test/mae_ligand", self.test_mae_ligand, on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            "test/mae_receptor",
            self.test_mae_receptor,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    _ = PinderLitModule(None, None, None, None)
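Note: configure_optimizers calls self.hparams.optimizer(params=...) and self.hparams.scheduler(optimizer=...), so the module expects partially-instantiated factories (e.g. Hydra configs with `_partial_: true`). A minimal wiring sketch outside Hydra; the learning rate and scheduler settings are placeholder assumptions:

from functools import partial

import torch

from src.models.components.equivariant_mpnn import PinderMPNNModel
from src.models.pinder_module import PinderLitModule

module = PinderLitModule(
    net=PinderMPNNModel(),
    optimizer=partial(torch.optim.Adam, lr=1e-3),  # called later with params=...
    scheduler=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, factor=0.5, patience=5),
    compile=False,
)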
src/train.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
2 |
+
|
3 |
+
import hydra
|
4 |
+
import lightning as L
|
5 |
+
import rootutils
|
6 |
+
import torch
|
7 |
+
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
8 |
+
from lightning.pytorch.loggers import Logger
|
9 |
+
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator, which controls the behavior
    during failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    main()

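The `@hydra.main` decorator wires `main` to `configs/train.yaml`, so the usual entry point is the CLI (`python src/train.py` plus overrides). For completeness, below is a hedged sketch of driving `train` programmatically via Hydra's compose API; it is not part of the upload. The config path, config name, and the `trainer.max_epochs` override are assumptions based on the decorator and comments above, and runtime-only interpolations (e.g. an output dir derived from the Hydra runtime) may need explicit overrides when `hydra.main` is bypassed.

# Illustrative sketch only, not part of the uploaded files.
from hydra import compose, initialize

from src.train import train

with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="train", overrides=["trainer.max_epochs=1"])

metric_dict, object_dict = train(cfg)
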
src/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import RankedLogger
from src.utils.rich_utils import enforce_tags, print_config_tree
from src.utils.utils import extras, get_metric_value, task_wrapper

src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (546 Bytes). View file

src/utils/__pycache__/instantiators.cpython-310.pyc
ADDED
Binary file (1.57 kB). View file

src/utils/__pycache__/logging_utils.cpython-310.pyc
ADDED
Binary file (1.96 kB). View file

src/utils/__pycache__/pylogger.cpython-310.pyc
ADDED
Binary file (2.55 kB). View file

src/utils/__pycache__/rich_utils.cpython-310.pyc
ADDED
Binary file (3.21 kB). View file

src/utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.69 kB). View file

src/utils/instantiators.py
ADDED
@@ -0,0 +1,56 @@
from typing import List

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger

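Both helpers expect a mapping of named sub-configs, each carrying a `_target_` key; entries without one are skipped silently. Below is a standalone sketch of the shape `instantiate_callbacks` consumes; the callback choices and the `val/loss` monitor key are illustrative assumptions, not this project's actual `callbacks` config group.

# Illustrative sketch only, not part of the uploaded files.
from omegaconf import OmegaConf

from src.utils.instantiators import instantiate_callbacks

callbacks_cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val/loss",
            "mode": "min",
        },
        "early_stopping": {
            "_target_": "lightning.pytorch.callbacks.EarlyStopping",
            "monitor": "val/loss",
            "patience": 3,
        },
    }
)
callbacks = instantiate_callbacks(callbacks_cfg)
# -> [ModelCheckpoint(...), EarlyStopping(...)]
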
src/utils/logging_utils.py
ADDED
@@ -0,0 +1,57 @@
from typing import Any, Dict

from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)

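`log_hyperparameters` only reads the `cfg`, `model`, and `trainer` entries of `object_dict`, and silently returns when the trainer has no logger attached. Below is a hedged, self-contained sketch with toy objects; the config fields, `TinyModule`, and the `CSVLogger` choice are assumptions made purely for illustration.

# Illustrative sketch only, not part of the uploaded files.
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from omegaconf import OmegaConf
from torch import nn

from src.utils.logging_utils import log_hyperparameters


class TinyModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)  # gives the helper some parameters to count


cfg = OmegaConf.create(
    {"model": {"hidden": 4}, "data": {"batch_size": 8}, "trainer": {"max_epochs": 1}}
)
trainer = L.Trainer(logger=CSVLogger("logs"), max_epochs=1)

# logs the cfg sections plus total/trainable/non-trainable parameter counts to the logger
log_hyperparameters({"cfg": cfg, "model": TinyModule(), "trainer": trainer})
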
src/utils/pylogger.py
ADDED
@@ -0,0 +1,51 @@
import logging
from typing import Mapping, Optional

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                if rank is None:
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)

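A standalone usage sketch follows. `rank_zero_only.rank` is normally set once a Lightning `Trainer` starts; the guard below sets it manually only so the snippet can run outside a training job, which is an assumption for demo purposes.

# Illustrative sketch only, not part of the uploaded files.
import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from src.utils.pylogger import RankedLogger

rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)  # normally set by Lightning
logging.basicConfig(level=logging.INFO)

log_all = RankedLogger(__name__)  # logs on every process, prefixed with its rank
log_zero = RankedLogger(__name__, rank_zero_only=True)  # logs only on rank 0

log_all.info("visible on every rank")
log_all.info("emitted only by rank 1", rank=1)
log_zero.warning("emitted once, from rank 0")
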
src/utils/rich_utils.py
ADDED
@@ -0,0 +1,103 @@
from pathlib import Path
from typing import Sequence

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.prompt import Prompt

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        (
            queue.append(field)
            if field in cfg
            else log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
            rich.print(tree, file=file)


@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        tags = [t.strip() for t in tags.split(",") if t != ""]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)

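`print_config_tree` can be exercised outside Hydra as long as `save_to_file` stays off (saving needs `cfg.paths.output_dir`). Below is a hedged sketch with a made-up config; fields missing from `print_order` (e.g. `callbacks`) only trigger a warning, and extra fields are appended at the end of the tree. The rank guard is again only needed because the snippet runs outside a Lightning job.

# Illustrative sketch only, not part of the uploaded files.
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import OmegaConf

from src.utils.rich_utils import print_config_tree

rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)  # normally set by Lightning

cfg = OmegaConf.create(
    {
        "data": {"batch_size": 32, "num_workers": 4},
        "model": {"lr": 1e-3},
        "trainer": {"max_epochs": 10},
    }
)
print_config_tree(cfg, resolve=True, save_to_file=False)
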
src/utils/utils.py
ADDED
@@ -0,0 +1,119 @@
import warnings
from importlib.util import find_spec
from typing import Any, Callable, Dict, Optional, Tuple

from omegaconf import DictConfig

from src.utils import pylogger, rich_utils

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value

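How the pieces compose: `task_wrapper` guarantees the output dir is logged and any wandb run is closed even when the task raises, and `get_metric_value` then pulls the scalar that a Hydra sweeper optimizes. Below is a hedged end-to-end sketch with a toy task; the config fields and the `val/loss` metric name are assumptions for illustration, and the rank guard is only needed because the snippet runs outside a Lightning job.

# Illustrative sketch only, not part of the uploaded files.
import torch
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf

from src.utils.utils import get_metric_value, task_wrapper

rank_zero_only.rank = getattr(rank_zero_only, "rank", 0)  # normally set by Lightning


@task_wrapper
def toy_task(cfg: DictConfig):
    # pretend training happened and a metric was logged
    return {"val/loss": torch.tensor(0.123)}, {"cfg": cfg}


cfg = OmegaConf.create({"paths": {"output_dir": "."}, "optimized_metric": "val/loss"})
metric_dict, object_dict = toy_task(cfg=cfg)
best = get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
print(best)  # 0.123
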