qninhdt committed on
Commit
8cc0674
1 Parent(s): c6851a9
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +4 -0
  2. Makefile +30 -0
  3. README.md +1297 -0
  4. configs/__init__.py +1 -0
  5. configs/callbacks/default.yaml +22 -0
  6. configs/callbacks/early_stopping.yaml +15 -0
  7. configs/callbacks/model_checkpoint.yaml +17 -0
  8. configs/callbacks/model_summary.yaml +5 -0
  9. configs/callbacks/none.yaml +0 -0
  10. configs/callbacks/rich_progress_bar.yaml +4 -0
  11. configs/data/swim.yaml +4 -0
  12. configs/debug/default.yaml +35 -0
  13. configs/debug/fdr.yaml +9 -0
  14. configs/debug/limit.yaml +12 -0
  15. configs/debug/overfit.yaml +13 -0
  16. configs/debug/profiler.yaml +12 -0
  17. configs/eval.yaml +18 -0
  18. configs/experiment/example.yaml +41 -0
  19. configs/extras/default.yaml +8 -0
  20. configs/hparams_search/mnist_optuna.yaml +52 -0
  21. configs/hydra/default.yaml +19 -0
  22. configs/local/.gitkeep +0 -0
  23. configs/logger/aim.yaml +28 -0
  24. configs/logger/comet.yaml +12 -0
  25. configs/logger/csv.yaml +7 -0
  26. configs/logger/many_loggers.yaml +9 -0
  27. configs/logger/mlflow.yaml +12 -0
  28. configs/logger/neptune.yaml +9 -0
  29. configs/logger/tensorboard.yaml +10 -0
  30. configs/logger/wandb.yaml +16 -0
  31. configs/model/autoencoder.yaml +11 -0
  32. configs/paths/default.yaml +18 -0
  33. configs/train.yaml +49 -0
  34. configs/trainer/cpu.yaml +5 -0
  35. configs/trainer/ddp.yaml +9 -0
  36. configs/trainer/ddp_sim.yaml +7 -0
  37. configs/trainer/default.yaml +19 -0
  38. configs/trainer/gpu.yaml +5 -0
  39. configs/trainer/mps.yaml +5 -0
  40. environment.yaml +45 -0
  41. notebooks/.gitkeep +0 -0
  42. pyproject.toml +25 -0
  43. requirements.txt +24 -0
  44. scripts/schedule.sh +7 -0
  45. setup.py +21 -0
  46. swim/__init__.py +0 -0
  47. swim/data/__init__.py +0 -0
  48. swim/data/swim_data.py +145 -0
  49. swim/eval.py +99 -0
  50. swim/models/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__
2
+ *.zip
3
+ wandb
4
+ logs
Makefile ADDED
@@ -0,0 +1,30 @@
1
+
2
+ help: ## Show help
3
+ @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4
+
5
+ clean: ## Clean autogenerated files
6
+ rm -rf dist
7
+ find . -type f -name "*.DS_Store" -ls -delete
8
+ find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
9
+ find . | grep -E ".pytest_cache" | xargs rm -rf
10
+ find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
11
+ rm -f .coverage
12
+
13
+ clean-logs: ## Clean logs
14
+ rm -rf logs/**
15
+
16
+ format: ## Run pre-commit hooks
17
+ pre-commit run -a
18
+
19
+ sync: ## Merge changes from main branch to your current branch
20
+ git pull
21
+ git pull origin main
22
+
23
+ test: ## Run not slow tests
24
+ pytest -k "not slow"
25
+
26
+ test-full: ## Run all tests
27
+ pytest
28
+
29
+ train: ## Train the model
30
+ python src/train.py
README.md ADDED
@@ -0,0 +1,1297 @@
1
+ <div align="center">
2
+
3
+ # Lightning-Hydra-Template
4
+
5
+ [![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit)
6
+ [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
7
+ [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
8
+ [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
9
+ [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
10
+ [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) <br>
11
+ [![tests](https://github.com/ashleve/lightning-hydra-template/actions/workflows/test.yml/badge.svg)](https://github.com/ashleve/lightning-hydra-template/actions/workflows/test.yml)
12
+ [![code-quality](https://github.com/ashleve/lightning-hydra-template/actions/workflows/code-quality-main.yaml/badge.svg)](https://github.com/ashleve/lightning-hydra-template/actions/workflows/code-quality-main.yaml)
13
+ [![codecov](https://codecov.io/gh/ashleve/lightning-hydra-template/branch/main/graph/badge.svg)](https://codecov.io/gh/ashleve/lightning-hydra-template) <br>
14
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/ashleve/lightning-hydra-template#license)
15
+ [![PRs](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/ashleve/lightning-hydra-template/pulls)
16
+ [![contributors](https://img.shields.io/github/contributors/ashleve/lightning-hydra-template.svg)](https://github.com/ashleve/lightning-hydra-template/graphs/contributors)
17
+
18
+ A clean template to kickstart your deep learning project 🚀⚡🔥<br>
19
+ Click on [<kbd>Use this template</kbd>](https://github.com/ashleve/lightning-hydra-template/generate) to initialize a new repository.
20
+
21
+ _Suggestions are always welcome!_
22
+
23
+ </div>
24
+
25
+ <br>
26
+
27
+ ## 📌  Introduction
28
+
29
+ **Why you might want to use it:**
30
+
31
+ ✅ Save on boilerplate <br>
32
+ Easily add new models, datasets, tasks, experiments, and train on different accelerators, like multi-GPU, TPU or SLURM clusters.
33
+
34
+ ✅ Education <br>
35
+ Thoroughly commented. You can use this repo as a learning resource.
36
+
37
+ ✅ Reusability <br>
38
+ Collection of useful MLOps tools, configs, and code snippets. You can use this repo as a reference for various utilities.
39
+
40
+ **Why you might not want to use it:**
41
+
42
+ ❌ Things break from time to time <br>
43
+ Lightning and Hydra are still evolving and integrate many libraries, which means sometimes things break. For the list of currently known problems visit [this page](https://github.com/ashleve/lightning-hydra-template/labels/bug).
44
+
45
+ ❌ Not adjusted for data engineering <br>
46
+ The template is not really adjusted for building data pipelines that depend on each other. It's more efficient to use it for model prototyping on ready-to-use data.
47
+
48
+ ❌ Overfitted to simple use case <br>
49
+ The configuration setup is built with simple lightning training in mind. You might need to put in some effort to adjust it for other use cases, e.g. Lightning Fabric.
50
+
51
+ ❌ Might not support your workflow <br>
52
+ For example, you can't resume hydra-based multirun or hyperparameter search.
53
+
54
+ > **Note**: _Keep in mind this is an unofficial community project._
55
+
56
+ <br>
57
+
58
+ ## Main Technologies
59
+
60
+ [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) - a lightweight PyTorch wrapper for high-performance AI research. Think of it as a framework for organizing your PyTorch code.
61
+
62
+ [Hydra](https://github.com/facebookresearch/hydra) - a framework for elegantly configuring complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line.
63
+
64
+ <br>
65
+
66
+ ## Main Ideas
67
+
68
+ - [**Rapid Experimentation**](#your-superpowers): thanks to hydra command line superpowers
69
+ - [**Minimal Boilerplate**](#how-it-works): thanks to automating pipelines with config instantiation
70
+ - [**Main Configs**](#main-config): allow you to specify default training configuration
71
+ - [**Experiment Configs**](#experiment-config): allow you to override chosen hyperparameters and version control experiments
72
+ - [**Workflow**](#workflow): comes down to 4 simple steps
73
+ - [**Experiment Tracking**](#experiment-tracking): Tensorboard, W&B, Neptune, Comet, MLFlow and CSVLogger
74
+ - [**Logs**](#logs): all logs (checkpoints, configs, etc.) are stored in a dynamically generated folder structure
75
+ - [**Hyperparameter Search**](#hyperparameter-search): simple search is effortless with Hydra plugins like Optuna Sweeper
76
+ - [**Tests**](#tests): generic, easy-to-adapt smoke tests for speeding up the development
77
+ - [**Continuous Integration**](#continuous-integration): automatically test and lint your repo with Github Actions
78
+ - [**Best Practices**](#best-practices): a couple of recommended tools, practices and standards
79
+
80
+ <br>
81
+
82
+ ## Project Structure
83
+
84
+ The directory structure of a new project looks like this:
85
+
86
+ ```
87
+ ├── .github <- Github Actions workflows
88
+
89
+ ├── configs <- Hydra configs
90
+ │ ├── callbacks <- Callbacks configs
91
+ │ ├── data <- Data configs
92
+ │ ├── debug <- Debugging configs
93
+ │ ├── experiment <- Experiment configs
94
+ │ ├── extras <- Extra utilities configs
95
+ │ ├── hparams_search <- Hyperparameter search configs
96
+ │ ├── hydra <- Hydra configs
97
+ │ ├── local <- Local configs
98
+ │ ├── logger <- Logger configs
99
+ │ ├── model <- Model configs
100
+ │ ├── paths <- Project paths configs
101
+ │ ├── trainer <- Trainer configs
102
+ │ │
103
+ │ ├── eval.yaml <- Main config for evaluation
104
+ │ └── train.yaml <- Main config for training
105
+
106
+ ├── data <- Project data
107
+
108
+ ├── logs <- Logs generated by hydra and lightning loggers
109
+
110
+ ├── notebooks <- Jupyter notebooks. Naming convention is a number (for ordering),
111
+ │ the creator's initials, and a short `-` delimited description,
112
+ │ e.g. `1.0-jqp-initial-data-exploration.ipynb`.
113
+
114
+ ├── scripts <- Shell scripts
115
+
116
+ ├── src <- Source code
117
+ │ ├── data <- Data scripts
118
+ │ ├── models <- Model scripts
119
+ │ ├── utils <- Utility scripts
120
+ │ │
121
+ │ ├── eval.py <- Run evaluation
122
+ │ └── train.py <- Run training
123
+
124
+ ├── tests <- Tests of any kind
125
+
126
+ ├── .env.example <- Example of file for storing private environment variables
127
+ ├── .gitignore <- List of files ignored by git
128
+ ├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
129
+ ├── .project-root <- File for inferring the position of project root directory
130
+ ├── environment.yaml <- File for installing conda environment
131
+ ├── Makefile <- Makefile with commands like `make train` or `make test`
132
+ ├── pyproject.toml <- Configuration options for testing and linting
133
+ ├── requirements.txt <- File for installing python dependencies
134
+ ├── setup.py <- File for installing project as a package
135
+ └── README.md
136
+ ```
137
+
138
+ <br>
139
+
140
+ ## 🚀  Quickstart
141
+
142
+ ```bash
143
+ # clone project
144
+ git clone https://github.com/ashleve/lightning-hydra-template
145
+ cd lightning-hydra-template
146
+
147
+ # [OPTIONAL] create conda environment
148
+ conda create -n myenv python=3.9
149
+ conda activate myenv
150
+
151
+ # install pytorch according to instructions
152
+ # https://pytorch.org/get-started/
153
+
154
+ # install requirements
155
+ pip install -r requirements.txt
156
+ ```
157
+
158
+ The template contains an example of MNIST classification.<br>
159
+ When running `python src/train.py` you should see something like this:
160
+
161
+ <div align="center">
162
+
163
+ ![](https://github.com/ashleve/lightning-hydra-template/blob/resources/terminal.png)
164
+
165
+ </div>
166
+
167
+ ## ⚡  Your Superpowers
168
+
169
+ <details>
170
+ <summary><b>Override any config parameter from command line</b></summary>
171
+
172
+ ```bash
173
+ python train.py trainer.max_epochs=20 model.optimizer.lr=1e-4
174
+ ```
175
+
176
+ > **Note**: You can also add new parameters with `+` sign.
177
+
178
+ ```bash
179
+ python train.py +model.new_param="owo"
180
+ ```
181
+
182
+ </details>
183
+
184
+ <details>
185
+ <summary><b>Train on CPU, GPU, multi-GPU and TPU</b></summary>
186
+
187
+ ```bash
188
+ # train on CPU
189
+ python train.py trainer=cpu
190
+
191
+ # train on 1 GPU
192
+ python train.py trainer=gpu
193
+
194
+ # train on TPU
195
+ python train.py +trainer.tpu_cores=8
196
+
197
+ # train with DDP (Distributed Data Parallel) (4 GPUs)
198
+ python train.py trainer=ddp trainer.devices=4
199
+
200
+ # train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
201
+ python train.py trainer=ddp trainer.devices=4 trainer.num_nodes=2
202
+
203
+ # simulate DDP on CPU processes
204
+ python train.py trainer=ddp_sim trainer.devices=2
205
+
206
+ # accelerate training on mac
207
+ python train.py trainer=mps
208
+ ```
209
+
210
+ > **Warning**: Currently there are problems with DDP mode, read [this issue](https://github.com/ashleve/lightning-hydra-template/issues/393) to learn more.
211
+
212
+ </details>
213
+
214
+ <details>
215
+ <summary><b>Train with mixed precision</b></summary>
216
+
217
+ ```bash
218
+ # train with pytorch native automatic mixed precision (AMP)
219
+ python train.py trainer=gpu +trainer.precision=16
220
+ ```
221
+
222
+ </details>
223
+
224
+ <!-- deepspeed support still in beta
225
+ <details>
226
+ <summary><b>Optimize large scale models on multiple GPUs with Deepspeed</b></summary>
227
+
228
+ ```bash
229
+ python train.py +trainer.
230
+ ```
231
+
232
+ </details>
233
+ -->
234
+
235
+ <details>
236
+ <summary><b>Train model with any logger available in PyTorch Lightning, like W&B or Tensorboard</b></summary>
237
+
238
+ ```yaml
239
+ # set project and entity names in `configs/logger/wandb`
240
+ wandb:
241
+ project: "your_project_name"
242
+ entity: "your_wandb_team_name"
243
+ ```
244
+
245
+ ```bash
246
+ # train model with Weights&Biases (link to wandb dashboard should appear in the terminal)
247
+ python train.py logger=wandb
248
+ ```
249
+
250
+ > **Note**: Lightning provides convenient integrations with most popular logging frameworks. Learn more [here](#experiment-tracking).
251
+
252
+ > **Note**: Using wandb requires you to [set up an account](https://www.wandb.com/) first. After that, just complete the config as shown above.
253
+
254
+ > **Note**: Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.
255
+
256
+ </details>
257
+
258
+ <details>
259
+ <summary><b>Train model with chosen experiment config</b></summary>
260
+
261
+ ```bash
262
+ python train.py experiment=example
263
+ ```
264
+
265
+ > **Note**: Experiment configs are placed in [configs/experiment/](configs/experiment/).
266
+
267
+ </details>
268
+
269
+ <details>
270
+ <summary><b>Attach some callbacks to run</b></summary>
271
+
272
+ ```bash
273
+ python train.py callbacks=default
274
+ ```
275
+
276
+ > **Note**: Callbacks can be used for things such as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).
277
+
278
+ > **Note**: Callbacks configs are placed in [configs/callbacks/](configs/callbacks/).
279
+
280
+ </details>
281
+
282
+ <details>
283
+ <summary><b>Use different tricks available in Pytorch Lightning</b></summary>
284
+
285
+ ```bash
286
+ # gradient clipping may be enabled to avoid exploding gradients
287
+ python train.py +trainer.gradient_clip_val=0.5
288
+
289
+ # run validation loop 4 times during a training epoch
290
+ python train.py +trainer.val_check_interval=0.25
291
+
292
+ # accumulate gradients
293
+ python train.py +trainer.accumulate_grad_batches=10
294
+
295
+ # terminate training after 12 hours
296
+ python train.py +trainer.max_time="00:12:00:00"
297
+ ```
298
+
299
+ > **Note**: PyTorch Lightning provides about [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
300
+
301
+ </details>
302
+
303
+ <details>
304
+ <summary><b>Easily debug</b></summary>
305
+
306
+ ```bash
307
+ # runs 1 epoch in default debugging mode
308
+ # changes logging directory to `logs/debugs/...`
309
+ # sets level of all command line loggers to 'DEBUG'
310
+ # enforces debug-friendly configuration
311
+ python train.py debug=default
312
+
313
+ # run 1 train, val and test loop, using only 1 batch
314
+ python train.py debug=fdr
315
+
316
+ # print execution time profiling
317
+ python train.py debug=profiler
318
+
319
+ # try overfitting to 1 batch
320
+ python train.py debug=overfit
321
+
322
+ # raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
323
+ python train.py +trainer.detect_anomaly=true
324
+
325
+ # use only 20% of the data
326
+ python train.py +trainer.limit_train_batches=0.2 \
327
+ +trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
328
+ ```
329
+
330
+ > **Note**: Visit [configs/debug/](configs/debug/) for different debugging configs.
331
+
332
+ </details>
333
+
334
+ <details>
335
+ <summary><b>Resume training from checkpoint</b></summary>
336
+
337
+ ```bash
338
+ python train.py ckpt_path="/path/to/ckpt/name.ckpt"
339
+ ```
340
+
341
+ > **Note**: The checkpoint can be either a path or a URL.
342
+
343
+ > **Note**: Currently loading ckpt doesn't resume logger experiment, but it will be supported in future Lightning release.
344
+
345
+ </details>
346
+
347
+ <details>
348
+ <summary><b>Evaluate checkpoint on test dataset</b></summary>
349
+
350
+ ```bash
351
+ python eval.py ckpt_path="/path/to/ckpt/name.ckpt"
352
+ ```
353
+
354
+ > **Note**: The checkpoint can be either a path or a URL.
355
+
356
+ </details>
357
+
358
+ <details>
359
+ <summary><b>Create a sweep over hyperparameters</b></summary>
360
+
361
+ ```bash
362
+ # this will run 6 experiments one after the other,
363
+ # each with different combination of batch_size and learning rate
364
+ python train.py -m data.batch_size=32,64,128 model.lr=0.001,0.0005
365
+ ```
366
+
367
+ > **Note**: Hydra composes configs lazily at job launch time. If you change code or configs after launching a job/sweep, the final composed configs might be impacted.
368
+
369
+ </details>
370
+
371
+ <details>
372
+ <summary><b>Create a sweep over hyperparameters with Optuna</b></summary>
373
+
374
+ ```bash
375
+ # this will run hyperparameter search defined in `configs/hparams_search/mnist_optuna.yaml`
376
+ # over chosen experiment config
377
+ python train.py -m hparams_search=mnist_optuna experiment=example
378
+ ```
379
+
380
+ > **Note**: Using [Optuna Sweeper](https://hydra.cc/docs/next/plugins/optuna_sweeper) doesn't require you to add any boilerplate to your code, everything is defined in a [single config file](configs/hparams_search/mnist_optuna.yaml).
381
+
382
+ > **Warning**: Optuna sweeps are not failure-resistant (if one job crashes then the whole sweep crashes).
383
+
384
+ </details>
385
+
386
+ <details>
387
+ <summary><b>Execute all experiments from folder</b></summary>
388
+
389
+ ```bash
390
+ python train.py -m 'experiment=glob(*)'
391
+ ```
392
+
393
+ > **Note**: Hydra provides special syntax for controlling behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command above executes all experiments from [configs/experiment/](configs/experiment/).
394
+
395
+ </details>
396
+
397
+ <details>
398
+ <summary><b>Execute run for multiple different seeds</b></summary>
399
+
400
+ ```bash
401
+ python train.py -m seed=1,2,3,4,5 trainer.deterministic=True logger=csv tags=["benchmark"]
402
+ ```
403
+
404
+ > **Note**: `trainer.deterministic=True` makes PyTorch more deterministic but impacts performance.
405
+
406
+ </details>
407
+
408
+ <details>
409
+ <summary><b>Execute sweep on a remote AWS cluster</b></summary>
410
+
411
+ > **Note**: This should be achievable with simple config using [Ray AWS launcher for Hydra](https://hydra.cc/docs/next/plugins/ray_launcher). Example is not implemented in this template.
412
+
413
+ </details>
414
+
415
+ <!-- <details>
416
+ <summary><b>Execute sweep on a SLURM cluster</b></summary>
417
+
418
+ > This should be achievable with either [the right lightning trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/clouds/cluster.html?highlight=SLURM#slurm-managed-cluster) or simple config using [Submitit launcher for Hydra](https://hydra.cc/docs/plugins/submitit_launcher). Example is not yet implemented in this template.
419
+
420
+ </details> -->
421
+
422
+ <details>
423
+ <summary><b>Use Hydra tab completion</b></summary>
424
+
425
+ > **Note**: Hydra allows you to autocomplete config argument overrides in shell as you write them, by pressing `tab` key. Read the [docs](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion).
426
+
427
+ </details>
428
+
429
+ <details>
430
+ <summary><b>Apply pre-commit hooks</b></summary>
431
+
432
+ ```bash
433
+ pre-commit run -a
434
+ ```
435
+
436
+ > **Note**: Apply pre-commit hooks to do things like auto-formatting code and configs, performing code analysis or removing output from jupyter notebooks. See [# Best Practices](#best-practices) for more.
437
+
438
+ Update pre-commit hook versions in `.pre-commit-config.yaml` with:
439
+
440
+ ```bash
441
+ pre-commit autoupdate
442
+ ```
443
+
444
+ </details>
445
+
446
+ <details>
447
+ <summary><b>Run tests</b></summary>
448
+
449
+ ```bash
450
+ # run all tests
451
+ pytest
452
+
453
+ # run tests from specific file
454
+ pytest tests/test_train.py
455
+
456
+ # run all tests except the ones marked as slow
457
+ pytest -k "not slow"
458
+ ```
459
+
460
+ </details>
461
+
462
+ <details>
463
+ <summary><b>Use tags</b></summary>
464
+
465
+ Each experiment should be tagged in order to easily filter them across files or in logger UI:
466
+
467
+ ```bash
468
+ python train.py tags=["mnist","experiment_X"]
469
+ ```
470
+
471
+ > **Note**: You might need to escape the bracket characters in your shell with `python train.py tags=\["mnist","experiment_X"\]`.
472
+
473
+ If no tags are provided, you will be asked to input them from command line:
474
+
475
+ ```bash
476
+ >>> python train.py tags=[]
477
+ [2022-07-11 15:40:09,358][src.utils.utils][INFO] - Enforcing tags! <cfg.extras.enforce_tags=True>
478
+ [2022-07-11 15:40:09,359][src.utils.rich_utils][WARNING] - No tags provided in config. Prompting user to input tags...
479
+ Enter a list of comma separated tags (dev):
480
+ ```
481
+
482
+ If no tags are provided for multirun, an error will be raised:
483
+
484
+ ```bash
485
+ >>> python train.py -m +x=1,2,3 tags=[]
486
+ ValueError: Specify tags before launching a multirun!
487
+ ```
488
+
489
+ > **Note**: Appending lists from command line is currently not supported in hydra :(
490
+
491
+ </details>
492
+
493
+ <br>
494
+
495
+ ## ❤️  Contributions
496
+
497
+ This project exists thanks to all the people who contribute.
498
+
499
+ ![Contributors](https://readme-contributors.now.sh/ashleve/lightning-hydra-template?extension=jpg&width=400&aspectRatio=1)
500
+
501
+ Have a question? Found a bug? Missing a specific feature? Feel free to file a new issue, discussion or PR with respective title and description.
502
+
503
+ Before making an issue, please verify that:
504
+
505
+ - The problem still exists on the current `main` branch.
506
+ - Your python dependencies are updated to recent versions.
507
+
508
+ Suggestions for improvements are always welcome!
509
+
510
+ <br>
511
+
512
+ ## How It Works
513
+
514
+ All PyTorch Lightning modules are dynamically instantiated from module paths specified in config. Example model config:
515
+
516
+ ```yaml
517
+ _target_: src.models.mnist_model.MNISTLitModule
518
+ lr: 0.001
519
+ net:
520
+ _target_: src.models.components.simple_dense_net.SimpleDenseNet
521
+ input_size: 784
522
+ lin1_size: 256
523
+ lin2_size: 256
524
+ lin3_size: 256
525
+ output_size: 10
526
+ ```
527
+
528
+ Using this config we can instantiate the object with the following line:
529
+
530
+ ```python
531
+ model = hydra.utils.instantiate(config.model)
532
+ ```
533
+
534
+ This allows you to easily iterate over new models! Every time you create a new one, just specify its module path and parameters in the appropriate config file. <br>
535
+
536
+ Switch between models and datamodules with command line arguments:
537
+
538
+ ```bash
539
+ python train.py model=mnist
540
+ ```
541
+
542
+ Example pipeline managing the instantiation logic: [src/train.py](src/train.py).
543
+
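+ Condensed, the instantiation flow inside [src/train.py](src/train.py) looks roughly like the sketch below (callbacks and loggers omitted for brevity; this is an illustration, not the full file):
+
+ ```python
+ # a condensed sketch of the config-driven pipeline, not the complete train.py
+ import hydra
+ import lightning as L
+ from omegaconf import DictConfig
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+ def main(cfg: DictConfig) -> None:
+     # every object below is created purely from its config group
+     datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+     model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+     trainer: L.Trainer = hydra.utils.instantiate(cfg.trainer)
+     trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+
+ if __name__ == "__main__":
+     main()
+ ```
+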
544
+ <br>
545
+
546
+ ## Main Config
547
+
548
+ Location: [configs/train.yaml](configs/train.yaml) <br>
549
+ Main project config contains default training configuration.<br>
550
+ It determines how config is composed when simply executing command `python train.py`.<br>
551
+
552
+ <details>
553
+ <summary><b>Show main project config</b></summary>
554
+
555
+ ```yaml
556
+ # order of defaults determines the order in which configs override each other
557
+ defaults:
558
+ - _self_
559
+ - data: mnist.yaml
560
+ - model: mnist.yaml
561
+ - callbacks: default.yaml
562
+ - logger: null # set logger here or use command line (e.g. `python train.py logger=csv`)
563
+ - trainer: default.yaml
564
+ - paths: default.yaml
565
+ - extras: default.yaml
566
+ - hydra: default.yaml
567
+
568
+ # experiment configs allow for version control of specific hyperparameters
569
+ # e.g. best hyperparameters for given model and datamodule
570
+ - experiment: null
571
+
572
+ # config for hyperparameter optimization
573
+ - hparams_search: null
574
+
575
+ # optional local config for machine/user specific settings
576
+ # it's optional since it doesn't need to exist and is excluded from version control
577
+ - optional local: default.yaml
578
+
579
+ # debugging config (enable through command line, e.g. `python train.py debug=default)
580
+ - debug: null
581
+
582
+ # task name, determines output directory path
583
+ task_name: "train"
584
+
585
+ # tags to help you identify your experiments
586
+ # you can overwrite this in experiment configs
587
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
588
+ # appending lists from command line is currently not supported :(
589
+ # https://github.com/facebookresearch/hydra/issues/1547
590
+ tags: ["dev"]
591
+
592
+ # set False to skip model training
593
+ train: True
594
+
595
+ # evaluate on test set, using best model weights achieved during training
596
+ # lightning chooses best weights based on the metric specified in checkpoint callback
597
+ test: True
598
+
599
+ # simply provide checkpoint path to resume training
600
+ ckpt_path: null
601
+
602
+ # seed for random number generators in pytorch, numpy and python.random
603
+ seed: null
604
+ ```
605
+
606
+ </details>
607
+
608
+ <br>
609
+
610
+ ## Experiment Config
611
+
612
+ Location: [configs/experiment](configs/experiment)<br>
613
+ Experiment configs allow you to overwrite parameters from main config.<br>
614
+ For example, you can use them to version control best hyperparameters for each combination of model and dataset.
615
+
616
+ <details>
617
+ <summary><b>Show example experiment config</b></summary>
618
+
619
+ ```yaml
620
+ # @package _global_
621
+
622
+ # to execute this experiment run:
623
+ # python train.py experiment=example
624
+
625
+ defaults:
626
+ - override /data: mnist.yaml
627
+ - override /model: mnist.yaml
628
+ - override /callbacks: default.yaml
629
+ - override /trainer: default.yaml
630
+
631
+ # all parameters below will be merged with parameters from default configurations set above
632
+ # this allows you to overwrite only specified parameters
633
+
634
+ tags: ["mnist", "simple_dense_net"]
635
+
636
+ seed: 12345
637
+
638
+ trainer:
639
+ min_epochs: 10
640
+ max_epochs: 10
641
+ gradient_clip_val: 0.5
642
+
643
+ model:
644
+ optimizer:
645
+ lr: 0.002
646
+ net:
647
+ lin1_size: 128
648
+ lin2_size: 256
649
+ lin3_size: 64
650
+
651
+ data:
652
+ batch_size: 64
653
+
654
+ logger:
655
+ wandb:
656
+ tags: ${tags}
657
+ group: "mnist"
658
+ ```
659
+
660
+ </details>
661
+
662
+ <br>
663
+
664
+ ## Workflow
665
+
666
+ **Basic workflow**
667
+
668
+ 1. Write your PyTorch Lightning module (see [models/mnist_module.py](src/models/mnist_module.py) for example)
669
+ 2. Write your PyTorch Lightning datamodule (see [data/mnist_datamodule.py](src/data/mnist_datamodule.py) for example)
670
+ 3. Write your experiment config, containing paths to model and datamodule
671
+ 4. Run training with chosen experiment config:
672
+ ```bash
673
+ python src/train.py experiment=experiment_name.yaml
674
+ ```
675
+
676
+ **Experiment design**
677
+
678
+ _Say you want to execute many runs to plot how accuracy changes in respect to batch size._
679
+
680
+ 1. Execute the runs with some config parameter that allows you to identify them easily, like tags:
681
+
682
+ ```bash
683
+ python train.py -m logger=csv data.batch_size=16,32,64,128 tags=["batch_size_exp"]
684
+ ```
685
+
686
+ 2. Write a script or notebook that searches the `logs/` folder, retrieves the CSV logs from runs whose config contains the given tags, and plots the results (see the sketch below).
687
+
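+ A minimal sketch of such a script is shown below. It assumes the default layout where each run directory contains `.hydra/config.yaml` and a `metrics.csv` written by the CSVLogger, and that a `test/acc` metric was logged - adjust the paths and names to your setup:
+
+ ```python
+ # a minimal sketch, not part of the template; paths, keys and metric names are assumptions
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import yaml
+
+ results = []
+ for cfg_file in Path("logs").rglob(".hydra/config.yaml"):
+     cfg = yaml.safe_load(cfg_file.read_text())
+     if "batch_size_exp" not in (cfg.get("tags") or []):
+         continue
+     run_dir = cfg_file.parent.parent  # the run directory that also holds the logger output
+     for csv_file in run_dir.rglob("metrics.csv"):
+         metrics = pd.read_csv(csv_file)
+         results.append((cfg["data"]["batch_size"], metrics["test/acc"].dropna().iloc[-1]))
+
+ results.sort()
+ plt.plot([bs for bs, _ in results], [acc for _, acc in results], marker="o")
+ plt.xlabel("batch size")
+ plt.ylabel("test accuracy")
+ plt.show()
+ ```
+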
688
+ <br>
689
+
690
+ ## Logs
691
+
692
+ Hydra creates a new output directory for every executed run.
693
+
694
+ Default logging structure:
695
+
696
+ ```
697
+ ├── logs
698
+ │ ├── task_name
699
+ │ │ ├── runs # Logs generated by single runs
700
+ │ │ │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the run
701
+ │ │ │ │ ├── .hydra # Hydra logs
702
+ │ │ │ │ ├── csv # Csv logs
703
+ │ │ │ │ ├── wandb # Weights&Biases logs
704
+ │ │ │ │ ├── checkpoints # Training checkpoints
705
+ │ │ │ │ └── ... # Any other thing saved during training
706
+ │ │ │ └── ...
707
+ │ │ │
708
+ │ │ └── multiruns # Logs generated by multiruns
709
+ │ │ ├── YYYY-MM-DD_HH-MM-SS # Datetime of the multirun
710
+ │ │ │ ├──1 # Multirun job number
711
+ │ │ │ ├──2
712
+ │ │ │ └── ...
713
+ │ │ └── ...
714
+ │ │
715
+ │ └── debugs # Logs generated when debugging config is attached
716
+ │ └── ...
717
+ ```
718
+
720
+
721
+ You can change this structure by modifying paths in [hydra configuration](configs/hydra).
722
+
723
+ <br>
724
+
725
+ ## Experiment Tracking
726
+
727
+ PyTorch Lightning supports many popular logging frameworks: [Weights&Biases](https://www.wandb.com/), [Neptune](https://neptune.ai/), [Comet](https://www.comet.ml/), [MLFlow](https://mlflow.org), [Tensorboard](https://www.tensorflow.org/tensorboard/).
728
+
729
+ These tools help you keep track of hyperparameters and output metrics and allow you to compare and visualize results. To use one of them simply complete its configuration in [configs/logger](configs/logger) and run:
730
+
731
+ ```bash
732
+ python train.py logger=logger_name
733
+ ```
734
+
735
+ You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for example).
736
+
737
+ You can also write your own logger.
738
+
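+ For example, a bare-bones custom logger might look like the sketch below (it follows Lightning's `Logger` interface and just prints everything it receives; it is not shipped with the template):
+
+ ```python
+ # a hedged sketch of a custom logger that prints hyperparameters and metrics
+ from lightning.pytorch.loggers import Logger
+ from lightning.pytorch.utilities import rank_zero_only
+
+
+ class PrintLogger(Logger):
+     @property
+     def name(self):
+         return "print"
+
+     @property
+     def version(self):
+         return "0"
+
+     @rank_zero_only
+     def log_hyperparams(self, params):
+         print(f"hparams: {params}")
+
+     @rank_zero_only
+     def log_metrics(self, metrics, step=None):
+         print(f"step {step}: {metrics}")
+ ```
+
+ An instance of it can be passed to the `Trainer` like any other logger, or wired in through a config with a `_target_` pointing to the class.
+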
739
+ Lightning provides a convenient method for logging custom metrics from inside a LightningModule. Read the [docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at the [MNIST example](src/models/mnist_module.py).
740
+
741
+ <br>
742
+
743
+ ## Tests
744
+
745
+ Template comes with generic tests implemented with `pytest`.
746
+
747
+ ```bash
748
+ # run all tests
749
+ pytest
750
+
751
+ # run tests from specific file
752
+ pytest tests/test_train.py
753
+
754
+ # run all tests except the ones marked as slow
755
+ pytest -k "not slow"
756
+ ```
757
+
758
+ Most of the implemented tests don't check for any specific output - they simply verify that executing the commands doesn't throw exceptions. You can execute them once in a while to speed up development.
759
+
760
+ Currently, the tests cover cases like:
761
+
762
+ - running 1 train, val and test step
763
+ - running 1 epoch on 1% of data, saving ckpt and resuming for the second epoch
764
+ - running 2 epochs on 1% of data, with DDP simulated on CPU
765
+
766
+ And many others. You should be able to modify them easily for your use case.
767
+
768
+ There is also a `@RunIf` decorator, which allows you to run tests only if certain conditions are met, e.g. a GPU is available or the system is not Windows. See the [examples](tests/test_train.py).
769
+
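+ For example, a GPU-only test might be marked like the sketch below (the helper path and keyword are assumptions based on the template's test helpers; check [tests/test_train.py](tests/test_train.py) for the real usage):
+
+ ```python
+ # hypothetical usage sketch of the @RunIf marker
+ from tests.helpers.run_if import RunIf  # assumed location of the helper
+
+
+ @RunIf(min_gpus=1)  # assumed keyword: skip this test when no GPU is available
+ def test_train_one_epoch_on_gpu(cfg_train):  # `cfg_train` is assumed to be a config fixture
+     ...
+ ```
+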
770
+ <br>
771
+
772
+ ## Hyperparameter Search
773
+
774
+ You can define a hyperparameter search by adding a new config file to [configs/hparams_search](configs/hparams_search).
775
+
776
+ <details>
777
+ <summary><b>Show example hyperparameter search config</b></summary>
778
+
779
+ ```yaml
780
+ # @package _global_
781
+
782
+ defaults:
783
+ - override /hydra/sweeper: optuna
784
+
785
+ # choose metric which will be optimized by Optuna
786
+ # make sure this is the correct name of some metric logged in lightning module!
787
+ optimized_metric: "val/acc_best"
788
+
789
+ # here we define Optuna hyperparameter search
790
+ # it optimizes for value returned from function with @hydra.main decorator
791
+ hydra:
792
+ sweeper:
793
+ _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
794
+
795
+ # 'minimize' or 'maximize' the objective
796
+ direction: maximize
797
+
798
+ # total number of runs that will be executed
799
+ n_trials: 20
800
+
801
+ # choose Optuna hyperparameter sampler
802
+ # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
803
+ sampler:
804
+ _target_: optuna.samplers.TPESampler
805
+ seed: 1234
806
+ n_startup_trials: 10 # number of random sampling runs before optimization starts
807
+
808
+ # define hyperparameter search space
809
+ params:
810
+ model.optimizer.lr: interval(0.0001, 0.1)
811
+ data.batch_size: choice(32, 64, 128, 256)
812
+ model.net.lin1_size: choice(64, 128, 256)
813
+ model.net.lin2_size: choice(64, 128, 256)
814
+ model.net.lin3_size: choice(32, 64, 128, 256)
815
+ ```
816
+
817
+ </details>
818
+
819
+ Next, execute it with: `python train.py -m hparams_search=mnist_optuna`
820
+
821
+ Using this approach doesn't require adding any boilerplate to your code; everything is defined in a single config file. The only necessary thing is to return the optimized metric value from the launch file, as in the sketch below.
822
+
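+ In practice this just means the `@hydra.main` entry point returns a float, roughly like the sketch below (`run_training` is a hypothetical stand-in for the template's training routine):
+
+ ```python
+ # a hedged sketch; `run_training` is hypothetical and stands for the actual pipeline
+ import hydra
+ from omegaconf import DictConfig
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
+ def main(cfg: DictConfig) -> float:
+     metric_dict = run_training(cfg)  # hypothetical: runs training and returns logged metrics
+     # the value returned here is what the Optuna sweeper maximizes/minimizes
+     return metric_dict[cfg.get("optimized_metric")]
+
+
+ if __name__ == "__main__":
+     main()
+ ```
+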
823
+ You can use different optimization frameworks integrated with Hydra, like [Optuna, Ax or Nevergrad](https://hydra.cc/docs/plugins/optuna_sweeper/).
824
+
825
+ The `optimization_results.yaml` file will be available under the `logs/task_name/multirun` folder.
826
+
827
+ This approach doesn't support resuming an interrupted search or advanced techniques like pruning - for more sophisticated searches and workflows, you should probably write a dedicated optimization task (without the multirun feature).
828
+
829
+ <br>
830
+
831
+ ## Continuous Integration
832
+
833
+ The template comes with CI workflows implemented in GitHub Actions:
834
+
835
+ - `.github/workflows/test.yaml`: running all tests with pytest
836
+ - `.github/workflows/code-quality-main.yaml`: running pre-commits on main branch for all files
837
+ - `.github/workflows/code-quality-pr.yaml`: running pre-commits on pull requests for modified files only
838
+
839
+ <br>
840
+
841
+ ## Distributed Training
842
+
843
+ Lightning supports multiple ways of doing distributed training. The most common one is DDP, which spawns a separate process for each GPU and averages gradients between them. To learn about other approaches, read the [lightning docs](https://lightning.ai/docs/pytorch/latest/advanced/speed.html).
844
+
845
+ You can run DDP on the MNIST example with 4 GPUs like this:
846
+
847
+ ```bash
848
+ python train.py trainer=ddp
849
+ ```
850
+
851
+ > **Note**: When using DDP you have to be careful how you write your models - read the [docs](https://lightning.ai/docs/pytorch/latest/advanced/speed.html).
852
+
853
+ <br>
854
+
855
+ ## Accessing Datamodule Attributes In Model
856
+
857
+ The simplest way is to pass the datamodule attribute directly to the model on initialization:
858
+
859
+ ```python
860
+ # ./src/train.py
861
+ datamodule = hydra.utils.instantiate(config.data)
862
+ model = hydra.utils.instantiate(config.model, some_param=datamodule.some_param)
863
+ ```
864
+
865
+ > **Note**: Not a very robust solution, since it assumes all your datamodules have a `some_param` attribute available.
866
+
867
+ Similarly, you can pass a whole datamodule config as an init parameter:
868
+
869
+ ```python
870
+ # ./src/train.py
871
+ model = hydra.utils.instantiate(config.model, dm_conf=config.data, _recursive_=False)
872
+ ```
873
+
874
+ You can also pass a datamodule config parameter to your model through variable interpolation:
875
+
876
+ ```yaml
877
+ # ./configs/model/my_model.yaml
878
+ _target_: src.models.my_module.MyLitModule
879
+ lr: 0.01
880
+ some_param: ${data.some_param}
881
+ ```
882
+
883
+ Another approach is to access the datamodule in the LightningModule directly through the Trainer:
884
+
885
+ ```python
886
+ # ./src/models/mnist_module.py
887
+ def on_train_start(self):
888
+ self.some_param = self.trainer.datamodule.some_param
889
+ ```
890
+
891
+ > **Note**: This only works after the training starts, since otherwise the trainer won't yet be available in the LightningModule.
892
+
893
+ <br>
894
+
895
+ ## Best Practices
896
+
897
+ <details>
898
+ <summary><b>Use Miniconda</b></summary>
899
+
900
+ It's usually unnecessary to install the full Anaconda distribution; Miniconda should be enough (it weighs around 80 MB).
901
+
902
+ A big advantage of conda is that it installs precompiled binaries, so packages don't require certain compilers or libraries to be available on the system; this often makes it easier to install some dependencies, e.g. cudatoolkit for GPU support.
903
+
904
+ It also allows you to access your environments globally, which might be more convenient than creating a new local environment for every project.
905
+
906
+ Example installation:
907
+
908
+ ```bash
909
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
910
+ bash Miniconda3-latest-Linux-x86_64.sh
911
+ ```
912
+
913
+ Update conda:
914
+
915
+ ```bash
916
+ conda update -n base -c defaults conda
917
+ ```
918
+
919
+ Create new conda environment:
920
+
921
+ ```bash
922
+ conda create -n myenv python=3.10
923
+ conda activate myenv
924
+ ```
925
+
926
+ </details>
927
+
928
+ <details>
929
+ <summary><b>Use automatic code formatting</b></summary>
930
+
931
+ Use pre-commit hooks to standardize code formatting of your project and save mental energy.<br>
932
+ Simply install pre-commit package with:
933
+
934
+ ```bash
935
+ pip install pre-commit
936
+ ```
937
+
938
+ Next, install hooks from [.pre-commit-config.yaml](.pre-commit-config.yaml):
939
+
940
+ ```bash
941
+ pre-commit install
942
+ ```
943
+
944
+ After that your code will be automatically reformatted on every new commit.
945
+
946
+ To reformat all files in the project use command:
947
+
948
+ ```bash
949
+ pre-commit run -a
950
+ ```
951
+
952
+ To update hook versions in [.pre-commit-config.yaml](.pre-commit-config.yaml) use:
953
+
954
+ ```bash
955
+ pre-commit autoupdate
956
+ ```
957
+
958
+ </details>
959
+
960
+ <details>
961
+ <summary><b>Set private environment variables in .env file</b></summary>
962
+
963
+ System-specific variables (e.g. absolute paths to datasets) should not be under version control, or it will result in conflicts between different users. Your private keys also shouldn't be versioned, since you don't want them to be leaked.<br>
964
+
965
+ The template contains a `.env.example` file, which serves as an example. Create a new file called `.env` (this name is excluded from version control in .gitignore).
966
+ You should use it for storing environment variables like this:
967
+
968
+ ```
969
+ MY_VAR=/home/user/my_system_path
970
+ ```
971
+
972
+ All variables from `.env` are loaded in `train.py` automatically.
973
+
974
+ Hydra allows you to reference any env variable in `.yaml` configs like this:
975
+
976
+ ```yaml
977
+ path_to_data: ${oc.env:MY_VAR}
978
+ ```
979
+
980
+ </details>
981
+
982
+ <details>
983
+ <summary><b>Name metrics using '/' character</b></summary>
984
+
985
+ Depending on which logger you're using, it's often useful to define metric name with `/` character:
986
+
987
+ ```python
988
+ self.log("train/loss", loss)
989
+ ```
990
+
991
+ This way loggers will treat your metrics as belonging to different sections, which helps to get them organised in UI.
992
+
993
+ </details>
994
+
995
+ <details>
996
+ <summary><b>Use torchmetrics</b></summary>
997
+
998
+ Use official [torchmetrics](https://github.com/PytorchLightning/metrics) library to ensure proper calculation of metrics. This is especially important for multi-GPU training!
999
+
1000
+ For example, instead of calculating accuracy by yourself, you should use the provided `Accuracy` class like this:
1001
+
1002
+ ```python
1003
+ from torchmetrics.classification.accuracy import Accuracy
1004
+
1005
+
1006
+ class LitModel(LightningModule):
1007
+ def __init__(self):
+ super().__init__()
1008
+ self.train_acc = Accuracy()
1009
+ self.val_acc = Accuracy()
1010
+
1011
+ def training_step(self, batch, batch_idx):
1012
+ ...
1013
+ acc = self.train_acc(predictions, targets)
1014
+ self.log("train/acc", acc)
1015
+ ...
1016
+
1017
+ def validation_step(self, batch, batch_idx):
1018
+ ...
1019
+ acc = self.val_acc(predictions, targets)
1020
+ self.log("val/acc", acc)
1021
+ ...
1022
+ ```
1023
+
1024
+ Make sure to use a different metric instance for each step to ensure proper value reduction over all GPU processes.
1025
+
1026
+ Torchmetrics provides metrics for most use cases, like F1 score or confusion matrix. Read [documentation](https://torchmetrics.readthedocs.io/en/latest/#more-reading) for more.
1027
+
1028
+ </details>
1029
+
1030
+ <details>
1031
+ <summary><b>Follow PyTorch Lightning style guide</b></summary>
1032
+
1033
+ The style guide is available [here](https://pytorch-lightning.readthedocs.io/en/latest/starter/style_guide.html).<br>
1034
+
1035
+ 1. Be explicit in your init. Try to define all the relevant defaults so that the user doesn’t have to guess. Provide type hints. This way your module is reusable across projects!
1036
+
1037
+ ```python
1038
+ class LitModel(LightningModule):
1039
+ def __init__(self, layer_size: int = 256, lr: float = 0.001):
1040
+ ```
1041
+
1042
+ 2. Preserve the recommended method order.
1043
+
1044
+ ```python
1045
+ class LitModel(LightningModule):
1046
+
1047
+ def __init__():
1048
+ ...
1049
+
1050
+ def forward():
1051
+ ...
1052
+
1053
+ def training_step():
1054
+ ...
1055
+
1056
+ def training_step_end():
1057
+ ...
1058
+
1059
+ def on_train_epoch_end():
1060
+ ...
1061
+
1062
+ def validation_step():
1063
+ ...
1064
+
1065
+ def validation_step_end():
1066
+ ...
1067
+
1068
+ def on_validation_epoch_end():
1069
+ ...
1070
+
1071
+ def test_step():
1072
+ ...
1073
+
1074
+ def test_step_end():
1075
+ ...
1076
+
1077
+ def on_test_epoch_end():
1078
+ ...
1079
+
1080
+ def configure_optimizers():
1081
+ ...
1082
+
1083
+ def any_extra_hook():
1084
+ ...
1085
+ ```
1086
+
1087
+ </details>
1088
+
1089
+ <details>
1090
+ <summary><b>Version control your data and models with DVC</b></summary>
1091
+
1092
+ Use [DVC](https://dvc.org) to version control big files, like your data or trained ML models.<br>
1093
+ To initialize the dvc repository:
1094
+
1095
+ ```bash
1096
+ dvc init
1097
+ ```
1098
+
1099
+ To start tracking a file or directory, use `dvc add`:
1100
+
1101
+ ```bash
1102
+ dvc add data/MNIST
1103
+ ```
1104
+
1105
+ DVC stores information about the added file (or a directory) in a special .dvc file named data/MNIST.dvc, a small text file with a human-readable format. This file can be easily versioned like source code with Git, as a placeholder for the original data:
1106
+
1107
+ ```bash
1108
+ git add data/MNIST.dvc data/.gitignore
1109
+ git commit -m "Add raw data"
1110
+ ```
1111
+
1112
+ </details>
1113
+
1114
+ <details>
1115
+ <summary><b>Support installing project as a package</b></summary>
1116
+
1117
+ It allows other people to easily use your modules in their own projects.
1118
+ Change the name of the `src` folder to your project name and complete the `setup.py` file.
1119
+
1120
+ Now your project can be installed from local files:
1121
+
1122
+ ```bash
1123
+ pip install -e .
1124
+ ```
1125
+
1126
+ Or directly from git repository:
1127
+
1128
+ ```bash
1129
+ pip install git+https://github.com/YourGithubName/your-repo-name.git --upgrade
1130
+ ```
1131
+
1132
+ So any file can be easily imported into any other file like so:
1133
+
1134
+ ```python
1135
+ from project_name.models.mnist_module import MNISTLitModule
1136
+ from project_name.data.mnist_datamodule import MNISTDataModule
1137
+ ```
1138
+
1139
+ </details>
1140
+
1141
+ <details>
1142
+ <summary><b>Keep local configs out of code versioning</b></summary>
1143
+
1144
+ Some configurations are user/machine/installation specific (e.g. configuration of a local cluster, or hard drive paths on a specific machine). For such scenarios, a file [configs/local/default.yaml](configs/local/) can be created, which is automatically loaded but not tracked by Git.
1145
+
1146
+ For example, you can use it for a SLURM cluster config:
1147
+
1148
+ ```yaml
1149
+ # @package _global_
1150
+
1151
+ defaults:
1152
+ - override /hydra/launcher@_here_: submitit_slurm
1153
+
1154
+ data_dir: /mnt/scratch/data/
1155
+
1156
+ hydra:
1157
+ launcher:
1158
+ timeout_min: 1440
1159
+ gpus_per_task: 1
1160
+ gres: gpu:1
1161
+ job:
1162
+ env_set:
1163
+ MY_VAR: /home/user/my/system/path
1164
+ MY_KEY: asdgjhawi8y23ihsghsueity23ihwd
1165
+ ```
1166
+
1167
+ </details>
1168
+
1169
+ <br>
1170
+
1171
+ ## Resources
1172
+
1173
+ This template was inspired by:
1174
+
1175
+ - [PyTorchLightning/deep-learning-project-template](https://github.com/PyTorchLightning/deep-learning-project-template)
1176
+ - [drivendata/cookiecutter-data-science](https://github.com/drivendata/cookiecutter-data-science)
1177
+ - [lucmos/nn-template](https://github.com/lucmos/nn-template)
1178
+
1179
+ Other useful repositories:
1180
+
1181
+ - [jxpress/lightning-hydra-template-vertex-ai](https://github.com/jxpress/lightning-hydra-template-vertex-ai) - lightning-hydra-template integration with Vertex AI hyperparameter tuning and custom training job
1182
+
1184
+
1185
+ <br>
1186
+
1187
+ ## License
1188
+
1189
+ Lightning-Hydra-Template is licensed under the MIT License.
1190
+
1191
+ ```
1192
+ MIT License
1193
+
1194
+ Copyright (c) 2021 ashleve
1195
+
1196
+ Permission is hereby granted, free of charge, to any person obtaining a copy
1197
+ of this software and associated documentation files (the "Software"), to deal
1198
+ in the Software without restriction, including without limitation the rights
1199
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1200
+ copies of the Software, and to permit persons to whom the Software is
1201
+ furnished to do so, subject to the following conditions:
1202
+
1203
+ The above copyright notice and this permission notice shall be included in all
1204
+ copies or substantial portions of the Software.
1205
+
1206
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1207
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1208
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1209
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1210
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1211
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1212
+ SOFTWARE.
1213
+ ```
1214
+
1215
+ <br>
1216
+ <br>
1217
+ <br>
1218
+ <br>
1219
+
1220
+ **DELETE EVERYTHING ABOVE FOR YOUR PROJECT**
1221
+
1222
+ ______________________________________________________________________
1223
+
1224
+ <div align="center">
1225
+
1226
+ # Your Project Name
1227
+
1228
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
1229
+ <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
1230
+ <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
1231
+ <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
1232
+ [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
1233
+ [![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/paper/2020)
1234
+
1235
+ </div>
1236
+
1237
+ ## Description
1238
+
1239
+ What it does
1240
+
1241
+ ## Installation
1242
+
1243
+ #### Pip
1244
+
1245
+ ```bash
1246
+ # clone project
1247
+ git clone https://github.com/YourGithubName/your-repo-name
1248
+ cd your-repo-name
1249
+
1250
+ # [OPTIONAL] create conda environment
1251
+ conda create -n myenv python=3.9
1252
+ conda activate myenv
1253
+
1254
+ # install pytorch according to instructions
1255
+ # https://pytorch.org/get-started/
1256
+
1257
+ # install requirements
1258
+ pip install -r requirements.txt
1259
+ ```
1260
+
1261
+ #### Conda
1262
+
1263
+ ```bash
1264
+ # clone project
1265
+ git clone https://github.com/YourGithubName/your-repo-name
1266
+ cd your-repo-name
1267
+
1268
+ # create conda environment and install dependencies
1269
+ conda env create -f environment.yaml -n myenv
1270
+
1271
+ # activate conda environment
1272
+ conda activate myenv
1273
+ ```
1274
+
1275
+ ## How to run
1276
+
1277
+ Train model with default configuration
1278
+
1279
+ ```bash
1280
+ # train on CPU
1281
+ python src/train.py trainer=cpu
1282
+
1283
+ # train on GPU
1284
+ python src/train.py trainer=gpu
1285
+ ```
1286
+
1287
+ Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
1288
+
1289
+ ```bash
1290
+ python src/train.py experiment=experiment_name.yaml
1291
+ ```
1292
+
1293
+ You can override any parameter from command line like this
1294
+
1295
+ ```bash
1296
+ python src/train.py trainer.max_epochs=20 data.batch_size=64
1297
+ ```
configs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # this file is needed here to include configs when building project as a package
configs/callbacks/default.yaml ADDED
@@ -0,0 +1,22 @@
1
+ defaults:
2
+ - model_checkpoint
3
+ - early_stopping
4
+ - model_summary
5
+ - rich_progress_bar
6
+ - _self_
7
+
8
+ model_checkpoint:
9
+ dirpath: ${paths.root_dir}/checkpoints
10
+ filename: "epoch_{epoch:03d}"
11
+ monitor: "val/psnr"
12
+ mode: "max"
13
+ save_last: True
14
+ auto_insert_metric_name: False
15
+
16
+ early_stopping:
17
+ monitor: "val/psnr"
18
+ patience: 100
19
+ mode: "max"
20
+
21
+ model_summary:
22
+ max_depth: -1
configs/callbacks/early_stopping.yaml ADDED
@@ -0,0 +1,15 @@
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2
+
3
+ early_stopping:
4
+ _target_: lightning.pytorch.callbacks.EarlyStopping
5
+ monitor: ??? # quantity to be monitored, must be specified !!!
6
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7
+ patience: 3 # number of checks with no improvement after which training will be stopped
8
+ verbose: False # verbosity mode
9
+ mode: "min" # "min" means lower metric value is better, "max" means higher is better
10
+ strict: True # whether to crash the training if monitor is not found in the validation metrics
11
+ check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12
+ stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13
+ divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14
+ check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
configs/callbacks/model_checkpoint.yaml ADDED
@@ -0,0 +1,17 @@
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2
+
3
+ model_checkpoint:
4
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
5
+ dirpath: null # directory to save the model file
6
+ filename: null # checkpoint filename
7
+ monitor: null # name of the logged metric which determines when model is improving
8
+ verbose: False # verbosity mode
9
+ save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10
+ save_top_k: 1 # save k best models (determined by above metric)
11
+ mode: "min" # "min" means lower metric value is better, "max" means higher is better
12
+ auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13
+ save_weights_only: False # if True, then only the model’s weights will be saved
14
+ every_n_train_steps: null # number of training steps between checkpoints
15
+ train_time_interval: null # checkpoints are monitored at the specified time interval
16
+ every_n_epochs: null # number of epochs between checkpoints
17
+ save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
configs/callbacks/model_summary.yaml ADDED
@@ -0,0 +1,5 @@
1
+ # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2
+
3
+ model_summary:
4
+ _target_: lightning.pytorch.callbacks.RichModelSummary
5
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
configs/callbacks/none.yaml ADDED
File without changes
configs/callbacks/rich_progress_bar.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2
+
3
+ rich_progress_bar:
4
+ _target_: lightning.pytorch.callbacks.RichProgressBar
configs/data/swim.yaml ADDED
@@ -0,0 +1,4 @@
1
+ _target_: swim.data.swim_data.SwimDataModule
2
+ root_dir: /home/qninh/projects/swim_/datasets/swim_data
3
+ batch_size: 4
4
+ img_size: 64
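
The `root_dir` above is machine-specific and would normally be overridden (e.g. `data.root_dir=...` on the command line). For reference, a minimal sketch of what `hydra.utils.instantiate` builds from this config, using a relative dataset path as an assumption:

```python
# Sketch: manual construction equivalent to instantiating configs/data/swim.yaml.
# The root_dir is an assumption; point it at your own copy of the dataset.
from swim.data.swim_data import SwimDataModule

datamodule = SwimDataModule(root_dir="./datasets/swim_data", batch_size=4, img_size=64)
datamodule.setup("fit")
batch = next(iter(datamodule.train_dataloader()))
print(batch["images"].shape, batch["styles"].shape)  # e.g. (4, 3, 64, 64) and (4, 4)
```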
configs/debug/default.yaml ADDED
@@ -0,0 +1,35 @@
+ # @package _global_
+
+ # default debugging setup, runs 1 full epoch
+ # other debugging configs can inherit from this one
+
+ # overwrite task name so debugging logs are stored in separate folder
+ task_name: "debug"
+
+ # disable callbacks and loggers during debugging
+ callbacks: null
+ logger: null
+
+ extras:
+   ignore_warnings: False
+   enforce_tags: False
+
+ # sets level of all command line loggers to 'DEBUG'
+ # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+ hydra:
+   job_logging:
+     root:
+       level: DEBUG
+
+ # use this to also set hydra loggers to 'DEBUG'
+ # verbose: True
+
+ trainer:
+   max_epochs: 1
+   accelerator: cpu # debuggers don't like gpus
+   devices: 1 # debuggers don't like multiprocessing
+   detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
+
+ data:
+   num_workers: 0 # debuggers don't like multiprocessing
+   pin_memory: False # disable gpu memory pin
configs/debug/fdr.yaml ADDED
@@ -0,0 +1,9 @@
+ # @package _global_
+
+ # runs 1 train, 1 validation and 1 test step
+
+ defaults:
+   - default
+
+ trainer:
+   fast_dev_run: true
configs/debug/limit.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ # uses only 1% of the training data and 5% of validation/test data
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 3
+   limit_train_batches: 0.01
+   limit_val_batches: 0.05
+   limit_test_batches: 0.05
configs/debug/overfit.yaml ADDED
@@ -0,0 +1,13 @@
+ # @package _global_
+
+ # overfits to 3 batches
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 20
+   overfit_batches: 3
+
+ # model ckpt and early stopping need to be disabled during overfitting
+ callbacks: null
configs/debug/profiler.yaml ADDED
@@ -0,0 +1,12 @@
+ # @package _global_
+
+ # runs with execution time profiling
+
+ defaults:
+   - default
+
+ trainer:
+   max_epochs: 1
+   profiler: "simple"
+   # profiler: "advanced"
+   # profiler: "pytorch"
configs/eval.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - _self_
+   - data: swim # choose datamodule with `test_dataloader()` for evaluation
+   - model: autoencoder
+   - logger: null
+   - trainer: default
+   - paths: default
+   - extras: default
+   - hydra: default
+
+ task_name: "eval"
+
+ tags: ["dev"]
+
+ # passing checkpoint path is necessary for evaluation
+ ckpt_path: ???
configs/experiment/example.yaml ADDED
@@ -0,0 +1,41 @@
+ # @package _global_
+
+ # to execute this experiment run:
+ # python train.py experiment=example
+
+ defaults:
+   - override /data: mnist
+   - override /model: mnist
+   - override /callbacks: default
+   - override /trainer: default
+
+ # all parameters below will be merged with parameters from default configurations set above
+ # this allows you to overwrite only specified parameters
+
+ tags: ["mnist", "simple_dense_net"]
+
+ seed: 12345
+
+ trainer:
+   min_epochs: 10
+   max_epochs: 10
+   gradient_clip_val: 0.5
+
+ model:
+   optimizer:
+     lr: 0.002
+   net:
+     lin1_size: 128
+     lin2_size: 256
+     lin3_size: 64
+   compile: false
+
+ data:
+   batch_size: 64
+
+ logger:
+   wandb:
+     tags: ${tags}
+     group: "mnist"
+   aim:
+     experiment: "mnist"
configs/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
+ # disable python warnings if they annoy you
+ ignore_warnings: False
+
+ # ask user for tags if none are provided in the config
+ enforce_tags: True
+
+ # pretty print config tree at the start of the run using Rich library
+ print_config: True
configs/hparams_search/mnist_optuna.yaml ADDED
@@ -0,0 +1,52 @@
+ # @package _global_
+
+ # example hyperparameter optimization of some experiment with Optuna:
+ # python train.py -m hparams_search=mnist_optuna experiment=example
+
+ defaults:
+   - override /hydra/sweeper: optuna
+
+ # choose metric which will be optimized by Optuna
+ # make sure this is the correct name of some metric logged in lightning module!
+ optimized_metric: "val/acc_best"
+
+ # here we define Optuna hyperparameter search
+ # it optimizes for value returned from function with @hydra.main decorator
+ # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
+ hydra:
+   mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
+
+   sweeper:
+     _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
+
+     # storage URL to persist optimization results
+     # for example, you can use SQLite if you set 'sqlite:///example.db'
+     storage: null
+
+     # name of the study to persist optimization results
+     study_name: null
+
+     # number of parallel workers
+     n_jobs: 1
+
+     # 'minimize' or 'maximize' the objective
+     direction: maximize
+
+     # total number of runs that will be executed
+     n_trials: 20
+
+     # choose Optuna hyperparameter sampler
+     # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
+     # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
+     sampler:
+       _target_: optuna.samplers.TPESampler
+       seed: 1234
+       n_startup_trials: 10 # number of random sampling runs before optimization starts
+
+     # define hyperparameter search space
+     params:
+       model.optimizer.lr: interval(0.0001, 0.1)
+       data.batch_size: choice(32, 64, 128, 256)
+       model.net.lin1_size: choice(64, 128, 256)
+       model.net.lin2_size: choice(64, 128, 256)
+       model.net.lin3_size: choice(32, 64, 128, 256)
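
For intuition, the `interval(...)` and `choice(...)` entries above correspond roughly to Optuna's float and categorical suggestions. The following is only an illustrative sketch of that mapping, not code from this repo (the real objective is the Hydra-decorated train function):

```python
# Rough Optuna equivalent of the `params` search space above.
import optuna

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("model.optimizer.lr", 1e-4, 0.1)
    batch_size = trial.suggest_categorical("data.batch_size", [32, 64, 128, 256])
    lin1_size = trial.suggest_categorical("model.net.lin1_size", [64, 128, 256])
    # ... run training with these values and return the optimized metric
    return 0.0  # placeholder for val/acc_best

sampler = optuna.samplers.TPESampler(seed=1234, n_startup_trials=10)
study = optuna.create_study(direction="maximize", sampler=sampler)
study.optimize(objective, n_trials=20)
```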
configs/hydra/default.yaml ADDED
@@ -0,0 +1,19 @@
+ # https://hydra.cc/docs/configure_hydra/intro/
+
+ # enable color logging
+ defaults:
+   - override hydra_logging: colorlog
+   - override job_logging: colorlog
+
+ # output directory, generated dynamically on each run
+ run:
+   dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ sweep:
+   dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
+   subdir: ${hydra.job.num}
+
+ job_logging:
+   handlers:
+     file:
+       # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+       filename: ${hydra.runtime.output_dir}/${task_name}.log
configs/local/.gitkeep ADDED
File without changes
configs/logger/aim.yaml ADDED
@@ -0,0 +1,28 @@
+ # https://aimstack.io/
+
+ # example usage in lightning module:
+ # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+ # open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+ # `aim up`
+
+ aim:
+   _target_: aim.pytorch_lightning.AimLogger
+   repo: ${paths.root_dir} # .aim folder will be created here
+   # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+   # aim allows to group runs under experiment name
+   experiment: null # any string, set to "default" if not specified
+
+   train_metric_prefix: "train/"
+   val_metric_prefix: "val/"
+   test_metric_prefix: "test/"
+
+   # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+   system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+   # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+   log_system_params: true
+
+   # enable/disable tracking console logs (default value is true)
+   capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
configs/logger/comet.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://www.comet.ml
+
+ comet:
+   _target_: lightning.pytorch.loggers.comet.CometLogger
+   api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+   save_dir: "${paths.output_dir}"
+   project_name: "lightning-hydra-template"
+   rest_api_key: null
+   # experiment_name: ""
+   experiment_key: null # set to resume experiment
+   offline: False
+   prefix: ""
configs/logger/csv.yaml ADDED
@@ -0,0 +1,7 @@
+ # CSV logger built into Lightning
+
+ csv:
+   _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+   save_dir: "${paths.output_dir}"
+   name: "csv/"
+   prefix: ""
configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,9 @@
+ # train with many loggers at once
+
+ defaults:
+   # - comet
+   - csv
+   # - mlflow
+   # - neptune
+   - tensorboard
+   - wandb
configs/logger/mlflow.yaml ADDED
@@ -0,0 +1,12 @@
+ # https://mlflow.org
+
+ mlflow:
+   _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+   # experiment_name: ""
+   # run_name: ""
+   tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+   tags: null
+   # save_dir: "./mlruns"
+   prefix: ""
+   artifact_location: null
+   # run_id: ""
configs/logger/neptune.yaml ADDED
@@ -0,0 +1,9 @@
+ # https://neptune.ai
+
+ neptune:
+   _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+   api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+   project: username/lightning-hydra-template
+   # name: ""
+   log_model_checkpoints: True
+   prefix: ""
configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
+ # https://www.tensorflow.org/tensorboard/
+
+ tensorboard:
+   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+   save_dir: "${paths.output_dir}/tensorboard/"
+   name: null
+   log_graph: False
+   default_hp_metric: True
+   prefix: ""
+   # version: ""
configs/logger/wandb.yaml ADDED
@@ -0,0 +1,16 @@
+ # https://wandb.ai
+
+ wandb:
+   _target_: lightning.pytorch.loggers.wandb.WandbLogger
+   # name: "" # name of the run (normally generated by wandb)
+   save_dir: "${paths.output_dir}"
+   offline: False
+   id: null # pass correct id to resume experiment!
+   anonymous: null # enable anonymous logging
+   project: "swim"
+   log_model: False # upload lightning ckpts
+   prefix: "" # a string to put at the beginning of metric keys
+   # entity: "" # set to name of your wandb team
+   group: ""
+   tags: []
+   job_type: ""
configs/model/autoencoder.yaml ADDED
@@ -0,0 +1,11 @@
+ _target_: swim.models.autoencoder.Autoencoder
+
+ learning_rate: 1e-4
+
+ channels: 128
+ channel_multipliers: [1, 2, 4]
+ n_resnet_blocks: 1
+ in_channels: 3
+ out_channels: 3
+ z_channels: 4
+ emb_channels: 4
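
The autoencoder itself lives in `swim/models/autoencoder.py`, which falls outside this truncated view; assuming its constructor mirrors the keys above, Hydra's instantiation is roughly equivalent to the following sketch:

```python
# Sketch of how hydra.utils.instantiate consumes configs/model/autoencoder.yaml.
# The Autoencoder signature is assumed to match the keys listed above.
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/autoencoder.yaml")
model = hydra.utils.instantiate(cfg)
# -> swim.models.autoencoder.Autoencoder(learning_rate=..., channels=128, channel_multipliers=[1, 2, 4], ...)
```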
configs/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
+ # path to root directory
+ # this requires PROJECT_ROOT environment variable to exist
+ # you can replace it with "." if you want the root to be the current working directory
+ root_dir: ${oc.env:PROJECT_ROOT}
+
+ # path to data directory
+ data_dir: ${paths.root_dir}/data/
+
+ # path to logging directory
+ log_dir: ${paths.root_dir}/logs/
+
+ # path to output directory, created dynamically by hydra
+ # path generation pattern is specified in `configs/hydra/default.yaml`
+ # use it to store all files generated during the run, like ckpts and metrics
+ output_dir: ${hydra:runtime.output_dir}
+
+ # path to working directory
+ work_dir: ${hydra:runtime.cwd}
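
`PROJECT_ROOT` is not expected to be set by hand: `rootutils.setup_root(...)` in the entry points (see `swim/eval.py` below) exports it after locating the project-root marker file. A minimal sketch of that mechanism, assuming a `.project-root` marker exists at the repository root:

```python
# Sketch: how PROJECT_ROOT becomes available to ${oc.env:PROJECT_ROOT}.
import os
import rootutils

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
print(os.environ["PROJECT_ROOT"])  # exported by setup_root, consumed by configs/paths/default.yaml
```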
configs/train.yaml ADDED
@@ -0,0 +1,49 @@
+ # @package _global_
+
+ # specify here default configuration
+ # order of defaults determines the order in which configs override each other
+ defaults:
+   - _self_
+   - data: swim
+   - model: autoencoder
+   - callbacks: default
+   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+   - trainer: gpu
+   - paths: default
+   - extras: default
+   - hydra: default
+
+   # experiment configs allow for version control of specific hyperparameters
+   # e.g. best hyperparameters for given model and datamodule
+   - experiment: null
+
+   # config for hyperparameter optimization
+   - hparams_search: null
+
+   # optional local config for machine/user specific settings
+   # it's optional since it doesn't need to exist and is excluded from version control
+   - optional local: default
+
+   # debugging config (enable through command line, e.g. `python train.py debug=default`)
+   - debug: null
+
+ # task name, determines output directory path
+ task_name: "train"
+
+ # tags to help you identify your experiments
+ # you can overwrite this in experiment configs
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+ tags: ["dev"]
+
+ # set False to skip model training
+ train: True
+
+ # evaluate on test set, using best model weights achieved during training
+ # lightning chooses best weights based on the metric specified in checkpoint callback
+ test: True
+
+ # simply provide checkpoint path to resume training
+ ckpt_path: null
+
+ # seed for random number generators in pytorch, numpy and python.random
+ seed: 42
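
The matching `swim/train.py` entry point is outside this truncated view; based on `swim/eval.py` below and the `train`/`test`/`ckpt_path` keys here, its core presumably looks roughly like this hypothetical sketch (callback and logger instantiation omitted for brevity):

```python
# Hypothetical sketch of the train entry point consuming configs/train.yaml;
# structure mirrors swim/eval.py shown further below.
import hydra
from lightning import LightningDataModule, LightningModule, Trainer
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer)

    if cfg.get("train"):
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
    if cfg.get("test"):
        trainer.test(model=model, datamodule=datamodule, ckpt_path="best")


if __name__ == "__main__":
    main()
```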
configs/trainer/cpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: cpu
+ devices: 1
configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,9 @@
+ defaults:
+   - default
+
+ strategy: ddp
+
+ accelerator: gpu
+ devices: 4
+ num_nodes: 1
+ sync_batchnorm: True
configs/trainer/ddp_sim.yaml ADDED
@@ -0,0 +1,7 @@
+ defaults:
+   - default
+
+ # simulate DDP on CPU, useful for debugging
+ accelerator: cpu
+ devices: 2
+ strategy: ddp_spawn
configs/trainer/default.yaml ADDED
@@ -0,0 +1,19 @@
+ _target_: lightning.pytorch.trainer.Trainer
+
+ default_root_dir: ${paths.output_dir}
+
+ min_epochs: 1 # prevents early stopping
+ max_epochs: 10
+
+ accelerator: cpu
+ devices: 1
+
+ # mixed precision for extra speed-up
+ # precision: 16
+
+ # perform a validation loop every N training epochs
+ check_val_every_n_epoch: 1
+
+ # set True to ensure deterministic results
+ # makes training slower but gives more reproducibility than just setting seeds
+ deterministic: False
configs/trainer/gpu.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: gpu
+ devices: 1
configs/trainer/mps.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - default
+
+ accelerator: mps
+ devices: 1
environment.yaml ADDED
@@ -0,0 +1,45 @@
+ # reasons you might want to use `environment.yaml` instead of `requirements.txt`:
+ # - pip installs packages in a loop, without ensuring dependencies across all packages
+ #   are fulfilled simultaneously, but conda achieves proper dependency control across
+ #   all packages
+ # - conda allows for installing packages without requiring certain compilers or
+ #   libraries to be available in the system, since it installs precompiled binaries
+
+ name: myenv
+
+ channels:
+   - pytorch
+   - conda-forge
+   - defaults
+
+ # it is strongly recommended to specify versions of packages installed through conda
+ # to avoid situations where version-unspecified packages install their latest major
+ # versions, which can sometimes break things
+
+ # current approach below keeps the dependencies in the same major versions across all
+ # users, but allows for different minor and patch versions of packages where backwards
+ # compatibility is usually guaranteed
+
+ dependencies:
+   - python=3.10
+   - pytorch=2.*
+   - torchvision=0.*
+   - lightning=2.*
+   - torchmetrics=0.*
+   - hydra-core=1.*
+   - rich=13.*
+   - pre-commit=3.*
+   - pytest=7.*
+
+   # --------- loggers --------- #
+   # - wandb
+   # - neptune-client
+   # - mlflow
+   # - comet-ml
+   # - aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
+
+   - pip>=23
+   - pip:
+       - hydra-optuna-sweeper
+       - hydra-colorlog
+       - rootutils
notebooks/.gitkeep ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,25 @@
+ [tool.pytest.ini_options]
+ addopts = [
+   "--color=yes",
+   "--durations=0",
+   "--strict-markers",
+   "--doctest-modules",
+ ]
+ filterwarnings = [
+   "ignore::DeprecationWarning",
+   "ignore::UserWarning",
+ ]
+ log_cli = "True"
+ markers = [
+   "slow: slow tests",
+ ]
+ minversion = "6.0"
+ testpaths = "tests/"
+
+ [tool.coverage.report]
+ exclude_lines = [
+   "pragma: nocover",
+   "raise NotImplementedError",
+   "raise NotImplementedError()",
+   "if __name__ == .__main__.:",
+ ]
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ # --------- pytorch --------- #
+ torch>=2.0.0
+ torchvision>=0.15.0
+ lightning>=2.0.0
+ torchmetrics>=0.11.4
+
+ # --------- hydra --------- #
+ hydra-core==1.3.2
+ hydra-colorlog==1.2.0
+ hydra-optuna-sweeper==1.2.0
+
+ # --------- loggers --------- #
+ wandb
+ # neptune-client
+ # mlflow
+ # comet-ml
+ # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550
+
+ # --------- others --------- #
+ rootutils # standardizing the project root setup
+ pre-commit # hooks for applying linters on commit
+ rich # beautiful text formatting in terminal
+ pytest # tests
+ # sh # for running bash commands in some tests (linux/macos only)
scripts/schedule.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+ # Schedule execution of many runs
+ # Run from root folder with: bash scripts/schedule.sh
+
+ python src/train.py trainer.max_epochs=5 logger=csv
+
+ python src/train.py trainer.max_epochs=10 logger=csv
setup.py ADDED
@@ -0,0 +1,21 @@
+ #!/usr/bin/env python
+
+ from setuptools import find_packages, setup
+
+ setup(
+     name="swim",
+     version="0.0.1",
+     description="Describe Your Cool Project",
+     author="",
+     author_email="",
+     url="https://github.com/user/project",
+     install_requires=["lightning", "hydra-core"],
+     packages=find_packages(),
+     # use this to customize global commands available in the terminal after installing the package
+     entry_points={
+         "console_scripts": [
+             "train_command = swim.train:main",
+             "eval_command = swim.eval:main",
+         ]
+     },
+ )
swim/__init__.py ADDED
File without changes
swim/data/__init__.py ADDED
File without changes
swim/data/swim_data.py ADDED
@@ -0,0 +1,145 @@
+ from typing import Literal, List
+
+ import os
+ import json
+ import torch
+ import torchvision.transforms as T
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ from lightning import LightningDataModule
+
+
+ class SwimDataset(Dataset):
+     def __init__(
+         self,
+         root_dir: str = "./datasets/swim_data",
+         split: Literal["train", "val"] = "train",
+         img_size: int = 512,
+     ):
+         super().__init__()
+         self.root_dir = root_dir
+         self.split_dir = os.path.join(root_dir, split)
+         self.img_size = img_size
+
+         if split == "train":
+             self.transform = T.Compose(
+                 [
+                     T.Resize(img_size),  # smaller edge of image resized to img_size
+                     T.RandomCrop(img_size),  # get a random crop of img_size x img_size
+                     T.RandomHorizontalFlip(),
+                     T.ToTensor(),
+                     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                 ]
+             )
+         elif split == "val":
+             self.transform = T.Compose(
+                 [
+                     T.Resize(img_size),
+                     T.CenterCrop(img_size),
+                     T.ToTensor(),
+                     T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+                 ]
+             )
+
+         with open(os.path.join(self.split_dir, "labels.json"), "r") as f:
+             self.data = json.load(f)
+
+         # filter out images that are both at night and have adverse weather conditions
+         self.data = [
+             img
+             for img in self.data
+             if not (img["timeofday"] == "night" and img["weather"] != "clear")
+         ]
+
+     def __len__(self):
+         return len(self.data) // 100
+
+     def __getitem__(self, idx):
+         data = self.data[idx]
+
+         # load image
+         img_path = os.path.join(self.split_dir, "images", data["name"])
+         img = Image.open(img_path).convert("RGB")
+         img = self.transform(img)
+
+         # load style
+         if data["weather"] != "clear":
+             style_name = data["weather"]
+         elif data["timeofday"] == "night":
+             style_name = "night"
+         else:
+             style_name = "clear"
+
+         # True if the image has any style
+         style_flag = style_name != "clear"
+
+         # one-hot encode style
+         style = torch.zeros(4)
+
+         if style_flag:
+             style[self.get_stylenames().index(style_name)] = 1
+
+         return {
+             "image": img,
+             "style": style,
+             "style_flag": style_flag,
+         }
+
+     def get_stylenames(self) -> List[str]:
+         return ["rain", "snow", "fog", "night"]
+
+
+ class SwimDataModule(LightningDataModule):
+     def __init__(
+         self,
+         root_dir: str = "./datasets/swim_data",
+         batch_size: int = 1,
+         img_size: int = 512,
+     ):
+         super().__init__()
+         self.root_dir = root_dir
+         self.img_size = img_size
+         self.batch_size = batch_size
+
+     def setup(self, stage=None):
+         if stage == "fit" or stage is None:
+             self.train_dataset = SwimDataset(
+                 root_dir=self.root_dir, split="train", img_size=self.img_size
+             )
+             self.val_dataset = SwimDataset(
+                 root_dir=self.root_dir, split="val", img_size=self.img_size
+             )
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=4,
+             collate_fn=self.custom_collate_fn,
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=4,
+             collate_fn=self.custom_collate_fn,
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self.val_dataset,
+             batch_size=1,
+             shuffle=False,
+             num_workers=4,
+             collate_fn=self.custom_collate_fn,
+         )
+
+     @staticmethod
+     def custom_collate_fn(batch):
+         images = torch.stack([item["image"] for item in batch])
+         styles = torch.stack([item["style"] for item in batch])
+         style_flags = [item["style_flag"] for item in batch]
+         return {"images": images, "styles": styles, "style_flags": style_flags}
swim/eval.py ADDED
@@ -0,0 +1,99 @@
+ from typing import Any, Dict, List, Tuple
+
+ import hydra
+ import rootutils
+ from lightning import LightningDataModule, LightningModule, Trainer
+ from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ # ------------------------------------------------------------------------------------ #
+ # the setup_root above is equivalent to:
+ # - adding project root dir to PYTHONPATH
+ #   (so you don't need to force user to install project as a package)
+ #   (necessary before importing any local modules e.g. `from swim import utils`)
+ # - setting up PROJECT_ROOT environment variable
+ #   (which is used as a base for paths in "configs/paths/default.yaml")
+ #   (this way all filepaths are the same no matter where you run the code)
+ # - loading environment variables from ".env" in root dir
+ #
+ # you can remove it if you:
+ # 1. either install project as a package or move entry files to project root dir
+ # 2. set `root_dir` to "." in "configs/paths/default.yaml"
+ #
+ # more info: https://github.com/ashleve/rootutils
+ # ------------------------------------------------------------------------------------ #
+
+ from swim.utils import (
+     RankedLogger,
+     extras,
+     instantiate_loggers,
+     log_hyperparameters,
+     task_wrapper,
+ )
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ @task_wrapper
+ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """Evaluates a given checkpoint on the datamodule's test set.
+
+     This method is wrapped in an optional @task_wrapper decorator, which controls the behavior
+     during failure. Useful for multiruns, saving info about the crash, etc.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
+     """
+     assert cfg.ckpt_path
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+     log.info(f"Instantiating model <{cfg.model._target_}>")
+     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+     log.info("Instantiating loggers...")
+     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": datamodule,
+         "model": model,
+         "logger": logger,
+         "trainer": trainer,
+     }
+
+     if logger:
+         log.info("Logging hyperparameters!")
+         log_hyperparameters(object_dict)
+
+     log.info("Starting testing!")
+     trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+     # for predictions use trainer.predict(...)
+     # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
+
+     metric_dict = trainer.callback_metrics
+
+     return metric_dict, object_dict
+
+
+ @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+ def main(cfg: DictConfig) -> None:
+     """Main entry point for evaluation.
+
+     :param cfg: DictConfig configuration composed by Hydra.
+     """
+     # apply extra utilities
+     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+     extras(cfg)
+
+     evaluate(cfg)
+
+
+ if __name__ == "__main__":
+     main()
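
Because `configs/eval.yaml` sets `ckpt_path: ???`, evaluation must be invoked with an explicit checkpoint, for example `python swim/eval.py ckpt_path=/path/to/checkpoint.ckpt` run from the repository root; Hydra then composes the same data, model, and trainer groups used for training, and `trainer.test` restores the weights from that path before reporting the test metrics.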
swim/models/__init__.py ADDED
File without changes