veichta committed on
Commit
205a7af
·
verified ·
1 Parent(s): 4ad7f0b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +138 -0
  3. .pre-commit-config.yaml +35 -0
  4. LICENSE +190 -0
  5. README.md +582 -7
  6. assets/fisheye-dog-pool.jpg +0 -0
  7. assets/fisheye-skyline.jpg +3 -0
  8. assets/pinhole-church.jpg +0 -0
  9. assets/pinhole-garden.jpg +0 -0
  10. assets/teaser.gif +3 -0
  11. demo.ipynb +3 -0
  12. geocalib/__init__.py +17 -0
  13. geocalib/camera.py +774 -0
  14. geocalib/extractor.py +126 -0
  15. geocalib/geocalib.py +150 -0
  16. geocalib/gravity.py +131 -0
  17. geocalib/interactive_demo.py +450 -0
  18. geocalib/lm_optimizer.py +642 -0
  19. geocalib/misc.py +318 -0
  20. geocalib/modules.py +575 -0
  21. geocalib/perspective_fields.py +366 -0
  22. geocalib/utils.py +325 -0
  23. geocalib/viz2d.py +502 -0
  24. gradio_app.py +228 -0
  25. hubconf.py +14 -0
  26. pyproject.toml +49 -0
  27. requirements.txt +5 -0
  28. siclib/LICENSE +190 -0
  29. siclib/__init__.py +15 -0
  30. siclib/configs/deepcalib.yaml +12 -0
  31. siclib/configs/geocalib-radial.yaml +38 -0
  32. siclib/configs/geocalib.yaml +10 -0
  33. siclib/configs/model/deepcalib.yaml +7 -0
  34. siclib/configs/model/geocalib.yaml +31 -0
  35. siclib/configs/train/deepcalib.yaml +22 -0
  36. siclib/configs/train/geocalib.yaml +50 -0
  37. siclib/datasets/__init__.py +25 -0
  38. siclib/datasets/augmentations.py +359 -0
  39. siclib/datasets/base_dataset.py +218 -0
  40. siclib/datasets/configs/openpano-radial.yaml +41 -0
  41. siclib/datasets/configs/openpano.yaml +34 -0
  42. siclib/datasets/create_dataset_from_pano.py +350 -0
  43. siclib/datasets/simple_dataset.py +237 -0
  44. siclib/datasets/utils/__init__.py +0 -0
  45. siclib/datasets/utils/align_megadepth.py +41 -0
  46. siclib/datasets/utils/download_openpano.py +75 -0
  47. siclib/datasets/utils/tonemapping.py +316 -0
  48. siclib/eval/__init__.py +18 -0
  49. siclib/eval/configs/deepcalib.yaml +3 -0
  50. siclib/eval/configs/geocalib-pinhole.yaml +2 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/fisheye-skyline.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
38
+ demo.ipynb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,138 @@
1
+ # folders
2
+ data/**
3
+ outputs/**
4
+ weights/**
5
+ **.DS_Store
6
+ .vscode/**
7
+ wandb/**
8
+ third_party/**
9
+
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ pip-wheel-metadata/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104
+ __pypackages__/
105
+
106
+ # Celery stuff
107
+ celerybeat-schedule
108
+ celerybeat.pid
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+
122
+ # Spyder project settings
123
+ .spyderproject
124
+ .spyproject
125
+
126
+ # Rope project settings
127
+ .ropeproject
128
+
129
+ # mkdocs documentation
130
+ /site
131
+
132
+ # mypy
133
+ .mypy_cache/
134
+ .dmypy.json
135
+ dmypy.json
136
+
137
+ # Pyre type checker
138
+ .pyre/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
1
+ default_stages: [commit]
2
+ default_language_version:
3
+ python: python3.10
4
+ repos:
5
+ - repo: https://github.com/psf/black
6
+ rev: 23.9.1
7
+ hooks:
8
+ - id: black
9
+ args: [--line-length=100]
10
+ exclude: ^(venv/|docs/)
11
+ types: [python]
12
+ - repo: https://github.com/PyCQA/flake8
13
+ rev: 6.1.0
14
+ hooks:
15
+ - id: flake8
16
+ additional_dependencies: [flake8-docstrings]
17
+ args:
18
+ [
19
+ --max-line-length=100,
20
+ --docstring-convention=google,
21
+ --ignore=E203 W503 E402 E731,
22
+ ]
23
+ exclude: ^(venv/|docs/|.*__init__.py)
24
+ types: [python]
25
+
26
+ - repo: https://github.com/pycqa/isort
27
+ rev: 5.12.0
28
+ hooks:
29
+ - id: isort
30
+ args: [--line-length=100, --profile=black, --atomic]
31
+
32
+ - repo: https://github.com/pre-commit/mirrors-mypy
33
+ rev: v1.1.1
34
+ hooks:
35
+ - id: mypy
LICENSE ADDED
@@ -0,0 +1,190 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ Copyright 2024 ETH Zurich
179
+
180
+ Licensed under the Apache License, Version 2.0 (the "License");
181
+ you may not use this file except in compliance with the License.
182
+ You may obtain a copy of the License at
183
+
184
+ http://www.apache.org/licenses/LICENSE-2.0
185
+
186
+ Unless required by applicable law or agreed to in writing, software
187
+ distributed under the License is distributed on an "AS IS" BASIS,
188
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189
+ See the License for the specific language governing permissions and
190
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,587 @@
1
  ---
2
  title: GeoCalib
3
- emoji: 👀
4
- colorFrom: pink
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.43.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: GeoCalib
3
+ app_file: gradio_app.py
 
 
4
  sdk: gradio
5
+ sdk_version: 4.38.1
 
 
6
  ---
7
+ <p align="center">
8
+ <h1 align="center"><ins>GeoCalib</ins> 📸<br>Single-image Calibration with Geometric Optimization</h1>
9
+ <p align="center">
10
+ <a href="https://www.linkedin.com/in/alexander-veicht/">Alexander Veicht</a>
11
+ ·
12
+ <a href="https://psarlin.com/">Paul-Edouard&nbsp;Sarlin</a>
13
+ ·
14
+ <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
15
+ ·
16
+ <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc&nbsp;Pollefeys</a>
17
+ </p>
18
+ <h2 align="center">
19
+ <p>ECCV 2024</p>
20
+ <a href="" align="center">Paper</a> | <!--TODO: update link-->
21
+ <a href="https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K" align="center">Colab</a> |
22
+ <a href="https://huggingface.co/spaces/veichta/GeoCalib" align="center">Demo 🤗</a>
23
+ </h2>
24
+
25
+ </p>
26
+ <p align="center">
27
+ <a href=""><img src="assets/teaser.gif" alt="example" width=80%></a> <!--TODO: update link-->
28
+ <br>
29
+ <em>
30
+ GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image
31
+ <br>
32
+ by combining geometric optimization with deep learning.
33
+ </em>
34
+ </p>
35
+
36
+ ##
37
+
38
+ GeoCalib is an algorithm for single-image calibration: it estimates the camera intrinsics and gravity direction from just a single image. By combining geometric optimization with deep learning, GeoCalib provides a more flexible and accurate calibration than previous approaches. This repository hosts the [inference](#setup-and-demo), [evaluation](#evaluation), and [training](#training) code for GeoCalib and instructions to download our training set [OpenPano](#openpano-dataset).
39
+
40
+
41
+ ## Setup and demo
42
+
43
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K)
44
+ [![Hugging Face](https://img.shields.io/badge/Gradio-Demo-blue)](https://huggingface.co/spaces/veichta/GeoCalib)
45
+
46
+ We provide a small inference package [`geocalib`](geocalib) that requires only minimal dependencies and Python >= 3.9. First clone the repository and install the dependencies:
47
+
48
+ ```bash
49
+ git clone https://github.com/cvg/GeoCalib.git && cd GeoCalib
50
+ python -m pip install -e .
51
+ # OR
52
+ python -m pip install -e "git+https://github.com/cvg/GeoCalib#egg=geocalib"
53
+ ```
54
+
55
+ Here is a minimal usage example:
56
+
57
+ ```python
58
+ import torch
+
+ from geocalib import GeoCalib
59
+
60
+ device = "cuda" if torch.cuda.is_available() else "cpu"
61
+ model = GeoCalib().to(device)
62
+
63
+ # load image as tensor in range [0, 1] with shape [C, H, W]
64
+ img = model.load_image("path/to/image.jpg").to(device)
65
+ result = model.calibrate(img)
66
+
67
+ print("camera:", result["camera"])
68
+ print("gravity:", result["gravity"])
69
+ ```
70
+
71
+ When either the intrinsics or the gravity direction is already known, it can be provided:
72
+
73
+ ```python
74
+ # known intrinsics:
75
+ result = model.calibrate(img, priors={"focal": focal_length_tensor})
76
+
77
+ # known gravity:
78
+ result = model.calibrate(img, priors={"gravity": gravity_direction_tensor})
79
+ ```
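+
+ The snippet above leaves the exact format of the prior tensors open. As a minimal sketch (assuming the focal length is given in pixels as a scalar tensor and gravity as a 3D unit vector in the camera frame; check the documentation of `model.calibrate` for the exact convention), the priors could be constructed like this:
+
+ ```python
+ import torch
+
+ # hypothetical values for illustration only
+ focal_length_tensor = torch.tensor(500.0)  # focal length in pixels (assumed convention)
+ gravity_direction_tensor = torch.nn.functional.normalize(
+     torch.tensor([0.0, 1.0, 0.1]), dim=-1
+ )  # unit vector pointing "down" in the camera frame (assumed convention)
+
+ result = model.calibrate(img, priors={"focal": focal_length_tensor})
+ ```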
80
+
81
+ The default model is optimized for pinhole images. To handle lens distortion, use the following:
82
+
83
+ ```python
84
+ model = GeoCalib(weights="distorted") # default is "pinhole"
85
+ result = model.calibrate(img, camera_model="simple_radial") # or pinhole, simple_divisional
86
+ ```
87
+
88
+ Check out our [demo notebook](demo.ipynb) for a full working example.
89
+
90
+ <details>
91
+ <summary><b>[Interactive demo for your webcam - click to expand]</b></summary>
92
+ Run the following command:
93
+
94
+ ```bash
95
+ python -m geocalib.interactive_demo --camera_id 0
96
+ ```
97
+
98
+ The demo will open a window showing the camera feed and the calibration results. If `--camera_id` is not provided, the demo will ask for the IP address of a [droidcam](https://droidcam.app) camera.
99
+
100
+ Controls:
101
+
102
+ >Toggle the different features using the following keys:
103
+ >
104
+ >- ```h```: Show the estimated horizon line
105
+ >- ```u```: Show the estimated up-vectors
106
+ >- ```l```: Show the estimated latitude heatmap
107
+ >- ```c```: Show the confidence heatmap for the up-vectors and latitudes
108
+ >- ```d```: Show the undistorted image; this overrides the other features
109
+ >- ```g```: Show a virtual grid of points
110
+ >- ```b```: Show a virtual box object
111
+ >
112
+ >Change the camera model using the following keys:
113
+ >
114
+ >- ```1```: Pinhole -> Simple and fast
115
+ >- ```2```: Simple Radial -> For small distortions
116
+ >- ```3```: Simple Divisional -> For large distortions
117
+ >
118
+ >Press ```q``` to quit the demo.
119
+
120
+ </details>
121
+
122
+
123
+ <details>
124
+ <summary><b>[Load GeoCalib with torch hub - click to expand]</b></summary>
125
+
126
+ ```python
127
+ model = torch.hub.load("cvg/GeoCalib", "GeoCalib", trust_repo=True)
128
+ ```
129
+
130
+ </details>
131
+
132
+ ## Evaluation
133
+
134
+ The full evaluation and training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed as:
135
+ ```bash
136
+ python -m pip install -e siclib
137
+ ```
138
+
139
+ Running the evaluation commands will write the results to `outputs/results/`.
140
+
141
+ ### LaMAR
142
+
143
+ Running the evaluation commands will download the dataset to ```data/lamar2k``` which will take around 400 MB of disk space.
144
+
145
+ <details>
146
+ <summary>[Evaluate GeoCalib]</summary>
147
+
148
+ To evaluate GeoCalib trained on the OpenPano dataset, run:
149
+
150
+ ```bash
151
+ python -m siclib.eval.lamar2k --conf geocalib-pinhole --tag geocalib --overwrite
152
+ ```
153
+
154
+ </details>
155
+
156
+ <details>
157
+ <summary>[Evaluate DeepCalib]</summary>
158
+
159
+ To evaluate DeepCalib trained on the OpenPano dataset, run:
160
+
161
+ ```bash
162
+ python -m siclib.eval.lamar2k --conf deepcalib --tag deepcalib --overwrite
163
+ ```
164
+
165
+ </details>
166
+
167
+ <details>
168
+ <summary>[Evaluate Perspective Fields]</summary>
169
+
170
+ Coming soon!
171
+
172
+ </details>
173
+
174
+ <details>
175
+ <summary>[Evaluate UVP]</summary>
176
+
177
+ To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
178
+
179
+ ```bash
180
+ python -m siclib.eval.lamar2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
181
+ ```
182
+
183
+ </details>
184
+
185
+ <details>
186
+ <summary>[Evaluate your own model]</summary>
187
+
188
+ If you have trained your own model, you can evaluate it by running:
189
+
190
+ ```bash
191
+ python -m siclib.eval.lamar2k --checkpoint <experiment name> --tag <eval name> --overwrite
192
+ ```
193
+
194
+ </details>
195
+
196
+
197
+ <details>
198
+ <summary>[Results]</summary>
199
+
200
+ Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
201
+
202
+ | Approach | Roll | Pitch | FoV |
203
+ | --------- | ------------------ | ------------------ | ------------------ |
204
+ | DeepCalib | 44.1 / 73.9 / 84.8 | 10.8 / 28.3 / 49.8 | 0.7 / 13.0 / 24.0 |
205
+ | ParamNet | 51.7 / 77.0 / 86.0 | 27.0 / 52.7 / 70.2 | 02.8 / 06.8 / 14.3 |
206
+ | UVP | 72.7 / 81.8 / 85.7 | 42.3 / 59.9 / 69.4 | 15.6 / 30.6 / 43.5 |
207
+ | GeoCalib | 86.4 / 92.5 / 95.0 | 55.0 / 76.9 / 86.2 | 19.1 / 41.5 / 60.0 |
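+
+ For reference, the AUC numbers summarize the cumulative error curve up to each threshold. Here is a minimal sketch of this metric (not the repository's evaluation code; it assumes a flat array of absolute errors in degrees and the 1°/5°/10° thresholds used above):
+
+ ```python
+ import numpy as np
+
+ def error_auc(errors, thresholds=(1.0, 5.0, 10.0)):
+     """AUC of the error-recall curve at each threshold, in percent."""
+     errors = np.sort(np.abs(np.asarray(errors, dtype=float)))
+     recall = np.arange(1, len(errors) + 1) / len(errors)
+     # prepend the origin of the curve
+     errors = np.concatenate(([0.0], errors))
+     recall = np.concatenate(([0.0], recall))
+     aucs = []
+     for t in thresholds:
+         last = np.searchsorted(errors, t)
+         e = np.concatenate((errors[:last], [t]))
+         r = np.concatenate((recall[:last], [recall[last - 1]]))
+         # trapezoidal integration, normalized by the threshold
+         area = np.sum((e[1:] - e[:-1]) * (r[1:] + r[:-1]) / 2)
+         aucs.append(100.0 * area / t)
+     return aucs
+ ```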
208
+ </details>
209
+
210
+ ### MegaDepth
211
+
212
+ Running the evaluation commands will download the dataset to ```data/megadepth2k``` or ```data/megadepth2k-radial```, which take around 2.1 GB and 1.47 GB of disk space, respectively.
213
+
214
+ <details>
215
+ <summary>[Evaluate GeoCalib]</summary>
216
+
217
+ To evaluate GeoCalib trained on the OpenPano dataset, run:
218
+
219
+ ```bash
220
+ python -m siclib.eval.megadepth2k --conf geocalib-pinhole --tag geocalib --overwrite
221
+ ```
222
+
223
+ To run the eval on the radial distorted images, run:
224
+
225
+ ```bash
226
+ python -m siclib.eval.megadepth2k_radial --conf geocalib-pinhole --tag geocalib --overwrite model.camera_model=simple_radial
227
+ ```
228
+
229
+ </details>
230
+
231
+ <details>
232
+ <summary>[Evaluate DeepCalib]</summary>
233
+
234
+ To evaluate DeepCalib trained on the OpenPano dataset, run:
235
+
236
+ ```bash
237
+ python -m siclib.eval.megadepth2k --conf deepcalib --tag deepcalib --overwrite
238
+ ```
239
+
240
+ </details>
241
+
242
+ <details>
243
+ <summary>[Evaluate Perspective Fields]</summary>
244
+
245
+ Coming soon!
246
+
247
+ </details>
248
+
249
+ <details>
250
+ <summary>[Evaluate UVP]</summary>
251
+
252
+ To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
253
+
254
+ ```bash
255
+ python -m siclib.eval.megadepth2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
256
+ ```
257
+
258
+ </details>
259
+
260
+ <details>
261
+ <summary>[Evaluate your own model]</summary>
262
+
263
+ If you have trained your own model, you can evaluate it by running:
264
+
265
+ ```bash
266
+ python -m siclib.eval.megadepth2k --checkpoint <experiment name> --tag <eval name> --overwrite
267
+ ```
268
+
269
+ </details>
270
+
271
+ <details>
272
+ <summary>[Results]</summary>
273
+
274
+ Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
275
+
276
+ | Approach | Roll | Pitch | FoV |
277
+ | --------- | ------------------ | ------------------ | ------------------ |
278
+ | DeepCalib | 34.6 / 65.4 / 79.4 | 11.9 / 27.8 / 44.8 | 5.6 / 12.1 / 22.9 |
279
+ | ParamNet | 43.4 / 70.7 / 82.2 | 15.4 / 34.5 / 53.3 | 3.2 / 10.1 / 21.3 |
280
+ | UVP | 69.2 / 81.6 / 86.9 | 21.6 / 36.2 / 47.4 | 8.2 / 18.7 / 29.8 |
281
+ | GeoCalib | 82.6 / 90.6 / 94.0 | 32.4 / 53.3 / 67.5 | 13.6 / 31.7 / 48.2 |
282
+ </details>
283
+
284
+ ### TartanAir
285
+
286
+ Running the evaluation commands will download the dataset to ```data/tartanair``` which will take around 1.85 GB of disk space.
287
+
288
+ <details>
289
+ <summary>[Evaluate GeoCalib]</summary>
290
+
291
+ To evaluate GeoCalib trained on the OpenPano dataset, run:
292
+
293
+ ```bash
294
+ python -m siclib.eval.tartanair --conf geocalib-pinhole --tag geocalib --overwrite
295
+ ```
296
+
297
+ </details>
298
+
299
+ <details>
300
+ <summary>[Evaluate DeepCalib]</summary>
301
+
302
+ To evaluate DeepCalib trained on the OpenPano dataset, run:
303
+
304
+ ```bash
305
+ python -m siclib.eval.tartanair --conf deepcalib --tag deepcalib --overwrite
306
+ ```
307
+
308
+ </details>
309
+
310
+ <details>
311
+ <summary>[Evaluate Perspective Fields]</summary>
312
+
313
+ Coming soon!
314
+
315
+ </details>
316
+
317
+ <details>
318
+ <summary>[Evaluate UVP]</summary>
319
+
320
+ To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
321
+
322
+ ```bash
323
+ python -m siclib.eval.tartanair --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
324
+ ```
325
+
326
+ </details>
327
+
328
+ <details>
329
+ <summary>[Evaluate your own model]</summary>
330
+
331
+ If you have trained your own model, you can evaluate it by running:
332
+
333
+ ```bash
334
+ python -m siclib.eval.tartanair --checkpoint <experiment name> --tag <eval name> --overwrite
335
+ ```
336
+
337
+ </details>
338
+
339
+ <details>
340
+ <summary>[Results]</summary>
341
+
342
+ Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
343
+
344
+ | Approach | Roll | Pitch | FoV |
345
+ | --------- | ------------------ | ------------------ | ------------------ |
346
+ | DeepCalib | 24.7 / 55.4 / 71.5 | 16.3 / 38.8 / 58.5 | 1.5 / 8.8 / 27.2 |
347
+ | ParamNet | 34.5 / 59.2 / 73.9 | 19.4 / 42.0 / 60.3 | 6.0 / 16.8 / 31.6 |
348
+ | UVP | 52.1 / 64.8 / 71.9 | 36.2 / 48.8 / 58.6 | 15.8 / 25.8 / 35.7 |
349
+ | GeoCalib | 71.3 / 83.8 / 89.8 | 38.2 / 62.9 / 76.6 | 14.1 / 30.4 / 47.6 |
350
+ </details>
351
+
352
+ ### Stanford2D3D
353
+
354
+ Before downloading and running the evaluation, you will need to agree to the [terms of use](https://docs.google.com/forms/d/e/1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1) for the Stanford2D3D dataset.
355
+ Running the evaluation commands will download the dataset to ```data/stanford2d3d``` which will take around 885 MB of disk space.
356
+
357
+ <details>
358
+ <summary>[Evaluate GeoCalib]</summary>
359
+
360
+ To evaluate GeoCalib trained on the OpenPano dataset, run:
361
+
362
+ ```bash
363
+ python -m siclib.eval.stanford2d3d --conf geocalib-pinhole --tag geocalib --overwrite
364
+ ```
365
+
366
+ </details>
367
+
368
+ <details>
369
+ <summary>[Evaluate DeepCalib]</summary>
370
+
371
+ To evaluate DeepCalib trained on the OpenPano dataset, run:
372
+
373
+ ```bash
374
+ python -m siclib.eval.stanford2d3d --conf deepcalib --tag deepcalib --overwrite
375
+ ```
376
+
377
+ </details>
378
+
379
+ <details>
380
+ <summary>[Evaluate Perspective Fields]</summary>
381
+
382
+ Coming soon!
383
+
384
+ </details>
385
+
386
+ <details>
387
+ <summary>[Evaluate UVP]</summary>
388
+
389
+ To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
390
+
391
+ ```bash
392
+ python -m siclib.eval.stanford2d3d --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
393
+ ```
394
+
395
+ </details>
396
+
397
+ <details>
398
+ <summary>[Evaluate your own model]</summary>
399
+
400
+ If you have trained your own model, you can evaluate it by running:
401
+
402
+ ```bash
403
+ python -m siclib.eval.stanford2d3d --checkpoint <experiment name> --tag <eval name> --overwrite
404
+ ```
405
+
406
+ </details>
407
+
408
+ <details>
409
+ <summary>[Results]</summary>
410
+
411
+ Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
412
+
413
+ | Approach | Roll | Pitch | FoV |
414
+ | --------- | ------------------ | ------------------ | ------------------ |
415
+ | DeepCalib | 33.8 / 63.9 / 79.2 | 21.6 / 46.9 / 65.7 | 8.1 / 20.6 / 37.6 |
416
+ | ParamNet | 44.6 / 73.9 / 84.8 | 29.2 / 56.7 / 73.1 | 5.8 / 14.3 / 27.8 |
417
+ | UVP | 65.3 / 74.6 / 79.1 | 51.2 / 63.0 / 69.2 | 22.2 / 39.5 / 51.3 |
418
+ | GeoCalib | 83.1 / 91.8 / 94.8 | 52.3 / 74.8 / 84.6 | 17.4 / 40.0 / 59.4 |
419
+
420
+ </details>
421
+
422
+ ### Evaluation options
423
+
424
+ If you want to provide priors during the evaluation, you can add one or more of the following flags:
425
+
426
+ ```bash
427
+ python -m siclib.eval.<benchmark> --conf <config> \
428
+ --tag <tag> \
429
+ data.use_prior_focal=true \
430
+ data.use_prior_gravity=true \
431
+ data.use_prior_k1=true
432
+ ```
433
+
434
+ <details>
435
+ <summary>[Visual inspection]</summary>
436
+
437
+ To visually inspect the results of the evaluation, you can run the following command:
438
+
439
+ ```bash
440
+ python -m siclib.eval.inspect <benchmark> <one or multiple tags>
441
+
442
+ ```
443
+ For example, to inspect the results of the evaluation of the GeoCalib model on the LaMAR dataset, you can run:
444
+ ```bash
445
+ python -m siclib.eval.inspect lamar2k geocalib
446
+ ```
447
+ </details>
448
+
449
+ ## OpenPano Dataset
450
+
451
+ The OpenPano dataset is a new dataset for single-image calibration, containing about 2.8k panoramas from various sources, namely [HDRMAPS](https://hdrmaps.com/hdris/), [PolyHaven](https://polyhaven.com/hdris), and the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation). While this dataset is smaller than previous ones, it is publicly available and provides a better balance between indoor and outdoor scenes.
452
+
453
+ <details>
454
+ <summary>[Downloading and preparing the dataset]</summary>
455
+
456
+ In order to assemble the training set, first download the Laval dataset following the instructions on [the corresponding project page](http://hdrdb.com/indoor/#presentation) and place the panoramas in ```data/indoorDatasetCalibrated```. Then, tonemap the HDR images using the following command:
457
+
458
+ ```bash
459
+ python -m siclib.datasets.utils.tonemapping --hdr_dir data/indoorDatasetCalibrated --out_dir data/laval-tonemap
460
+ ```
461
+
462
+ We provide a script to download the PolyHaven and HDRMAPS panos. The script will create folders ```data/openpano/panoramas/{split}``` containing the panoramas specified by the ```{split}_panos.txt``` files. To run the script, execute the following command:
463
+
464
+ ```bash
465
+ python -m siclib.datasets.utils.download_openpano --name openpano --laval_dir data/laval-tonemap
466
+ ```
467
+ Alternatively, you can download the PolyHaven and HDRMAPS panos from [here](https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/).
468
+
469
+
470
+ After downloading the panoramas, you can create the training set by running the following command:
471
+
472
+ ```bash
473
+ python -m siclib.datasets.create_dataset_from_pano --config-name openpano
474
+ ```
475
+
476
+ The dataset creation can be sped up by using multiple workers and a GPU. To do so, add the following arguments to the command:
477
+
478
+ ```bash
479
+ python -m siclib.datasets.create_dataset_from_pano --config-name openpano n_workers=10 device=cuda
480
+ ```
481
+
482
+ This will create the training set in ```data/openpano/openpano``` with about 37k images for training, 2.1k for validation, and 2.1k for testing.
483
+
484
+ <details>
485
+ <summary>[Distorted OpenPano]</summary>
486
+
487
+ To create the OpenPano dataset with radial distortion, run the following command:
488
+
489
+ ```bash
490
+ python -m siclib.datasets.create_dataset_from_pano --config-name openpano-radial
491
+ ```
492
+
493
+ </details>
494
+
495
+ </details>
496
+
497
+ ## Training
498
+
499
+ As for the evaluation, the training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed by:
500
+
501
+ ```bash
502
+ python -m pip install -e siclib
503
+ ```
504
+
505
+ Once the [OpenPano Dataset](#openpano-dataset) has been downloaded and prepared, we can train GeoCalib on it.
506
+
507
+ First, download the pre-trained weights for the [MSCAN-B](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) backbone:
508
+
509
+ ```bash
510
+ mkdir weights
511
+ wget "https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/files/?p=%2Fmscan_b.pth&dl=1" -O weights/mscan_b.pth
512
+ ```
513
+
514
+ Then, start the training with the following command:
515
+
516
+ ```bash
517
+ python -m siclib.train geocalib-pinhole-openpano --conf geocalib --distributed
518
+ ```
519
+
520
+ Feel free to use any other experiment name. By default, the checkpoints will be written to ```outputs/training/```. The default batch size is 24, which requires 2x 4090 GPUs with 24 GB of VRAM each. Configurations are managed by [Hydra](https://hydra.cc/) and can be overwritten from the command line.
521
+ For example, to train GeoCalib on a single GPU with a batch size of 5, run:
522
+
523
+ ```bash
524
+ python -m siclib.train geocalib-pinhole-openpano \
525
+ --conf geocalib \
526
+ data.train_batch_size=5 # for 1x 2080 GPU
527
+ ```
528
+
529
+ Be aware that this can impact the overall performance. You might need to adjust the learning rate and number of training steps accordingly.
530
+
531
+ If you want to log the training progress to [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://wandb.ai/), you can set the ```train.writer``` option:
532
+
533
+ ```bash
534
+ python -m siclib.train geocalib-pinhole-openpano \
535
+ --conf geocalib \
536
+ --distributed \
537
+ train.writer=tensorboard
538
+ ```
539
+
540
+ The model can then be evaluated using its experiment name:
541
+
542
+ ```bash
543
+ python -m siclib.eval.<benchmark> --checkpoint geocalib-pinhole-openpano \
544
+ --tag geocalib-retrained
545
+ ```
546
+
547
+ <details>
548
+ <summary>[Training DeepCalib]</summary>
549
+
550
+ To train DeepCalib on the OpenPano dataset, run:
551
+
552
+ ```bash
553
+ python -m siclib.train deepcalib-openpano --conf deepcalib --distributed
554
+ ```
555
+
556
+ Make sure that you have generated the [OpenPano Dataset](#openpano-dataset) with radial distortion or add
557
+ the flag ```data=openpano``` to the command to train on the pinhole images.
558
+
559
+ </details>
560
+
561
+ <details>
562
+ <summary>[Training Perspective Fields]</summary>
563
+
564
+ Coming soon!
565
+
566
+ </details>
567
+
568
+ ## BibTeX citation
569
+
570
+ If you use any ideas from the paper or code from this repo, please consider citing:
571
+
572
+ ```bibtex
573
+ @inproceedings{veicht2024geocalib,
574
+ author = {Alexander Veicht and
575
+ Paul-Edouard Sarlin and
576
+ Philipp Lindenberger and
577
+ Marc Pollefeys},
578
+ title = {{GeoCalib: Single-image Calibration with Geometric Optimization}},
579
+ booktitle = {ECCV},
580
+ year = {2024}
581
+ }
582
+ ```
583
+
584
+ ## License
585
+
586
+ The code is provided under the [Apache-2.0 License](LICENSE) while the weights of the trained model are provided under the [Creative Commons Attribution 4.0 International Public License](https://creativecommons.org/licenses/by/4.0/legalcode). Thanks to the authors of the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation) for allowing this.
587
 
 
assets/fisheye-dog-pool.jpg ADDED
assets/fisheye-skyline.jpg ADDED

Git LFS Details

  • SHA256: 1f32ee048e0e9f931f9d42b19269419201834322f40a172367ebbca4752826a2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
assets/pinhole-church.jpg ADDED
assets/pinhole-garden.jpg ADDED
assets/teaser.gif ADDED

Git LFS Details

  • SHA256: 2944bad746c92415438f65adb4362bfd4b150db50729f118266d86f638ac3d21
  • Pointer size: 133 Bytes
  • Size of remote file: 12 MB
demo.ipynb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:115886d06182eca375251eab9e301180e86465fd0ed152e917d43d7eb4cbd722
3
+ size 13275966
geocalib/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ import logging
2
+
3
+ from geocalib.extractor import GeoCalib # noqa
4
+
5
+ formatter = logging.Formatter(
6
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
7
+ )
8
+ handler = logging.StreamHandler()
9
+ handler.setFormatter(formatter)
10
+ handler.setLevel(logging.INFO)
11
+
12
+ logger = logging.getLogger(__name__)
13
+ logger.setLevel(logging.INFO)
14
+ logger.addHandler(handler)
15
+ logger.propagate = False
16
+
17
+ __module_name__ = __name__
geocalib/camera.py ADDED
@@ -0,0 +1,774 @@
1
+ """Implementation of the pinhole, simple radial, and simple divisional camera models."""
2
+
3
+ from abc import abstractmethod
4
+ from typing import Dict, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch.func import jacfwd, vmap
8
+ from torch.nn import functional as F
9
+
10
+ from geocalib.gravity import Gravity
11
+ from geocalib.misc import TensorWrapper, autocast
12
+ from geocalib.utils import deg2rad, focal2fov, fov2focal, rad2rotmat
13
+
14
+ # flake8: noqa: E741
15
+ # mypy: ignore-errors
16
+
17
+
18
+ class BaseCamera(TensorWrapper):
19
+ """Camera tensor class."""
20
+
21
+ eps = 1e-3
22
+
23
+ @autocast
24
+ def __init__(self, data: torch.Tensor):
25
+ """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).
26
+
27
+ Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where
28
+ - w, h: image size in pixels
29
+ - fx, fy: focal lengths in pixels
30
+ - cx, cy: principal points in normalized image coordinates
31
+ - dist: distortion parameters
32
+
33
+ Args:
34
+ data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
35
+ """
36
+ # w, h, fx, fy, cx, cy, dist
37
+ assert data.shape[-1] in {6, 7, 8}, data.shape
38
+
39
+ pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
40
+ data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
41
+ super().__init__(data)
42
+
43
+ @classmethod
44
+ def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
45
+ """Create a Camera object from a dictionary of parameters.
46
+
47
+ Args:
48
+ param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.
49
+
50
+ Returns:
51
+ Camera: Camera object.
52
+ """
53
+ for key, value in param_dict.items():
54
+ if not isinstance(value, torch.Tensor):
55
+ param_dict[key] = torch.tensor(value)
56
+
57
+ h, w = param_dict["height"], param_dict["width"]
58
+ cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)
59
+
60
+ if "f" in param_dict:
61
+ f = param_dict["f"]
62
+ elif "vfov" in param_dict:
63
+ vfov = param_dict["vfov"]
64
+ f = fov2focal(vfov, h)
65
+ else:
66
+ raise ValueError("Focal length or vertical field of view must be provided.")
67
+
68
+ if "dist" in param_dict:
69
+ k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
70
+ elif "k1_hat" in param_dict:
71
+ k1 = param_dict["k1_hat"] * (f / h) ** 2
72
+
73
+ k2 = param_dict.get("k2", torch.zeros_like(k1))
74
+ else:
75
+ k1 = param_dict.get("k1", torch.zeros_like(f))
76
+ k2 = param_dict.get("k2", torch.zeros_like(f))
77
+
78
+ fx, fy = f, f
79
+ if "scales" in param_dict:
80
+ fx = fx * param_dict["scales"][..., 0] / param_dict["scales"][..., 1]
81
+
82
+ params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
83
+ return cls(params)
84
+
85
+ def pinhole(self):
86
+ """Return the pinhole camera model."""
87
+ return self.__class__(self._data[..., :6])
88
+
89
+ @property
90
+ def size(self) -> torch.Tensor:
91
+ """Size (width height) of the images, with shape (..., 2)."""
92
+ return self._data[..., :2]
93
+
94
+ @property
95
+ def f(self) -> torch.Tensor:
96
+ """Focal lengths (fx, fy) with shape (..., 2)."""
97
+ return self._data[..., 2:4]
98
+
99
+ @property
100
+ def vfov(self) -> torch.Tensor:
101
+ """Vertical field of view in radians."""
102
+ return focal2fov(self.f[..., 1], self.size[..., 1])
103
+
104
+ @property
105
+ def hfov(self) -> torch.Tensor:
106
+ """Horizontal field of view in radians."""
107
+ return focal2fov(self.f[..., 0], self.size[..., 0])
108
+
109
+ @property
110
+ def c(self) -> torch.Tensor:
111
+ """Principal points (cx, cy) with shape (..., 2)."""
112
+ return self._data[..., 4:6]
113
+
114
+ @property
115
+ def K(self) -> torch.Tensor:
116
+ """Returns the self intrinsic matrix with shape (..., 3, 3)."""
117
+ shape = self.shape + (3, 3)
118
+ K = self._data.new_zeros(shape)
119
+ K[..., 0, 0] = self.f[..., 0]
120
+ K[..., 1, 1] = self.f[..., 1]
121
+ K[..., 0, 2] = self.c[..., 0]
122
+ K[..., 1, 2] = self.c[..., 1]
123
+ K[..., 2, 2] = 1
124
+ return K
125
+
126
+ def update_focal(self, delta: torch.Tensor, as_log: bool = False):
127
+ """Update the self parameters after changing the focal length."""
128
+ f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta
129
+
130
+ # clamp focal length to a reasonable range for stability during training
131
+ min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
132
+ max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
133
+ min_f = min_f.unsqueeze(-1).expand(-1, 2)
134
+ max_f = max_f.unsqueeze(-1).expand(-1, 2)
135
+ f = f.clamp(min=min_f, max=max_f)
136
+
137
+ # make sure the focal ratio stays the same (avoid inplace operations)
138
+ fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
139
+ f = torch.stack([fx, f[..., 1]], -1)
140
+
141
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
142
+ return self.__class__(torch.cat([self.size, f, self.c, dist], -1))
143
+
144
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
145
+ """Update the self parameters after resizing an image."""
146
+ scales = (scales, scales) if isinstance(scales, (int, float)) else scales
147
+ s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)
148
+
149
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
150
+ return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))
151
+
152
+ def crop(self, pad: Tuple[float]):
153
+ """Update the self parameters after cropping an image."""
154
+ pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
155
+ size = self.size + pad.to(self.size)
156
+ c = self.c + pad.to(self.c) / 2
157
+
158
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
159
+ return self.__class__(torch.cat([size, self.f, c, dist], -1))
160
+
161
+ @autocast
162
+ def in_image(self, p2d: torch.Tensor):
163
+ """Check if 2D points are within the image boundaries."""
164
+ assert p2d.shape[-1] == 2
165
+ size = self.size.unsqueeze(-2)
166
+ return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
167
+
168
+ @autocast
169
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
170
+ """Project 3D points into the self plane and check for visibility."""
171
+ z = p3d[..., -1]
172
+ valid = z > self.eps
173
+ z = z.clamp(min=self.eps)
174
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
175
+ return p2d, valid
176
+
177
+ def J_project(self, p3d: torch.Tensor):
178
+ """Jacobian of the projection function."""
179
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
180
+ zero = torch.zeros_like(z)
181
+ z = z.clamp(min=self.eps)
182
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
183
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
184
+ return J # N x 2 x 3
185
+
186
+ def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
187
+ """Undo transforms done during scaling and cropping."""
188
+ camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
189
+ return camera.scale(1.0 / data["scales"])
190
+
191
+ @abstractmethod
192
+ def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
193
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
194
+ raise NotImplementedError("distort() must be implemented.")
195
+
196
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
197
+ """Jacobian of the distortion function."""
198
+ if wrt == "scale2pts": # (..., 2)
199
+ J = [
200
+ vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
201
+ for idx in range(p2d.shape[0])
202
+ ]
203
+
204
+ return torch.cat(J, dim=0).squeeze(-3, -2)
205
+
206
+ elif wrt == "scale2dist": # (..., 1)
207
+ J = []
208
+ for idx in range(p2d.shape[0]): # loop to batch pts dimension
209
+
210
+ def func(x):
211
+ params = torch.cat([self._data[idx, :6], x[None]], -1)
212
+ return self.__class__(params).distort(p2d[idx], return_scale=True)[0]
213
+
214
+ J.append(vmap(jacfwd(func))(self[idx].dist))
215
+
216
+ return torch.cat(J, dim=0)
217
+
218
+ else:
219
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
220
+
221
+ @abstractmethod
222
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
223
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
224
+ raise NotImplementedError("undistort() must be implemented.")
225
+
226
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
227
+ """Jacobian of the undistortion function."""
228
+ if wrt == "pts": # (..., 2, 2)
229
+ J = [
230
+ vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
231
+ for idx in range(p2d.shape[0])
232
+ ]
233
+
234
+ return torch.cat(J, dim=0).squeeze(-3)
235
+
236
+ elif wrt == "dist": # (..., 1)
237
+ J = []
238
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
239
+
240
+ def func(x):
241
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
242
+ return self.__class__(params).undistort(p2d[batch_idx])[0]
243
+
244
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
245
+
246
+ return torch.cat(J, dim=0)
247
+ else:
248
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
249
+
250
+ @autocast
251
+ def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
252
+ """Compute the offset for the up-projection."""
253
+ return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
254
+
255
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
256
+ """Jacobian of the distortion offset for up-projection."""
257
+ if wrt == "uv": # (B, N, 2, 2)
258
+ J = [
259
+ vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
260
+ for idx in range(p2d.shape[0])
261
+ ]
262
+
263
+ return torch.cat(J, dim=0)
264
+
265
+ elif wrt == "dist": # (B, N, 2)
266
+ J = []
267
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
268
+
269
+ def func(x):
270
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
271
+ return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
272
+
273
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
274
+
275
+ return torch.cat(J, dim=0).squeeze(1)
276
+ else:
277
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
278
+
279
+ @autocast
280
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
281
+ """Convert normalized 2D coordinates into pixel coordinates."""
282
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
283
+
284
+ def J_denormalize(self):
285
+ """Jacobian of the denormalization function."""
286
+ return torch.diag_embed(self.f) # ..., 2 x 2
287
+
288
+ @autocast
289
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
290
+ """Convert pixel coordinates into normalized 2D coordinates."""
291
+ return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
292
+
293
+ def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
294
+ """Jacobian of the normalization function."""
295
+ # ... x N x 2 x 2
296
+ if wrt == "f":
297
+ J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
298
+ return torch.diag_embed(J_f)
299
+ elif wrt == "pts":
300
+ J_pts = 1 / self.f
301
+ return torch.diag_embed(J_pts)
302
+ else:
303
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
304
+
305
+ def pixel_coordinates(self) -> torch.Tensor:
306
+ """Pixel coordinates in self frame.
307
+
308
+ Returns:
309
+ torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
310
+ """
311
+ w, h = self.size[0].unbind(-1)
312
+ h, w = h.round().to(int), w.round().to(int)
313
+
314
+ # create grid
315
+ x = torch.arange(0, w, dtype=self.dtype, device=self.device)
316
+ y = torch.arange(0, h, dtype=self.dtype, device=self.device)
317
+ x, y = torch.meshgrid(x, y, indexing="xy")
318
+ xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
319
+
320
+ # add batch dimension (normalize() would broadcast but we make it explicit)
321
+ B = self.shape[0]
322
+ xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
323
+
324
+ return xy.to(self.device).to(self.dtype)
325
+
326
+ @autocast
327
+ def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
328
+ """Get the bearing vectors of pixel coordinates by normalizing them."""
329
+ return F.normalize(p3d, dim=-1)
330
+
331
+ @autocast
332
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
333
+ """Transform 3D points into 2D pixel coordinates."""
334
+ p2d, visible = self.project(p3d)
335
+ p2d, mask = self.distort(p2d)
336
+ p2d = self.denormalize(p2d)
337
+ valid = visible & mask & self.in_image(p2d)
338
+ return p2d, valid
339
+
340
+ @autocast
341
+ def J_world2image(self, p3d: torch.Tensor):
342
+ """Jacobian of the world2image function."""
343
+ p2d_proj, valid = self.project(p3d)
344
+
345
+ J_dnorm = self.J_denormalize()
346
+ J_dist = self.J_distort(p2d_proj)
347
+ J_proj = self.J_project(p3d)
348
+
349
+ J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
350
+ return J, valid
351
+
352
+ @autocast
353
+ def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
354
+ """Transform point in the image plane to 3D world coordinates."""
355
+ p2d = self.normalize(p2d)
356
+ p2d, valid = self.undistort(p2d)
357
+ ones = p2d.new_ones(p2d.shape[:-1] + (1,))
358
+ p3d = torch.cat([p2d, ones], -1)
359
+ return p3d, valid
360
+
361
+ @autocast
362
+ def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
363
+ """Jacobian of the image2world function."""
364
+ if wrt == "dist":
365
+ p2d_norm = self.normalize(p2d)
366
+ return self.J_undistort(p2d_norm, wrt)
367
+ elif wrt == "f":
368
+ J_norm2f = self.J_normalize(p2d, wrt)
369
+ p2d_norm = self.normalize(p2d)
370
+ J_dist2norm = self.J_undistort(p2d_norm, "pts")
371
+
372
+ return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
373
+ else:
374
+ raise ValueError(f"Unknown wrt: {wrt}")
375
+
376
+ @autocast
377
+ def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
378
+ """Undistort an image using the distortion model."""
379
+ assert self.shape[0] == 1, "Batch size must be 1."
380
+ W, H = self.size.unbind(-1)
381
+ H, W = H.int().item(), W.int().item()
382
+
383
+ x, y = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing="xy")
384
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
385
+
386
+ p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
387
+ p2d, _ = self.world2image(p3d)
388
+
389
+ mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
390
+ grid = torch.stack((mapx, mapy), dim=-1)
391
+ grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
392
+ return F.grid_sample(img, grid, align_corners=True)
393
+
394
+ def get_img_from_pano(
395
+ self,
396
+ pano_img: torch.Tensor,
397
+ gravity: Gravity,
398
+ yaws: torch.Tensor = 0.0,
399
+ resize_factor: Optional[torch.Tensor] = None,
400
+ ) -> torch.Tensor:
401
+ """Render an image from a panorama.
402
+
403
+ Args:
404
+ pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
405
+ gravity (Gravity): Gravity direction of the camera.
406
+ yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
407
+ resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the
408
+ field of view. Defaults to None.
409
+
410
+ Returns:
411
+ torch.Tensor: Image rendered from the panorama.
412
+ """
413
+ B = self.shape[0]
414
+ if B > 0:
415
+ assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
416
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
417
+
418
+ w, h = self.size[0].unbind(-1)
419
+ h, w = h.round().to(int), w.round().to(int)
420
+
421
+ if isinstance(yaws, (int, float)):
422
+ yaws = [yaws]
423
+ if isinstance(resize_factor, (int, float)):
424
+ resize_factor = [resize_factor]
425
+
426
+ yaws = (
427
+ yaws.to(self.dtype).to(self.device)
428
+ if isinstance(yaws, torch.Tensor)
429
+ else self.new_tensor(yaws)
430
+ )
431
+
432
+ if isinstance(resize_factor, torch.Tensor):
433
+ resize_factor = resize_factor.to(self.dtype).to(self.device)
434
+ elif resize_factor is not None:
435
+ resize_factor = self.new_tensor(resize_factor)
436
+
437
+ assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
438
+ pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x H x W x 3
439
+
440
+ pano_imgs = []
441
+ for i, yaw in enumerate(yaws):
442
+ if resize_factor is not None:
443
+ # resize the panorama such that the fov of the panorama has the same height as the
444
+ # image
445
+ vfov = self.vfov[i] if B != 0 else self.vfov
446
+ scale = torch.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i]
447
+ pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale))
448
+
449
+ mode = "bicubic" if scale >= 1 else "area"
450
+ resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
451
+ else:
452
+ # make sure to copy: resized_pano = pano_img
453
+ resized_pano = pano_img
454
+ pano_shape = pano_img.shape[-2:][::-1]
455
+
456
+ pano_imgs.append((resized_pano, pano_shape))
457
+
458
+ xy = self.pixel_coordinates()
459
+ uv1, _ = self.image2world(xy)
460
+ bearings = self.pixel_bearing_many(uv1)
461
+
462
+ # rotate bearings
463
+ R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
464
+ rotated_bearings = bearings @ gravity.R @ R_yaw
465
+
466
+ # spherical coordinates
467
+ lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
468
+ lat = torch.atan2(
469
+ rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
470
+ )
471
+
472
+ images = []
473
+ for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
474
+ min_lon, max_lon = -torch.pi, torch.pi
475
+ min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
476
+ min_x, max_x = 0, pano_shape[0] - 1.0
477
+ min_y, max_y = 0, pano_shape[1] - 1.0
478
+
479
+ # map Spherical Coordinates to Panoramic Coordinates
480
+ nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
481
+ ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
482
+
483
+ # reshape and cast to numpy for remap
484
+ mapx, mapy = nx.reshape((1, h, w)), ny.reshape((1, h, w))
485
+
486
+ grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension
487
+ # Normalize to [-1, 1]
488
+ grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
489
+ # Apply grid sample
490
+ image = F.grid_sample(resized_pano, grid, align_corners=True)
491
+ images.append(image)
492
+
493
+ return torch.concatenate(images, 0) if B > 0 else images[0]
494
+
495
+ def __repr__(self):
496
+ """Print the Camera object."""
497
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
498
+
499
+
500
+ class Pinhole(BaseCamera):
501
+ """Implementation of the pinhole camera model.
502
+
503
+ Use this model for undistorted images.
504
+ """
505
+
506
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
507
+ """Distort normalized 2D coordinates."""
508
+ if return_scale:
509
+ return p2d.new_ones(p2d.shape[:-1] + (1,))
510
+
511
+ return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
512
+
513
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
514
+ """Jacobian of the distortion function."""
515
+ if wrt == "pts":
516
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
517
+
518
+ raise ValueError(f"Unknown wrt: {wrt}")
519
+
520
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
521
+ """Undistort normalized 2D coordinates."""
522
+ return pts, pts.new_ones((pts.shape[0], 1)).bool()
523
+
524
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
525
+ """Jacobian of the undistortion function."""
526
+ if wrt == "pts":
527
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
528
+
529
+ raise ValueError(f"Unknown wrt: {wrt}")
530
+
531
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
532
+ """Jacobian of the up-projection offset."""
533
+ if wrt == "uv":
534
+ return torch.zeros(p2d.shape[:-1] + (2, 2), device=p2d.device, dtype=p2d.dtype)
535
+
536
+ raise ValueError(f"Unknown wrt: {wrt}")
537
+
538
+
539
+ class SimpleRadial(BaseCamera):
540
+ """Implementation of the simple radial camera model.
541
+
542
+ Use this model for weakly distorted images.
543
+
544
+ The distortion model is 1 + k1 * r^2 where r^2 = x^2 + y^2.
545
+ The undistortion model is 1 - k1 * r^2 estimated as in
546
+ "An Exact Formula for Calculating Inverse Radial Lens Distortions" by Pierre Drap.
547
+ """
548
+
549
+ @property
550
+ def dist(self) -> torch.Tensor:
551
+ """Distortion parameters, with shape (..., 1)."""
552
+ return self._data[..., 6:]
553
+
554
+ @property
555
+ def k1(self) -> torch.Tensor:
556
+ """Distortion parameters, with shape (...)."""
557
+ return self._data[..., 6]
558
+
559
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
560
+ """Update the self parameters after changing the k1 distortion parameter."""
561
+ delta_dist = self.new_ones(self.dist.shape) * delta
562
+ dist = (self.dist + delta_dist).clamp(*dist_range)
563
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
564
+ return self.__class__(data)
565
+
566
+ @autocast
567
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
568
+ """Check if the distorted points are valid."""
569
+ return p2d.new_ones(p2d.shape[:-1]).bool()
570
+
571
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
572
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
573
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
574
+ radial = 1 + self.k1[..., None, None] * r2
575
+
576
+ if return_scale:
577
+ return radial, None
578
+
579
+ return p2d * radial, self.check_valid(p2d)
580
+
581
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
582
+ """Jacobian of the distortion function."""
583
+ if wrt == "scale2dist": # (..., 1)
584
+ return torch.sum(p2d**2, -1, keepdim=True)
585
+ elif wrt == "scale2pts": # (..., 2)
586
+ return 2 * self.k1[..., None, None] * p2d
587
+ else:
588
+ return super().J_distort(p2d, wrt)
589
+
590
+ @autocast
591
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
592
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
593
+ b1 = -self.k1[..., None, None]
594
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
595
+ radial = 1 + b1 * r2
596
+ return p2d * radial, self.check_valid(p2d)
597
+
598
+ @autocast
599
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
600
+ """Jacobian of the undistortion function."""
601
+ b1 = -self.k1[..., None, None]
602
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
603
+ if wrt == "dist":
604
+ return -r2 * p2d
605
+ elif wrt == "pts":
606
+ radial = 1 + b1 * r2
607
+ radial_diag = torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
608
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
609
+ return (2 * b1[..., None] * ppT) + radial_diag
610
+ else:
611
+ return super().J_undistort(p2d, wrt)
612
+
613
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
614
+ """Jacobian of the up-projection offset."""
615
+ if wrt == "uv": # (..., 2, 2)
616
+ return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
617
+ elif wrt == "dist":
618
+ return 2 * p2d # (..., 2)
619
+ else:
620
+ return super().J_up_projection_offset(p2d, wrt)
621
+
622
+
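As a quick, illustrative sanity check of the approximate inverse described in the SimpleRadial docstring (not part of the diff above; parameter values are made up, layout [w, h, fx, fy, cx, cy, k1] as used by the tensors in this file), distorting and then undistorting a point should roughly recover it when k1 * r^2 is small:

import torch
cam = SimpleRadial(torch.tensor([[640.0, 480.0, 500.0, 500.0, 320.0, 240.0, 0.05]]))
p = torch.tensor([[[0.2, 0.1]]])          # normalized image coordinates
pd, _ = cam.distort(p)                    # scale by 1 + k1 * r^2
pu, _ = cam.undistort(pd)                 # first-order inverse 1 - k1 * r^2
assert torch.allclose(p, pu, atol=1e-3)   # holds approximately for mild distortion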
623
+ class SimpleDivisional(BaseCamera):
624
+ """Implementation of the simple divisional camera model.
625
+
626
+ Use this model for strongly distorted images.
627
+
628
+ The distortion model is (1 - sqrt(1 - 4 * k1 * r^2)) / (2 * k1 * r^2) where r^2 = x^2 + y^2.
629
+ The undistortion model is 1 / (1 + k1 * r^2).
630
+ """
631
+
632
+ @property
633
+ def dist(self) -> torch.Tensor:
634
+ """Distortion parameters, with shape (..., 1)."""
635
+ return self._data[..., 6:]
636
+
637
+ @property
638
+ def k1(self) -> torch.Tensor:
639
+ """Distortion parameters, with shape (...)."""
640
+ return self._data[..., 6]
641
+
642
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
643
+ """Update the self parameters after changing the k1 distortion parameter."""
644
+ delta_dist = self.new_ones(self.dist.shape) * delta
645
+ dist = (self.dist + delta_dist).clamp(*dist_range)
646
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
647
+ return self.__class__(data)
648
+
649
+ @autocast
650
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
651
+ """Check if the distorted points are valid."""
652
+ return p2d.new_ones(p2d.shape[:-1]).bool()
653
+
654
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
655
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
656
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
657
+ radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
658
+ denom = 2 * self.k1[..., None, None] * r2
659
+
660
+ ones = radial.new_ones(radial.shape)
661
+ radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
662
+
663
+ if return_scale:
664
+ return radial, None
665
+
666
+ return p2d * radial, self.check_valid(p2d)
667
+
668
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
669
+ """Jacobian of the distortion function."""
670
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
671
+ t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
672
+ if wrt == "scale2pts": # (B, N, 2)
673
+ d1 = t0 * 2 * r2
674
+ d2 = self.k1[..., None, None] * r2**2
675
+ denom = d1 * d2
676
+ return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
677
+
678
+ elif wrt == "scale2dist":
679
+ d1 = 2 * self.k1[..., None, None] * t0
680
+ d2 = 2 * r2 * self.k1[..., None, None] ** 2
681
+ denom = d1 * d2
682
+ return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
683
+
684
+ else:
685
+ return super().J_distort(p2d, wrt)
686
+
687
+ @autocast
688
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
689
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
690
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
691
+ denom = 1 + self.k1[..., None, None] * r2
692
+ radial = 1 / denom.masked_fill(denom == 0, 1e6)
693
+ return p2d * radial, self.check_valid(p2d)
694
+
695
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
696
+ """Jacobian of the undistortion function."""
697
+ # return super().J_undistort(p2d, wrt)
698
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
699
+ k1 = self.k1[..., None, None]
700
+ if wrt == "dist":
701
+ denom = (1 + k1 * r2) ** 2
702
+ return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
703
+ elif wrt == "pts":
704
+ t0 = 1 + k1 * r2
705
+ t0 = t0.masked_fill(t0 == 0, 1e6)
706
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
707
+ J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
708
+ return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
709
+
710
+ else:
711
+ return super().J_undistort(p2d, wrt)
712
+
713
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
714
+ """Jacobian of the up-projection offset.
715
+
716
+ func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
717
+ - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
718
+ """
719
+ k1 = self.k1[..., None, None]
720
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
721
+ t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
722
+ t1 = torch.sqrt(t0)
723
+ if wrt == "dist":
724
+ denom = 4 * t0 ** (3 / 2)
725
+ denom = denom.masked_fill(denom == 0, 1e6)
726
+ J = 16 / denom
727
+
728
+ denom = r2 * t1 * k1
729
+ denom = denom.masked_fill(denom == 0, 1e6)
730
+ J = J - 2 / denom
731
+
732
+ denom = (r2 * k1) ** 2
733
+ denom = denom.masked_fill(denom == 0, 1e6)
734
+ J = J + (1 - t1) / denom
735
+
736
+ return J * p2d
737
+ elif wrt == "uv":
738
+ # ! unstable (gradient checker might fail), rewrite to use single division (by denom)
739
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
740
+
741
+ denom = 2 * r2 * t1
742
+ denom = denom.masked_fill(denom == 0, 1e6)
743
+ J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
744
+
745
+ denom = 4 * t1 * r2**2
746
+ denom = denom.masked_fill(denom == 0, 1e6)
747
+ J = J - 16 / denom[..., None] * ppT
748
+
749
+ denom = 4 * r2 * t0 ** (3 / 2)
750
+ denom = denom.masked_fill(denom == 0, 1e6)
751
+ J = J + (32 * k1[..., None]) / denom[..., None] * ppT
752
+
753
+ denom = r2**2 * t1
754
+ denom = denom.masked_fill(denom == 0, 1e6)
755
+ J = J - 4 / denom[..., None] * ppT
756
+
757
+ denom = k1 * r2**3
758
+ denom = denom.masked_fill(denom == 0, 1e6)
759
+ J = J + (4 * (1 - t1) / denom)[..., None] * ppT
760
+
761
+ denom = k1 * r2**2
762
+ denom = denom.masked_fill(denom == 0, 1e6)
763
+ J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
764
+
765
+ return J
766
+ else:
767
+ return super().J_up_projection_offset(p2d, wrt)
768
+
769
+
770
+ camera_models = {
771
+ "pinhole": Pinhole,
772
+ "simple_radial": SimpleRadial,
773
+ "simple_divisional": SimpleDivisional,
774
+ }
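A rough usage sketch of this registry (illustrative only, not part of the commit; the parameter layout [w, h, fx, fy, cx, cy, k1] follows the torch.cat([size, f, c, dist], -1) convention used by the classes above):

import torch
from geocalib.camera import camera_models

CameraCls = camera_models["simple_radial"]
cam = CameraCls(torch.tensor([[640.0, 480.0, 500.0, 500.0, 320.0, 240.0, -0.1]]))

# project camera-frame 3D points to pixels, then lift the pixels back to rays
p3d = torch.tensor([[[0.1, -0.2, 2.0], [0.0, 0.0, 1.0]]])
p2d, valid = cam.world2image(p3d)
rays, _ = cam.image2world(p2d)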
geocalib/extractor.py ADDED
@@ -0,0 +1,126 @@
1
+ """Simple interface for GeoCalib model."""
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.functional import interpolate
9
+
10
+ from geocalib.camera import BaseCamera
11
+ from geocalib.geocalib import GeoCalib as Model
12
+ from geocalib.utils import ImagePreprocessor, load_image
13
+
14
+
15
+ class GeoCalib(nn.Module):
16
+ """Simple interface for GeoCalib model."""
17
+
18
+ def __init__(self, weights: str = "pinhole"):
19
+ """Initialize the model with optional config overrides.
20
+
21
+ Args:
22
+ weights (str): trained variant, "pinhole" (default) or "distorted".
23
+ """
24
+ super().__init__()
25
+ if weights == "pinhole":
26
+ url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
27
+ elif weights == "distorted":
28
+ url = (
29
+ "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
30
+ )
31
+ else:
32
+ raise ValueError(f"Unknown weights: {weights}")
33
+
34
+ # load checkpoint
35
+ model_dir = f"{torch.hub.get_dir()}/geocalib"
36
+ state_dict = torch.hub.load_state_dict_from_url(
37
+ url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
38
+ )
39
+
40
+ self.model = Model()
41
+ self.model.flexible_load(state_dict["model"])
42
+ self.model.eval()
43
+
44
+ self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})
45
+
46
+ def load_image(self, path: Path) -> torch.Tensor:
47
+ """Load image from path."""
48
+ return load_image(path)
49
+
50
+ def _post_process(
51
+ self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor]
52
+ ) -> tuple[BaseCamera, dict[str, torch.Tensor]]:
53
+ """Post-process model output by undoing scaling and cropping."""
54
+ camera = camera.undo_scale_crop(img_data)
55
+
56
+ w, h = camera.size.unbind(-1)
57
+ h = h[0].round().int().item()
58
+ w = w[0].round().int().item()
59
+
60
+ for k in ["latitude_field", "up_field"]:
61
+ out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
62
+ for k in ["up_confidence", "latitude_confidence"]:
63
+ out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]
64
+
65
+ inverse_scales = 1.0 / img_data["scales"]
66
+ zero = camera.new_zeros(camera.f.shape[0])
67
+ out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
68
+ return camera, out
69
+
70
+ @torch.no_grad()
71
+ def calibrate(
72
+ self,
73
+ img: torch.Tensor,
74
+ camera_model: str = "pinhole",
75
+ priors: Optional[Dict[str, torch.Tensor]] = None,
76
+ shared_intrinsics: bool = False,
77
+ ) -> Dict[str, torch.Tensor]:
78
+ """Perform calibration with online resizing.
79
+
80
+ Assumes input image is in range [0, 1] and in RGB format.
81
+
82
+ Args:
83
+ img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W)
84
+ camera_model (str, optional): Camera model. Defaults to "pinhole".
85
+ priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}.
86
+ shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.
87
+
88
+ Returns:
89
+ Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties.
90
+ """
91
+ if len(img.shape) == 3:
92
+ img = img[None] # add batch dim
93
+ if not shared_intrinsics:
94
+ assert len(img.shape) == 4 and img.shape[0] == 1
95
+
96
+ img_data = self.image_processor(img)
97
+
98
+ if priors is None:
99
+ priors = {}
100
+
101
+ prior_values = {}
102
+ if prior_focal := priors.get("focal"):
103
+ prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
104
+ prior_values["prior_focal"] = prior_focal * img_data["scales"][1]
105
+
106
+ if "gravity" in priors:
107
+ prior_gravity = priors["gravity"]
108
+ prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
109
+ prior_values["prior_gravity"] = prior_gravity
110
+
111
+ self.model.optimizer.set_camera_model(camera_model)
112
+ self.model.optimizer.shared_intrinsics = shared_intrinsics
113
+
114
+ out = self.model(img_data | prior_values)
115
+
116
+ camera, gravity = out["camera"], out["gravity"]
117
+ camera, out = self._post_process(camera, img_data, out)
118
+
119
+ return {
120
+ "camera": camera,
121
+ "gravity": gravity,
122
+ "covariance": out["covariance"],
123
+ **{k: out[k] for k in out.keys() if "field" in k},
124
+ **{k: out[k] for k in out.keys() if "confidence" in k},
125
+ **{k: out[k] for k in out.keys() if "uncertainty" in k},
126
+ }
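A minimal end-to-end sketch of the interface above (illustrative; the image path and device are placeholders):

import torch
from geocalib.extractor import GeoCalib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib(weights="pinhole").to(device)

img = model.load_image("path/to/image.jpg").to(device)  # RGB, values in [0, 1]
result = model.calibrate(img, camera_model="pinhole")

camera, gravity = result["camera"], result["gravity"]
print(camera.f)    # focal length in pixels
print(gravity.rp)  # roll and pitch in radians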
geocalib/geocalib.py ADDED
@@ -0,0 +1,150 @@
1
+ """GeoCalib model definition."""
2
+
3
+ import logging
4
+ from typing import Dict
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from geocalib.lm_optimizer import LMOptimizer
11
+ from geocalib.modules import MSCAN, ConvModule, LightHamHead
12
+
13
+ # mypy: ignore-errors
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class LowLevelEncoder(nn.Module):
19
+ """Very simple low-level encoder."""
20
+
21
+ def __init__(self):
22
+ """Simple low-level encoder."""
23
+ super().__init__()
24
+ self.in_channel = 3
25
+ self.feat_dim = 64
26
+
27
+ self.conv1 = ConvModule(self.in_channel, self.feat_dim, kernel_size=3, padding=1)
28
+ self.conv2 = ConvModule(self.feat_dim, self.feat_dim, kernel_size=3, padding=1)
29
+
30
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
31
+ """Forward pass."""
32
+ x = data["image"]
33
+
34
+ assert (
35
+ x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0
36
+ ), "Image size must be multiple of 32 if not using single image input."
37
+
38
+ c1 = self.conv1(x)
39
+ c2 = self.conv2(c1)
40
+
41
+ return {"features": c2}
42
+
43
+
44
+ class UpDecoder(nn.Module):
45
+ """Minimal implementation of UpDecoder."""
46
+
47
+ def __init__(self):
48
+ """Up decoder."""
49
+ super().__init__()
50
+ self.decoder = LightHamHead()
51
+ self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, 2, kernel_size=1)
52
+
53
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
54
+ """Forward pass."""
55
+ x, log_confidence = self.decoder(data["features"])
56
+ up = self.linear_pred_up(x)
57
+ return {"up_field": F.normalize(up, dim=1), "up_confidence": torch.sigmoid(log_confidence)}
58
+
59
+
60
+ class LatitudeDecoder(nn.Module):
61
+ """Minimal implementation of LatitudeDecoder."""
62
+
63
+ def __init__(self):
64
+ """Latitude decoder."""
65
+ super().__init__()
66
+ self.decoder = LightHamHead()
67
+ self.linear_pred_latitude = nn.Conv2d(self.decoder.out_channels, 1, kernel_size=1)
68
+
69
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
70
+ """Forward pass."""
71
+ x, log_confidence = self.decoder(data["features"])
72
+ eps = 1e-5 # avoid nan in backward of asin
73
+ lat = torch.tanh(self.linear_pred_latitude(x))
74
+ lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps))
75
+ return {"latitude_field": lat, "latitude_confidence": torch.sigmoid(log_confidence)}
76
+
77
+
78
+ class PerspectiveDecoder(nn.Module):
79
+ """Minimal implementation of PerspectiveDecoder."""
80
+
81
+ def __init__(self):
82
+ """Perspective decoder wrapping up and latitude decoders."""
83
+ super().__init__()
84
+ self.up_head = UpDecoder()
85
+ self.latitude_head = LatitudeDecoder()
86
+
87
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
88
+ """Forward pass."""
89
+ return self.up_head(data) | self.latitude_head(data)
90
+
91
+
92
+ class GeoCalib(nn.Module):
93
+ """GeoCalib inference model."""
94
+
95
+ def __init__(self, **optimizer_options):
96
+ """Initialize the GeoCalib inference model.
97
+
98
+ Args:
99
+ optimizer_options: Options for the lm optimizer.
100
+ """
101
+ super().__init__()
102
+ self.backbone = MSCAN()
103
+ self.ll_enc = LowLevelEncoder()
104
+ self.perspective_decoder = PerspectiveDecoder()
105
+
106
+ self.optimizer = LMOptimizer({**optimizer_options})
107
+
108
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
109
+ """Forward pass."""
110
+ features = {"hl": self.backbone(data)["features"], "ll": self.ll_enc(data)["features"]}
111
+ out = self.perspective_decoder({"features": features})
112
+
113
+ out |= {
114
+ k: data[k]
115
+ for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"]
116
+ if k in data
117
+ }
118
+
119
+ out |= self.optimizer(out)
120
+
121
+ return out
122
+
123
+ def flexible_load(self, state_dict: Dict[str, torch.Tensor]) -> None:
124
+ """Load a checkpoint with flexible key names."""
125
+ dict_params = set(state_dict.keys())
126
+ model_params = set(map(lambda n: n[0], self.named_parameters()))
127
+
128
+ if dict_params == model_params: # perfect fit
129
+ logger.info("Loading all parameters of the checkpoint.")
130
+ self.load_state_dict(state_dict, strict=True)
131
+ return
132
+ elif len(dict_params & model_params) == 0: # perfect mismatch
133
+ strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
134
+ state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
135
+ dict_params = set(state_dict.keys())
136
+ if len(dict_params & model_params) == 0:
137
+ raise ValueError(
138
+ "Could not manage to load the checkpoint with"
139
+ "parameters:" + "\n\t".join(sorted(dict_params))
140
+ )
141
+ common_params = dict_params & model_params
142
+ left_params = dict_params - model_params
143
+ left_params = [
144
+ p for p in left_params if "running" not in p and "num_batches_tracked" not in p
145
+ ]
146
+ logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
147
+ if left_params:
148
+ # ignore running stats of batchnorm
149
+ logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
150
+ self.load_state_dict(state_dict, strict=False)
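For context, a hedged sketch of how the inference model above can be driven directly (weights are randomly initialized here; in practice a checkpoint is loaded via flexible_load as done in geocalib/extractor.py, and image sides must be multiples of 32 as asserted in LowLevelEncoder):

import torch
from geocalib.geocalib import GeoCalib

model = GeoCalib().eval()
data = {"image": torch.rand(1, 3, 320, 320)}  # RGB in [0, 1]
with torch.no_grad():
    out = model(data)
# out contains "up_field", "latitude_field", their confidences, and the optimizer outputs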
geocalib/gravity.py ADDED
@@ -0,0 +1,131 @@
1
+ """Tensor class for gravity vector in camera frame."""
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ from geocalib.misc import EuclideanManifold, SphericalManifold, TensorWrapper, autocast
7
+ from geocalib.utils import rad2rotmat
8
+
9
+ # mypy: ignore-errors
10
+
11
+
12
+ class Gravity(TensorWrapper):
13
+ """Gravity vector in camera frame."""
14
+
15
+ eps = 1e-4
16
+
17
+ @autocast
18
+ def __init__(self, data: torch.Tensor) -> None:
19
+ """Create gravity vector from data.
20
+
21
+ Args:
22
+ data (torch.Tensor): gravity vector as 3D vector in camera frame.
23
+ """
24
+ assert data.shape[-1] == 3, data.shape
25
+
26
+ data = F.normalize(data, dim=-1)
27
+
28
+ super().__init__(data)
29
+
30
+ @classmethod
31
+ def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
32
+ """Create gravity vector from roll and pitch angles."""
33
+ if not isinstance(roll, torch.Tensor):
34
+ roll = torch.tensor(roll)
35
+ if not isinstance(pitch, torch.Tensor):
36
+ pitch = torch.tensor(pitch)
37
+
38
+ sr, cr = torch.sin(roll), torch.cos(roll)
39
+ sp, cp = torch.sin(pitch), torch.cos(pitch)
40
+ return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
41
+
42
+ @property
43
+ def vec3d(self) -> torch.Tensor:
44
+ """Return the gravity vector in the representation."""
45
+ return self._data
46
+
47
+ @property
48
+ def x(self) -> torch.Tensor:
49
+ """Return first component of the gravity vector."""
50
+ return self._data[..., 0]
51
+
52
+ @property
53
+ def y(self) -> torch.Tensor:
54
+ """Return second component of the gravity vector."""
55
+ return self._data[..., 1]
56
+
57
+ @property
58
+ def z(self) -> torch.Tensor:
59
+ """Return third component of the gravity vector."""
60
+ return self._data[..., 2]
61
+
62
+ @property
63
+ def roll(self) -> torch.Tensor:
64
+ """Return the roll angle of the gravity vector."""
65
+ roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
66
+ offset = -torch.pi * torch.sign(self.x)
67
+ return torch.where(self.y < 0, roll, -roll + offset)
68
+
69
+ def J_roll(self) -> torch.Tensor:
70
+ """Return the Jacobian of the roll angle of the gravity vector."""
71
+ cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
72
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
73
+ Jr = self.new_zeros(self.shape + (3,))
74
+ Jr[..., 0] = -cr * cp
75
+ Jr[..., 1] = sr * cp
76
+ return Jr
77
+
78
+ @property
79
+ def pitch(self) -> torch.Tensor:
80
+ """Return the pitch angle of the gravity vector."""
81
+ return torch.asin(self.z)
82
+
83
+ def J_pitch(self) -> torch.Tensor:
84
+ """Return the Jacobian of the pitch angle of the gravity vector."""
85
+ cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
86
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
87
+
88
+ Jp = self.new_zeros(self.shape + (3,))
89
+ Jp[..., 0] = sr * sp
90
+ Jp[..., 1] = cr * sp
91
+ Jp[..., 2] = cp
92
+ return Jp
93
+
94
+ @property
95
+ def rp(self) -> torch.Tensor:
96
+ """Return the roll and pitch angles of the gravity vector."""
97
+ return torch.stack([self.roll, self.pitch], dim=-1)
98
+
99
+ def J_rp(self) -> torch.Tensor:
100
+ """Return the Jacobian of the roll and pitch angles of the gravity vector."""
101
+ return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
102
+
103
+ @property
104
+ def R(self) -> torch.Tensor:
105
+ """Return the rotation matrix from the gravity vector."""
106
+ return rad2rotmat(roll=self.roll, pitch=self.pitch)
107
+
108
+ def J_R(self) -> torch.Tensor:
109
+ """Return the Jacobian of the rotation matrix from the gravity vector."""
110
+ raise NotImplementedError
111
+
112
+ def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
113
+ """Update the gravity vector by adding a delta."""
114
+ if spherical:
115
+ data = SphericalManifold.plus(self.vec3d, delta)
116
+ return self.__class__(data)
117
+
118
+ data = EuclideanManifold.plus(self.rp, delta)
119
+ return self.from_rp(data[..., 0], data[..., 1])
120
+
121
+ def J_update(self, spherical: bool = False) -> torch.Tensor:
122
+ """Return the Jacobian of the update."""
123
+ return (
124
+ SphericalManifold.J_plus(self.vec3d)
125
+ if spherical
126
+ else EuclideanManifold.J_plus(self.vec3d)
127
+ )
128
+
129
+ def __repr__(self):
130
+ """Print the Camera object."""
131
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
geocalib/interactive_demo.py ADDED
@@ -0,0 +1,450 @@
1
+ import argparse
2
+ import logging
3
+ import queue
4
+ import threading
5
+ import time
6
+ from time import time
7
+
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import torch
12
+
13
+ from geocalib.extractor import GeoCalib
14
+ from geocalib.perspective_fields import get_perspective_field
15
+ from geocalib.utils import get_device, rad2deg
16
+
17
+ # flake8: noqa
18
+ # mypy: ignore-errors
19
+
20
+
21
+ description = """
22
+ -------------------------
23
+ GeoCalib Interactive Demo
24
+ -------------------------
25
+
26
+ This script is an interactive demo for GeoCalib. It will open a window showing the camera feed and
27
+ the calibration results.
28
+
29
+ Arguments:
30
+ - '--camera_id': Camera ID to use. If none, will ask for ip of droidcam (https://droidcam.app)
31
+
32
+ You can toggle different features using the following keys:
33
+
34
+ - 'h': Toggle horizon line
35
+ - 'u': Toggle up vector field
36
+ - 'l': Toggle latitude heatmap
37
+ - 'c': Toggle confidence heatmap
38
+ - 'd': Toggle undistorted image
39
+ - 'g': Toggle grid of points
40
+ - 'b': Toggle box object
41
+
42
+ You can also change the camera model using the following keys:
43
+
44
+ - '1': Pinhole
45
+ - '2': Simple Radial
46
+ - '3': Simple Divisional
47
+
48
+ Press 'q' to quit the demo.
49
+ """
50
+
51
+
52
+ # Custom VideoCapture class that returns the most recent frame instead of FIFO order
53
+ class VideoCapture:
54
+ def __init__(self, name):
55
+ self.cap = cv2.VideoCapture(name)
56
+ self.q = queue.Queue()
57
+ t = threading.Thread(target=self._reader)
58
+ t.daemon = True
59
+ t.start()
60
+
61
+ # read frames as soon as they are available, keeping only most recent one
62
+ def _reader(self):
63
+ while True:
64
+ ret, frame = self.cap.read()
65
+ if not ret:
66
+ break
67
+ if not self.q.empty():
68
+ try:
69
+ self.q.get_nowait() # discard previous (unprocessed) frame
70
+ except queue.Empty:
71
+ pass
72
+ self.q.put(frame)
73
+
74
+ def read(self):
75
+ return 1, self.q.get()
76
+
77
+ def isOpened(self):
78
+ return self.cap.isOpened()
79
+
80
+
81
+ def add_text(frame, text, align_left=True, align_top=True):
82
+ """Add text to a plot."""
83
+ h, w = frame.shape[:2]
84
+ sc = min(h / 640.0, 2.0)
85
+ Ht = int(40 * sc) # text height
86
+
87
+ for i, l in enumerate(text.split("\n")):
88
+ max_line = len(max([l for l in text.split("\n")], key=len))
89
+ x = int(8 * sc if align_left else w - (max_line) * sc * 18)
90
+ y = Ht * (i + 1) if align_top else h - Ht * (len(text.split("\n")) - i - 1) - int(8 * sc)
91
+
92
+ c_back, c_front = (0, 0, 0), (255, 255, 255)
93
+ font, style = cv2.FONT_HERSHEY_DUPLEX, cv2.LINE_AA
94
+ cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_back, int(6 * sc), style)
95
+ cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_front, int(1 * sc), style)
96
+ return frame
97
+
98
+
99
+ def is_corner(p, h, w):
100
+ """Check if a point is a corner."""
101
+ return p in [(0, 0), (0, h - 1), (w - 1, 0), (w - 1, h - 1)]
102
+
103
+
104
+ def plot_latitude(frame, latitude):
105
+ """Plot latitude heatmap."""
106
+ if not isinstance(latitude, np.ndarray):
107
+ latitude = latitude.cpu().numpy()
108
+
109
+ cmap = plt.get_cmap("seismic")
110
+ h, w = frame.shape[0], frame.shape[1]
111
+ sc = min(h / 640.0, 2.0)
112
+
113
+ vmin, vmax = -90, 90
114
+ latitude = (latitude - vmin) / (vmax - vmin)
115
+
116
+ colors = (cmap(latitude)[..., :3] * 255).astype(np.uint8)[..., ::-1]
117
+ frame = cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
118
+
119
+ for contour_line in np.linspace(vmin, vmax, 15):
120
+ contour_line = (contour_line - vmin) / (vmax - vmin)
121
+
122
+ mask = (latitude > contour_line).astype(np.uint8)
123
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
124
+
125
+ for contour in contours:
126
+ color = (np.array(cmap(contour_line))[:3] * 255).astype(np.uint8)[::-1]
127
+
128
+ # remove corners
129
+ contour = [p for p in contour if not is_corner(tuple(p[0]), h, w)]
130
+ for index, item in enumerate(contour[:-1]):
131
+ cv2.line(frame, item[0], contour[index + 1][0], color.tolist(), int(5 * sc))
132
+
133
+ return frame
134
+
135
+
136
+ def draw_horizon_line(frame, heatmap):
137
+ """Draw a horizon line."""
138
+ if not isinstance(heatmap, np.ndarray):
139
+ heatmap = heatmap.cpu().numpy()
140
+
141
+ h, w = frame.shape[0], frame.shape[1]
142
+ sc = min(h / 640.0, 2.0)
143
+
144
+ color = (0, 255, 255)
145
+ vmin, vmax = -90, 90
146
+ heatmap = (heatmap - vmin) / (vmax - vmin)
147
+
148
+ contours, _ = cv2.findContours(
149
+ (heatmap > 0.5).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
150
+ )
151
+ if contours:
152
+ contour = [p for p in contours[0] if not is_corner(tuple(p[0]), h, w)]
153
+ for index, item in enumerate(contour[:-1]):
154
+ cv2.line(frame, item[0], contour[index + 1][0], color, int(5 * sc))
155
+ return frame
156
+
157
+
158
+ def plot_confidence(frame, confidence):
159
+ """Plot confidence heatmap."""
160
+ if not isinstance(confidence, np.ndarray):
161
+ confidence = confidence.cpu().numpy()
162
+
163
+ confidence = np.log10(confidence.clip(1e-6)).clip(-4)
164
+ confidence = (confidence - confidence.min()) / (confidence.max() - confidence.min())
165
+
166
+ cmap = plt.get_cmap("turbo")
167
+ colors = (cmap(confidence)[..., :3] * 255).astype(np.uint8)[..., ::-1]
168
+ return cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
169
+
170
+
171
+ def plot_vector_field(frame, vector_field):
172
+ """Plot a vector field."""
173
+ if not isinstance(vector_field, np.ndarray):
174
+ vector_field = vector_field.cpu().numpy()
175
+
176
+ H, W = frame.shape[:2]
177
+ sc = min(H / 640.0, 2.0)
178
+
179
+ subsample = min(W, H) // 10
180
+ offset_x = ((W % subsample) + subsample) // 2
181
+ samples_x = np.arange(offset_x, W, subsample)
182
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
183
+
184
+ vec_len = 40 * sc
185
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
186
+ x, y = vector_field[:, samples_y][:, :, samples_x]
187
+ for xi, yi, xi_dir, yi_dir in zip(x_grid.ravel(), y_grid.ravel(), x.ravel(), y.ravel()):
188
+ start = (xi, yi)
189
+ end = (int(xi + xi_dir * vec_len), int(yi + yi_dir * vec_len))
190
+ cv2.arrowedLine(
191
+ frame, start, end, (0, 255, 0), int(5 * sc), line_type=cv2.LINE_AA, tipLength=0.3
192
+ )
193
+
194
+ return frame
195
+
196
+
197
+ def plot_box(frame, gravity, camera):
198
+ """Plot a box object."""
199
+ pts = np.array(
200
+ [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]
201
+ )
202
+ pts = pts - np.array([0.5, 1, 0.5])
203
+ rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
204
+ t = np.array([0, 0, 1], dtype=float)
205
+ K = camera.K[0].cpu().numpy().astype(float)
206
+ dist = np.zeros(4, dtype=float)
207
+ axis_points, _ = cv2.projectPoints(
208
+ 0.1 * pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist
209
+ )
210
+
211
+ h = frame.shape[0]
212
+ sc = min(h / 640.0, 2.0)
213
+
214
+ color = (85, 108, 228)
215
+ for p in axis_points:
216
+ center = tuple((int(p[0][0]), int(p[0][1])))
217
+ frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)
218
+
219
+ for i in range(0, 4):
220
+ p1 = axis_points[i].astype(int)
221
+ p2 = axis_points[i + 4].astype(int)
222
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
223
+
224
+ p1 = axis_points[i].astype(int)
225
+ p2 = axis_points[(i + 1) % 4].astype(int)
226
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
227
+
228
+ p1 = axis_points[i + 4].astype(int)
229
+ p2 = axis_points[(i + 1) % 4 + 4].astype(int)
230
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
231
+
232
+ return frame
233
+
234
+
235
+ def plot_grid(frame, gravity, camera, grid_size=0.2, num_points=5):
236
+ """Plot a grid of points."""
237
+ h = frame.shape[0]
238
+ sc = min(h / 640.0, 2.0)
239
+
240
+ samples = np.linspace(-grid_size, grid_size, num_points)
241
+ xz = np.meshgrid(samples, samples)
242
+ pts = np.stack((xz[0].ravel(), np.zeros_like(xz[0].ravel()), xz[1].ravel()), axis=-1)
243
+
244
+ # project points
245
+ rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
246
+ t = np.array([0, 0, 1], dtype=float)
247
+ K = camera.K[0].cpu().numpy().astype(float)
248
+ dist = np.zeros(4, dtype=float)
249
+ axis_points, _ = cv2.projectPoints(pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist)
250
+
251
+ color = (192, 77, 58)
252
+ # draw points
253
+ for p in axis_points:
254
+ center = tuple((int(p[0][0]), int(p[0][1])))
255
+ frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)
256
+
257
+ # draw lines
258
+ for i in range(num_points):
259
+ for j in range(num_points - 1):
260
+ p1 = axis_points[i * num_points + j].astype(int)
261
+ p2 = axis_points[i * num_points + j + 1].astype(int)
262
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
263
+
264
+ p1 = axis_points[j * num_points + i].astype(int)
265
+ p2 = axis_points[(j + 1) * num_points + i].astype(int)
266
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
267
+
268
+ return frame
269
+
270
+
271
+ def undistort_image(img, camera, padding=0.3):
272
+ """Undistort an image."""
273
+ W, H = camera.size.unbind(-1)
274
+ H, W = H.int().item(), W.int().item()
275
+
276
+ pad_h, pad_w = int(H * padding), int(W * padding)
277
+ x, y = torch.meshgrid(torch.arange(0, W + pad_w), torch.arange(0, H + pad_h), indexing="xy")
278
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2) - torch.tensor([pad_w / 2, pad_h / 2])
279
+
280
+ p3d, _ = camera.pinhole().image2world(coords.to(camera.device).to(camera.dtype))
281
+ p2d, _ = camera.world2image(p3d)
282
+
283
+ p2d = p2d.float().numpy().reshape(H + pad_h, W + pad_w, 2)
284
+ img = cv2.remap(img, p2d[..., 0], p2d[..., 1], cv2.INTER_LINEAR, borderValue=(254, 254, 254))
285
+ return cv2.resize(img, (W, H))
286
+
287
+
288
+ class InteractiveDemo:
289
+ def __init__(self, capture: VideoCapture, device: str) -> None:
290
+ self.cap = capture
291
+
292
+ self.device = torch.device(device)
293
+ self.model = GeoCalib().to(device)
294
+
295
+ self.up_toggle = False
296
+ self.lat_toggle = False
297
+ self.conf_toggle = False
298
+
299
+ self.hl_toggle = False
300
+ self.grid_toggle = False
301
+ self.box_toggle = False
302
+
303
+ self.undist_toggle = False
304
+
305
+ self.camera_model = "pinhole"
306
+
307
+ def render_frame(self, frame, calibration):
308
+ """Render the frame with the calibration results."""
309
+ camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()
310
+
311
+ if self.undist_toggle:
312
+ return undistort_image(frame, camera)
313
+
314
+ up, lat = get_perspective_field(camera, gravity)
315
+
316
+ if gravity.pitch[0] > 0:
317
+ frame = plot_box(frame, gravity, camera) if self.box_toggle else frame
318
+ frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
319
+ else:
320
+ frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
321
+ frame = plot_box(frame, gravity, camera) if self.box_toggle else frame
322
+
323
+ frame = draw_horizon_line(frame, lat[0, 0]) if self.hl_toggle else frame
324
+
325
+ if self.conf_toggle and self.up_toggle:
326
+ frame = plot_confidence(frame, calibration["up_confidence"][0])
327
+ frame = plot_vector_field(frame, up[0]) if self.up_toggle else frame
328
+
329
+ if self.conf_toggle and self.lat_toggle:
330
+ frame = plot_confidence(frame, calibration["latitude_confidence"][0])
331
+ frame = plot_latitude(frame, rad2deg(lat)[0, 0]) if self.lat_toggle else frame
332
+
333
+ return frame
334
+
335
+ def format_results(self, calibration):
336
+ """Format the calibration results."""
337
+ camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()
338
+
339
+ vfov, focal = camera.vfov[0].item(), camera.f[0, 0].item()
340
+ fov_unc = rad2deg(calibration["vfov_uncertainty"].item())
341
+ f_unc = calibration["focal_uncertainty"].item()
342
+
343
+ roll, pitch = gravity.rp[0].unbind(-1)
344
+ roll, pitch, vfov = rad2deg(roll), rad2deg(pitch), rad2deg(vfov)
345
+ roll_unc = rad2deg(calibration["roll_uncertainty"].item())
346
+ pitch_unc = rad2deg(calibration["pitch_uncertainty"].item())
347
+
348
+ text = f"{self.camera_model.replace('_', ' ').title()}\n"
349
+ text += f"Roll: {roll:.2f} (+- {roll_unc:.2f})\n"
350
+ text += f"Pitch: {pitch:.2f} (+- {pitch_unc:.2f})\n"
351
+ text += f"vFoV: {vfov:.2f} (+- {fov_unc:.2f})\n"
352
+ text += f"Focal: {focal:.2f} (+- {f_unc:.2f})"
353
+
354
+ if hasattr(camera, "k1"):
355
+ text += f"\nK1: {camera.k1[0].item():.2f}"
356
+
357
+ return text
358
+
359
+ def update_toggles(self):
360
+ """Update the toggles."""
361
+ key = cv2.waitKey(100) & 0xFF
362
+ if key == ord("h"):
363
+ self.hl_toggle = not self.hl_toggle
364
+ elif key == ord("u"):
365
+ self.up_toggle = not self.up_toggle
366
+ elif key == ord("l"):
367
+ self.lat_toggle = not self.lat_toggle
368
+ elif key == ord("c"):
369
+ self.conf_toggle = not self.conf_toggle
370
+ elif key == ord("d"):
371
+ self.undist_toggle = not self.undist_toggle
372
+ elif key == ord("g"):
373
+ self.grid_toggle = not self.grid_toggle
374
+ elif key == ord("b"):
375
+ self.box_toggle = not self.box_toggle
376
+
377
+ elif key == ord("1"):
378
+ self.camera_model = "pinhole"
379
+ elif key == ord("2"):
380
+ self.camera_model = "simple_radial"
381
+ elif key == ord("3"):
382
+ self.camera_model = "simple_divisional"
383
+
384
+ elif key == ord("q"):
385
+ return True
386
+
387
+ return False
388
+
389
+ def run(self):
390
+ """Run the interactive demo."""
391
+ while True:
392
+ start = time()
393
+ ret, frame = self.cap.read()
394
+
395
+ if not ret:
396
+ print("Error: Failed to retrieve frame.")
397
+ break
398
+
399
+ # create tensor from frame
400
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
401
+ img = torch.tensor(img).permute(2, 0, 1) / 255.0
402
+
403
+ calibration = self.model.calibrate(img.to(self.device), camera_model=self.camera_model)
404
+
405
+ # render results to the frame
406
+ frame = self.render_frame(frame, calibration)
407
+ frame = add_text(frame, self.format_results(calibration))
408
+
409
+ end = time()
410
+ frame = add_text(
411
+ frame, f"FPS: {1 / (end - start):04.1f}", align_left=False, align_top=False
412
+ )
413
+
414
+ cv2.imshow("GeoCalib Demo", frame)
415
+
416
+ if self.update_toggles():
417
+ break
418
+
419
+
420
+ def main():
421
+ parser = argparse.ArgumentParser()
422
+ parser.add_argument(
423
+ "--camera_id",
424
+ type=int,
425
+ default=None,
426
+ help="Camera ID to use. If none, will ask for ip of droidcam.",
427
+ )
428
+ args = parser.parse_args()
429
+
430
+ print(description)
431
+
432
+ device = get_device()
433
+ print(f"Running on: {device}")
434
+
435
+ # setup video capture
436
+ if args.camera_id is not None:
437
+ cap = VideoCapture(args.camera_id)
438
+ else:
439
+ ip = input("Enter the IP address of the camera: ")
440
+ cap = VideoCapture(f"http://{ip}:4747/video/force/1920x1080")
441
+
442
+ if not cap.isOpened():
443
+ raise ValueError("Error: Could not open camera.")
444
+
445
+ demo = InteractiveDemo(cap, device)
446
+ demo.run()
447
+
448
+
449
+ if __name__ == "__main__":
450
+ main()
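Assuming the geocalib package is installed (or on the PYTHONPATH), the demo above can be launched with either a local camera or a DroidCam stream, for example:

python -m geocalib.interactive_demo --camera_id 0   # built-in webcam
python -m geocalib.interactive_demo                 # prompts for the DroidCam IP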
geocalib/lm_optimizer.py ADDED
@@ -0,0 +1,642 @@
1
+ """Implementation of the Levenberg-Marquardt optimizer for camera calibration."""
2
+
3
+ import logging
4
+ import time
5
+ from types import SimpleNamespace
6
+ from typing import Any, Callable, Dict, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from geocalib.camera import BaseCamera, camera_models
12
+ from geocalib.gravity import Gravity
13
+ from geocalib.misc import J_focal2fov
14
+ from geocalib.perspective_fields import J_perspective_field, get_perspective_field
15
+ from geocalib.utils import focal2fov, rad2deg
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def get_trivial_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera:
21
+ """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w).
22
+
23
+ Args:
24
+ data (Dict[str, torch.Tensor]): Input data dictionary.
25
+ camera_model (BaseCamera): Camera model to use.
26
+
27
+ Returns:
28
+ Tuple[BaseCamera, Gravity]: Initial camera and gravity for optimization.
29
+ """
31
+ ref = data.get("up_field", data["latitude_field"])
32
+ ref = ref.detach()
33
+
34
+ h, w = ref.shape[-2:]
35
+ batch_h, batch_w = (
36
+ ref.new_ones((ref.shape[0],)) * h,
37
+ ref.new_ones((ref.shape[0],)) * w,
38
+ )
39
+
40
+ init_r = ref.new_zeros((ref.shape[0],))
41
+ init_p = ref.new_zeros((ref.shape[0],))
42
+
43
+ focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w))
44
+ init_vfov = focal2fov(focal, h)
45
+
46
+ params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
47
+ params |= {"scales": data["scales"]} if "scales" in data else {}
48
+ params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
49
+ camera = camera_model.from_dict(params)
50
+ camera = camera.float().to(ref.device)
51
+
52
+ gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)
53
+
54
+ if "prior_gravity" in data:
55
+ gravity = data["prior_gravity"].float().to(ref.device)
56
+ gravity = Gravity(gravity) if isinstance(gravity, torch.Tensor) else gravity
57
+
58
+ return camera, gravity
59
+
60
+
61
+ def scaled_loss(
62
+ x: torch.Tensor, fn: Callable, a: float
63
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
64
+ """Apply a loss function to a tensor and pre- and post-scale it.
65
+
66
+ Args:
67
+ x: the data tensor, should already be squared: `x = y**2`.
68
+ fn: the loss function, with signature `fn(x) -> y`.
69
+ a: the scale parameter.
70
+
71
+ Returns:
72
+ The value of the loss, and its first and second derivatives.
73
+ """
74
+ a2 = a**2
75
+ loss, loss_d1, loss_d2 = fn(x / a2)
76
+ return loss * a2, loss_d1, loss_d2 / a2
77
+
78
+
79
+ def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
80
+ """The classical robust Huber loss, with first and second derivatives."""
81
+ mask = x <= 1
82
+ sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass
83
+ isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
84
+ loss = torch.where(mask, x, 2 * sx - 1)
85
+ loss_d1 = torch.where(mask, torch.ones_like(x), isx)
86
+ loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
87
+ return loss, loss_d1, loss_d2
88
+
89
+
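A tiny sketch of how scaled_loss and huber_loss combine (values are illustrative): squared residuals that are small relative to the scale stay quadratic, while larger ones are down-weighted.

import torch
res_sq = torch.tensor([1e-5, 1e-3, 1.0])                 # squared residuals
loss, weight, _ = scaled_loss(res_sq, huber_loss, 1e-2)   # scale as in up_loss_fn_scale
# weight acts as an IRLS weight: ~1 for inliers, < 1 for outliers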
90
+ def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
91
+ """Early stopping criterion based on cost convergence."""
92
+ return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol)
93
+
94
+
95
+ def update_lambda(
96
+ lamb: torch.Tensor,
97
+ prev_cost: torch.Tensor,
98
+ new_cost: torch.Tensor,
99
+ lambda_min: float = 1e-6,
100
+ lambda_max: float = 1e2,
101
+ ) -> torch.Tensor:
102
+ """Update damping factor for Levenberg-Marquardt optimization."""
103
+ new_lamb = lamb.new_zeros(lamb.shape)
104
+ new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1)
105
+ lamb = torch.clamp(new_lamb, lambda_min, lambda_max)
106
+ return lamb
107
+
108
+
109
+ def optimizer_step(
110
+ G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
111
+ ) -> torch.Tensor:
112
+ """One optimization step with Gauss-Newton or Levenberg-Marquardt.
113
+
114
+ Args:
115
+ G (torch.Tensor): Batched gradient tensor of size (..., N).
116
+ H (torch.Tensor): Batched hessian tensor of size (..., N, N).
117
+ lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
118
+ eps (float, optional): Epsilon for damping. Defaults to 1e-6.
119
+
120
+ Returns:
121
+ torch.Tensor: Batched update tensor of size (..., N).
122
+ """
123
+ diag = H.diagonal(dim1=-2, dim2=-1)
124
+ diag = diag * lambda_.unsqueeze(-1) # (B, 3)
125
+
126
+ H = H + diag.clamp(min=eps).diag_embed()
127
+
128
+ H_, G_ = H.cpu(), G.cpu()
129
+ try:
130
+ U = torch.linalg.cholesky(H_)
131
+ except RuntimeError:
132
+ logger.warning("Cholesky decomposition failed. Stopping.")
133
+ delta = H.new_zeros((H.shape[0], H.shape[-1])) # (B, 3)
134
+ else:
135
+ delta = torch.cholesky_solve(G_[..., None], U)[..., 0]
136
+
137
+ return delta.to(H.device)
138
+
139
+
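A toy check of optimizer_step (illustrative): with an identity Hessian the undamped step returns G itself, while damping with lambda_ = 1 doubles the diagonal and halves the step.

import torch
H = torch.eye(3)[None]                           # (B, 3, 3) toy Hessian
G = torch.tensor([[1.0, -2.0, 0.5]])             # (B, 3) toy gradient
step_gn = optimizer_step(G, H, torch.zeros(1))   # ~G (Gauss-Newton)
step_lm = optimizer_step(G, H, torch.ones(1))    # ~G / 2 (damped)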
140
+ # mypy: ignore-errors
141
+ class LMOptimizer(nn.Module):
142
+ """Levenberg-Marquardt optimizer for camera calibration."""
143
+
144
+ default_conf = {
145
+ # Camera model parameters
146
+ "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"}
147
+ "shared_intrinsics": False, # share focal length across all images in batch
148
+ # LM optimizer parameters
149
+ "num_steps": 30,
150
+ "lambda_": 0.1,
151
+ "fix_lambda": False,
152
+ "early_stop": True,
153
+ "atol": 1e-8,
154
+ "rtol": 1e-8,
155
+ "use_spherical_manifold": True, # use spherical manifold for gravity optimization
156
+ "use_log_focal": True, # use log focal length for optimization
157
+ # Loss function parameters
158
+ "up_loss_fn_scale": 1e-2,
159
+ "lat_loss_fn_scale": 1e-2,
160
+ # Misc
161
+ "verbose": False,
162
+ }
163
+
164
+ def __init__(self, conf: Dict[str, Any]):
165
+ """Initialize the LM optimizer."""
166
+ super().__init__()
167
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
168
+ self.num_steps = conf.num_steps
169
+
170
+ self.set_camera_model(conf.camera_model)
171
+ self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)
172
+
173
+ def set_camera_model(self, camera_model: str) -> None:
174
+ """Set the camera model to use for the optimization.
175
+
176
+ Args:
177
+ camera_model (str): Camera model to use.
178
+ """
179
+ assert (
180
+ camera_model in camera_models.keys()
181
+ ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
182
+ self.camera_model = camera_models[camera_model]
183
+ self.camera_has_distortion = hasattr(self.camera_model, "dist")
184
+
185
+ logger.debug(
186
+ f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
187
+ )
188
+
189
+ def setup_optimization_and_priors(
190
+ self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
191
+ ) -> None:
192
+ """Setup the optimization and priors for the LM optimizer.
193
+
194
+ Args:
195
+ data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults
196
+ to None.
197
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
198
+ Defaults to False.
199
+ """
200
+ if data is None:
201
+ data = {}
202
+ self.shared_intrinsics = shared_intrinsics
203
+
204
+ if shared_intrinsics: # shared intrinsics => must use pinhole
205
+ assert (
206
+ self.camera_model == camera_models["pinhole"]
207
+ ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}"
208
+
209
+ self.estimate_gravity = True
210
+ if "prior_gravity" in data:
211
+ self.estimate_gravity = False
212
+ logger.debug("Using provided gravity as prior.")
213
+
214
+ self.estimate_focal = True
215
+ if "prior_focal" in data:
216
+ self.estimate_focal = False
217
+ logger.debug("Using provided focal as prior.")
218
+
219
+ self.estimate_k1 = True
220
+ if "prior_k1" in data:
221
+ self.estimate_k1 = False
222
+ logger.debug("Using provided k1 as prior.")
223
+
224
+ self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
225
+ self.focal_delta_dims = (
226
+ (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
227
+ )
228
+ self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,)
229
+
230
+ logger.debug(f"Camera Model: {self.camera_model}")
231
+ logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})")
232
+ logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})")
233
+ logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})")
234
+
235
+ logger.debug(f"Shared intrinsics: {self.shared_intrinsics}")
236
+
237
+ def calculate_residuals(
238
+ self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
239
+ ) -> Dict[str, torch.Tensor]:
240
+ """Calculate the residuals for the optimization.
241
+
242
+ Args:
243
+ camera (BaseCamera): Optimized camera.
244
+ gravity (Gravity): Optimized gravity.
245
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields.
246
+
247
+ Returns:
248
+ Dict[str, torch.Tensor]: Residuals for the optimization.
249
+ """
250
+ perspective_up, perspective_lat = get_perspective_field(camera, gravity)
251
+ perspective_lat = torch.sin(perspective_lat)
252
+
253
+ residuals = {}
254
+ if "up_field" in data:
255
+ up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
256
+ residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)
257
+
258
+ if "latitude_field" in data:
259
+ target_lat = torch.sin(data["latitude_field"])
260
+ lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
261
+ residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)
262
+
263
+ return residuals
264
+
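The field layouts consumed here, as implied by the permute/reshape calls above (a hypothetical sketch; B, H, W are arbitrary):

B, H, W = 1, 240, 320
data = {
    "up_field": torch.nn.functional.normalize(torch.randn(B, 2, H, W), dim=1),
    "latitude_field": (torch.rand(B, 1, H, W) - 0.5) * 3.14,  # radians in (-pi/2, pi/2)
}
# calculate_residuals flattens these into (B, H*W, 2) and (B, H*W, 1) residual tensors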
265
+ def calculate_costs(
266
+ self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
267
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
268
+ """Calculate the costs and weights for the optimization.
269
+
270
+ Args:
271
+ residuals (torch.Tensor): Residuals for the optimization.
272
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence.
273
+
274
+ Returns:
275
+ Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the
276
+ optimization.
277
+ """
278
+ costs, weights = {}, {}
279
+
280
+ if "up_residual" in residuals:
281
+ up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
282
+ up_cost, up_weight, _ = scaled_loss(up_cost, huber_loss, self.conf.up_loss_fn_scale)
283
+
284
+ if "up_confidence" in data:
285
+ up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
286
+ up_weight = up_weight * up_conf
287
+ up_cost = up_cost * up_conf
288
+
289
+ costs["up_cost"] = up_cost
290
+ weights["up_weights"] = up_weight
291
+
292
+ if "latitude_residual" in residuals:
293
+ lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
294
+ lat_cost, lat_weight, _ = scaled_loss(lat_cost, huber_loss, self.conf.lat_loss_fn_scale)
295
+
296
+ if "latitude_confidence" in data:
297
+ lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
298
+ lat_weight = lat_weight * lat_conf
299
+ lat_cost = lat_cost * lat_conf
300
+
301
+ costs["latitude_cost"] = lat_cost
302
+ weights["latitude_weights"] = lat_weight
303
+
304
+ return costs, weights
305
+
306
+ def calculate_gradient_and_hessian(
307
+ self,
308
+ J: torch.Tensor,
309
+ residuals: torch.Tensor,
310
+ weights: torch.Tensor,
311
+ shared_intrinsics: bool,
312
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
313
+ """Calculate the gradient and Hessian for given the Jacobian, residuals, and weights.
314
+
315
+ Args:
316
+ J (torch.Tensor): Jacobian.
317
+ residuals (torch.Tensor): Residuals.
318
+ weights (torch.Tensor): Weights.
319
+ shared_intrinsics (bool): Whether to share the intrinsics across the batch.
320
+
321
+ Returns:
322
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian.
323
+ """
324
+ dims = ()
325
+ if self.estimate_gravity:
326
+ dims = (0, 1)
327
+ if self.estimate_focal:
328
+ dims += (2,)
329
+ if self.camera_has_distortion and self.estimate_k1:
330
+ dims += (3,)
331
+ assert dims, "No parameters to optimize"
332
+
333
+ J = J[..., dims]
334
+
335
+ Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
336
+ Grad = weights[..., None] * Grad
337
+ Grad = Grad.sum(-2) # (B, N_params)
338
+
339
+ if shared_intrinsics:
340
+ # reshape to (1, B * (N_params-1) + 1)
341
+ Grad_g = Grad[..., :2].reshape(1, -1)
342
+ Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
343
+ Grad = torch.cat([Grad_g, Grad_f], dim=-1)
344
+
345
+ Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
346
+ Hess = weights[..., None, None] * Hess
347
+ Hess = Hess.sum(-3)
348
+
349
+ if shared_intrinsics:
350
+ H_g = torch.block_diag(*list(Hess[..., :2, :2]))
351
+ J_fg = Hess[..., :2, 2].flatten()
352
+ J_gf = Hess[..., 2, :2].flatten()
353
+ J_f = Hess[..., 2, 2].sum()
354
+ dims = H_g.shape[-1] + 1
355
+ Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
356
+ Hess[:-1, :-1] = H_g
357
+ Hess[-1, :-1] = J_gf
358
+ Hess[:-1, -1] = J_fg
359
+ Hess[-1, -1] = J_f
360
+ Hess = Hess.unsqueeze(0)
361
+
362
+ return Grad, Hess
363
+
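The einsum calls above accumulate the weighted normal equations G = sum_i w_i J_i^T r_i and H = sum_i w_i J_i^T J_i; a minimal standalone sketch with hypothetical shapes (no shared intrinsics):

J = torch.randn(4, 100, 2, 3)   # (B, N, residual_dim, n_params)
r = torch.randn(4, 100, 2)      # (B, N, residual_dim)
w = torch.rand(4, 100)          # (B, N)
G = (w[..., None] * torch.einsum("...Njk,...Nj->...Nk", J, r)).sum(-2)           # (B, 3)
H = (w[..., None, None] * torch.einsum("...Njk,...Njl->...Nkl", J, J)).sum(-3)   # (B, 3, 3)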
364
+ def setup_system(
365
+ self,
366
+ camera: BaseCamera,
367
+ gravity: Gravity,
368
+ residuals: Dict[str, torch.Tensor],
369
+ weights: Dict[str, torch.Tensor],
370
+ as_rpf: bool = False,
371
+ shared_intrinsics: bool = False,
372
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
373
+ """Calculate the gradient and Hessian for the optimization.
374
+
375
+ Args:
376
+ camera (BaseCamera): Optimized camera.
377
+ gravity (Gravity): Optimized gravity.
378
+ residuals (Dict[str, torch.Tensor]): Residuals for the optimization.
379
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
380
+ as_rpf (bool, optional): Whether to calculate the gradient and Hessian with respect to
381
+ roll, pitch, and focal length. Defaults to False.
382
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
383
+ Defaults to False.
384
+
385
+ Returns:
386
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization.
387
+ """
388
+ J_up, J_lat = J_perspective_field(
389
+ camera,
390
+ gravity,
391
+ spherical=self.conf.use_spherical_manifold and not as_rpf,
392
+ log_focal=self.conf.use_log_focal and not as_rpf,
393
+ )
394
+
395
+ J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3)
396
+ J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3)
397
+
398
+ n_params = (
399
+ 2 * self.estimate_gravity
400
+ + self.estimate_focal
401
+ + (self.camera_has_distortion and self.estimate_k1)
402
+ )
403
+ Grad = J_up.new_zeros(J_up.shape[0], n_params)
404
+ Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)
405
+
406
+ if shared_intrinsics:
407
+ N_params = Grad.shape[0] * (n_params - 1) + 1
408
+ Grad = Grad.new_zeros(1, N_params)
409
+ Hess = Hess.new_zeros(1, N_params, N_params)
410
+
411
+ if "up_residual" in residuals:
412
+ Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
413
+ J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
414
+ )
415
+
416
+ if self.conf.verbose:
417
+ logger.info(f"Up J:\n{Up_Grad.mean(0)}")
418
+
419
+ Grad = Grad + Up_Grad
420
+ Hess = Hess + Up_Hess
421
+
422
+ if "latitude_residual" in residuals:
423
+ Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
424
+ J_lat,
425
+ residuals["latitude_residual"],
426
+ weights["latitude_weights"],
427
+ shared_intrinsics,
428
+ )
429
+
430
+ if self.conf.verbose:
431
+ logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")
432
+
433
+ Grad = Grad + Lat_Grad
434
+ Hess = Hess + Lat_Hess
435
+
436
+ return Grad, Hess
437
+
438
+ def estimate_uncertainty(
439
+ self,
440
+ camera_opt: BaseCamera,
441
+ gravity_opt: Gravity,
442
+ errors: Dict[str, torch.Tensor],
443
+ weights: Dict[str, torch.Tensor],
444
+ ) -> Dict[str, torch.Tensor]:
445
+ """Estimate the uncertainty of the optimized camera and gravity at the final step.
446
+
447
+ Args:
448
+ camera_opt (BaseCamera): Final optimized camera.
449
+ gravity_opt (Gravity): Final optimized gravity.
450
+ errors (Dict[str, torch.Tensor]): Costs for the optimization.
451
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
452
+
453
+ Returns:
454
+ Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity.
455
+ """
456
+ _, Hess = self.setup_system(
457
+ camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
458
+ )
459
+ Cov = torch.inverse(Hess)
460
+
461
+ roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
462
+ pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
463
+ gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
464
+ if self.estimate_gravity:
465
+ roll_uncertainty = Cov[..., 0, 0]
466
+ pitch_uncertainty = Cov[..., 1, 1]
467
+
468
+ try:
469
+ delta_uncertainty = Cov[..., :2, :2]
470
+ eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
471
+ gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
472
+ except RuntimeError:
473
+ logger.warning("Could not calculate gravity uncertainty")
474
+ gravity_uncertainty = Cov.new_zeros(Cov.shape[0])
475
+
476
+ focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
477
+ fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
478
+ if self.estimate_focal:
479
+ focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
480
+ fov_uncertainty = (
481
+ J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
482
+ )
483
+
484
+ return {
485
+ "covariance": Cov,
486
+ "roll_uncertainty": torch.sqrt(roll_uncertainty),
487
+ "pitch_uncertainty": torch.sqrt(pitch_uncertainty),
488
+ "gravity_uncertainty": torch.sqrt(gravity_uncertainty),
489
+ "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
490
+ "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
491
+ }
492
+
493
+ def update_estimate(
494
+ self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
495
+ ) -> Tuple[BaseCamera, Gravity]:
496
+ """Update the camera and gravity estimates with the given delta.
497
+
498
+ Args:
499
+ camera (BaseCamera): Optimized camera.
500
+ gravity (Gravity): Optimized gravity.
501
+ delta (torch.Tensor): Delta to update the camera and gravity estimates.
502
+
503
+ Returns:
504
+ Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates.
505
+ """
506
+ delta_gravity = (
507
+ delta[..., self.gravity_delta_dims]
508
+ if self.estimate_gravity
509
+ else delta.new_zeros(delta.shape[:-1] + (2,))
510
+ )
511
+ new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)
512
+
513
+ delta_f = (
514
+ delta[..., self.focal_delta_dims]
515
+ if self.estimate_focal
516
+ else delta.new_zeros(delta.shape[:-1] + (1,))
517
+ )
518
+ new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)
519
+
520
+ delta_dist = (
521
+ delta[..., self.k1_delta_dims]
522
+ if self.camera_has_distortion and self.estimate_k1
523
+ else delta.new_zeros(delta.shape[:-1] + (1,))
524
+ )
525
+ if self.camera_has_distortion:
526
+ new_camera = new_camera.update_dist(delta_dist)
527
+
528
+ return new_camera, new_gravity
529
+
530
+ def optimize(
531
+ self,
532
+ data: Dict[str, torch.Tensor],
533
+ camera_opt: BaseCamera,
534
+ gravity_opt: Gravity,
535
+ ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:
536
+ """Optimize the camera and gravity estimates.
537
+
538
+ Args:
539
+ data (Dict[str, torch.Tensor]): Input data.
540
+ camera_opt (BaseCamera): Initial camera estimate.
+ gravity_opt (Gravity): Initial gravity estimate.
542
+
543
+ Returns:
544
+ Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity
545
+ estimates and optimization information.
546
+ """
547
+ key = list(data.keys())[0]
548
+ B = data[key].shape[0]
549
+
550
+ lamb = data[key].new_ones(B) * self.conf.lambda_
551
+ if self.shared_intrinsics:
552
+ lamb = data[key].new_ones(1) * self.conf.lambda_
553
+
554
+ infos = {"stop_at": self.num_steps}
555
+ for i in range(self.num_steps):
556
+ if self.conf.verbose:
557
+ logger.info(f"Step {i+1}/{self.num_steps}")
558
+
559
+ errors = self.calculate_residuals(camera_opt, gravity_opt, data)
560
+ costs, weights = self.calculate_costs(errors, data)
561
+
562
+ if i == 0:
563
+ prev_cost = sum(c.mean(-1) for c in costs.values())
564
+ for k, c in costs.items():
565
+ infos[f"initial_{k}"] = c.mean(-1)
566
+
567
+ infos["initial_cost"] = prev_cost
568
+
569
+ Grad, Hess = self.setup_system(
570
+ camera_opt,
571
+ gravity_opt,
572
+ errors,
573
+ weights,
574
+ shared_intrinsics=self.shared_intrinsics,
575
+ )
576
+ delta = optimizer_step(Grad, Hess, lamb) # (B, N_params)
577
+
578
+ if self.shared_intrinsics:
579
+ delta_g = delta[..., :-1].reshape(B, 2)
580
+ delta_f = delta[..., -1].expand(B, 1)
581
+ delta = torch.cat([delta_g, delta_f], dim=-1)
582
+
583
+ # calculate new cost
584
+ camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
585
+ new_cost, _ = self.calculate_costs(
586
+ self.calculate_residuals(camera_opt, gravity_opt, data), data
587
+ )
588
+ new_cost = sum(c.mean(-1) for c in new_cost.values())
589
+
590
+ if not self.conf.fix_lambda and not self.shared_intrinsics:
591
+ lamb = update_lambda(lamb, prev_cost, new_cost)
592
+
593
+ if self.conf.verbose:
594
+ logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}")
595
+ logger.info(f"Camera:\n{camera_opt._data}")
596
+
597
+ if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
598
+ infos["stop_at"] = min(i + 1, infos["stop_at"])
599
+
600
+ if self.conf.early_stop:
601
+ if self.conf.verbose:
602
+ logger.info(f"Early stopping at step {i+1}")
603
+ break
604
+
605
+ prev_cost = new_cost
606
+
607
+ if i == self.num_steps - 1 and self.conf.early_stop:
608
+ logger.warning("Reached maximum number of steps without convergence.")
609
+
610
+ final_errors = self.calculate_residuals(camera_opt, gravity_opt, data) # (B, N, 3)
611
+ final_cost, weights = self.calculate_costs(final_errors, data) # (B, N)
612
+
613
+ if not self.training:
614
+ infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)
615
+
616
+ infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
617
+ for k, c in final_cost.items():
618
+ infos[f"final_{k}"] = c.mean(-1)
619
+
620
+ infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())
621
+
622
+ return camera_opt, gravity_opt, infos
623
+
624
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
625
+ """Run the LM optimization."""
626
+ camera_init, gravity_init = get_trivial_estimation(data, self.camera_model)
627
+
628
+ self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)
629
+
630
+ start = time.time()
631
+ camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)
632
+
633
+ if self.conf.verbose:
634
+ logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")
635
+
636
+ logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
637
+ logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")
638
+
639
+ logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
640
+ logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")
641
+
642
+ return {"camera": camera_opt, "gravity": gravity_opt, **infos}
geocalib/misc.py ADDED
@@ -0,0 +1,318 @@
1
+ """Miscellaneous functions and classes for the geocalib_inference package."""
2
+
3
+ import functools
4
+ import inspect
5
+ import logging
6
+ from typing import Callable, List
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # mypy: ignore-errors
14
+
15
+
16
+ def autocast(func: Callable) -> Callable:
17
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
18
+
19
+ Use the device and dtype of the wrapper.
20
+
21
+ Args:
22
+ func (Callable): Method of a TensorWrapper class.
23
+
24
+ Returns:
25
+ Callable: Wrapped method.
26
+ """
27
+
28
+ @functools.wraps(func)
29
+ def wrap(self, *args):
30
+ device = torch.device("cpu")
31
+ dtype = None
32
+ if isinstance(self, TensorWrapper):
33
+ if self._data is not None:
34
+ device = self.device
35
+ dtype = self.dtype
36
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
37
+ raise ValueError(self)
38
+
39
+ cast_args = []
40
+ for arg in args:
41
+ if isinstance(arg, np.ndarray):
42
+ arg = torch.from_numpy(arg)
43
+ arg = arg.to(device=device, dtype=dtype)
44
+ cast_args.append(arg)
45
+ return func(self, *cast_args)
46
+
47
+ return wrap
48
+
49
+
50
+ class TensorWrapper:
51
+ """Wrapper for PyTorch tensors."""
52
+
53
+ _data = None
54
+
55
+ @autocast
56
+ def __init__(self, data: torch.Tensor):
57
+ """Wrapper for PyTorch tensors."""
58
+ self._data = data
59
+
60
+ @property
61
+ def shape(self) -> torch.Size:
62
+ """Shape of the underlying tensor."""
63
+ return self._data.shape[:-1]
64
+
65
+ @property
66
+ def device(self) -> torch.device:
67
+ """Get the device of the underlying tensor."""
68
+ return self._data.device
69
+
70
+ @property
71
+ def dtype(self) -> torch.dtype:
72
+ """Get the dtype of the underlying tensor."""
73
+ return self._data.dtype
74
+
75
+ def __getitem__(self, index) -> torch.Tensor:
76
+ """Get the underlying tensor."""
77
+ return self.__class__(self._data[index])
78
+
79
+ def __setitem__(self, index, item):
80
+ """Set the underlying tensor."""
81
+ self._data[index] = item.data
82
+
83
+ def to(self, *args, **kwargs):
84
+ """Move the underlying tensor to a new device."""
85
+ return self.__class__(self._data.to(*args, **kwargs))
86
+
87
+ def cpu(self):
88
+ """Move the underlying tensor to the CPU."""
89
+ return self.__class__(self._data.cpu())
90
+
91
+ def cuda(self):
92
+ """Move the underlying tensor to the GPU."""
93
+ return self.__class__(self._data.cuda())
94
+
95
+ def pin_memory(self):
96
+ """Pin the underlying tensor to memory."""
97
+ return self.__class__(self._data.pin_memory())
98
+
99
+ def float(self):
100
+ """Cast the underlying tensor to float."""
101
+ return self.__class__(self._data.float())
102
+
103
+ def double(self):
104
+ """Cast the underlying tensor to double."""
105
+ return self.__class__(self._data.double())
106
+
107
+ def detach(self):
108
+ """Detach the underlying tensor."""
109
+ return self.__class__(self._data.detach())
110
+
111
+ def numpy(self):
112
+ """Convert the underlying tensor to a numpy array."""
113
+ return self._data.detach().cpu().numpy()
114
+
115
+ def new_tensor(self, *args, **kwargs):
116
+ """Create a new tensor of the same type and device."""
117
+ return self._data.new_tensor(*args, **kwargs)
118
+
119
+ def new_zeros(self, *args, **kwargs):
120
+ """Create a new tensor of the same type and device."""
121
+ return self._data.new_zeros(*args, **kwargs)
122
+
123
+ def new_ones(self, *args, **kwargs):
124
+ """Create a new tensor of the same type and device."""
125
+ return self._data.new_ones(*args, **kwargs)
126
+
127
+ def new_full(self, *args, **kwargs):
128
+ """Create a new tensor of the same type and device."""
129
+ return self._data.new_full(*args, **kwargs)
130
+
131
+ def new_empty(self, *args, **kwargs):
132
+ """Create a new tensor of the same type and device."""
133
+ return self._data.new_empty(*args, **kwargs)
134
+
135
+ def unsqueeze(self, *args, **kwargs):
136
+ """Create a new tensor of the same type and device."""
137
+ return self.__class__(self._data.unsqueeze(*args, **kwargs))
138
+
139
+ def squeeze(self, *args, **kwargs):
140
+ """Create a new tensor of the same type and device."""
141
+ return self.__class__(self._data.squeeze(*args, **kwargs))
142
+
143
+ @classmethod
144
+ def stack(cls, objects: List, dim=0, *, out=None):
145
+ """Stack a list of objects with the same type and shape."""
146
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
147
+ return cls(data)
148
+
149
+ @classmethod
150
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
151
+ """Support torch functions."""
152
+ if kwargs is None:
153
+ kwargs = {}
154
+ return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
155
+
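A small usage sketch (hypothetical subclass, reusing this module's imports): torch.stack on a list of wrapped objects dispatches through __torch_function__ and returns a wrapped result.

class Points(TensorWrapper):
    pass

a, b = Points(torch.zeros(5, 2)), Points(torch.ones(5, 2))
stacked = torch.stack([a, b])          # dispatched to TensorWrapper.stack
assert isinstance(stacked, Points) and stacked.shape == (2, 5)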
156
+
157
+ class EuclideanManifold:
158
+ """Simple euclidean manifold."""
159
+
160
+ @staticmethod
161
+ def J_plus(x: torch.Tensor) -> torch.Tensor:
162
+ """Plus operator Jacobian."""
163
+ return torch.eye(x.shape[-1]).to(x)
164
+
165
+ @staticmethod
166
+ def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
167
+ """Plus operator."""
168
+ return x + delta
169
+
170
+
171
+ class SphericalManifold:
172
+ """Implementation of the spherical manifold.
173
+
174
+ Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
175
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
176
+
177
+ Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
178
+ Golub et al.
179
+ """
180
+
181
+ @staticmethod
182
+ def householder_vector(x: torch.Tensor) -> torch.Tensor:
183
+ """Return the Householder vector and beta.
184
+
185
+ Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
186
+ in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
187
+ first.
188
+
189
+ This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
190
+ orthogonal and H * x = ||x||_2 * e_n.
191
+
192
+ Args:
193
+ x (torch.Tensor): [..., n] tensor.
194
+
195
+ Returns:
196
+ torch.Tensor: v of shape [..., n]
197
+ torch.Tensor: beta of shape [...]
198
+ """
199
+ sigma = torch.sum(x[..., :-1] ** 2, -1)
200
+ xpiv = x[..., -1]
201
+ norm = torch.norm(x, dim=-1)
202
+ if torch.any(sigma < 1e-7):
203
+ sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
204
+ logger.warning("sigma < 1e-7")
205
+
206
+ vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
207
+ beta = 2 * vpiv**2 / (sigma + vpiv**2)
208
+ v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
209
+ return v, beta
210
+
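A hypothetical numerical check of this construction: the reflection H = I - beta * v v^T should be orthogonal and map x onto the last basis vector scaled by its norm.

x = torch.randn(3, dtype=torch.float64)
v, beta = SphericalManifold.householder_vector(x)
H = torch.eye(3, dtype=torch.float64) - beta * torch.outer(v, v)
y = H @ x
assert torch.allclose(y[-1], x.norm()) and y[:2].abs().max() < 1e-8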
211
+ @staticmethod
212
+ def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
213
+ """Apply Householder transformation.
214
+
215
+ Args:
216
+ y (torch.Tensor): Vector to transform of shape [..., n].
217
+ v (torch.Tensor): Householder vector of shape [..., n].
218
+ beta (torch.Tensor): Householder beta of shape [...].
219
+
220
+ Returns:
221
+ torch.Tensor: Transformed vector of shape [..., n].
222
+ """
223
+ return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
224
+
225
+ @classmethod
226
+ def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
227
+ """Plus operator Jacobian."""
228
+ v, beta = cls.householder_vector(x)
229
+ H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
230
+ H = H + torch.eye(H.shape[-1]).to(H)
231
+ return H[..., :-1] # J
232
+
233
+ @classmethod
234
+ def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
235
+ """Plus operator.
236
+
237
+ Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
238
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
239
+ element of the input vector as pivot instead of first.
240
+
241
+ Args:
242
+ x: point on the manifold
243
+ delta: tangent vector
244
+ """
245
+ eps = 1e-7
246
+ # keep the norm in case it is not equal to 1
247
+ nx = torch.norm(x, dim=-1, keepdim=True)
248
+ nd = torch.norm(delta, dim=-1, keepdim=True)
249
+
250
+ # make sure we don't divide by zero in backward as torch.where computes grad for both
251
+ # branches
252
+ nd_ = torch.where(nd < eps, nd + eps, nd)
253
+ sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
254
+
255
+ # cos is applied to last dim instead of first
256
+ exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
257
+
258
+ v, beta = cls.householder_vector(x)
259
+ return nx * cls.apply_householder(exp_delta, v, beta)
260
+
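A quick sanity check (hypothetical values): the retraction preserves the norm of the point, and a zero tangent vector maps the point to itself.

x = torch.nn.functional.normalize(torch.randn(4, 3), dim=-1)
assert torch.allclose(SphericalManifold.plus(x, torch.zeros(4, 2)), x, atol=1e-5)
x_new = SphericalManifold.plus(x, 0.1 * torch.randn(4, 2))
assert torch.allclose(x_new.norm(dim=-1), x.norm(dim=-1), atol=1e-5)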
261
+
262
+ @torch.jit.script
263
+ def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
264
+ """Compute the jacobian of vec / norm2(vec).
265
+
266
+ Args:
267
+ vec (torch.Tensor): [..., D] tensor.
268
+
269
+ Returns:
270
+ torch.Tensor: [..., D, D] Jacobian.
271
+ """
272
+ D = vec.shape[-1]
273
+ norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
274
+
275
+ if (norm_x == 0).any():
276
+ norm_x = norm_x + 1e-6
277
+
278
+ xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
279
+ identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
280
+
281
+ return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
282
+
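A hypothetical check of this Jacobian against autograd for a single vector:

vec = torch.randn(5, 3, dtype=torch.float64)
J_autograd = torch.autograd.functional.jacobian(
    lambda v: torch.nn.functional.normalize(v, dim=-1), vec[0]
)
assert torch.allclose(J_vecnorm(vec)[0], J_autograd, atol=1e-6)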
283
+
284
+ @torch.jit.script
285
+ def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
286
+ """Compute the jacobian of the focal2fov function."""
287
+ return -4 * h / (4 * focal**2 + h**2)
288
+
289
+
290
+ @torch.jit.script
291
+ def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
292
+ """Compute the jacobian of the up-vector projection.
293
+
294
+ Args:
295
+ uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
296
+ abc (torch.Tensor): Gravity vector of shape (..., 3).
297
+ wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
298
+
299
+ Raises:
300
+ ValueError: If the wrt parameter is unknown.
301
+
302
+ Returns:
303
+ torch.Tensor: Jacobian with respect to the parameter.
304
+ """
305
+ if wrt == "uv":
306
+ c = abc[..., 2][..., None, None, None]
307
+ return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
308
+
309
+ elif wrt == "abc":
310
+ J = uv.new_zeros(uv.shape[:-1] + (2, 3))
311
+ J[..., 0, 0] = 1
312
+ J[..., 1, 1] = 1
313
+ J[..., 0, 2] = -uv[..., 0]
314
+ J[..., 1, 2] = -uv[..., 1]
315
+ return J
316
+
317
+ else:
318
+ raise ValueError(f"Unknown wrt: {wrt}")
geocalib/modules.py ADDED
@@ -0,0 +1,575 @@
1
+ """Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic
2
+ Segmentation (NeurIPS 2022) adapted from
3
+
4
+ https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/backbones/mscan.py
5
+
6
+
7
+ Light Hamburger Decoder adapted from:
8
+
9
+ https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py
10
+ """
11
+
12
+ from typing import Dict, Tuple
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.modules.utils import _pair as to_2tuple
18
+
19
+ # flake8: noqa: E266
20
+ # mypy: ignore-errors
21
+
22
+
23
+ class ConvModule(nn.Module):
24
+ """Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency."""
25
+
26
+ def __init__(
27
+ self,
28
+ in_channels: int,
29
+ out_channels: int,
30
+ kernel_size: int,
31
+ padding: int = 0,
32
+ use_norm: bool = False,
33
+ bias: bool = True,
34
+ ):
35
+ """Simple convolution block.
36
+
37
+ Args:
38
+ in_channels (int): Input channels.
39
+ out_channels (int): Output channels.
40
+ kernel_size (int): Kernel size.
41
+ padding (int, optional): Padding. Defaults to 0.
42
+ use_norm (bool, optional): Whether to use normalization. Defaults to False.
43
+ bias (bool, optional): Whether to use bias. Defaults to True.
44
+ """
45
+ super().__init__()
46
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
47
+ self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
48
+ self.activate = nn.ReLU(inplace=True)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ """Forward pass."""
52
+ x = self.conv(x)
53
+ x = self.bn(x)
54
+ return self.activate(x)
55
+
56
+
57
+ class ResidualConvUnit(nn.Module):
58
+ """Residual convolution module."""
59
+
60
+ def __init__(self, features):
61
+ """Simple residual convolution block.
62
+
63
+ Args:
64
+ features (int): number of features
65
+ """
66
+ super().__init__()
67
+
68
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
69
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
70
+
71
+ self.relu = torch.nn.ReLU(inplace=True)
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ """Forward pass."""
75
+ out = self.relu(x)
76
+ out = self.conv1(out)
77
+ out = self.relu(out)
78
+ out = self.conv2(out)
79
+ return out + x
80
+
81
+
82
+ class FeatureFusionBlock(nn.Module):
83
+ """Feature fusion block."""
84
+
85
+ def __init__(self, features: int, unit2only=False, upsample=True):
86
+ """Feature fusion block.
87
+
88
+ Args:
89
+ features (int): Number of features.
90
+ unit2only (bool, optional): Whether to use only the second unit. Defaults to False.
91
+ upsample (bool, optional): Whether to upsample. Defaults to True.
92
+ """
93
+ super().__init__()
94
+ self.upsample = upsample
95
+
96
+ if not unit2only:
97
+ self.resConfUnit1 = ResidualConvUnit(features)
98
+ self.resConfUnit2 = ResidualConvUnit(features)
99
+
100
+ def forward(self, *xs: torch.Tensor) -> torch.Tensor:
101
+ """Forward pass."""
102
+ output = xs[0]
103
+
104
+ if len(xs) == 2:
105
+ output = output + self.resConfUnit1(xs[1])
106
+
107
+ output = self.resConfUnit2(output)
108
+
109
+ if self.upsample:
110
+ output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False)
111
+
112
+ return output
113
+
114
+
115
+ ###################################################
116
+ ########### Light Hamburger Decoder ###############
117
+ ###################################################
118
+
119
+
120
+ class NMF2D(nn.Module):
121
+ """Non-negative Matrix Factorization (NMF) for 2D data."""
122
+
123
+ def __init__(self):
124
+ """Non-negative Matrix Factorization (NMF) for 2D data."""
125
+ super().__init__()
126
+ self.S, self.D, self.R = 1, 512, 64
127
+ self.train_steps = 6
128
+ self.eval_steps = 7
129
+ self.inv_t = 1
130
+
131
+ def _build_bases(self, B: int, S: int, D: int, R: int, device: str = "cpu") -> torch.Tensor:
132
+ bases = torch.rand((B * S, D, R)).to(device)
133
+ return F.normalize(bases, dim=1)
134
+
135
+ def local_step(
136
+ self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Update bases and coefficient."""
139
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
140
+ numerator = torch.bmm(x.transpose(1, 2), bases)
141
+ # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
142
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
143
+ # Multiplicative Update
144
+ coef = coef * numerator / (denominator + 1e-6)
145
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
146
+ numerator = torch.bmm(x, coef)
147
+ # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
148
+ denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
149
+ # Multiplicative Update
150
+ bases = bases * numerator / (denominator + 1e-6)
151
+ return bases, coef
152
+
153
+ def compute_coef(
154
+ self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
155
+ ) -> torch.Tensor:
156
+ """Compute coefficient."""
157
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
158
+ numerator = torch.bmm(x.transpose(1, 2), bases)
159
+ # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
160
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
161
+ # multiplication update
162
+ return coef * numerator / (denominator + 1e-6)
163
+
164
+ def local_inference(
165
+ self, x: torch.Tensor, bases: torch.Tensor
166
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
167
+ """Local inference."""
168
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
169
+ coef = torch.bmm(x.transpose(1, 2), bases)
170
+ coef = F.softmax(self.inv_t * coef, dim=-1)
171
+
172
+ steps = self.train_steps if self.training else self.eval_steps
173
+ for _ in range(steps):
174
+ bases, coef = self.local_step(x, bases, coef)
175
+
176
+ return bases, coef
177
+
178
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
179
+ """Forward pass."""
180
+ B, C, H, W = x.shape
181
+
182
+ # (B, C, H, W) -> (B * S, D, N)
183
+ D = C // self.S
184
+ N = H * W
185
+ x = x.view(B * self.S, D, N)
186
+
187
+ # (S, D, R) -> (B * S, D, R)
188
+ bases = self._build_bases(B, self.S, D, self.R, device=x.device)
189
+ bases, coef = self.local_inference(x, bases)
190
+ # (B * S, N, R)
191
+ coef = self.compute_coef(x, bases, coef)
192
+ # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
193
+ x = torch.bmm(bases, coef.transpose(1, 2))
194
+ # (B * S, D, N) -> (B, C, H, W)
195
+ x = x.view(B, C, H, W)
196
+ # (B * S, D, R) -> (B, S, D, R)
197
+ bases = bases.view(B, self.S, D, self.R)
198
+
199
+ return x
200
+
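Rough usage sketch (hypothetical input): the module factorizes the flattened feature map into R = 64 bases and reconstructs it, so the output keeps the input shape.

nmf = NMF2D()
feats = torch.randn(2, 512, 23, 31)
out = nmf(feats)
assert out.shape == feats.shape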
201
+
202
+ class Hamburger(nn.Module):
203
+ """Hamburger Module."""
204
+
205
+ def __init__(self, ham_channels: int = 512):
206
+ """Hambuger Module.
207
+
208
+ Args:
209
+ ham_channels (int, optional): Number of channels in the hamburger module. Defaults to
210
+ 512.
211
+ """
212
+ super().__init__()
213
+ self.ham_in = ConvModule(ham_channels, ham_channels, 1)
214
+ self.ham = NMF2D()
215
+ self.ham_out = ConvModule(ham_channels, ham_channels, 1)
216
+
217
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
218
+ """Forward pass."""
219
+ enjoy = self.ham_in(x)
220
+ enjoy = F.relu(enjoy, inplace=False)
221
+ enjoy = self.ham(enjoy)
222
+ enjoy = self.ham_out(enjoy)
223
+ ham = F.relu(x + enjoy, inplace=False)
224
+ return ham
225
+
226
+
227
+ class LightHamHead(nn.Module):
228
+ """Is Attention Better Than Matrix Decomposition?
229
+
230
+ This head is the implementation of `HamNet <https://arxiv.org/abs/2109.04553>`.
231
+ """
232
+
233
+ def __init__(self):
234
+ """Light hamburger decoder head."""
235
+ super().__init__()
236
+ self.in_index = [0, 1, 2, 3]
237
+ self.in_channels = [64, 128, 320, 512]
238
+ self.out_channels = 64
239
+ self.ham_channels = 512
240
+ self.align_corners = False
241
+
242
+ self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1)
243
+
244
+ self.hamburger = Hamburger(self.ham_channels)
245
+
246
+ self.align = ConvModule(self.ham_channels, self.out_channels, 1)
247
+
248
+ self.linear_pred_uncertainty = nn.Sequential(
249
+ ConvModule(
250
+ in_channels=self.out_channels,
251
+ out_channels=self.out_channels,
252
+ kernel_size=3,
253
+ padding=1,
254
+ bias=False,
255
+ ),
256
+ nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
257
+ )
258
+
259
+ self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1, bias=False)
260
+ self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)
261
+
262
+ def forward(self, features: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
263
+ """Forward pass."""
264
+ inputs = [features["hl"][i] for i in self.in_index]
265
+
266
+ inputs = [
267
+ F.interpolate(
268
+ level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
269
+ )
270
+ for level in inputs
271
+ ]
272
+
273
+ inputs = torch.cat(inputs, dim=1)
274
+ x = self.squeeze(inputs)
275
+
276
+ x = self.hamburger(x)
277
+
278
+ feats = self.align(x)
279
+
280
+ assert "ll" in features, "Low-level features are required for this model"
281
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
282
+ feats = self.out_conv(feats)
283
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
284
+ feats = self.ll_fusion(feats, features["ll"].clone())
285
+
286
+ uncertainty = self.linear_pred_uncertainty(feats).squeeze(1)
287
+
288
+ return feats, uncertainty
289
+
290
+
291
+ ###################################################
292
+ ########### MSCAN ################
293
+ ###################################################
294
+
295
+
296
+ class DWConv(nn.Module):
297
+ """Depthwise convolution."""
298
+
299
+ def __init__(self, dim: int = 768):
300
+ """Depthwise convolution.
301
+
302
+ Args:
303
+ dim (int, optional): Number of features. Defaults to 768.
304
+ """
305
+ super().__init__()
306
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
307
+
308
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
309
+ """Forward pass."""
310
+ return self.dwconv(x)
311
+
312
+
313
+ class Mlp(nn.Module):
314
+ """MLP module."""
315
+
316
+ def __init__(
317
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
318
+ ):
319
+ """Initialize the MLP."""
320
+ super().__init__()
321
+ out_features = out_features or in_features
322
+ hidden_features = hidden_features or in_features
323
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
324
+ self.dwconv = DWConv(hidden_features)
325
+ self.act = act_layer()
326
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
327
+ self.drop = nn.Dropout(drop)
328
+
329
+ def forward(self, x):
330
+ """Forward pass."""
331
+ x = self.fc1(x)
332
+
333
+ x = self.dwconv(x)
334
+ x = self.act(x)
335
+ x = self.drop(x)
336
+ x = self.fc2(x)
337
+ x = self.drop(x)
338
+
339
+ return x
340
+
341
+
342
+ class StemConv(nn.Module):
343
+ """Simple stem convolution module."""
344
+
345
+ def __init__(self, in_channels: int, out_channels: int):
346
+ """Simple stem convolution module.
347
+
348
+ Args:
349
+ in_channels (int): Input channels.
350
+ out_channels (int): Output channels.
351
+ """
352
+ super().__init__()
353
+
354
+ self.proj = nn.Sequential(
355
+ nn.Conv2d(
356
+ in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
357
+ ),
358
+ nn.BatchNorm2d(out_channels // 2),
359
+ nn.GELU(),
360
+ nn.Conv2d(
361
+ out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
362
+ ),
363
+ nn.BatchNorm2d(out_channels),
364
+ )
365
+
366
+ def forward(self, x):
367
+ """Forward pass."""
368
+ x = self.proj(x)
369
+ _, _, H, W = x.size()
370
+ x = x.flatten(2).transpose(1, 2)
371
+ return x, H, W
372
+
373
+
374
+ class AttentionModule(nn.Module):
375
+ """Attention module."""
376
+
377
+ def __init__(self, dim: int):
378
+ """Attention module.
379
+
380
+ Args:
381
+ dim (int): Number of features.
382
+ """
383
+ super().__init__()
384
+ self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
385
+ self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
386
+ self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
387
+
388
+ self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
389
+ self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
390
+
391
+ self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
392
+ self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
393
+ self.conv3 = nn.Conv2d(dim, dim, 1)
394
+
395
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
396
+ """Forward pass."""
397
+ u = x.clone()
398
+ attn = self.conv0(x)
399
+
400
+ attn_0 = self.conv0_1(attn)
401
+ attn_0 = self.conv0_2(attn_0)
402
+
403
+ attn_1 = self.conv1_1(attn)
404
+ attn_1 = self.conv1_2(attn_1)
405
+
406
+ attn_2 = self.conv2_1(attn)
407
+ attn_2 = self.conv2_2(attn_2)
408
+ attn = attn + attn_0 + attn_1 + attn_2
409
+
410
+ attn = self.conv3(attn)
411
+ return attn * u
412
+
413
+
414
+ class SpatialAttention(nn.Module):
415
+ """Spatial attention module."""
416
+
417
+ def __init__(self, dim: int):
418
+ """Spatial attention module.
419
+
420
+ Args:
421
+ dim (int): Number of features.
422
+ """
423
+ super().__init__()
424
+ self.d_model = dim
425
+ self.proj_1 = nn.Conv2d(dim, dim, 1)
426
+ self.activation = nn.GELU()
427
+ self.spatial_gating_unit = AttentionModule(dim)
428
+ self.proj_2 = nn.Conv2d(dim, dim, 1)
429
+
430
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
431
+ """Forward pass."""
432
+ shortcut = x.clone()
433
+ x = self.proj_1(x)
434
+ x = self.activation(x)
435
+ x = self.spatial_gating_unit(x)
436
+ x = self.proj_2(x)
437
+ x = x + shortcut
438
+ return x
439
+
440
+
441
+ class Block(nn.Module):
442
+ """MSCAN block."""
443
+
444
+ def __init__(
445
+ self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0, act_layer: nn.Module = nn.GELU
446
+ ):
447
+ """MSCAN block.
448
+
449
+ Args:
450
+ dim (int): Number of features.
451
+ mlp_ratio (float, optional): Ratio of the hidden features in the MLP. Defaults to 4.0.
452
+ drop (float, optional): Dropout rate. Defaults to 0.0.
453
+ act_layer (nn.Module, optional): Activation layer. Defaults to nn.GELU.
454
+ """
455
+ super().__init__()
456
+ self.norm1 = nn.BatchNorm2d(dim)
457
+ self.attn = SpatialAttention(dim)
458
+ self.drop_path = nn.Identity() # only used in training
459
+ self.norm2 = nn.BatchNorm2d(dim)
460
+ mlp_hidden_dim = int(dim * mlp_ratio)
461
+ self.mlp = Mlp(
462
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
463
+ )
464
+ layer_scale_init_value = 1e-2
465
+ self.layer_scale_1 = nn.Parameter(
466
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
467
+ )
468
+ self.layer_scale_2 = nn.Parameter(
469
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
470
+ )
471
+
472
+ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
473
+ """Forward pass."""
474
+ B, N, C = x.shape
475
+ x = x.permute(0, 2, 1).view(B, C, H, W)
476
+ x = x + self.drop_path(self.layer_scale_1[..., None, None] * self.attn(self.norm1(x)))
477
+ x = x + self.drop_path(self.layer_scale_2[..., None, None] * self.mlp(self.norm2(x)))
478
+ return x.view(B, C, N).permute(0, 2, 1)
479
+
480
+
481
+ class OverlapPatchEmbed(nn.Module):
482
+ """Image to Patch Embedding"""
483
+
484
+ def __init__(
485
+ self, patch_size: int = 7, stride: int = 4, in_chans: int = 3, embed_dim: int = 768
486
+ ):
487
+ """Image to Patch Embedding.
488
+
489
+ Args:
490
+ patch_size (int, optional): Image patch size. Defaults to 7.
491
+ stride (int, optional): Stride. Defaults to 4.
492
+ in_chans (int, optional): Number of input channels. Defaults to 3.
493
+ embed_dim (int, optional): Embedding dimension. Defaults to 768.
494
+ """
495
+ super().__init__()
496
+ patch_size = to_2tuple(patch_size)
497
+
498
+ self.proj = nn.Conv2d(
499
+ in_chans,
500
+ embed_dim,
501
+ kernel_size=patch_size,
502
+ stride=stride,
503
+ padding=(patch_size[0] // 2, patch_size[1] // 2),
504
+ )
505
+ self.norm = nn.BatchNorm2d(embed_dim)
506
+
507
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
508
+ """Forward pass."""
509
+ x = self.proj(x)
510
+ _, _, H, W = x.shape
511
+ x = self.norm(x)
512
+ x = x.flatten(2).transpose(1, 2)
513
+ return x, H, W
514
+
515
+
516
+ class MSCAN(nn.Module):
517
+ """Multi-scale convolutional attention network."""
518
+
519
+ def __init__(self):
520
+ """Multi-scale convolutional attention network."""
521
+ super().__init__()
522
+ self.in_channels = 3
523
+ self.embed_dims = [64, 128, 320, 512]
524
+ self.mlp_ratios = [8, 8, 4, 4]
525
+ self.drop_rate = 0.0
526
+ self.drop_path_rate = 0.1
527
+ self.depths = [3, 3, 12, 3]
528
+ self.num_stages = 4
529
+
530
+ for i in range(self.num_stages):
531
+ if i == 0:
532
+ patch_embed = StemConv(3, self.embed_dims[0])
533
+ else:
534
+ patch_embed = OverlapPatchEmbed(
535
+ patch_size=7 if i == 0 else 3,
536
+ stride=4 if i == 0 else 2,
537
+ in_chans=self.in_channels if i == 0 else self.embed_dims[i - 1],
538
+ embed_dim=self.embed_dims[i],
539
+ )
540
+
541
+ block = nn.ModuleList(
542
+ [
543
+ Block(
544
+ dim=self.embed_dims[i],
545
+ mlp_ratio=self.mlp_ratios[i],
546
+ drop=self.drop_rate,
547
+ )
548
+ for _ in range(self.depths[i])
549
+ ]
550
+ )
551
+ norm = nn.LayerNorm(self.embed_dims[i])
552
+
553
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
554
+ setattr(self, f"block{i + 1}", block)
555
+ setattr(self, f"norm{i + 1}", norm)
556
+
557
+ def forward(self, data):
558
+ """Forward pass."""
559
+ # rgb -> bgr and from [0, 1] to [0, 255]
560
+ x = data["image"][:, [2, 1, 0], :, :] * 255.0
561
+ B = x.shape[0]
562
+
563
+ outs = []
564
+ for i in range(self.num_stages):
565
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
566
+ block = getattr(self, f"block{i + 1}")
567
+ norm = getattr(self, f"norm{i + 1}")
568
+ x, H, W = patch_embed(x)
569
+ for blk in block:
570
+ x = blk(x, H, W)
571
+ x = norm(x)
572
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
573
+ outs.append(x)
574
+
575
+ return {"features": outs}
geocalib/perspective_fields.py ADDED
@@ -0,0 +1,366 @@
1
+ """Implementation of perspective fields.
2
+
3
+ Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+ from geocalib.camera import BaseCamera
12
+ from geocalib.gravity import Gravity
13
+ from geocalib.misc import J_up_projection, J_vecnorm, SphericalManifold
14
+
15
+ # flake8: noqa: E266
16
+
17
+
18
+ def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
19
+ """Get the horizon line from the camera parameters.
20
+
21
+ Args:
22
+ camera (Camera): Camera parameters.
23
+ gravity (Gravity): Gravity vector.
24
+ relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.
25
+
26
+ Returns:
27
+ torch.Tensor: In image frame, fraction of image left/right border intersection with
28
+ respect to image height.
29
+ """
30
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
31
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
32
+
33
+ # project horizon midpoint to image plane
34
+ horizon_midpoint = camera.new_tensor([0, 0, 1])
35
+ horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
36
+ midpoint = horizon_midpoint[:2] / horizon_midpoint[2]
37
+
38
+ # compute left and right offset to borders
39
+ left_offset = midpoint[0] * torch.tan(gravity.roll)
40
+ right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
41
+ left, right = midpoint[1] + left_offset, midpoint[1] - right_offset
42
+
43
+ horizon = camera.new_tensor([left, right])
44
+ return horizon / camera.size[1] if relative else horizon
45
+
46
+
47
+ def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
48
+ """Get the up vector field from the camera parameters.
49
+
50
+ Args:
51
+ camera (Camera): Camera parameters.
52
+ normalize (bool, optional): Whether to normalize the up vector. Defaults to True.
53
+
54
+ Returns:
55
+ torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
56
+ """
57
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
58
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
59
+
60
+ w, h = camera.size[0].unbind(-1)
61
+ h, w = h.round().to(int), w.round().to(int)
62
+
63
+ uv = camera.normalize(camera.pixel_coordinates())
64
+
65
+ # projected up is (a, b) - c * (u, v)
66
+ abc = gravity.vec3d
67
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2)
68
+
69
+ if hasattr(camera, "dist"):
70
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
71
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
72
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
73
+ offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
74
+
75
+ # (..., N, 2)
76
+ projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)
77
+
78
+ if normalize:
79
+ projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2)
80
+
81
+ return projected_up2d.reshape(camera.shape[0], h, w, 2)
82
+
83
+
84
+ def J_up_field(
85
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
86
+ ) -> torch.Tensor:
87
+ """Get the jacobian of the up field.
88
+
89
+ Args:
90
+ camera (Camera): Camera parameters.
91
+ gravity (Gravity): Gravity vector.
92
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
93
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
94
+
95
+ Returns:
96
+ torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3).
97
+ """
98
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
99
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
100
+
101
+ w, h = camera.size[0].unbind(-1)
102
+ h, w = h.round().to(int), w.round().to(int)
103
+
104
+ # Forward
105
+ xy = camera.pixel_coordinates()
106
+ uv = camera.normalize(xy)
107
+
108
+ projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv
109
+
110
+ # Backward
111
+ J = []
112
+
113
+ # (..., N, 2, 2)
114
+ J_norm2proj = J_vecnorm(
115
+ get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
116
+ )
117
+
118
+ # distortion values
119
+ if hasattr(camera, "dist"):
120
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
121
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
122
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
123
+ offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
124
+
125
+ ######################
126
+ ## Gravity Jacobian ##
127
+ ######################
128
+
129
+ J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3)
130
+
131
+ if hasattr(camera, "dist"):
132
+ # (..., N, 2, 3)
133
+ J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)
134
+
135
+ J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
136
+ J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
137
+ J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
138
+ J.append(J_up2delta)
139
+
140
+ ######################
141
+ ### Focal Jacobian ###
142
+ ######################
143
+
144
+ J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2)
145
+
146
+ if hasattr(camera, "dist"):
147
+ J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
148
+ J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)
149
+
150
+ inner = (uv * projected_up2d).sum(-1)[..., None, None]
151
+ J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
152
+ J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2)
153
+ J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up
154
+
155
+ J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2)
156
+
157
+ if log_focal:
158
+ J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2)
159
+
160
+ J_uv2f = J_uv2f.sum(-1) # (..., N, 2)
161
+
162
+ J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2)
163
+ J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1)
164
+ J.append(J_up2f)
165
+
166
+ ######################
167
+ ##### K1 Jacobian ####
168
+ ######################
169
+
170
+ if hasattr(camera, "dist"):
171
+ J_duv = camera.J_distort(uv, wrt="scale2dist")
172
+ J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2)
173
+ J_offset = torch.einsum(
174
+ "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
175
+ )
176
+ J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
177
+ J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
178
+ J.append(J_k1)
179
+
180
+ n_params = sum(j.shape[-1] for j in J)
181
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params)
182
+
183
+
184
+ def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
185
+ """Get the latitudes of the camera pixels in radians.
186
+
187
+ Latitudes are defined as the angle between the ray and the up vector.
188
+
189
+ Args:
190
+ camera (Camera): Camera parameters.
191
+ gravity (Gravity): Gravity vector.
192
+
193
+ Returns:
194
+ torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
195
+ """
196
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
197
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
198
+
199
+ w, h = camera.size[0].unbind(-1)
200
+ h, w = h.round().to(int), w.round().to(int)
201
+
202
+ uv1, _ = camera.image2world(camera.pixel_coordinates())
203
+ rays = camera.pixel_bearing_many(uv1)
204
+
205
+ lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)
206
+
207
+ eps = 1e-6
208
+ lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))
209
+
210
+ return lat_asin.reshape(camera.shape[0], h, w, 1)
211
+
212
+
213
+ def J_latitude_field(
214
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
215
+ ) -> torch.Tensor:
216
+ """Get the jacobian of the latitude field.
217
+
218
+ Args:
219
+ camera (Camera): Camera parameters.
220
+ gravity (Gravity): Gravity vector.
221
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
222
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
223
+
224
+ Returns:
225
+ torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3).
226
+ """
227
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
228
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
229
+
230
+ w, h = camera.size[0].unbind(-1)
231
+ h, w = h.round().to(int), w.round().to(int)
232
+
233
+ # Forward
234
+ xy = camera.pixel_coordinates()
235
+ uv1, _ = camera.image2world(xy)
236
+ uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3)
237
+
238
+ # Backward
239
+ J = []
240
+ J_norm2w_to_img = J_vecnorm(uv1)[..., :2]  # (..., N, 3, 2)
241
+
242
+ ######################
243
+ ## Gravity Jacobian ##
244
+ ######################
245
+
246
+ J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
247
+ J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2)
248
+ J.append(J_delta)
249
+
250
+ ######################
251
+ ### Focal Jacobian ###
252
+ ######################
253
+
254
+ J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2)
255
+ if log_focal:
256
+ J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
257
+ J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2)
258
+
259
+ J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3)
260
+ J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1)
261
+ J.append(J_f)
262
+
263
+ ######################
264
+ ##### K1 Jacobian ####
265
+ ######################
266
+
267
+ if hasattr(camera, "dist"):
268
+ J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2)
269
+ # (..., N, 2)
270
+ J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
271
+ # (..., N, 1)
272
+ J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
273
+ J.append(J_k1)
274
+
275
+ n_params = sum(j.shape[-1] for j in J)
276
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params)
277
+
278
+
279
+ def get_perspective_field(
280
+ camera: BaseCamera,
281
+ gravity: Gravity,
282
+ use_up: bool = True,
283
+ use_latitude: bool = True,
284
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
285
+ """Get the perspective field from the camera parameters.
286
+
287
+ Args:
288
+ camera (Camera): Camera parameters.
289
+ gravity (Gravity): Gravity vector.
290
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
291
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
292
+
293
+ Returns:
294
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
295
+ (..., 2, h, w) and (..., 1, h, w).
296
+ """
297
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
298
+
299
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
300
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
301
+
302
+ w, h = camera.size[0].unbind(-1)
303
+ h, w = h.round().to(int), w.round().to(int)
304
+
305
+ if use_up:
306
+ permute = (0, 3, 1, 2)
307
+ # (..., 2, h, w)
308
+ up = get_up_field(camera, gravity).permute(permute)
309
+ else:
310
+ shape = (camera.shape[0], 2, h, w)
311
+ up = camera.new_zeros(shape)
312
+
313
+ if use_latitude:
314
+ permute = (0, 3, 1, 2)
315
+ # (..., 1, h, w)
316
+ lat = get_latitude_field(camera, gravity).permute(permute)
317
+ else:
318
+ shape = (camera.shape[0], 1, h, w)
319
+ lat = camera.new_zeros(shape)
320
+
321
+ return up, lat
322
+
323
+
324
+ def J_perspective_field(
325
+ camera: BaseCamera,
326
+ gravity: Gravity,
327
+ use_up: bool = True,
328
+ use_latitude: bool = True,
329
+ spherical: bool = False,
330
+ log_focal: bool = False,
331
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
332
+ """Get the jacobian of the perspective field.
333
+
334
+ Args:
335
+ camera (Camera): Camera parameters.
336
+ gravity (Gravity): Gravity vector.
337
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
338
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
339
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
340
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
341
+
342
+ Returns:
343
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
344
+ (..., h, w, 2, 4) and (..., h, w, 1, 4).
345
+ """
346
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
347
+
348
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
349
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
350
+
351
+ w, h = camera.size[0].unbind(-1)
352
+ h, w = h.round().to(int), w.round().to(int)
353
+
354
+ if use_up:
355
+ J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4)
356
+ else:
357
+ shape = (camera.shape[0], h, w, 2, 4)
358
+ J_up = camera.new_zeros(shape)
359
+
360
+ if use_latitude:
361
+ J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4)
362
+ else:
363
+ shape = (camera.shape[0], h, w, 1, 4)
364
+ J_lat = camera.new_zeros(shape)
365
+
366
+ return J_up, J_lat
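
A minimal usage sketch for this module, assuming the calibration entry point from `geocalib/extractor.py` with its default (pinhole) camera model and one of the asset images shipped in this commit:

```python
import torch

from geocalib.extractor import GeoCalib
from geocalib.perspective_fields import get_perspective_field

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib().to(device)

img = model.load_image("assets/pinhole-church.jpg")  # (3, h, w), values in [0, 1]
results = model.calibrate(img.to(device))  # default pinhole camera model

# up: (1, 2, h, w) unit up-vectors, lat: (1, 1, h, w) latitudes in radians
up, lat = get_perspective_field(results["camera"], results["gravity"])
```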
geocalib/utils.py ADDED
@@ -0,0 +1,325 @@
1
+ """Image loading and general conversion utilities."""
2
+
3
+ import collections.abc as collections
4
+ from pathlib import Path
5
+ from types import SimpleNamespace
6
+ from typing import Dict, Optional, Tuple
7
+
8
+ import cv2
9
+ import kornia
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+
14
+ # mypy: ignore-errors
15
+
16
+
17
+ def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
18
+ """Get padding to make the image size a multiple of the given number.
19
+
20
+ Args:
21
+ x (torch.Tensor): Input tensor.
22
+ multiple (int, optional): Multiple to fit to.
23
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
24
+
25
+ Returns:
26
+ Tuple[int, int, int, int]: Padding as (left, right, top, bottom).
27
+ """
28
+ h, w = x.shape[-2:]
29
+
30
+ if crop:
31
+ pad_w = (w // multiple) * multiple - w
32
+ pad_h = (h // multiple) * multiple - h
33
+ else:
34
+ pad_w = (multiple - w % multiple) % multiple
35
+ pad_h = (multiple - h % multiple) % multiple
36
+
37
+ if mode == "center":
38
+ pad_l = pad_w // 2
39
+ pad_r = pad_w - pad_l
40
+ pad_t = pad_h // 2
41
+ pad_b = pad_h - pad_t
42
+ elif mode == "left":
43
+ pad_l, pad_r = 0, pad_w
44
+ pad_t, pad_b = 0, pad_h
45
+ else:
46
+ raise ValueError(f"Unknown mode {mode}")
47
+
48
+ return (pad_l, pad_r, pad_t, pad_b)
49
+
50
+
51
+ def fit_features_to_multiple(
52
+ features: torch.Tensor, multiple: int = 32, crop: bool = False
53
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
54
+ """Pad or crop image to a multiple of the given number.
55
+
56
+ Args:
57
+ features (torch.Tensor): Input features.
58
+ multiple (int, optional): Multiple. Defaults to 32.
59
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
60
+
61
+ Returns:
62
+ Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
63
+ """
64
+ pad = fit_to_multiple(features, multiple, crop=crop)
65
+ return torch.nn.functional.pad(features, pad, mode="reflect"), pad
66
+
67
+
68
+ class ImagePreprocessor:
69
+ """Preprocess images for calibration."""
70
+
71
+ default_conf = {
72
+ "resize": 320, # target edge length, None for no resizing
73
+ "edge_divisible_by": None,
74
+ "side": "short",
75
+ "interpolation": "bilinear",
76
+ "align_corners": None,
77
+ "antialias": True,
78
+ "square_crop": False,
79
+ "add_padding_mask": False,
80
+ "resize_backend": "kornia", # torchvision, kornia
81
+ }
82
+
83
+ def __init__(self, conf) -> None:
84
+ """Initialize the image preprocessor."""
85
+ self.conf = {**self.default_conf, **conf}
86
+ self.conf = SimpleNamespace(**self.conf)
87
+
88
+ def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
89
+ """Resize and preprocess an image, return image and resize scale."""
90
+ h, w = img.shape[-2:]
91
+ size = h, w
92
+
93
+ if self.conf.square_crop:
94
+ min_size = min(h, w)
95
+ offset = (h - min_size) // 2, (w - min_size) // 2
96
+ img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
97
+ size = img.shape[-2:]
98
+
99
+ if self.conf.resize is not None:
100
+ if interpolation is None:
101
+ interpolation = self.conf.interpolation
102
+ size = self.get_new_image_size(h, w)
103
+ img = self.resize(img, size, interpolation)
104
+
105
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
106
+ T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])
107
+
108
+ data = {
109
+ "scales": scale,
110
+ "image_size": np.array(size[::-1]),
111
+ "transform": T,
112
+ "original_image_size": np.array([w, h]),
113
+ }
114
+
115
+ if self.conf.edge_divisible_by is not None:
116
+ # crop to make the edge divisible by a number
117
+ w_, h_ = img.shape[-1], img.shape[-2]
118
+ img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
119
+ crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
120
+ data["crop_pad"] = crop_pad
121
+ data["image_size"] = np.array([img.shape[-1], img.shape[-2]])
122
+
123
+ data["image"] = img
124
+ return data
125
+
126
+ def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
127
+ """Resize an image using the specified backend."""
128
+ if self.conf.resize_backend == "kornia":
129
+ return kornia.geometry.transform.resize(
130
+ img,
131
+ size,
132
+ side=self.conf.side,
133
+ antialias=self.conf.antialias,
134
+ align_corners=self.conf.align_corners,
135
+ interpolation=interpolation,
136
+ )
137
+ elif self.conf.resize_backend == "torchvision":
138
+ return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
139
+ else:
140
+ raise ValueError(f"{self.conf.resize_backend} not implemented.")
141
+
142
+ def load_image(self, image_path: Path) -> dict:
143
+ """Load an image from a path and preprocess it."""
144
+ return self(load_image(image_path))
145
+
146
+ def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
147
+ """Get the new image size after resizing."""
148
+ side = self.conf.side
149
+ if isinstance(self.conf.resize, collections.Iterable):
150
+ assert len(self.conf.resize) == 2
151
+ return tuple(self.conf.resize)
152
+ side_size = self.conf.resize
153
+ aspect_ratio = w / h
154
+ if side not in ("short", "long", "vert", "horz"):
155
+ raise ValueError(
156
+ f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
157
+ )
158
+ return (
159
+ (side_size, int(side_size * aspect_ratio))
160
+ if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
161
+ else (int(side_size / aspect_ratio), side_size)
162
+ )
163
+
164
+
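
A short usage sketch of the preprocessor, assuming the default configuration above with only `resize` overridden:

```python
from geocalib.utils import ImagePreprocessor, load_image

preprocessor = ImagePreprocessor({"resize": 320})  # resize the short side to 320 px
data = preprocessor(load_image("assets/pinhole-church.jpg"))

# "image" is the resized tensor; "scales" holds the per-axis resize factors (x, y)
print(data["image"].shape, data["scales"], data["original_image_size"])
```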
165
+ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
166
+ """Normalize the image tensor and reorder the dimensions."""
167
+ if image.ndim == 3:
168
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
169
+ elif image.ndim == 2:
170
+ image = image[None] # add channel axis
171
+ else:
172
+ raise ValueError(f"Not an image: {image.shape}")
173
+ return torch.tensor(image / 255.0, dtype=torch.float)
174
+
175
+
176
+ def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
177
+ """Normalize and reorder the dimensions of an image tensor."""
178
+ if image.ndim == 3:
179
+ image = image.permute((1, 2, 0)) # CxHxW to HxWxC
180
+ elif image.ndim == 2:
181
+ image = image[None] # add channel axis
182
+ else:
183
+ raise ValueError(f"Not an image: {image.shape}")
184
+ return (image.cpu().detach().numpy() * 255).astype(np.uint8)
185
+
186
+
187
+ def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
188
+ """Read an image from path as RGB or grayscale."""
189
+ if not Path(path).exists():
190
+ raise FileNotFoundError(f"No image at path {path}.")
191
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
192
+ image = cv2.imread(str(path), mode)
193
+ if image is None:
194
+ raise IOError(f"Could not read image at {path}.")
195
+ if not grayscale:
196
+ image = image[..., ::-1]
197
+ return image
198
+
199
+
200
+ def write_image(img: torch.Tensor, path: Path):
201
+ """Write an image tensor to a file."""
202
+ img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
203
+ cv2.imwrite(str(path), img[..., ::-1])
204
+
205
+
206
+ def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
207
+ """Load an image from a path and return as a tensor."""
208
+ image = read_image(path, grayscale=grayscale)
209
+ if return_tensor:
210
+ return numpy_image_to_torch(image)
211
+
212
+ assert image.ndim in [2, 3], f"Not an image: {image.shape}"
213
+ image = image[None] if image.ndim == 2 else image
214
+ return torch.tensor(image.copy(), dtype=torch.uint8)
215
+
216
+
217
+ def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
218
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
219
+
220
+ Args:
221
+ v (torch.Tensor): Vector of size (..., 3).
222
+
223
+ Returns:
224
+ (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
225
+ """
226
+ z = torch.zeros_like(v[..., 0])
227
+ return torch.stack(
228
+ [z, -v[..., 2], v[..., 1], v[..., 2], z, -v[..., 0], -v[..., 1], v[..., 0], z], dim=-1
229
+ ).reshape(v.shape[:-1] + (3, 3))
230
+
231
+
232
+ def rad2rotmat(
233
+ roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
234
+ ) -> torch.Tensor:
235
+ """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.
236
+
237
+ Args:
238
+ roll (torch.Tensor): Roll angle in radians.
239
+ pitch (torch.Tensor): Pitch angle in radians.
240
+ yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.
241
+
242
+ Returns:
243
+ torch.Tensor: Rotation matrix of shape (..., 3, 3).
244
+ """
245
+ if yaw is None:
246
+ yaw = roll.new_zeros(roll.shape)
247
+
248
+ Rx = pitch.new_zeros(pitch.shape + (3, 3))
249
+ Rx[..., 0, 0] = 1
250
+ Rx[..., 1, 1] = torch.cos(pitch)
251
+ Rx[..., 1, 2] = torch.sin(pitch)
252
+ Rx[..., 2, 1] = -torch.sin(pitch)
253
+ Rx[..., 2, 2] = torch.cos(pitch)
254
+
255
+ Ry = yaw.new_zeros(yaw.shape + (3, 3))
256
+ Ry[..., 0, 0] = torch.cos(yaw)
257
+ Ry[..., 0, 2] = -torch.sin(yaw)
258
+ Ry[..., 1, 1] = 1
259
+ Ry[..., 2, 0] = torch.sin(yaw)
260
+ Ry[..., 2, 2] = torch.cos(yaw)
261
+
262
+ Rz = roll.new_zeros(roll.shape + (3, 3))
263
+ Rz[..., 0, 0] = torch.cos(roll)
264
+ Rz[..., 0, 1] = torch.sin(roll)
265
+ Rz[..., 1, 0] = -torch.sin(roll)
266
+ Rz[..., 1, 1] = torch.cos(roll)
267
+ Rz[..., 2, 2] = 1
268
+
269
+ return Rz @ Rx @ Ry
270
+
271
+
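
A quick sanity check for the helper above (the angle values are arbitrary): `rad2rotmat` composes Rz(roll) @ Rx(pitch) @ Ry(yaw), so the result should be orthonormal with unit determinant.

```python
import torch

from geocalib.utils import rad2rotmat

roll, pitch = torch.tensor([0.1]), torch.tensor([-0.2])
R = rad2rotmat(roll, pitch)  # (1, 3, 3), yaw defaults to zero
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3), atol=1e-5)
assert torch.allclose(torch.linalg.det(R), torch.ones(1), atol=1e-5)
```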
272
+ def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
273
+ """Compute focal length from (vertical/horizontal) field of view."""
274
+ return size / 2 / torch.tan(fov / 2)
275
+
276
+
277
+ def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
278
+ """Compute (vertical/horizontal) field of view from focal length."""
279
+ return 2 * torch.arctan(size / (2 * focal))
280
+
281
+
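
A worked example for the two conversions above (using `deg2rad`/`rad2deg` defined a few lines below): a 60° vertical field of view on a 480 px tall image corresponds to f = 240 / tan(30°) ≈ 415.7 px.

```python
import torch

from geocalib.utils import deg2rad, focal2fov, fov2focal, rad2deg

h = torch.tensor(480.0)
f = fov2focal(deg2rad(torch.tensor(60.0)), h)
print(f"{f.item():.1f} px")  # ~415.7 px
print(f"{rad2deg(focal2fov(f, h)).item():.1f} deg")  # back to ~60.0
```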
282
+ def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
283
+ """Compute the distance from principal point to the horizon."""
284
+ return torch.tan(pitch) * f / h
285
+
286
+
287
+ def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
288
+ """Compute the pitch angle from the distance to the horizon."""
289
+ return torch.atan(rho * h / f)
290
+
291
+
292
+ def rad2deg(rad: torch.Tensor) -> torch.Tensor:
293
+ """Convert radians to degrees."""
294
+ return rad / torch.pi * 180
295
+
296
+
297
+ def deg2rad(deg: torch.Tensor) -> torch.Tensor:
298
+ """Convert degrees to radians."""
299
+ return deg / 180 * torch.pi
300
+
301
+
302
+ def get_device() -> str:
303
+ """Get the device (cpu, cuda, mps) available."""
304
+ device = "cpu"
305
+ if torch.cuda.is_available():
306
+ device = "cuda"
307
+ elif torch.backends.mps.is_available():
308
+ device = "mps"
309
+ return device
310
+
311
+
312
+ def print_calibration(results: Dict[str, torch.Tensor]) -> None:
313
+ """Print the calibration results."""
314
+ camera, gravity = results["camera"], results["gravity"]
315
+ vfov = rad2deg(camera.vfov)
316
+ roll, pitch = rad2deg(gravity.rp).unbind(-1)
317
+
318
+ print("\nEstimated parameters (Pred):")
319
+ print(f"Roll: {roll.item():.1f}° (± {rad2deg(results['roll_uncertainty']).item():.1f})°")
320
+ print(f"Pitch: {pitch.item():.1f}° (± {rad2deg(results['pitch_uncertainty']).item():.1f})°")
321
+ print(f"vFoV: {vfov.item():.1f}° (± {rad2deg(results['vfov_uncertainty']).item():.1f})°")
322
+ print(f"Focal: {camera.f[0, 1].item():.1f} px (± {results['focal_uncertainty'].item():.1f} px)")
323
+
324
+ if hasattr(camera, "k1"):
325
+ print(f"K1: {camera.k1.item():.1f}")
geocalib/viz2d.py ADDED
@@ -0,0 +1,502 @@
1
+ """2D visualization primitives based on Matplotlib.
2
+
3
+ 1) Plot images with `plot_images`.
4
+ 2) Call functions to plot heatmaps, vector fields, and horizon lines.
5
+ 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
6
+ """
7
+
8
+ import matplotlib.patheffects as path_effects
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import torch
12
+
13
+ from geocalib.perspective_fields import get_perspective_field
14
+ from geocalib.utils import rad2deg
15
+
16
+ # mypy: ignore-errors
17
+
18
+
19
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
20
+ """Plot a list of images.
21
+
22
+ Args:
23
+ imgs (List[np.ndarray]): List of images to plot.
24
+ titles (List[str], optional): Titles. Defaults to None.
25
+ cmaps (str, optional): Colormaps. Defaults to "gray".
26
+ dpi (int, optional): Dots per inch. Defaults to 200.
27
+ pad (float, optional): Padding. Defaults to 0.5.
28
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
29
+
30
+ Returns:
31
+ plt.Figure: Figure of the images.
32
+ """
33
+ n = len(imgs)
34
+ if not isinstance(cmaps, (list, tuple)):
35
+ cmaps = [cmaps] * n
36
+
37
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
38
+ figsize = [sum(ratios) * 4.5, 4.5]
39
+ fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
40
+ if n == 1:
41
+ axs = [axs]
42
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
43
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
44
+ ax.set_axis_off()
45
+ if titles:
46
+ ax.set_title(titles[i])
47
+ fig.tight_layout(pad=pad)
48
+
49
+ return fig
50
+
51
+
52
+ def plot_image_grid(
53
+ imgs,
54
+ titles=None,
55
+ cmaps="gray",
56
+ dpi=100,
57
+ pad=0.5,
58
+ fig=None,
59
+ adaptive=True,
60
+ figs=3.0,
61
+ return_fig=False,
62
+ set_lim=False,
63
+ ) -> plt.Figure:
64
+ """Plot a grid of images.
65
+
66
+ Args:
67
+ imgs (List[np.ndarray]): List of images to plot.
68
+ titles (List[str], optional): Titles. Defaults to None.
69
+ cmaps (str, optional): Colormaps. Defaults to "gray".
70
+ dpi (int, optional): Dots per inch. Defaults to 100.
71
+ pad (float, optional): Padding. Defaults to 0.5.
72
+ fig (_type_, optional): Figure to plot on. Defaults to None.
73
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
74
+ figs (float, optional): Figure size. Defaults to 3.0.
75
+ return_fig (bool, optional): Whether to return the figure. Defaults to False.
76
+ set_lim (bool, optional): Whether to set the limits. Defaults to False.
77
+
78
+ Returns:
79
+ plt.Figure: Figure and axes or just axes.
80
+ """
81
+ nr, n = len(imgs), len(imgs[0])
82
+ if not isinstance(cmaps, (list, tuple)):
83
+ cmaps = [cmaps] * n
84
+
85
+ if adaptive:
86
+ ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H
87
+ else:
88
+ ratios = [4 / 3] * n
89
+
90
+ figsize = [sum(ratios) * figs, nr * figs]
91
+ if fig is None:
92
+ fig, axs = plt.subplots(
93
+ nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
94
+ )
95
+ else:
96
+ axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
97
+ fig.figure.set_size_inches(figsize)
98
+
99
+ if nr == 1 and n == 1:
100
+ axs = [[axs]]
101
+ elif n == 1:
102
+ axs = axs[:, None]
103
+ elif nr == 1:
104
+ axs = [axs]
105
+
106
+ for j in range(nr):
107
+ for i in range(n):
108
+ ax = axs[j][i]
109
+ ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
110
+ ax.set_axis_off()
111
+ if set_lim:
112
+ ax.set_xlim([0, imgs[j][i].shape[1]])
113
+ ax.set_ylim([imgs[j][i].shape[0], 0])
114
+ if titles:
115
+ ax.set_title(titles[j][i])
116
+ if isinstance(fig, plt.Figure):
117
+ fig.tight_layout(pad=pad)
118
+ return (fig, axs) if return_fig else axs
119
+
120
+
121
+ def add_text(
122
+ idx,
123
+ text,
124
+ pos=(0.01, 0.99),
125
+ fs=15,
126
+ color="w",
127
+ lcolor="k",
128
+ lwidth=4,
129
+ ha="left",
130
+ va="top",
131
+ axes=None,
132
+ **kwargs,
133
+ ):
134
+ """Add text to a plot.
135
+
136
+ Args:
137
+ idx (int): Index of the axes.
138
+ text (str): Text to add.
139
+ pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
140
+ fs (int, optional): Font size. Defaults to 15.
141
+ color (str, optional): Text color. Defaults to "w".
142
+ lcolor (str, optional): Line color. Defaults to "k".
143
+ lwidth (int, optional): Line width. Defaults to 4.
144
+ ha (str, optional): Horizontal alignment. Defaults to "left".
145
+ va (str, optional): Vertical alignment. Defaults to "top".
146
+ axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.
147
+
148
+ Returns:
149
+ plt.Text: Text object.
150
+ """
151
+ if axes is None:
152
+ axes = plt.gcf().axes
153
+
154
+ ax = axes[idx]
155
+
156
+ t = ax.text(
157
+ *pos,
158
+ text,
159
+ fontsize=fs,
160
+ ha=ha,
161
+ va=va,
162
+ color=color,
163
+ transform=ax.transAxes,
164
+ zorder=5,
165
+ **kwargs,
166
+ )
167
+ if lcolor is not None:
168
+ t.set_path_effects(
169
+ [
170
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
171
+ path_effects.Normal(),
172
+ ]
173
+ )
174
+ return t
175
+
176
+
177
+ def plot_heatmaps(
178
+ heatmaps,
179
+ vmin=-1e-6, # include negative zero
180
+ vmax=None,
181
+ cmap="Spectral",
182
+ a=0.5,
183
+ axes=None,
184
+ contours_every=None,
185
+ contour_style="solid",
186
+ colorbar=False,
187
+ ):
188
+ """Plot heatmaps with optional contours.
189
+
190
+ To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.
191
+
192
+ Args:
193
+ heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
194
+ vmin (float, optional): Min Value. Defaults to -1e-6.
195
+ vmax (float, optional): Max Value. Defaults to None.
196
+ cmap (str, optional): Colormap. Defaults to "Spectral".
197
+ a (float, optional): Alpha value. Defaults to 0.5.
198
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
199
+ contours_every (int, optional): If not none, will draw contours. Defaults to None.
200
+ contour_style (str, optional): Style of the contours. Defaults to "solid".
201
+ colorbar (bool, optional): Whether to show colorbar. Defaults to False.
202
+
203
+ Returns:
204
+ List[plt.Artist]: List of artists.
205
+ """
206
+ if axes is None:
207
+ axes = plt.gcf().axes
208
+ artists = []
209
+
210
+ for i in range(len(axes)):
211
+ a_ = a if isinstance(a, float) else a[i]
212
+
213
+ if isinstance(heatmaps[i], torch.Tensor):
214
+ heatmaps[i] = heatmaps[i].cpu().numpy()
215
+
216
+ alpha = a_
217
+ # Plot the heatmap
218
+ art = axes[i].imshow(
219
+ heatmaps[i],
220
+ alpha=alpha,
221
+ vmin=vmin,
222
+ vmax=vmax,
223
+ cmap=cmap,
224
+ )
225
+ if colorbar:
226
+ cmax = vmax or np.percentile(heatmaps[i], 99)
227
+ art.set_clim(vmin, cmax)
228
+ cbar = plt.colorbar(art, ax=axes[i])
229
+ artists.append(cbar)
230
+
231
+ artists.append(art)
232
+
233
+ if contours_every is not None:
234
+ # Add contour lines to the heatmap
235
+ contour_data = np.arange(vmin, vmax + contours_every, contours_every)
236
+
237
+ # Get the colormap colors for contour lines
238
+ contour_colors = [
239
+ plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
240
+ for level in contour_data
241
+ ]
242
+ contours = axes[i].contour(
243
+ heatmaps[i],
244
+ levels=contour_data,
245
+ linewidths=2,
246
+ colors=contour_colors,
247
+ linestyles=contour_style,
248
+ )
249
+
250
+ contours.set_clim(vmin, vmax)
251
+
252
+ fmt = {
253
+ level: f"{label}°"
254
+ for level, label in zip(contour_data, contour_data.astype(int).astype(str))
255
+ }
256
+ t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
257
+
258
+ for label in t:
259
+ label.set_path_effects(
260
+ [
261
+ path_effects.Stroke(linewidth=1, foreground="k"),
262
+ path_effects.Normal(),
263
+ ]
264
+ )
265
+ artists.append(contours)
266
+
267
+ return artists
268
+
269
+
270
+ def plot_horizon_lines(
271
+ cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
272
+ ):
273
+ """Plot horizon lines on the perspective field.
274
+
275
+ Args:
276
+ cameras (List[Camera]): List of cameras.
277
+ gravities (List[Gravity]): Gravities.
278
+ line_colors (str, optional): Line Colors. Defaults to "orange".
279
+ lw (int, optional): Line width. Defaults to 2.
280
+ styles (str, optional): Line styles. Defaults to "solid".
281
+ alpha (float, optional): Alphas. Defaults to 1.0.
282
+ ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
283
+ """
284
+ if not isinstance(line_colors, list):
285
+ line_colors = [line_colors] * len(cameras)
286
+
287
+ if not isinstance(styles, list):
288
+ styles = [styles] * len(cameras)
289
+
290
+ fig = plt.gcf()
291
+ ax = fig.gca() if ax is None else ax
292
+
293
+ if isinstance(ax, plt.Axes):
294
+ ax = [ax] * len(cameras)
295
+
296
+ assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"
297
+
298
+ for i in range(len(cameras)):
299
+ _, lat = get_perspective_field(cameras[i], gravities[i])
300
+ # horizon line is zero level of the latitude field
301
+ lat = lat[0, 0].cpu().numpy()
302
+ contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
303
+ for contour_line in contours.collections:
304
+ contour_line.set_linestyle(styles[i])
305
+
306
+
307
+ def plot_vector_fields(
308
+ vector_fields,
309
+ cmap="lime",
310
+ subsample=15,
311
+ scale=None,
312
+ lw=None,
313
+ alphas=0.8,
314
+ axes=None,
315
+ ):
316
+ """Plot vector fields.
317
+
318
+ Args:
319
+ vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
320
+ cmap (str, optional): Color of the vectors. Defaults to "lime".
321
+ subsample (int, optional): Subsample the vector field. Defaults to 15.
322
+ scale (float, optional): Scale of the vectors. Defaults to None.
323
+ lw (float, optional): Line width of the vectors. Defaults to None.
324
+ alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
325
+ axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
326
+
327
+ Returns:
328
+ List[plt.Artist]: List of artists.
329
+ """
330
+ if axes is None:
331
+ axes = plt.gcf().axes
332
+
333
+ vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
334
+
335
+ artists = []
336
+
337
+ H, W = vector_fields[0].shape[-2:]
338
+ if scale is None:
339
+ scale = subsample / min(H, W)
340
+
341
+ if lw is None:
342
+ lw = 0.1 / subsample
343
+
344
+ if alphas is None:
345
+ alphas = np.ones_like(vector_fields[0][0])
346
+ alphas = np.stack([alphas] * len(vector_fields), 0)
347
+ elif isinstance(alphas, float):
348
+ alphas = np.ones_like(vector_fields[0][0]) * alphas
349
+ alphas = np.stack([alphas] * len(vector_fields), 0)
350
+ else:
351
+ alphas = np.array(alphas)
352
+
353
+ subsample = min(W, H) // subsample
354
+ offset_x = ((W % subsample) + subsample) // 2
355
+
356
+ samples_x = np.arange(offset_x, W, subsample)
357
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
358
+
359
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
360
+
361
+ for i in range(len(axes)):
362
+ # vector field of shape (2, H, W) with vectors of norm == 1
363
+ vector_field = vector_fields[i]
364
+
365
+ a = alphas[i][samples_y][:, samples_x]
366
+ x, y = vector_field[:, samples_y][:, :, samples_x]
367
+
368
+ c = cmap
369
+ if not isinstance(cmap, str):
370
+ c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
371
+
372
+ s = scale * min(H, W)
373
+ arrows = axes[i].quiver(
374
+ x_grid,
375
+ y_grid,
376
+ x,
377
+ y,
378
+ scale=s,
379
+ scale_units="width" if H > W else "height",
380
+ units="width" if H > W else "height",
381
+ alpha=a,
382
+ color=c,
383
+ angles="xy",
384
+ antialiased=True,
385
+ width=lw,
386
+ headaxislength=3.5,
387
+ zorder=5,
388
+ )
389
+
390
+ artists.append(arrows)
391
+
392
+ return artists
393
+
394
+
395
+ def plot_latitudes(
396
+ latitude,
397
+ is_radians=True,
398
+ vmin=-90,
399
+ vmax=90,
400
+ cmap="seismic",
401
+ contours_every=15,
402
+ alpha=0.4,
403
+ axes=None,
404
+ **kwargs,
405
+ ):
406
+ """Plot latitudes.
407
+
408
+ Args:
409
+ latitude (List[torch.Tensor]): List of latitudes.
410
+ is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
411
+ vmin (int, optional): Min value to clip to. Defaults to -90.
412
+ vmax (int, optional): Max value to clip to. Defaults to 90.
413
+ cmap (str, optional): Colormap. Defaults to "seismic".
414
+ contours_every (int, optional): Contours every. Defaults to 15.
415
+ alpha (float, optional): Alpha value. Defaults to 0.4.
416
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
417
+
418
+ Returns:
419
+ List[plt.Artist]: List of artists.
420
+ """
421
+ if axes is None:
422
+ axes = plt.gcf().axes
423
+
424
+ assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
425
+ lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
426
+ return plot_heatmaps(
427
+ lat,
428
+ vmin=vmin,
429
+ vmax=vmax,
430
+ cmap=cmap,
431
+ a=alpha,
432
+ axes=axes,
433
+ contours_every=contours_every,
434
+ **kwargs,
435
+ )
436
+
437
+
438
+ def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
439
+ """Plot perspective fields.
440
+
441
+ Args:
442
+ cameras (List[Camera]): List of cameras.
443
+ gravities (List[Gravity]): List of gravities.
444
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
445
+
446
+ Returns:
447
+ List[plt.Artist]: List of artists.
448
+ """
449
+ if axes is None:
450
+ axes = plt.gcf().axes
451
+
452
+ assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"
453
+
454
+ artists = []
455
+ for i in range(len(axes)):
456
+ up, lat = get_perspective_field(cameras[i], gravities[i])
457
+ artists += plot_vector_fields([up[0]], axes=[axes[i]], **kwargs)
458
+ artists += plot_latitudes([lat[0, 0]], axes=[axes[i]], **kwargs)
459
+
460
+ return artists
461
+
462
+
463
+ def plot_confidences(
464
+ confidence,
465
+ as_log=True,
466
+ vmin=-4,
467
+ vmax=0,
468
+ cmap="turbo",
469
+ alpha=0.4,
470
+ axes=None,
471
+ **kwargs,
472
+ ):
473
+ """Plot confidences.
474
+
475
+ Args:
476
+ confidence (List[torch.Tensor]): Confidence maps.
477
+ as_log (bool, optional): Whether to plot in log scale. Defaults to True.
478
+ vmin (int, optional): Min value to clip to. Defaults to -4.
479
+ vmax (int, optional): Max value to clip to. Defaults to 0.
480
+ cmap (str, optional): Colormap. Defaults to "turbo".
481
+ alpha (float, optional): Alpha value. Defaults to 0.4.
482
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
483
+
484
+ Returns:
485
+ List[plt.Artist]: List of artists.
486
+ """
487
+ if axes is None:
488
+ axes = plt.gcf().axes
489
+
490
+ assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
491
+
492
+ if as_log:
493
+ confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
494
+
495
+ # normalize to [0, 1]
496
+ confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
497
+ return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
498
+
499
+
500
+ def save_plot(path, **kw):
501
+ """Save the current figure without any white margin."""
502
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
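
Putting this module together with the calibration API, a minimal plotting sketch (assuming one of the asset images shipped in this commit):

```python
from geocalib import viz2d
from geocalib.extractor import GeoCalib

model = GeoCalib()
img = model.load_image("assets/pinhole-church.jpg")
results = model.calibrate(img)

# overlay the up-vector and latitude fields on the input image and save it
fig = viz2d.plot_images([img.permute(1, 2, 0).numpy()], pad=0)
viz2d.plot_perspective_fields([results["camera"]], [results["gravity"]], axes=fig.get_axes())
viz2d.save_plot("perspective_field.png")
```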
gradio_app.py ADDED
@@ -0,0 +1,228 @@
1
+ """Gradio app for GeoCalib inference."""
2
+
3
+ from copy import deepcopy
4
+ from time import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+
11
+ from geocalib import viz2d
12
+ from geocalib.camera import camera_models
13
+ from geocalib.extractor import GeoCalib
14
+ from geocalib.perspective_fields import get_perspective_field
15
+ from geocalib.utils import rad2deg
16
+
17
+ # flake8: noqa
18
+ # mypy: ignore-errors
19
+
20
+ description = """
21
+ <p align="center">
22
+ <h1 align="center"><ins>GeoCalib</ins> 📸<br>Single-image Calibration with Geometric Optimization</h1>
23
+ <p align="center">
24
+ <a href="https://www.linkedin.com/in/alexander-veicht/">Alexander Veicht</a>
25
+ ·
26
+ <a href="https://psarlin.com/">Paul-Edouard&nbsp;Sarlin</a>
27
+ ·
28
+ <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
29
+ ·
30
+ <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc&nbsp;Pollefeys</a>
31
+ </p>
32
+ <h2 align="center">
33
+ <p>ECCV 2024</p>
34
+ <a href="" align="center">Paper</a> | <!--TODO: update link-->
35
+ <a href="https://github.com/cvg/GeoCalib" align="center">Code</a> |
36
+ <a href="https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K" align="center">Colab</a>
37
+ </h2>
38
+ </p>
39
+
40
+ ## Getting Started
41
+ GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image by
42
+ combining geometric optimization with deep learning.
43
+
44
+ To get started, upload an image or select one of the examples below.
45
+ You can choose between different camera models and visualize the calibration results.
46
+
47
+ """
48
+
49
+ example_images = [
50
+ ["assets/pinhole-church.jpg"],
51
+ ["assets/pinhole-garden.jpg"],
52
+ ["assets/fisheye-skyline.jpg"],
53
+ ["assets/fisheye-dog-pool.jpg"],
54
+ ]
55
+
56
+ device = "cuda" if torch.cuda.is_available() else "cpu"
57
+ model = GeoCalib().to(device)
58
+
59
+
60
+ def format_output(results):
61
+ camera, gravity = results["camera"], results["gravity"]
62
+ vfov = rad2deg(camera.vfov)
63
+ roll, pitch = rad2deg(gravity.rp).unbind(-1)
64
+
65
+ txt = "Estimated parameters:\n"
66
+ txt += f"Roll: {roll.item():.2f}° (± {rad2deg(results['roll_uncertainty']).item():.2f})°\n"
67
+ txt += f"Pitch: {pitch.item():.2f}° (± {rad2deg(results['pitch_uncertainty']).item():.2f})°\n"
68
+ txt += f"vFoV: {vfov.item():.2f}° (± {rad2deg(results['vfov_uncertainty']).item():.2f})°\n"
69
+ txt += (
70
+ f"Focal: {camera.f[0, 1].item():.2f} px (± {results['focal_uncertainty'].item():.2f} px)\n"
71
+ )
72
+ if hasattr(camera, "k1"):
73
+ txt += f"K1: {camera.k1[0].item():.2f}\n"
74
+ return txt
75
+
76
+
77
+ @spaces.GPU(duration=10)
78
+ def inference(img, camera_model):
79
+ out = model.calibrate(img.to(device), camera_model=camera_model)
80
+ save_keys = ["camera", "gravity"] + [f"{k}_uncertainty" for k in ["roll", "pitch", "vfov"]]
81
+ res = {k: v.cpu() for k, v in out.items() if k in save_keys}
82
+ # not converting to numpy results in gpu abort
83
+ res["up_confidence"] = out["up_confidence"].cpu().numpy()
84
+ res["latitude_confidence"] = out["latitude_confidence"].cpu().numpy()
85
+ return res
86
+
87
+
88
+ def process_results(
89
+ image_path,
90
+ camera_model,
91
+ plot_up,
92
+ plot_up_confidence,
93
+ plot_latitude,
94
+ plot_latitude_confidence,
95
+ plot_undistort,
96
+ ):
97
+ """Process the image and return the calibration results."""
98
+
99
+ if image_path is None:
100
+ raise gr.Error("Please upload an image first.")
101
+
102
+ img = model.load_image(image_path)
103
+ print("Running inference...")
104
+ start = time()
105
+ inference_result = inference(img, camera_model)
106
+ print(f"Done ({time() - start:.2f}s)")
107
+ inference_result["image"] = img.cpu()
108
+
109
+ if inference_result is None:
110
+ return ("", np.ones((128, 256, 3)), None)
111
+
112
+ plot_img = update_plot(
113
+ inference_result,
114
+ plot_up,
115
+ plot_up_confidence,
116
+ plot_latitude,
117
+ plot_latitude_confidence,
118
+ plot_undistort,
119
+ )
120
+
121
+ return format_output(inference_result), plot_img, inference_result
122
+
123
+
124
+ def update_plot(
125
+ inference_result,
126
+ plot_up,
127
+ plot_up_confidence,
128
+ plot_latitude,
129
+ plot_latitude_confidence,
130
+ plot_undistort,
131
+ ):
132
+ """Update the plot based on the selected options."""
133
+ if inference_result is None:
134
+ gr.Error("Please calibrate an image first.")
135
+ return np.ones((128, 256, 3))
136
+
137
+ camera, gravity = inference_result["camera"], inference_result["gravity"]
138
+ img = inference_result["image"].permute(1, 2, 0).numpy()
139
+
140
+ if plot_undistort:
141
+ if not hasattr(camera, "k1"):
142
+ return img
143
+
144
+ return camera.undistort_image(inference_result["image"][None])[0].permute(1, 2, 0).numpy()
145
+
146
+ up, lat = get_perspective_field(camera, gravity)
147
+
148
+ fig = viz2d.plot_images([img], pad=0)
149
+ ax = fig.get_axes()
150
+
151
+ if plot_up:
152
+ viz2d.plot_vector_fields([up[0]], axes=[ax[0]])
153
+
154
+ if plot_latitude:
155
+ viz2d.plot_latitudes([lat[0, 0]], axes=[ax[0]])
156
+
157
+ if plot_up_confidence:
158
+ viz2d.plot_confidences([inference_result["up_confidence"][0]], axes=[ax[0]])
159
+
160
+ if plot_latitude_confidence:
161
+ viz2d.plot_confidences([inference_result["latitude_confidence"][0]], axes=[ax[0]])
162
+
163
+ fig.canvas.draw()
164
+ img = np.array(fig.canvas.renderer.buffer_rgba())
165
+
166
+ return img
167
+
168
+
169
+ # Create the Gradio interface
170
+ with gr.Blocks() as demo:
171
+ gr.Markdown(description)
172
+ with gr.Row():
173
+ with gr.Column():
174
+ gr.Markdown("""## Input Image""")
175
+ image_path = gr.Image(label="Upload image to calibrate", type="filepath")
176
+ choice_input = gr.Dropdown(
177
+ choices=list(camera_models.keys()), label="Choose a camera model.", value="pinhole"
178
+ )
179
+ submit_btn = gr.Button("Calibrate 📸")
180
+ gr.Examples(examples=example_images, inputs=[image_path, choice_input])
181
+
182
+ with gr.Column():
183
+ gr.Markdown("""## Results""")
184
+ image_output = gr.Image(label="Calibration Results")
185
+ gr.Markdown("### Plot Options")
186
+ plot_undistort = gr.Checkbox(
187
+ label="undistort",
188
+ value=False,
189
+ info="Undistorted image "
190
+ + "(this is only available for models with distortion "
191
+ + "parameters and will overwrite other options).",
192
+ )
193
+
194
+ with gr.Row():
195
+ plot_up = gr.Checkbox(label="up-vectors", value=True)
196
+ plot_up_confidence = gr.Checkbox(label="up confidence", value=False)
197
+ plot_latitude = gr.Checkbox(label="latitude", value=True)
198
+ plot_latitude_confidence = gr.Checkbox(label="latitude confidence", value=False)
199
+
200
+ gr.Markdown("### Calibration Results")
201
+ text_output = gr.Textbox(label="Estimated parameters", type="text", lines=5)
202
+
203
+ # Define the action when the button is clicked
204
+ inference_state = gr.State()
205
+ plot_inputs = [
206
+ inference_state,
207
+ plot_up,
208
+ plot_up_confidence,
209
+ plot_latitude,
210
+ plot_latitude_confidence,
211
+ plot_undistort,
212
+ ]
213
+ submit_btn.click(
214
+ fn=process_results,
215
+ inputs=[image_path, choice_input] + plot_inputs[1:],
216
+ outputs=[text_output, image_output, inference_state],
217
+ )
218
+
219
+ # Define the action when the plot checkboxes are clicked
220
+ plot_up.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
221
+ plot_up_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
222
+ plot_latitude.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
223
+ plot_latitude_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
224
+ plot_undistort.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
225
+
226
+
227
+ # Launch the app
228
+ demo.launch()
hubconf.py ADDED
@@ -0,0 +1,14 @@
1
+ """Entrypoint for torch hub."""
2
+
3
+ dependencies = ["torch", "torchvision", "opencv-python", "kornia", "matplotlib"]
4
+
5
+ from geocalib import GeoCalib
6
+
7
+
8
+ def model(*args, **kwargs):
9
+ """Pre-trained Geocalib model.
10
+
11
+ Args:
12
+ weights (str): trained variant, "pinhole" (default) or "distorted".
13
+ """
14
+ return GeoCalib(*args, **kwargs)
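
A hedged loading sketch via torch hub; the repository slug is inferred from the URL in `pyproject.toml`, and the `weights` variants are those listed in the docstring above:

```python
import torch

model = torch.hub.load("cvg/GeoCalib", "model", weights="pinhole", trust_repo=True)
img = model.load_image("path/to/image.jpg")  # placeholder path
results = model.calibrate(img)
```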
pyproject.toml ADDED
@@ -0,0 +1,49 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "geocalib"
7
+ version = "1.0"
8
+ description = "GeoCalib Inference Package"
9
+ authors = [
10
+ { name = "Alexander Veicht" },
11
+ { name = "Paul-Edouard Sarlin" },
12
+ { name = "Philipp Lindenberger" },
13
+ ]
14
+ readme = "README.md"
15
+ requires-python = ">=3.9"
16
+ license = { file = "LICENSE" }
17
+ classifiers = [
18
+ "Programming Language :: Python :: 3",
19
+ "License :: OSI Approved :: Apache Software License",
20
+ "Operating System :: OS Independent",
21
+ ]
22
+ urls = { Repository = "https://github.com/cvg/GeoCalib" }
23
+
24
+ dynamic = ["dependencies"]
25
+
26
+ [project.optional-dependencies]
27
+ dev = ["black==23.9.1", "flake8", "isort==5.12.0"]
28
+
29
+ [tool.setuptools]
30
+ packages = ["geocalib"]
31
+
32
+ [tool.setuptools.dynamic]
33
+ dependencies = { file = ["requirements.txt"] }
34
+
35
+
36
+ [tool.black]
37
+ line-length = 100
38
+ exclude = "(venv/|docs/|third_party/)"
39
+
40
+ [tool.isort]
41
+ profile = "black"
42
+ line_length = 100
43
+ atomic = true
44
+
45
+ [tool.flake8]
46
+ max-line-length = 100
47
+ docstring-convention = "google"
48
+ ignore = ["E203", "W503", "E402"]
49
+ exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ torch
2
+ torchvision
3
+ opencv-python
4
+ kornia
5
+ matplotlib
siclib/LICENSE ADDED
@@ -0,0 +1,190 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ Copyright 2024 ETH Zurich
179
+
180
+ Licensed under the Apache License, Version 2.0 (the "License");
181
+ you may not use this file except in compliance with the License.
182
+ You may obtain a copy of the License at
183
+
184
+ http://www.apache.org/licenses/LICENSE-2.0
185
+
186
+ Unless required by applicable law or agreed to in writing, software
187
+ distributed under the License is distributed on an "AS IS" BASIS,
188
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189
+ See the License for the specific language governing permissions and
190
+ limitations under the License.
siclib/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ import logging
2
+
3
+ formatter = logging.Formatter(
4
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
5
+ )
6
+ handler = logging.StreamHandler()
7
+ handler.setFormatter(formatter)
8
+ handler.setLevel(logging.INFO)
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.INFO)
12
+ logger.addHandler(handler)
13
+ logger.propagate = False
14
+
15
+ __module_name__ = __name__
siclib/configs/deepcalib.yaml ADDED
@@ -0,0 +1,12 @@
1
+ defaults:
2
+ - data: openpano-radial
3
+ - train: deepcalib
4
+ - model: deepcalib
5
+ - _self_
6
+
7
+ data:
8
+ train_batch_size: 32
9
+ val_batch_size: 32
10
+ test_batch_size: 32
11
+ augmentations:
12
+ name: "deepcalib"
siclib/configs/geocalib-radial.yaml ADDED
@@ -0,0 +1,38 @@
1
+ defaults:
2
+ - data: openpano-radial
3
+ - train: geocalib
4
+ - model: geocalib
5
+ - _self_
6
+
7
+ data:
8
+ # smaller batch size since lm takes more memory
9
+ train_batch_size: 18
10
+ val_batch_size: 18
11
+ test_batch_size: 18
12
+
13
+ model:
14
+ optimizer:
15
+ camera_model: simple_radial
16
+
17
+ weights: weights/geocalib.tar
18
+
19
+ train:
20
+ lr: 1e-5 # smaller lr since we are fine-tuning
21
+ num_steps: 200_000 # adapt to see same number of samples as previous training
22
+
23
+ lr_schedule:
24
+ type: SequentialLR
25
+ on_epoch: false
26
+ options:
27
+ # adapt to see same number of samples as previous training
28
+ milestones: [5_000]
29
+ schedulers:
30
+ - type: LinearLR
31
+ options:
32
+ start_factor: 1e-3
33
+ total_iters: 5_000
34
+ - type: MultiStepLR
35
+ options:
36
+ gamma: 0.1
37
+ # adapt to see same number of samples as previous training
38
+ milestones: [110_000, 170_000]
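The `lr_schedule` block encodes a linear warm-up over the first 5k steps followed by 10x decays at 110k and 170k steps, chained by `SequentialLR`. A rough sketch of the equivalent `torch.optim.lr_scheduler` construction (illustrative only, not the trainer's actual wiring):

```python
import torch

model = torch.nn.Linear(2, 2)  # placeholder module
opt = torch.optim.Adam(model.parameters(), lr=1e-5)

warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1e-3, total_iters=5_000)
decay = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[110_000, 170_000], gamma=0.1)
# Switch from warm-up to step decay after 5k optimizer steps.
scheduler = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, decay], milestones=[5_000])
```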
siclib/configs/geocalib.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: openpano
3
+ - train: geocalib
4
+ - model: geocalib
5
+ - _self_
6
+
7
+ data:
8
+ train_batch_size: 24
9
+ val_batch_size: 24
10
+ test_batch_size: 24
siclib/configs/model/deepcalib.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ name: networks.deepcalib
2
+ bounds:
3
+ roll: [-45, 45]
4
+ # rho = torch.tan(pitch) / torch.tan(vfov / 2) / 2 -> rho in [-1/0.3526, 1/0.0872]
5
+ rho: [-2.83607487, 2.83607487]
6
+ vfov: [20, 105]
7
+ k1_hat: [-0.7, 0.7]
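The `rho` bound is consistent with the extremes of the other parameters: with `|pitch| <= 45°` and `vfov >= 20°`, `|rho| = |tan(pitch)| / (2 * tan(vfov / 2))` peaks at `1 / (2 * tan(10°)) ≈ 2.836`, i.e. the `1/0.3526` mentioned in the comment. A quick standalone check:

```python
import math

rho_max = math.tan(math.radians(45)) / (2 * math.tan(math.radians(20 / 2)))
print(rho_max)  # ~2.8356; the config uses 1/0.3526 ≈ 2.8361 with the denominator rounded
```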
siclib/configs/model/geocalib.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: networks.geocalib
2
+
3
+ ll_enc:
4
+ name: encoders.low_level_encoder
5
+
6
+ backbone:
7
+ name: encoders.mscan
8
+ weights: weights/mscan_b.pth
9
+
10
+ perspective_decoder:
11
+ name: decoders.perspective_decoder
12
+
13
+ up_decoder:
14
+ name: decoders.up_decoder
15
+ loss_type: l1
16
+ use_uncertainty_loss: true
17
+ decoder:
18
+ name: decoders.light_hamburger
19
+ predict_uncertainty: true
20
+
21
+ latitude_decoder:
22
+ name: decoders.latitude_decoder
23
+ loss_type: l1
24
+ use_uncertainty_loss: true
25
+ decoder:
26
+ name: decoders.light_hamburger
27
+ predict_uncertainty: true
28
+
29
+ optimizer:
30
+ name: optimization.lm_optimizer
31
+ camera_model: pinhole
siclib/configs/train/deepcalib.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 0
2
+ num_steps: 20_000
3
+ log_every_iter: 500
4
+ eval_every_iter: 3000
5
+ test_every_epoch: 1
6
+ writer: null
7
+ lr: 1.0e-4
8
+ clip_grad: 1.0
9
+ lr_schedule:
10
+ type: null
11
+ optimizer: adam
12
+ submodules: []
13
+ median_metrics:
14
+ - roll_error
15
+ - pitch_error
16
+ - vfov_error
17
+ recall_metrics:
18
+ roll_error: [1, 5, 10]
19
+ pitch_error: [1, 5, 10]
20
+ vfov_error: [1, 5, 10]
21
+
22
+ plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
siclib/configs/train/geocalib.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 0
2
+ num_steps: 150_000
3
+
4
+ writer: null
5
+ log_every_iter: 500
6
+ eval_every_iter: 1000
7
+
8
+ lr: 1e-4
9
+ optimizer: adamw
10
+ clip_grad: 1.0
11
+ best_key: loss/param_total
12
+
13
+ lr_schedule:
14
+ type: SequentialLR
15
+ on_epoch: false
16
+ options:
17
+ milestones: [4_000]
18
+ schedulers:
19
+ - type: LinearLR
20
+ options:
21
+ start_factor: 1e-3
22
+ total_iters: 4_000
23
+ - type: MultiStepLR
24
+ options:
25
+ gamma: 0.1
26
+ milestones: [80_000, 130_000]
27
+
28
+ submodules: []
29
+
30
+ median_metrics:
31
+ - roll_error
32
+ - pitch_error
33
+ - gravity_error
34
+ - vfov_error
35
+ - up_angle_error
36
+ - latitude_angle_error
37
+ - up_angle_recall@1
38
+ - up_angle_recall@5
39
+ - up_angle_recall@10
40
+ - latitude_angle_recall@1
41
+ - latitude_angle_recall@5
42
+ - latitude_angle_recall@10
43
+
44
+ recall_metrics:
45
+ roll_error: [1, 3, 5, 10]
46
+ pitch_error: [1, 3, 5, 10]
47
+ gravity_error: [1, 3, 5, 10]
48
+ vfov_error: [1, 3, 5, 10]
49
+
50
+ plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
siclib/datasets/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+
3
+ from siclib.datasets.base_dataset import BaseDataset
4
+ from siclib.utils.tools import get_class
5
+
6
+
7
+ def get_dataset(name):
8
+ import_paths = [name, f"{__name__}.{name}"]
9
+ for path in import_paths:
10
+ try:
11
+ spec = importlib.util.find_spec(path)
12
+ except ModuleNotFoundError:
13
+ spec = None
14
+ if spec is not None:
15
+ try:
16
+ return get_class(path, BaseDataset)
17
+ except AssertionError:
18
+ mod = __import__(path, fromlist=[""])
19
+ try:
20
+ return mod.__main_dataset__
21
+ except AttributeError as exc:
22
+ print(exc)
23
+ continue
24
+
25
+ raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(import_paths)}]')
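`get_dataset` resolves a dataset class either from a fully qualified module path or from a module name inside `siclib.datasets`. A minimal usage sketch (the config values are illustrative and assume the OpenPano data has already been generated):

```python
from siclib.datasets import get_dataset

DatasetClass = get_dataset("simple_dataset")  # -> siclib.datasets.simple_dataset.SimpleDataset
dataset = DatasetClass({"dataset_dir": "data/openpano/openpano"})
loader = dataset.get_data_loader("val")
```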
siclib/datasets/augmentations.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import albumentations as A
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from albumentations.pytorch.transforms import ToTensorV2
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ # flake8: noqa
12
+ # mypy: ignore-errors
13
+ class IdentityTransform(A.ImageOnlyTransform):
14
+ def apply(self, img, **params):
15
+ return img
16
+
17
+ def get_transform_init_args_names(self):
18
+ return ()
19
+
20
+
21
+ class RandomAdditiveShade(A.ImageOnlyTransform):
22
+ def __init__(
23
+ self,
24
+ nb_ellipses=10,
25
+ transparency_limit=[-0.5, 0.8],
26
+ kernel_size_limit=[150, 350],
27
+ always_apply=False,
28
+ p=0.5,
29
+ ):
30
+ super().__init__(always_apply, p)
31
+ self.nb_ellipses = nb_ellipses
32
+ self.transparency_limit = transparency_limit
33
+ self.kernel_size_limit = kernel_size_limit
34
+
35
+ def apply(self, img, **params):
36
+ if img.dtype == np.float32:
37
+ shaded = self._py_additive_shade(img * 255.0)
38
+ shaded /= 255.0
39
+ elif img.dtype == np.uint8:
40
+ shaded = self._py_additive_shade(img.astype(np.float32))
41
+ shaded = shaded.astype(np.uint8)
42
+ else:
43
+ raise NotImplementedError(f"Data augmentation not available for type: {img.dtype}")
44
+ return shaded
45
+
46
+ def _py_additive_shade(self, img):
47
+ grayscale = len(img.shape) == 2
48
+ if grayscale:
49
+ img = img[None]
50
+ min_dim = min(img.shape[:2]) / 4
51
+ mask = np.zeros(img.shape[:2], img.dtype)
52
+ for i in range(self.nb_ellipses):
53
+ ax = int(max(np.random.rand() * min_dim, min_dim / 5))
54
+ ay = int(max(np.random.rand() * min_dim, min_dim / 5))
55
+ max_rad = max(ax, ay)
56
+ x = np.random.randint(max_rad, img.shape[1] - max_rad) # center
57
+ y = np.random.randint(max_rad, img.shape[0] - max_rad)
58
+ angle = np.random.rand() * 90
59
+ cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)
60
+
61
+ transparency = np.random.uniform(*self.transparency_limit)
62
+ ks = np.random.randint(*self.kernel_size_limit)
63
+ if (ks % 2) == 0: # kernel_size has to be odd
64
+ ks += 1
65
+ mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
66
+ shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
67
+ out = np.clip(shaded, 0, 255)
68
+ if grayscale:
69
+ out = out.squeeze(0)
70
+ return out
71
+
72
+ def get_transform_init_args_names(self):
73
+ return "transparency_limit", "kernel_size_limit", "nb_ellipses"
74
+
75
+
76
+ def kw(entry: Union[float, dict], n=None, **default):
77
+ if not isinstance(entry, dict):
78
+ entry = {"p": entry}
79
+ entry = OmegaConf.create(entry)
80
+ if n is not None:
81
+ entry = default.get(n, entry)
82
+ return OmegaConf.merge(default, entry)
83
+
84
+
85
+ def kwi(entry: Union[float, dict], n=None, **default):
86
+ conf = kw(entry, n=n, **default)
87
+ return {k: conf[k] for k in set(default.keys()).union(set(["p"]))}
88
+
89
+
90
+ def replay_str(transforms, s="Replay:\n", log_inactive=True):
91
+ for t in transforms:
92
+ if "transforms" in t.keys():
93
+ s = replay_str(t["transforms"], s=s)
94
+ elif t["applied"] or log_inactive:
95
+ s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n"
96
+ return s
97
+
98
+
99
+ class BaseAugmentation(object):
100
+ base_default_conf = {
101
+ "name": "???",
102
+ "shuffle": False,
103
+ "p": 1.0,
104
+ "verbose": False,
105
+ "dtype": "uint8", # (byte, float)
106
+ }
107
+
108
+ default_conf = {}
109
+
110
+ def __init__(self, conf={}):
111
+ """Perform some logic and call the _init method of the child model."""
112
+ default_conf = OmegaConf.merge(
113
+ OmegaConf.create(self.base_default_conf),
114
+ OmegaConf.create(self.default_conf),
115
+ )
116
+ OmegaConf.set_struct(default_conf, True)
117
+ if isinstance(conf, dict):
118
+ conf = OmegaConf.create(conf)
119
+ self.conf = OmegaConf.merge(default_conf, conf)
120
+ OmegaConf.set_readonly(self.conf, True)
121
+ self._init(self.conf)
122
+
123
+ self.conf = OmegaConf.merge(self.conf, conf)
124
+ if self.conf.verbose:
125
+ self.compose = A.ReplayCompose
126
+ else:
127
+ self.compose = A.Compose
128
+ if self.conf.dtype == "uint8":
129
+ self.dtype = np.uint8
130
+ self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
131
+ self.postprocess = A.ToFloat(always_apply=True)
132
+ elif self.conf.dtype == "float32":
133
+ self.dtype = np.float32
134
+ self.preprocess = A.ToFloat(always_apply=True)
135
+ self.postprocess = IdentityTransform()
136
+ else:
137
+ raise ValueError(f"Unsupported dtype {self.conf.dtype}")
138
+ self.to_tensor = ToTensorV2()
139
+
140
+ def _init(self, conf):
141
+ """Child class overwrites this, setting up a list of transforms"""
142
+ self.transforms = []
143
+
144
+ def __call__(self, image, return_tensor=False):
145
+ """image as HW or HWC"""
146
+ if isinstance(image, torch.Tensor):
147
+ image = image.cpu().numpy()
148
+ data = {"image": image}
149
+ if image.dtype != self.dtype:
150
+ data = self.preprocess(**data)
151
+ transforms = self.transforms
152
+ if self.conf.shuffle:
153
+ order = [i for i, _ in enumerate(transforms)]
154
+ np.random.shuffle(order)
155
+ transforms = [transforms[i] for i in order]
156
+ transformed = self.compose(transforms, p=self.conf.p)(**data)
157
+ if self.conf.verbose:
158
+ print(replay_str(transformed["replay"]["transforms"]))
159
+ transformed = self.postprocess(**transformed)
160
+ if return_tensor:
161
+ return self.to_tensor(**transformed)["image"]
162
+ else:
163
+ return transformed["image"]
164
+
165
+
166
+ class IdentityAugmentation(BaseAugmentation):
167
+ default_conf = {}
168
+
169
+ def _init(self, conf):
170
+ self.transforms = [IdentityTransform(p=1.0)]
171
+
172
+
173
+ class DarkAugmentation(BaseAugmentation):
174
+ default_conf = {"p": 0.75}
175
+
176
+ def _init(self, conf):
177
+ bright_contr = 0.5
178
+ blur = 0.1
179
+ random_gamma = 0.1
180
+ hue = 0.1
181
+ self.transforms = [
182
+ A.RandomRain(p=0.2),
183
+ A.RandomBrightnessContrast(
184
+ **kw(
185
+ bright_contr,
186
+ brightness_limit=(-0.4, 0.0),
187
+ contrast_limit=(-0.3, 0.0),
188
+ )
189
+ ),
190
+ A.OneOf(
191
+ [
192
+ A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
193
+ A.MotionBlur(**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")),
194
+ A.ISONoise(),
195
+ A.ImageCompression(),
196
+ ],
197
+ **kwi(blur, p=0.1),
198
+ ),
199
+ A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))),
200
+ A.OneOf(
201
+ [
202
+ A.Equalize(),
203
+ A.CLAHE(p=0.2),
204
+ A.ToGray(),
205
+ A.ToSepia(p=0.1),
206
+ A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))),
207
+ ],
208
+ p=0.5,
209
+ ),
210
+ ]
211
+
212
+
213
+ class DefaultAugmentation(BaseAugmentation):
214
+ default_conf = {"p": 1.0}
215
+
216
+ def _init(self, conf):
217
+ self.transforms = [
218
+ A.RandomBrightnessContrast(p=0.2),
219
+ A.HueSaturationValue(p=0.2),
220
+ A.ToGray(p=0.2),
221
+ A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
222
+ A.OneOf(
223
+ [
224
+ A.MotionBlur(p=0.2),
225
+ A.MedianBlur(blur_limit=3, p=0.1),
226
+ A.Blur(blur_limit=3, p=0.1),
227
+ ],
228
+ p=0.2,
229
+ ),
230
+ ]
231
+
232
+
233
+ class PerspectiveAugmentation(BaseAugmentation):
234
+ default_conf = {"p": 1.0}
235
+
236
+ def _init(self, conf):
237
+ self.transforms = [
238
+ A.RandomBrightnessContrast(p=0.2),
239
+ A.HueSaturationValue(p=0.2),
240
+ A.ToGray(p=0.2),
241
+ A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
242
+ A.OneOf(
243
+ [
244
+ A.MotionBlur(p=0.2),
245
+ A.MedianBlur(blur_limit=3, p=0.1),
246
+ A.Blur(blur_limit=3, p=0.1),
247
+ ],
248
+ p=0.2,
249
+ ),
250
+ ]
251
+
252
+
253
+ class DeepCalibAugmentations(BaseAugmentation):
254
+ default_conf = {"p": 1.0}
255
+
256
+ def _init(self, conf):
257
+ self.transforms = [
258
+ A.RandomBrightnessContrast(p=0.5),
259
+ A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
260
+ A.Downscale(
261
+ scale_min=0.5,
262
+ scale_max=0.95,
263
+ interpolation=dict(downscale=cv2.INTER_AREA, upscale=cv2.INTER_LINEAR),
264
+ p=0.5,
265
+ ),
266
+ A.Downscale(scale_min=0.5, scale_max=0.95, interpolation=cv2.INTER_LINEAR, p=0.5),
267
+ A.ImageCompression(quality_lower=20, quality_upper=85, p=1, always_apply=True),
268
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
269
+ A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5),
270
+ A.ToGray(always_apply=False, p=0.2),
271
+ A.GaussianBlur(blur_limit=(3, 5), sigma_limit=0, p=0.25),
272
+ A.MotionBlur(blur_limit=5, allow_shifted=True, p=0.25),
273
+ A.MultiplicativeNoise(multiplier=[0.85, 1.15], elementwise=True, p=0.5),
274
+ ]
275
+
276
+
277
+ class GeoCalibAugmentations(BaseAugmentation):
278
+ default_conf = {"p": 1.0}
279
+
280
+ def _init(self, conf):
281
+ self.color_transforms = [
282
+ A.RandomGamma(gamma_limit=(80, 180), p=0.8),
283
+ A.RandomToneCurve(scale=0.1, p=0.5),
284
+ A.RandomBrightnessContrast(p=0.5),
285
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
286
+ A.OneOf([A.ToGray(p=0.1), A.ToSepia(p=0.1), IdentityTransform(p=0.8)], p=1),
287
+ ]
288
+
289
+ self.noise_transforms = [
290
+ A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
291
+ A.ImageCompression(quality_lower=20, quality_upper=100, p=1),
292
+ A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5),
293
+ A.OneOrOther(
294
+ first=A.Compose(
295
+ [
296
+ A.AdvancedBlur(
297
+ p=1,
298
+ blur_limit=(3, 7),
299
+ sigmaX_limit=(0.2, 1.0),
300
+ sigmaY_limit=(0.2, 1.0),
301
+ rotate_limit=(-90, 90),
302
+ beta_limit=(0.5, 8.0),
303
+ noise_limit=(0.9, 1.1),
304
+ ),
305
+ A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
306
+ ]
307
+ ),
308
+ second=A.Compose(
309
+ [
310
+ A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
311
+ A.AdvancedBlur(
312
+ p=1,
313
+ blur_limit=(3, 7),
314
+ sigmaX_limit=(0.2, 1.0),
315
+ sigmaY_limit=(0.2, 1.0),
316
+ rotate_limit=(-90, 90),
317
+ beta_limit=(0.5, 8.0),
318
+ noise_limit=(0.9, 1.1),
319
+ ),
320
+ ]
321
+ ),
322
+ ),
323
+ ]
324
+
325
+ self.image_transforms = [
326
+ A.OneOf(
327
+ [
328
+ A.Downscale(
329
+ scale_min=0.5,
330
+ scale_max=0.99,
331
+ interpolation=dict(downscale=down, upscale=up),
332
+ p=1,
333
+ )
334
+ for down, up in [
335
+ (cv2.INTER_AREA, cv2.INTER_LINEAR),
336
+ (cv2.INTER_LINEAR, cv2.INTER_CUBIC),
337
+ (cv2.INTER_CUBIC, cv2.INTER_LINEAR),
338
+ (cv2.INTER_LINEAR, cv2.INTER_AREA),
339
+ ]
340
+ ],
341
+ p=1,
342
+ )
343
+ ]
344
+
345
+ self.transforms = [
346
+ *self.color_transforms,
347
+ *self.noise_transforms,
348
+ *self.image_transforms,
349
+ ]
350
+
351
+
352
+ augmentations = {
353
+ "default": DefaultAugmentation,
354
+ "dark": DarkAugmentation,
355
+ "perspective": PerspectiveAugmentation,
356
+ "deepcalib": DeepCalibAugmentations,
357
+ "geocalib": GeoCalibAugmentations,
358
+ "identity": IdentityAugmentation,
359
+ }
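Each entry in this registry wraps an albumentations pipeline behind the `BaseAugmentation` interface: the object is called on an HW or HWC image and returns the augmented image, optionally as a CHW float tensor. A minimal sketch on a random image:

```python
import numpy as np
from siclib.datasets.augmentations import augmentations

aug = augmentations["geocalib"]({"name": "geocalib"})

image = np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8)  # HWC uint8 in
augmented = aug(image, return_tensor=True)                        # CHW float tensor out
print(augmented.shape, augmented.dtype)  # torch.Size([3, 240, 320]) torch.float32
```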
siclib/datasets/base_dataset.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base class for dataset.
2
+
3
+ See mnist.py for an example of a dataset.
4
+ """
5
+
6
+ import collections
7
+ import logging
8
+ from abc import ABCMeta, abstractmethod
9
+
10
+ import omegaconf
11
+ import torch
12
+ from omegaconf import OmegaConf
13
+ from torch.utils.data import DataLoader, Sampler, get_worker_info
14
+ from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern
15
+
16
+ from siclib.utils.tensor import string_classes
17
+ from siclib.utils.tools import set_num_threads, set_seed
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # mypy: ignore-errors
22
+
23
+
24
+ class LoopSampler(Sampler):
25
+ """Infinite sampler that loops over a given number of elements."""
26
+
27
+ def __init__(self, loop_size: int, total_size: int = None):
28
+ """Initialize the sampler.
29
+
30
+ Args:
31
+ loop_size (int): Number of elements to loop over.
32
+ total_size (int, optional): Total number of elements. Defaults to None.
33
+ """
34
+ self.loop_size = loop_size
35
+ self.total_size = total_size - (total_size % loop_size)
36
+
37
+ def __iter__(self):
38
+ """Return an iterator over the elements."""
39
+ return (i % self.loop_size for i in range(self.total_size))
40
+
41
+ def __len__(self):
42
+ """Return the number of elements."""
43
+ return self.total_size
44
+
45
+
46
+ def worker_init_fn(i):
47
+ """Initialize the workers with a different seed."""
48
+ info = get_worker_info()
49
+ if hasattr(info.dataset, "conf"):
50
+ conf = info.dataset.conf
51
+ set_seed(info.id + conf.seed)
52
+ set_num_threads(conf.num_threads)
53
+ else:
54
+ set_num_threads(1)
55
+
56
+
57
+ def collate(batch):
58
+ """Difference with PyTorch default_collate: it can stack of other objects."""
59
+ if not isinstance(batch, list): # no batching
60
+ return batch
61
+ elem = batch[0]
62
+ elem_type = type(elem)
63
+ if isinstance(elem, torch.Tensor):
64
+ # out = None
65
+ if torch.utils.data.get_worker_info() is not None:
66
+ # If we're in a background process, concatenate directly into a
67
+ # shared memory tensor to avoid an extra copy
68
+ numel = sum([x.numel() for x in batch])
69
+ try:
70
+ _ = elem.untyped_storage()._new_shared(numel)
71
+ except AttributeError:
72
+ _ = elem.storage()._new_shared(numel)
73
+ return torch.stack(batch, dim=0)
74
+ elif (
75
+ elem_type.__module__ == "numpy"
76
+ and elem_type.__name__ != "str_"
77
+ and elem_type.__name__ != "string_"
78
+ ):
79
+ if elem_type.__name__ in ["ndarray", "memmap"]:
80
+ # array of string classes and object
81
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
82
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
83
+ return collate([torch.as_tensor(b) for b in batch])
84
+ elif elem.shape == (): # scalars
85
+ return torch.as_tensor(batch)
86
+ elif isinstance(elem, float):
87
+ return torch.tensor(batch, dtype=torch.float64)
88
+ elif isinstance(elem, int):
89
+ return torch.tensor(batch)
90
+ elif isinstance(elem, string_classes):
91
+ return batch
92
+ elif isinstance(elem, collections.abc.Mapping):
93
+ return {key: collate([d[key] for d in batch]) for key in elem}
94
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
95
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
96
+ elif isinstance(elem, collections.abc.Sequence):
97
+ # check to make sure that the elements in batch have consistent size
98
+ it = iter(batch)
99
+ elem_size = len(next(it))
100
+ if any(len(elem) != elem_size for elem in it):
101
+ raise RuntimeError("each element in list of batch should be of equal size")
102
+ transposed = zip(*batch)
103
+ return [collate(samples) for samples in transposed]
104
+ elif elem is None:
105
+ return elem
106
+ else:
107
+ # try to stack anyway in case the object implements stacking.
108
+ return torch.stack(batch, 0)
109
+
110
+
111
+ class BaseDataset(metaclass=ABCMeta):
112
+ """Base class for dataset.
113
+
114
+ What the dataset model is expected to declare:
115
+ default_conf: dictionary of the default configuration of the dataset.
116
+ It overwrites base_default_conf in BaseModel, and it is overwritten by
117
+ the user-provided configuration passed to __init__.
118
+ Configurations can be nested.
119
+
120
+ _init(self, conf): initialization method, where conf is the final
121
+ configuration object (also accessible with `self.conf`). Accessing
122
+ unknown configuration entries will raise an error.
123
+
124
+ get_dataset(self, split): method that returns an instance of
125
+ torch.utils.data.Dataset corresponding to the requested split string,
126
+ which can be `'train'`, `'val'`, or `'test'`.
127
+ """
128
+
129
+ base_default_conf = {
130
+ "name": "???",
131
+ "num_workers": "???",
132
+ "train_batch_size": "???",
133
+ "val_batch_size": "???",
134
+ "test_batch_size": "???",
135
+ "shuffle_training": True,
136
+ "batch_size": 1,
137
+ "num_threads": 1,
138
+ "seed": 0,
139
+ "prefetch_factor": 2,
140
+ }
141
+ default_conf = {}
142
+
143
+ def __init__(self, conf):
144
+ """Perform some logic and call the _init method of the child model."""
145
+ default_conf = OmegaConf.merge(
146
+ OmegaConf.create(self.base_default_conf),
147
+ OmegaConf.create(self.default_conf),
148
+ )
149
+ OmegaConf.set_struct(default_conf, True)
150
+ if isinstance(conf, dict):
151
+ conf = OmegaConf.create(conf)
152
+ self.conf = OmegaConf.merge(default_conf, conf)
153
+ OmegaConf.set_readonly(self.conf, True)
154
+ logger.info(f"Creating dataset {self.__class__.__name__}")
155
+ self._init(self.conf)
156
+
157
+ @abstractmethod
158
+ def _init(self, conf):
159
+ """To be implemented by the child class."""
160
+ raise NotImplementedError
161
+
162
+ @abstractmethod
163
+ def get_dataset(self, split):
164
+ """To be implemented by the child class."""
165
+ raise NotImplementedError
166
+
167
+ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
168
+ """Return a data loader for a given split."""
169
+ assert split in ["train", "val", "test"]
170
+ dataset = self.get_dataset(split)
171
+ try:
172
+ batch_size = self.conf[f"{split}_batch_size"]
173
+ except omegaconf.MissingMandatoryValue:
174
+ batch_size = self.conf.batch_size
175
+ num_workers = self.conf.get("num_workers", batch_size)
176
+ if distributed:
177
+ shuffle = False
178
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
179
+ else:
180
+ sampler = None
181
+ if shuffle is None:
182
+ shuffle = split == "train" and self.conf.shuffle_training
183
+ return DataLoader(
184
+ dataset,
185
+ batch_size=batch_size,
186
+ shuffle=shuffle,
187
+ sampler=sampler,
188
+ pin_memory=pinned,
189
+ collate_fn=collate,
190
+ num_workers=num_workers,
191
+ worker_init_fn=worker_init_fn,
192
+ prefetch_factor=self.conf.prefetch_factor,
193
+ )
194
+
195
+ def get_overfit_loader(self, split: str):
196
+ """Return an overfit data loader.
197
+
198
+ The training set is composed of a single duplicated batch, while
199
+ the validation and test sets contain a single copy of this same batch.
200
+ This is useful to debug a model and make sure that losses and metrics
201
+ correlate well.
202
+ """
203
+ assert split in {"train", "val", "test"}
204
+ dataset = self.get_dataset("train")
205
+ sampler = LoopSampler(
206
+ self.conf.batch_size,
207
+ len(dataset) if split == "train" else self.conf.batch_size,
208
+ )
209
+ num_workers = self.conf.get("num_workers", self.conf.batch_size)
210
+ return DataLoader(
211
+ dataset,
212
+ batch_size=self.conf.batch_size,
213
+ pin_memory=True,
214
+ num_workers=num_workers,
215
+ sampler=sampler,
216
+ worker_init_fn=worker_init_fn,
217
+ collate_fn=collate,
218
+ )
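A concrete dataset only has to implement `_init` and `get_dataset`; batching, shuffling, workers and the custom `collate` come from the base class. A minimal hypothetical subclass (purely illustrative, not part of the library; `prefetch_factor` is set to `None` because no workers are used, mirroring the `simple_dataset.py` example further below):

```python
import torch
from siclib.datasets.base_dataset import BaseDataset

class ToyDataset(BaseDataset):
    """Hypothetical dataset returning random vectors for every split."""

    default_conf = {"num_items": 100}

    def _init(self, conf):
        self.data = torch.randn(conf.num_items, 3)

    def get_dataset(self, split):
        return torch.utils.data.TensorDataset(self.data)

conf = {"name": "toy", "num_workers": 0, "prefetch_factor": None,
        "train_batch_size": 8, "val_batch_size": 8, "test_batch_size": 8}
loader = ToyDataset(conf).get_data_loader("train")
```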
siclib/datasets/configs/openpano-radial.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: openpano_radial
2
+ base_dir: data/openpano
3
+ pano_dir: "${.base_dir}/panoramas"
4
+ images_per_pano: 16
5
+ resize_factor: null
6
+ n_workers: 1
7
+ device: cpu
8
+ overwrite: true
9
+ parameter_dists:
10
+ roll:
11
+ type: uniform # uni[-45, 45]
12
+ options:
13
+ loc: -0.7853981633974483 # -45 degrees
14
+ scale: 1.5707963267948966 # 90 degrees
15
+ pitch:
16
+ type: uniform # uni[-45, 45]
17
+ options:
18
+ loc: -0.7853981633974483 # -45 degrees
19
+ scale: 1.5707963267948966 # 90 degrees
20
+ vfov:
21
+ type: uniform # uni[20, 105]
22
+ options:
23
+ loc: 0.3490658503988659 # 20 degrees
24
+ scale: 1.48352986419518 # 85 degrees
25
+ k1_hat:
26
+ type: truncnorm
27
+ options:
28
+ a: -4.285714285714286 # corresponds to -0.3
29
+ b: 4.285714285714286 # corresponds to 0.3
30
+ loc: 0
31
+ scale: 0.07
32
+ resize_factor:
33
+ type: uniform
34
+ options:
35
+ loc: 1.2
36
+ scale: 0.5
37
+ shape:
38
+ type: fix
39
+ value:
40
+ - 640
41
+ - 640
siclib/datasets/configs/openpano.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: openpano
2
+ base_dir: data/openpano
3
+ pano_dir: "${.base_dir}/panoramas"
4
+ images_per_pano: 16
5
+ resize_factor: null
6
+ n_workers: 1
7
+ device: cpu
8
+ overwrite: true
9
+ parameter_dists:
10
+ roll:
11
+ type: uniform # uni[-45, 45]
12
+ options:
13
+ loc: -0.7853981633974483 # -45 degrees
14
+ scale: 1.5707963267948966 # 90 degrees
15
+ pitch:
16
+ type: uniform # uni[-45, 45]
17
+ options:
18
+ loc: -0.7853981633974483 # -45 degrees
19
+ scale: 1.5707963267948966 # 90 degrees
20
+ vfov:
21
+ type: uniform # uni[20, 105]
22
+ options:
23
+ loc: 0.3490658503988659 # 20 degrees
24
+ scale: 1.48352986419518 # 85 degrees
25
+ resize_factor:
26
+ type: uniform
27
+ options:
28
+ loc: 1.2
29
+ scale: 0.5
30
+ shape:
31
+ type: fix
32
+ value:
33
+ - 640
34
+ - 640
siclib/datasets/create_dataset_from_pano.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to create a dataset from panorama images."""
2
+
3
+ import hashlib
4
+ import logging
5
+ from concurrent import futures
6
+ from pathlib import Path
7
+
8
+ import hydra
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+ import scipy
13
+ import torch
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from tqdm import tqdm
16
+
17
+ from siclib.geometry.camera import camera_models
18
+ from siclib.geometry.gravity import Gravity
19
+ from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg
20
+ from siclib.utils.image import load_image, write_image
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ # mypy: ignore-errors
26
+
27
+
28
+ def max_radius(a, b):
29
+ """Compute the maximum radius of a Brown distortion model."""
30
+ discrim = a * a - 4 * b
31
+ # if torch.isfinite(discrim) and discrim >= 0.0:
32
+ # discrim = np.sqrt(discrim) - a
33
+ # if discrim > 0.0:
34
+ # return 2.0 / discrim
35
+
36
+ valid = torch.isfinite(discrim) & (discrim >= 0.0)
37
+ discrim = torch.sqrt(discrim) - a
38
+ valid &= discrim > 0.0
39
+ return 2.0 / torch.where(valid, discrim, 0)
40
+
41
+
42
+ def brown_max_radius(k1, k2):
43
+ """Compute the maximum radius of a Brown distortion model."""
44
+ # fold the constants from the derivative into a and b
45
+ a = k1 * 3
46
+ b = k2 * 5
47
+ return torch.sqrt(max_radius(a, b))
48
+
49
+
50
+ class ParallelProcessor:
51
+ """Generic parallel processor class."""
52
+
53
+ def __init__(self, max_workers):
54
+ """Init processor and pbars."""
55
+ self.max_workers = max_workers
56
+ self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
57
+ self.pbars = {}
58
+
59
+ def update_pbar(self, pbar_key):
60
+ """Update progressbar."""
61
+ pbar = self.pbars.get(pbar_key)
62
+ pbar.update(1)
63
+
64
+ def submit_tasks(self, task_func, task_args, pbar_key):
65
+ """Submit tasks."""
66
+ pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)
67
+ self.pbars[pbar_key] = pbar
68
+
69
+ def update_pbar(future):
70
+ self.update_pbar(pbar_key)
71
+
72
+ futures = []
73
+ for args in task_args:
74
+ future = self.executor.submit(task_func, *args)
75
+ future.add_done_callback(update_pbar)
76
+ futures.append(future)
77
+
78
+ return futures
79
+
80
+ def wait_for_completion(self, futures):
81
+ """Wait for completion and return results."""
82
+ results = []
83
+ for f in futures:
84
+ results += f.result()
85
+
86
+ for key in self.pbars.keys():
87
+ self.pbars[key].close()
88
+
89
+ return results
90
+
91
+ def shutdown(self):
92
+ """Close the executer."""
93
+ self.executor.shutdown()
94
+
95
+
96
+ class DatasetGenerator:
97
+ """Dataset generator class to create perspective datasets from panoramas."""
98
+
99
+ default_conf = {
100
+ "name": "???",
101
+ # paths
102
+ "base_dir": "???",
103
+ "pano_dir": "${.base_dir}/panoramas",
104
+ "pano_train": "${.pano_dir}/train",
105
+ "pano_val": "${.pano_dir}/val",
106
+ "pano_test": "${.pano_dir}/test",
107
+ "perspective_dir": "${.base_dir}/${.name}",
108
+ "perspective_train": "${.perspective_dir}/train",
109
+ "perspective_val": "${.perspective_dir}/val",
110
+ "perspective_test": "${.perspective_dir}/test",
111
+ "train_csv": "${.perspective_dir}/train.csv",
112
+ "val_csv": "${.perspective_dir}/val.csv",
113
+ "test_csv": "${.perspective_dir}/test.csv",
114
+ # data options
115
+ "camera_model": "pinhole",
116
+ "parameter_dists": {
117
+ "roll": {
118
+ "type": "uniform",
119
+ "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45]
120
+ },
121
+ "pitch": {
122
+ "type": "uniform",
123
+ "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45]
124
+ },
125
+ "vfov": {
126
+ "type": "uniform",
127
+ "options": {"loc": deg2rad(20), "scale": deg2rad(85)}, # in [20, 105]
128
+ },
129
+ "resize_factor": {
130
+ "type": "uniform",
131
+ "options": {"loc": 1.0, "scale": 1.0}, # factor in [1.0, 2.0]
132
+ },
133
+ "shape": {"type": "fix", "value": (640, 640)},
134
+ },
135
+ "images_per_pano": 16,
136
+ "n_workers": 10,
137
+ "device": "cpu",
138
+ "overwrite": False,
139
+ }
140
+
141
+ def __init__(self, conf):
142
+ """Init the class by merging and storing the config."""
143
+ self.conf = OmegaConf.merge(
144
+ OmegaConf.create(self.default_conf),
145
+ OmegaConf.create(conf),
146
+ )
147
+ logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")
148
+
149
+ self.infos = {}
150
+ self.device = self.conf.device
151
+
152
+ self.camera_model = camera_models[self.conf.camera_model]
153
+
154
+ def sample_value(self, parameter_name, seed=None):
155
+ """Sample a value from the specified distribution."""
156
+ param_conf = self.conf["parameter_dists"][parameter_name]
157
+
158
+ if param_conf.type == "fix":
159
+ return torch.tensor(param_conf.value)
160
+
161
+ # fix seed for reproducibility
162
+ generator = None
163
+ if seed:
164
+ if not isinstance(seed, (int, float)):
165
+ seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
166
+ generator = np.random.default_rng(seed)
167
+
168
+ sampler = getattr(scipy.stats, param_conf.type)
169
+ return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))
170
+
171
+ def plot_distributions(self):
172
+ """Plot parameter distributions."""
173
+ fig, ax = plt.subplots(3, 3, figsize=(15, 10))
174
+ for i, split in enumerate(["train", "val", "test"]):
175
+ roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]]
176
+ ax[i, 0].hist(roll_vals, bins=100)
177
+ ax[i, 0].set_xlabel("Roll (°)")
178
+ ax[i, 0].set_ylabel(f"Count {split}")
179
+
180
+ pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]]
181
+ ax[i, 1].hist(pitch_vals, bins=100)
182
+ ax[i, 1].set_xlabel("Pitch (°)")
183
+ ax[i, 1].set_ylabel(f"Count {split}")
184
+
185
+ vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]]
186
+ ax[i, 2].hist(vfov_vals, bins=100)
187
+ ax[i, 2].set_xlabel("vFoV (°)")
188
+ ax[i, 2].set_ylabel(f"Count {split}")
189
+
190
+ plt.tight_layout()
191
+ plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf")
192
+
193
+ fig, ax = plt.subplots(3, 3, figsize=(15, 10))
194
+ for i, k1 in enumerate(["roll", "pitch", "vfov"]):
195
+ for j, k2 in enumerate(["roll", "pitch", "vfov"]):
196
+ ax[i, j].scatter(
197
+ [rad2deg(row[k1]) for row in self.infos["train"]],
198
+ [rad2deg(row[k2]) for row in self.infos["train"]],
199
+ s=1,
200
+ label="train",
201
+ )
202
+
203
+ ax[i, j].scatter(
204
+ [rad2deg(row[k1]) for row in self.infos["val"]],
205
+ [rad2deg(row[k2]) for row in self.infos["val"]],
206
+ s=1,
207
+ label="val",
208
+ )
209
+
210
+ ax[i, j].scatter(
211
+ [rad2deg(row[k1]) for row in self.infos["test"]],
212
+ [rad2deg(row[k2]) for row in self.infos["test"]],
213
+ s=1,
214
+ label="test",
215
+ )
216
+
217
+ ax[i, j].set_xlabel(k1)
218
+ ax[i, j].set_ylabel(k2)
219
+ ax[i, j].legend()
220
+
221
+ plt.tight_layout()
222
+ plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf")
223
+
224
+ def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
225
+ """Generate perspective images from a single panorama."""
226
+ infos = []
227
+
228
+ pano = load_image(pano_path).to(self.device)
229
+
230
+ yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
231
+ params = {
232
+ k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
233
+ for k in self.conf.parameter_dists
234
+ if k != "shape"
235
+ }
236
+ shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
237
+ params |= {
238
+ "height": [shape[0] for shape in shapes],
239
+ "width": [shape[1] for shape in shapes],
240
+ }
241
+
242
+ if "k1_hat" in params:
243
+ height = torch.tensor(params["height"])
244
+ width = torch.tensor(params["width"])
245
+ k1_hat = torch.tensor(params["k1_hat"])
246
+ vfov = torch.tensor(params["vfov"])
247
+ focal = fov2focal(vfov, height)
248
+ focal = focal
249
+ rel_focal = focal / height
250
+ k1 = k1_hat * rel_focal
251
+
252
+ # distance to image corner
253
+ # r_max_im = f_px * r_max * (1 + k1*r_max**2)
254
+ # function of r_max_im: f_px = r_max_im / (r_max * (1 + k1*r_max**2))
255
+ min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
256
+ r_max = brown_max_radius(k1=k1, k2=0)
257
+ lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2))
258
+ valid = lowest_possible_f_px <= focal
259
+
260
+ f = torch.where(valid, focal, lowest_possible_f_px)
261
+ vfov = focal2fov(f, height)
262
+
263
+ params["vfov"] = vfov
264
+ params |= {"k1": k1}
265
+
266
+ cam = self.camera_model.from_dict(params).float().to(self.device)
267
+ gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)
268
+
269
+ if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
270
+ for i in range(self.conf.images_per_pano):
271
+ perspective_name = f"{pano_path.stem}_{i}.jpg"
272
+ info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
273
+ infos.append(info)
274
+
275
+ logger.info(f"Perspectives for {pano_path.stem} already exist.")
276
+
277
+ return infos
278
+
279
+ perspective_images = cam.get_img_from_pano(
280
+ pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
281
+ )
282
+
283
+ for i, perspective_image in enumerate(perspective_images):
284
+ perspective_name = f"{pano_path.stem}_{i}.jpg"
285
+
286
+ n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
287
+ valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
288
+ if not valid:
289
+ logger.debug(f"Perspective {perspective_name} has too many black pixels.")
290
+ continue
291
+
292
+ write_image(perspective_image, out_dir / perspective_name)
293
+
294
+ info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
295
+ infos.append(info)
296
+
297
+ return infos
298
+
299
+ def generate_split(self, split: str, parallel_processor: ParallelProcessor):
300
+ """Generate a single split of a dataset."""
301
+ self.infos[split] = []
302
+ panorama_paths = [
303
+ path
304
+ for path in Path(self.conf[f"pano_{split}"]).glob("*")
305
+ if not path.name.startswith(".")
306
+ ]
307
+
308
+ out_dir = Path(self.conf[f"perspective_{split}"])
309
+ logger.info(f"Writing perspective images to {str(out_dir)}")
310
+ if not out_dir.exists():
311
+ out_dir.mkdir(parents=True)
312
+
313
+ futures = parallel_processor.submit_tasks(
314
+ self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
315
+ )
316
+ self.infos[split] = parallel_processor.wait_for_completion(futures)
317
+ # parallel_processor.shutdown()
318
+
319
+ metadata = pd.DataFrame(data=self.infos[split])
320
+ metadata.to_csv(self.conf[f"{split}_csv"])
321
+
322
+ def generate_dataset(self):
323
+ """Generate all splits of a dataset."""
324
+ out_dir = Path(self.conf.perspective_dir)
325
+ if not out_dir.exists():
326
+ out_dir.mkdir(parents=True)
327
+
328
+ OmegaConf.save(self.conf, out_dir / "config.yaml")
329
+
330
+ processor = ParallelProcessor(self.conf.n_workers)
331
+ for split in ["train", "val", "test"]:
332
+ self.generate_split(split=split, parallel_processor=processor)
333
+
334
+ processor.shutdown()
335
+
336
+ for split in ["train", "val", "test"]:
337
+ logger.info(f"Generated {len(self.infos[split])} {split} images.")
338
+
339
+ self.plot_distributions()
340
+
341
+
342
+ @hydra.main(version_base=None, config_path="configs", config_name="SUN360")
343
+ def main(cfg: DictConfig) -> None:
344
+ """Run dataset generation."""
345
+ generator = DatasetGenerator(conf=cfg)
346
+ generator.generate_dataset()
347
+
348
+
349
+ if __name__ == "__main__":
350
+ main()
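For the `simple_radial` samples, a negative `k1` only maps points monotonically up to a maximum radius, so `generate_images_from_pano` clamps the focal length such that the image corner stays inside that radius. A small standalone check of that rule, assuming the package is installed (the values are illustrative):

```python
import torch
from siclib.datasets.create_dataset_from_pano import brown_max_radius

h = w = torch.tensor(640.0)
k1 = torch.tensor(-0.1)

r_max = brown_max_radius(k1=k1, k2=torch.tensor(0.0))   # max normalized radius
corner = torch.sqrt((h / 2) ** 2 + (w / 2) ** 2)         # distance to image corner in pixels
f_min = corner / (r_max * (1 + k1 * r_max**2))           # smallest focal keeping the corner valid
print(r_max.item(), f_min.item())
```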
siclib/datasets/simple_dataset.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset for images created with 'create_dataset_from_pano.py'."""
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Tuple
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from omegaconf import DictConfig
10
+
11
+ from siclib.datasets.augmentations import IdentityAugmentation, augmentations
12
+ from siclib.datasets.base_dataset import BaseDataset
13
+ from siclib.geometry.camera import SimpleRadial
14
+ from siclib.geometry.gravity import Gravity
15
+ from siclib.geometry.perspective_fields import get_perspective_field
16
+ from siclib.utils.conversions import fov2focal
17
+ from siclib.utils.image import ImagePreprocessor, load_image
18
+ from siclib.utils.tools import fork_rng
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # mypy: ignore-errors
23
+
24
+
25
+ def load_csv(
26
+ csv_file: Path, img_root: Path
27
+ ) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
28
+ """Load a CSV file containing image information.
29
+
30
+ Args:
31
+ csv_file (str): Path to the CSV file.
32
+ img_root (str): Path to the root directory containing the images.
33
+
34
+ Returns:
35
+ list: List of dictionaries containing the image paths and camera parameters.
36
+ """
37
+ df = pd.read_csv(csv_file)
38
+
39
+ infos, params, gravity = [], [], []
40
+ for _, row in df.iterrows():
41
+ h = row["height"]
42
+ w = row["width"]
43
+ px = row.get("px", w / 2)
44
+ py = row.get("py", h / 2)
45
+ vfov = row["vfov"]
46
+ f = fov2focal(torch.tensor(vfov), h)
47
+ k1 = row.get("k1", 0)
48
+ k2 = row.get("k2", 0)
49
+ params.append(torch.tensor([w, h, f, f, px, py, k1, k2]))
50
+
51
+ roll = row["roll"]
52
+ pitch = row["pitch"]
53
+ gravity.append(torch.tensor([roll, pitch]))
54
+
55
+ infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])})
56
+
57
+ params = torch.stack(params).float()
58
+ gravity = torch.stack(gravity).float()
59
+ return infos, params, gravity
60
+
61
+
62
+ class SimpleDataset(BaseDataset):
63
+ """Dataset for images created with 'create_dataset_from_pano.py'."""
64
+
65
+ default_conf = {
66
+ # paths
67
+ "dataset_dir": "???",
68
+ "train_img_dir": "${.dataset_dir}/train",
69
+ "val_img_dir": "${.dataset_dir}/val",
70
+ "test_img_dir": "${.dataset_dir}/test",
71
+ "train_csv": "${.dataset_dir}/train.csv",
72
+ "val_csv": "${.dataset_dir}/val.csv",
73
+ "test_csv": "${.dataset_dir}/test.csv",
74
+ # data options
75
+ "use_up": True,
76
+ "use_latitude": True,
77
+ "use_prior_focal": False,
78
+ "use_prior_gravity": False,
79
+ "use_prior_k1": False,
80
+ # image options
81
+ "grayscale": False,
82
+ "preprocessing": ImagePreprocessor.default_conf,
83
+ "augmentations": {"name": "geocalib", "verbose": False},
84
+ "p_rotate": 0.0, # probability to rotate image by +/- 90°
85
+ "reseed": False,
86
+ "seed": 0,
87
+ # data loader options
88
+ "num_workers": 8,
89
+ "prefetch_factor": 2,
90
+ "train_batch_size": 32,
91
+ "val_batch_size": 32,
92
+ "test_batch_size": 32,
93
+ }
94
+
95
+ def _init(self, conf):
96
+ pass
97
+
98
+ def get_dataset(self, split: str) -> torch.utils.data.Dataset:
99
+ """Return a dataset for a given split."""
100
+ return _SimpleDataset(self.conf, split)
101
+
102
+
103
+ class _SimpleDataset(torch.utils.data.Dataset):
104
+ """Dataset for dataset for images created with 'create_dataset_from_pano.py'."""
105
+
106
+ def __init__(self, conf: DictConfig, split: str):
107
+ """Initialize the dataset."""
108
+ self.conf = conf
109
+ self.split = split
110
+ self.img_dir = Path(conf.get(f"{split}_img_dir"))
111
+
112
+ self.preprocessor = ImagePreprocessor(conf.preprocessing)
113
+
114
+ # load image information
115
+ assert f"{split}_csv" in conf, f"Missing {split}_csv in conf"
116
+ infos_path = self.conf.get(f"{split}_csv")
117
+ self.infos, self.parameters, self.gravity = load_csv(infos_path, self.img_dir)
118
+
119
+ # define augmentations
120
+ aug_name = conf.augmentations.name
121
+ assert (
122
+ aug_name in augmentations.keys()
123
+ ), f'{aug_name} not in {" ".join(augmentations.keys())}'
124
+
125
+ if self.split == "train":
126
+ self.augmentation = augmentations[aug_name](conf.augmentations)
127
+ else:
128
+ self.augmentation = IdentityAugmentation()
129
+
130
+ def __len__(self):
131
+ return len(self.infos)
132
+
133
+ def __getitem__(self, idx):
134
+ if not self.conf.reseed:
135
+ return self.getitem(idx)
136
+ with fork_rng(self.conf.seed + idx, False):
137
+ return self.getitem(idx)
138
+
139
+ def _read_image(
140
+ self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor
141
+ ) -> Dict[str, Any]:
142
+ path = Path(str(infos["file_name"]))
143
+
144
+ # load image as uint8 and HWC for augmentation
145
+ image = load_image(path, self.conf.grayscale, return_tensor=False)
146
+ image = self.augmentation(image, return_tensor=True)
147
+
148
+ # create radial camera -> same as pinhole if k1 = 0
149
+ camera = SimpleRadial(parameters[None]).float()
150
+
151
+ roll, pitch = gravity[None].unbind(-1)
152
+ gravity = Gravity.from_rp(roll, pitch)
153
+
154
+ # preprocess
155
+ data = self.preprocessor(image)
156
+ camera = camera.scale(data["scales"])
157
+ camera = camera.crop(data["crop_pad"]) if "crop_pad" in data else camera
158
+
159
+ priors = {"prior_gravity": gravity} if self.conf.use_prior_gravity else {}
160
+ priors |= {"prior_focal": camera.f[..., 1]} if self.conf.use_prior_focal else {}
161
+ priors |= {"prior_k1": camera.k1} if self.conf.use_prior_k1 else {}
162
+ return {
163
+ "name": infos["name"],
164
+ "path": str(path),
165
+ "camera": camera[0],
166
+ "gravity": gravity[0],
167
+ **priors,
168
+ **data,
169
+ }
170
+
171
+ def _get_perspective(self, data):
172
+ """Get perspective field."""
173
+ camera = data["camera"]
174
+ gravity = data["gravity"]
175
+
176
+ up_field, lat_field = get_perspective_field(
177
+ camera, gravity, use_up=self.conf.use_up, use_latitude=self.conf.use_latitude
178
+ )
179
+
180
+ out = {}
181
+ if self.conf.use_up:
182
+ out["up_field"] = up_field[0]
183
+ if self.conf.use_latitude:
184
+ out["latitude_field"] = lat_field[0]
185
+
186
+ return out
187
+
188
+ def getitem(self, idx: int):
189
+ """Return a sample from the dataset."""
190
+ infos = self.infos[idx]
191
+ parameters = self.parameters[idx]
192
+ gravity = self.gravity[idx]
193
+ data = self._read_image(infos, parameters, gravity)
194
+
195
+ if self.conf.use_up or self.conf.use_latitude:
196
+ data |= self._get_perspective(data)
197
+
198
+ return data
199
+
200
+
201
+ if __name__ == "__main__":
202
+ # Create a dump of the dataset
203
+ import argparse
204
+
205
+ import matplotlib.pyplot as plt
206
+
207
+ from siclib.visualization.visualize_batch import make_perspective_figures
208
+
209
+ parser = argparse.ArgumentParser()
210
+ parser.add_argument("--name", type=str, required=True)
211
+ parser.add_argument("--data_dir", type=str)
212
+ parser.add_argument("--split", type=str, default="train")
213
+ parser.add_argument("--shuffle", action="store_true")
214
+ parser.add_argument("--n_rows", type=int, default=4)
215
+ parser.add_argument("--dpi", type=int, default=100)
216
+ args = parser.parse_intermixed_args()
217
+
218
+ dconf = SimpleDataset.default_conf
219
+ dconf["name"] = args.name
220
+ dconf["num_workers"] = 0
221
+ dconf["prefetch_factor"] = None
222
+
223
+ dconf["dataset_dir"] = args.data_dir
224
+ dconf[f"{args.split}_batch_size"] = args.n_rows
225
+
226
+ torch.set_grad_enabled(False)
227
+
228
+ dataset = SimpleDataset(dconf)
229
+ loader = dataset.get_data_loader(args.split, args.shuffle)
230
+
231
+ with fork_rng(seed=42):
232
+ for data in loader:
233
+ pred = data
234
+ break
235
+ fig = make_perspective_figures(pred, data, n_pairs=args.n_rows)
236
+
237
+ plt.show()
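`load_csv` expects one row per generated image with at least `fname`, `height`, `width`, `vfov`, `roll` and `pitch` (angles in radians); `px`, `py`, `k1` and `k2` are optional and default to a centered principal point and zero distortion. A small sketch with a hypothetical metadata row:

```python
from pathlib import Path

import pandas as pd

from siclib.datasets.simple_dataset import load_csv

# Hypothetical row mirroring the CSVs written by create_dataset_from_pano.py.
pd.DataFrame([{
    "fname": "pano_000_0.jpg", "height": 640, "width": 640,
    "vfov": 0.9, "roll": 0.1, "pitch": -0.2, "k1": 0.0,
}]).to_csv("example.csv", index=False)

infos, params, gravity = load_csv(Path("example.csv"), Path("images"))
print(params.shape, gravity.shape)  # torch.Size([1, 8]) torch.Size([1, 2])
```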
siclib/datasets/utils/__init__.py ADDED
File without changes
siclib/datasets/utils/align_megadepth.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ # flake8: noqa
6
+ # mypy: ignore-errors
7
+
8
+ parser = argparse.ArgumentParser(description="Aligns a COLMAP model and plots the horizon lines")
9
+ parser.add_argument(
10
+ "--base_dir", type=str, help="Path to the base directory of the MegaDepth dataset"
11
+ )
12
+ parser.add_argument("--out_dir", type=str, help="Path to the output directory")
13
+ args = parser.parse_args()
14
+
15
+ base_dir = Path(args.base_dir)
16
+ out_dir = Path(args.out_dir)
17
+
18
+ scenes = [d.name for d in base_dir.iterdir() if d.is_dir()]
19
+ print(scenes[:3], len(scenes))
20
+
21
+ # exit()
22
+
23
+ for scene in scenes:
24
+ image_dir = base_dir / scene / "images"
25
+ sfm_dir = base_dir / scene / "sparse" / "manhattan" / "0"
26
+
27
+ # Align model
28
+ align_dir = out_dir / scene / "sparse" / "align"
29
+ align_dir.mkdir(exist_ok=True, parents=True)
30
+
31
+ print(f"image_dir ({image_dir.exists()}): {image_dir}")
32
+ print(f"sfm_dir ({sfm_dir.exists()}): {sfm_dir}")
33
+ print(f"align_dir ({align_dir.exists()}): {align_dir}")
34
+
35
+ cmd = (
36
+ "colmap model_orientation_aligner "
37
+ + f"--image_path {image_dir} "
38
+ + f"--input_path {sfm_dir} "
39
+ + f"--output_path {str(align_dir)}"
40
+ )
41
+ subprocess.run(cmd, shell=True)
siclib/datasets/utils/download_openpano.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helper script to download and extract OpenPano dataset."""
2
+
3
+ import argparse
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from siclib import logger
11
+
12
+ PANO_URL = "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/openpano.zip"
13
+
14
+
15
+ def download_and_extract_dataset(name: str, url: Path, output: Path) -> None:
16
+ """Download and extract a dataset from a URL."""
17
+ dataset_dir = output / name
18
+ if not output.exists():
19
+ output.mkdir(parents=True)
20
+
21
+ if dataset_dir.exists():
22
+ logger.info(f"Dataset {name} already exists at {dataset_dir}, skipping download.")
23
+ return
24
+
25
+ zip_file = output / f"{name}.zip"
26
+
27
+ if not zip_file.exists():
28
+ logger.info(f"Downloading dataset {name} to {zip_file} from {url}.")
29
+ torch.hub.download_url_to_file(url, zip_file)
30
+
31
+ logger.info(f"Extracting dataset {name} in {output}.")
32
+ shutil.unpack_archive(zip_file, output, format="zip")
33
+ zip_file.unlink()
34
+
35
+
36
+ def main():
37
+ """Prepare the OpenPano dataset."""
38
+ parser = argparse.ArgumentParser(description="Download and extract OpenPano dataset.")
39
+ parser.add_argument("--name", type=str, default="openpano", help="Name of the dataset.")
40
+ parser.add_argument(
41
+ "--laval_dir", type=str, default="data/laval-tonemap", help="Path the Laval dataset."
42
+ )
43
+
44
+ args = parser.parse_args()
45
+
46
+ out_dir = Path("data")
47
+ download_and_extract_dataset(args.name, PANO_URL, out_dir)
48
+
49
+ pano_dir = out_dir / args.name / "panoramas"
50
+ for split in ["train", "test", "val"]:
51
+ with open(pano_dir / f"{split}_panos.txt", "r") as f:
52
+ pano_list = f.readlines()
53
+ pano_list = [fname.strip() for fname in pano_list]
54
+
55
+ for fname in tqdm(pano_list, ncols=80, desc=f"Copying {split} panoramas"):
56
+ laval_path = Path(args.laval_dir) / fname
57
+ target_path = pano_dir / split / fname
58
+
59
+ # pano either exists in laval or is in split
60
+ if target_path.exists():
61
+ continue
62
+
63
+ if laval_path.exists():
64
+ shutil.copy(laval_path, target_path)
65
+ else: # not in laval and not in split
66
+ logger.warning(f"Panorama {fname} not found in {args.laval_dir} or {split} split.")
67
+
68
+ n_train = len(list(pano_dir.glob("train/*.jpg")))
69
+ n_test = len(list(pano_dir.glob("test/*.jpg")))
70
+ n_val = len(list(pano_dir.glob("val/*.jpg")))
71
+ logger.info(f"{args.name} contains {n_train}/{n_test}/{n_val} train/test/val panoramas.")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
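Besides preparing the splits, the download helper can be reused on its own to fetch and unpack the archive; a minimal sketch:

```python
from pathlib import Path

from siclib.datasets.utils.download_openpano import PANO_URL, download_and_extract_dataset

# Downloads openpano.zip into data/ and unpacks it (skipped if data/openpano already exists).
download_and_extract_dataset("openpano", PANO_URL, Path("data"))
```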
siclib/datasets/utils/tonemapping.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
4
+ import argparse
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ from matplotlib import colors
11
+ from tqdm import tqdm
12
+
13
+ # flake8: noqa
14
+ # mypy: ignore-errors
15
+
16
+
17
+ class tonemap:
18
+ def __init__(self):
19
+ pass
20
+
21
+ def process(self, img):
22
+ return img
23
+
24
+ def inv_process(self, img):
25
+ return img
26
+
27
+
28
+ # Log correction
29
+ class log_tonemap(tonemap):
30
+ # Constructor
31
+ # Base of log
32
+ # Scale of tonemapped
33
+ # Offset
34
+ def __init__(self, base, scale=1, offset=1):
35
+ self.base = base
36
+ self.scale = scale
37
+ self.offset = offset
38
+
39
+ def process(self, img):
40
+ tonemapped = (np.log(img + self.offset) / np.log(self.base)) * self.scale
41
+ return tonemapped
42
+
43
+ def inv_process(self, img):
44
+ inverse_tonemapped = np.power(self.base, (img) / self.scale) - self.offset
45
+ return inverse_tonemapped
46
+
47
+
48
+ class log_tonemap_clip(tonemap):
49
+ # Constructor
50
+ # Base of log
51
+ # Scale of tonemapped
52
+ # Offset
53
+ def __init__(self, base, scale=1, offset=1):
54
+ self.base = base
55
+ self.scale = scale
56
+ self.offset = offset
57
+
58
+ def process(self, img):
59
+ tonemapped = np.clip((np.log(img * self.scale + self.offset) / np.log(self.base)), 0, 2) - 1
60
+ return tonemapped
61
+
62
+ def inv_process(self, img):
63
+ inverse_tonemapped = (np.power(self.base, (img + 1)) - self.offset) / self.scale
64
+ return inverse_tonemapped
65
+
66
+
67
+ # Gamma Tonemap
68
+ class gamma_tonemap(tonemap):
69
+ def __init__(
70
+ self,
71
+ gamma,
72
+ ):
73
+ self.gamma = gamma
74
+
75
+ def process(self, img):
76
+ tonemapped = np.power(img, 1 / self.gamma)
77
+ return tonemapped
78
+
79
+ def inv_process(self, img):
80
+ inverse_tonemapped = np.power(img, self.gamma)
81
+ return inverse_tonemapped
82
+
83
+
84
+ class linear_clip(tonemap):
85
+ def __init__(self, scale, mean):
86
+ self.scale = scale
87
+ self.mean = mean
88
+
89
+ def process(self, img):
90
+ tonemapped = np.clip((img - self.mean) / self.scale, -1, 1)
91
+ return tonemapped
92
+
93
+ def inv_process(self, img):
94
+ inverse_tonemapped = img * self.scale + self.mean
95
+ return inverse_tonemapped
96
+
97
+
98
+ def make_tonemap_HDR(opt):
99
+ if opt.mode == "luminance":
100
+ res_tonemap = log_tonemap_clip(10, 1.0, 1.0)
101
+ else: # temperature
102
+ res_tonemap = linear_clip(5000.0, 5000.0)
103
+ return res_tonemap
104
+
105
+
106
+ class LDRfromHDR:
+     def __init__(
+         self, tonemap=("none", None), orig_scale=False, clip=True, quantization=0, color_jitter=0, noise=0
+     ):
+         # `tonemap` is a ("name", value) pair, e.g. ("gamma", 2.2) or ("log10", 10).
+         self.tonemap_str, val = tonemap
+         if tonemap[0] == "gamma":
+             self.tonemap = gamma_tonemap(val)
+         elif tonemap[0] == "log10":
+             self.tonemap = log_tonemap(val)
+         else:
+             print("Warning: No tonemap specified, using linear")
+
+         self.clip = clip
+         self.orig_scale = orig_scale
+         self.bits = quantization
+         self.jitter = color_jitter
+         self.noise = noise
+
+         self.wbModel = None
+
+     def process(self, HDR):
+         LDR, normalized_scale = self.rescale(HDR)
+         LDR = self.apply_clip(LDR)
+         LDR = self.apply_scale(LDR, normalized_scale)
+         LDR = self.apply_tonemap(LDR)
+         LDR = self.colorJitter(LDR)
+         LDR = self.gaussianNoise(LDR)
+         LDR = self.quantize(LDR)
+         LDR = self.apply_white_balance(LDR)
+         return LDR, normalized_scale
+
+     def rescale(self, img, percentile=90, max_mapping=0.8):
+         # Re-expose so that the given percentile of the HDR image maps to max_mapping.
+         r_percentile = np.percentile(img, percentile)
+         alpha = max_mapping / (r_percentile + 1e-10)
+
+         img_reexposed = img * alpha
+
+         normalized_scale = normalizeScale(1 / alpha)
+
+         return img_reexposed, normalized_scale
+
+     def rescaleAlpha(self, img, percentile=90, max_mapping=0.8):
+         r_percentile = np.percentile(img, percentile)
+         alpha = max_mapping / (r_percentile + 1e-10)
+
+         return alpha
+
+     def apply_clip(self, img):
+         if self.clip:
+             img = np.clip(img, 0, 1)
+         return img
+
+     def apply_scale(self, img, scale):
+         if self.orig_scale:
+             scale = unNormalizeScale(scale)
+             img = img * scale
+         return img
+
+     def apply_tonemap(self, img):
+         if self.tonemap_str == "none":
+             return img
+         gammaed = self.tonemap.process(img)
+         return gammaed
+
+     def quantize(self, img):
+         if self.bits == 0:
+             return img
+         max_val = np.power(2, self.bits)
+         img = img * max_val
+         img = np.floor(img)
+         img = img / max_val
+         return img
+
+     def colorJitter(self, img):
+         if self.jitter == 0:
+             return img
+         hsv = colors.rgb_to_hsv(img)
+         hue_offset = np.random.normal(0, self.jitter, 1)
+         hsv[:, :, 0] = (hsv[:, :, 0] + hue_offset) % 1.0
+         rgb = colors.hsv_to_rgb(hsv)
+         return rgb
+
+     def gaussianNoise(self, img):
+         if self.noise == 0:
+             return img
+         noise_amount = np.random.uniform(0, self.noise, 1)
+         noise_img = np.random.normal(0, noise_amount, img.shape)
+         img = img + noise_img
+         img = np.clip(img, 0, 1).astype(np.float32)
+         return img
+
+     def apply_white_balance(self, img):
+         if self.wbModel is None:
+             return img
+         img = self.wbModel.correctImage(img)
+         return img.copy()
+
+
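`process` chains re-exposure, clipping, tonemapping and the optional degradations in a fixed order and returns both the LDR image and the normalized exposure scale. A minimal sketch of driving it directly, assuming a synthetic float32 HDR array (shapes and values are illustrative, not from the repository):

import numpy as np

hdr = np.random.uniform(0.0, 10.0, size=(64, 64, 3)).astype(np.float32)

proc = LDRfromHDR(tonemap=("gamma", 2.2), clip=True, quantization=8)
ldr, norm_scale = proc.process(hdr)

print(ldr.min(), ldr.max())   # stays within [0, 1] after clipping, gamma and 8-bit quantization
print(norm_scale)             # log-normalized 1 / alpha from the re-exposure step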
+ def make_LDRfromHDR(opt):
+     LDR_from_HDR = LDRfromHDR(
+         opt.tonemap_LDR, opt.orig_scale, opt.clip, opt.quantization, opt.color_jitter, opt.noise
+     )
+     return LDR_from_HDR
+
+
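As with `make_tonemap_HDR`, the factory only reads a handful of attributes from `opt`. A stand-in config that satisfies it could look like the following; the concrete values are assumptions, not defaults taken from the repository:

from types import SimpleNamespace

opt = SimpleNamespace(
    tonemap_LDR=("gamma", 2.2),   # ("name", value) pair expected by LDRfromHDR
    orig_scale=False,
    clip=True,
    quantization=0,
    color_jitter=0,
    noise=0,
)
ldr_maker = make_LDRfromHDR(opt)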
+ def torchnormalizeEV(EV, mean=5.12, scale=6, clip=True):
+     # Normalize based on the computed distribution between -1 1
+     EV -= mean
+     EV = EV / scale
+
+     if clip:
+         EV = torch.clip(EV, min=-1, max=1)
+
+     return EV
+
+
+ def torchnormalizeEV0(EV, mean=5.12, scale=6, clip=True):
+     # Normalize based on the computed distribution between 0 1
+     EV -= mean
+     EV = EV / scale
+
+     if clip:
+         EV = torch.clip(EV, min=-1, max=1)
+
+     EV += 0.5
+     EV = EV / 2
+
+     return EV
+
+
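A quick numerical check of the two normalizers, with representative exposure values that are assumptions rather than repository data. Note that after clipping to [-1, 1], the `+ 0.5` followed by `/ 2` in `torchnormalizeEV0` lands in [-0.25, 0.75] rather than [0, 1]; `(EV + 1) / 2` would give an exact [0, 1] mapping if that is the intent.

import torch

ev = torch.tensor([5.12 - 6.0, 5.12, 5.12 + 6.0])

print(torchnormalizeEV(ev.clone()))    # tensor([-1.,  0.,  1.])
print(torchnormalizeEV0(ev.clone()))   # tensor([-0.2500,  0.2500,  0.7500])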
+ def normalizeScale(x, scale=4):
+     x = np.log10(x + 1)
+
+     x = x / (scale / 2)
+     x = x - 1
+
+     return x
+
+
+ def unNormalizeScale(x, scale=4):
+     x = x + 1
+     x = x * (scale / 2)
+
+     x = np.power(10, x) - 1
+
+     return x
+
+
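The raw scale is compressed logarithmically and mapped into [-1, 1] (exactly so for inputs in [0, 10^scale - 1] with the default scale=4), and `unNormalizeScale` is its exact inverse. A small worked example with illustrative values:

import numpy as np

s = 99.0
n = normalizeScale(s)          # log10(100) = 2 -> 2 / 2 - 1 = 0.0
s_back = unNormalizeScale(n)   # 10 ** ((0 + 1) * 2) - 1 = 99.0
assert np.isclose(s_back, s)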
+ def normalizeIlluminance(x, scale=5):
+     x = np.log10(x + 1)
+
+     x = x / (scale / 2)
+     x = x - 1
+
+     return x
+
+
+ def unNormalizeIlluminance(x, scale=5):
+     x = x + 1
+     x = x * (scale / 2)
+
+     x = np.power(10, x) - 1
+
+     return x
+
+
+ def main(args):
+     processor = LDRfromHDR(
+         # tonemap=("log10", 10),
+         tonemap=("gamma", args.gamma),
+         orig_scale=False,
+         clip=True,
+         quantization=0,
+         color_jitter=0,
+         noise=0,
+     )
+
+     img_list = list(os.listdir(args.hdr_dir))
+     img_list = [f for f in img_list if f.endswith(args.extension)]
+     img_list = [f for f in img_list if not f.startswith("._")]
+
+     if not os.path.exists(args.out_dir):
+         os.makedirs(args.out_dir)
+
+     for fname in tqdm(img_list):
+         fname_out = ".".join(fname.split(".")[:-1])
+         out = os.path.join(args.out_dir, f"{fname_out}.jpg")
+         if os.path.exists(out) and not args.overwrite:
+             continue
+
+         fpath = os.path.join(args.hdr_dir, fname)
+         img = cv2.imread(fpath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         ldr, scale = processor.process(img)
+
+         ldr = (ldr * 255).astype(np.uint8)
+         ldr = cv2.cvtColor(ldr, cv2.COLOR_RGB2BGR)
+         cv2.imwrite(out, ldr)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--hdr_dir", type=str, default="hdr")
+     parser.add_argument("--out_dir", type=str, default="ldr")
+     parser.add_argument("--extension", type=str, default=".exr")
+     parser.add_argument("--overwrite", action="store_true")
+     parser.add_argument("--gamma", type=float, default=2)
+     args = parser.parse_args()
+
+     main(args)
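Run as a script, the tool gamma-tonemaps every HDR file in a directory into 8-bit JPEGs. An invocation along these lines should work; the paths are placeholders, and reading .exr files through OpenCV may additionally require OPENCV_IO_ENABLE_OPENEXR=1 depending on the build:

python siclib/datasets/utils/tonemapping.py \
    --hdr_dir data/pano_hdr \
    --out_dir data/pano_ldr \
    --extension .exr \
    --gamma 2.2 \
    --overwrite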
siclib/eval/__init__.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+
+ from siclib.eval.eval_pipeline import EvalPipeline
+ from siclib.utils.tools import get_class
+
+
+ def get_benchmark(benchmark):
+     return get_class(f"{__name__}.{benchmark}", EvalPipeline)
+
+
+ @torch.no_grad()
+ def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
+     """Run a benchmark, overwriting any existing results."""
+     experiment_dir.mkdir(exist_ok=True, parents=True)
+     bm = get_benchmark(benchmark)
+
+     pipeline = bm(eval_conf)
+     return pipeline.run(experiment_dir, model=model, overwrite=True, overwrite_eval=True)
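`run_benchmark` resolves a benchmark module inside `siclib.eval` by name, instantiates its `EvalPipeline` subclass with the given config, and runs it with overwriting enabled. A rough sketch of calling it directly; the benchmark name below is a placeholder, and loading the config with OmegaConf is an assumption about how these YAML files are consumed:

from pathlib import Path

from omegaconf import OmegaConf

from siclib.eval import run_benchmark

eval_conf = OmegaConf.load("siclib/eval/configs/geocalib-pinhole.yaml")
summary = run_benchmark(
    "simple_pipeline",                       # placeholder benchmark module name
    eval_conf,
    Path("outputs/results/geocalib-pinhole"),
)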
siclib/eval/configs/deepcalib.yaml ADDED
@@ -0,0 +1,3 @@
+ model:
+   name: networks.deepcalib
+   weights: weights/deepcalib.tar
siclib/eval/configs/geocalib-pinhole.yaml ADDED
@@ -0,0 +1,2 @@
+ model:
+   name: networks.geocalib_pretrained