Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set of changes.
- .gitattributes +3 -0
- .gitignore +138 -0
- .pre-commit-config.yaml +35 -0
- LICENSE +190 -0
- README.md +582 -7
- assets/fisheye-dog-pool.jpg +0 -0
- assets/fisheye-skyline.jpg +3 -0
- assets/pinhole-church.jpg +0 -0
- assets/pinhole-garden.jpg +0 -0
- assets/teaser.gif +3 -0
- demo.ipynb +3 -0
- geocalib/__init__.py +17 -0
- geocalib/camera.py +774 -0
- geocalib/extractor.py +126 -0
- geocalib/geocalib.py +150 -0
- geocalib/gravity.py +131 -0
- geocalib/interactive_demo.py +450 -0
- geocalib/lm_optimizer.py +642 -0
- geocalib/misc.py +318 -0
- geocalib/modules.py +575 -0
- geocalib/perspective_fields.py +366 -0
- geocalib/utils.py +325 -0
- geocalib/viz2d.py +502 -0
- gradio_app.py +228 -0
- hubconf.py +14 -0
- pyproject.toml +49 -0
- requirements.txt +5 -0
- siclib/LICENSE +190 -0
- siclib/__init__.py +15 -0
- siclib/configs/deepcalib.yaml +12 -0
- siclib/configs/geocalib-radial.yaml +38 -0
- siclib/configs/geocalib.yaml +10 -0
- siclib/configs/model/deepcalib.yaml +7 -0
- siclib/configs/model/geocalib.yaml +31 -0
- siclib/configs/train/deepcalib.yaml +22 -0
- siclib/configs/train/geocalib.yaml +50 -0
- siclib/datasets/__init__.py +25 -0
- siclib/datasets/augmentations.py +359 -0
- siclib/datasets/base_dataset.py +218 -0
- siclib/datasets/configs/openpano-radial.yaml +41 -0
- siclib/datasets/configs/openpano.yaml +34 -0
- siclib/datasets/create_dataset_from_pano.py +350 -0
- siclib/datasets/simple_dataset.py +237 -0
- siclib/datasets/utils/__init__.py +0 -0
- siclib/datasets/utils/align_megadepth.py +41 -0
- siclib/datasets/utils/download_openpano.py +75 -0
- siclib/datasets/utils/tonemapping.py +316 -0
- siclib/eval/__init__.py +18 -0
- siclib/eval/configs/deepcalib.yaml +3 -0
- siclib/eval/configs/geocalib-pinhole.yaml +2 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/fisheye-skyline.jpg filter=lfs diff=lfs merge=lfs -text
+assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
+demo.ipynb filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,138 @@
# folders
data/**
outputs/**
weights/**
**.DS_Store
.vscode/**
wandb/**
third_party/**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,35 @@
default_stages: [commit]
default_language_version:
  python: python3.10
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black
        args: [--line-length=100]
        exclude: ^(venv/|docs/)
        types: [python]
  - repo: https://github.com/PyCQA/flake8
    rev: 6.1.0
    hooks:
      - id: flake8
        additional_dependencies: [flake8-docstrings]
        args:
          [
            --max-line-length=100,
            --docstring-convention=google,
            --ignore=E203 W503 E402 E731,
          ]
        exclude: ^(venv/|docs/|.*__init__.py)
        types: [python]

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: [--line-length=100, --profile=black, --atomic]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.1.1
    hooks:
      - id: mypy
LICENSE
ADDED
@@ -0,0 +1,190 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

Copyright 2024 ETH Zurich

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
README.md
CHANGED
@@ -1,12 +1,587 @@
 ---
 title: GeoCalib
-
-colorFrom: pink
-colorTo: yellow
 sdk: gradio
-sdk_version: 4.
-app_file: app.py
-pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

---
title: GeoCalib
app_file: gradio_app.py
sdk: gradio
sdk_version: 4.38.1
---
<p align="center">
  <h1 align="center"><ins>GeoCalib</ins> 📸<br>Single-image Calibration with Geometric Optimization</h1>
  <p align="center">
    <a href="https://www.linkedin.com/in/alexander-veicht/">Alexander Veicht</a>
    ·
    <a href="https://psarlin.com/">Paul-Edouard Sarlin</a>
    ·
    <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
    ·
    <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc Pollefeys</a>
  </p>
  <h2 align="center">
    <p>ECCV 2024</p>
    <a href="" align="center">Paper</a> | <!--TODO: update link-->
    <a href="https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K" align="center">Colab</a> |
    <a href="https://huggingface.co/spaces/veichta/GeoCalib" align="center">Demo 🤗</a>
  </h2>

</p>
<p align="center">
  <a href=""><img src="assets/teaser.gif" alt="example" width=80%></a> <!--TODO: update link-->
  <br>
  <em>
    GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image
    <br>
    by combining geometric optimization with deep learning.
  </em>
</p>

##

GeoCalib is an algorithm for single-image calibration: it estimates the camera intrinsics and gravity direction from a single image only. By combining geometric optimization with deep learning, GeoCalib provides a more flexible and accurate calibration than previous approaches. This repository hosts the [inference](#setup-and-demo), [evaluation](#evaluation), and [training](#training) code for GeoCalib, as well as instructions to download our training set [OpenPano](#openpano-dataset).


## Setup and demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K)
[![Hugging Face](https://img.shields.io/badge/Gradio-Demo-blue)](https://huggingface.co/spaces/veichta/GeoCalib)

We provide a small inference package [`geocalib`](geocalib) that requires only minimal dependencies and Python >= 3.9. First clone the repository and install the dependencies:

```bash
git clone https://github.com/cvg/GeoCalib.git && cd GeoCalib
python -m pip install -e .
# OR
python -m pip install -e "git+https://github.com/cvg/GeoCalib#egg=geocalib"
```

Here is a minimal usage example:

```python
import torch

from geocalib import GeoCalib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib().to(device)

# load image as tensor in range [0, 1] with shape [C, H, W]
img = model.load_image("path/to/image.jpg").to(device)
result = model.calibrate(img)

print("camera:", result["camera"])
print("gravity:", result["gravity"])
```

When either the intrinsics or the gravity are already known, they can be provided:

```python
# known intrinsics:
result = model.calibrate(img, priors={"focal": focal_length_tensor})

# known gravity:
result = model.calibrate(img, priors={"gravity": gravity_direction_tensor})
```
80 |
+
|
81 |
+
The default model is optimized for pinhole images. To handle lens distortion, use the following:
|
82 |
+
|
83 |
+
```python
|
84 |
+
model = GeoCalib(weights="distorted") # default is "pinhole"
|
85 |
+
result = model.calibrate(img, camera_model="simple_radial") # or pinhole, simple_divisional
|
86 |
+
```
|
87 |
+
|
88 |
+
Check out our [demo notebook](demo.ipynb) for a full working example.
|
89 |
+
|
90 |
+
<details>
|
91 |
+
<summary><b>[Interactive demo for your webcam - click to expand]</b></summary>
|
92 |
+
Run the following command:
|
93 |
+
|
94 |
+
```bash
|
95 |
+
python -m geocalib.interactive_demo --camera_id 0
|
96 |
+
```
|
97 |
+
|
98 |
+
The demo will open a window showing the camera feed and the calibration results. If `--camera_id` is not provided, the demo will ask for the IP address of a [droidcam](https://droidcam.app) camera.
|
99 |
+
|
100 |
+
Controls:
|
101 |
+
|
102 |
+
>Toggle the different features using the following keys:
|
103 |
+
>
|
104 |
+
>- ```h```: Show the estimated horizon line
|
105 |
+
>- ```u```: Show the estimated up-vectors
|
106 |
+
>- ```l```: Show the estimated latitude heatmap
|
107 |
+
>- ```c```: Show the confidence heatmap for the up-vectors and latitudes
|
108 |
+
>- ```d```: Show undistorted image, will overwrite the other features
|
109 |
+
>- ```g```: Shows a virtual grid of points
|
110 |
+
>- ```b```: Shows a virtual box object
|
111 |
+
>
|
112 |
+
>Change the camera model using the following keys:
|
113 |
+
>
|
114 |
+
>- ```1```: Pinhole -> Simple and fast
|
115 |
+
>- ```2```: Simple Radial -> For small distortions
|
116 |
+
>- ```3```: Simple Divisional -> For large distortions
|
117 |
+
>
|
118 |
+
>Press ```q``` to quit the demo.
|
119 |
+
|
120 |
+
</details>
|
121 |
+
|
122 |
+
|
123 |
+
<details>
|
124 |
+
<summary><b>[Load GeoCalib with torch hub - click to expand]</b></summary>
|
125 |
+
|
126 |
+
```python
|
127 |
+
model = torch.hub.load("cvg/GeoCalib", "GeoCalib", trust_repo=True)
|
128 |
+
```
|
129 |
+
|
130 |
+
</details>
|
131 |
+
|
132 |
+
## Evaluation

The full evaluation and training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed as:

```bash
python -m pip install -e siclib
```

Running the evaluation commands will write the results to `outputs/results/`.

### LaMAR

Running the evaluation commands will download the dataset to ```data/lamar2k```, which will take around 400 MB of disk space.

<details>
<summary>[Evaluate GeoCalib]</summary>

To evaluate GeoCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.lamar2k --conf geocalib-pinhole --tag geocalib --overwrite
```

</details>

<details>
<summary>[Evaluate DeepCalib]</summary>

To evaluate DeepCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.lamar2k --conf deepcalib --tag deepcalib --overwrite
```

</details>

<details>
<summary>[Evaluate Perspective Fields]</summary>

Coming soon!

</details>

<details>
<summary>[Evaluate UVP]</summary>

To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:

```bash
python -m siclib.eval.lamar2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
```

</details>

<details>
<summary>[Evaluate your own model]</summary>

If you have trained your own model, you can evaluate it by running:

```bash
python -m siclib.eval.lamar2k --checkpoint <experiment name> --tag <eval name> --overwrite
```

</details>


<details>
<summary>[Results]</summary>

Here are the results as the Area Under the Curve (AUC) of the roll, pitch, and field of view (FoV) errors at 1/5/10 degrees for the different methods:

| Approach  | Roll               | Pitch              | FoV                |
| --------- | ------------------ | ------------------ | ------------------ |
| DeepCalib | 44.1 / 73.9 / 84.8 | 10.8 / 28.3 / 49.8 | 0.7 / 13.0 / 24.0  |
| ParamNet  | 51.7 / 77.0 / 86.0 | 27.0 / 52.7 / 70.2 | 2.8 / 6.8 / 14.3   |
| UVP       | 72.7 / 81.8 / 85.7 | 42.3 / 59.9 / 69.4 | 15.6 / 30.6 / 43.5 |
| GeoCalib  | 86.4 / 92.5 / 95.0 | 55.0 / 76.9 / 86.2 | 19.1 / 41.5 / 60.0 |

</details>

### MegaDepth

Running the evaluation commands will download the dataset to ```data/megadepth2k``` or ```data/megadepth2k-radial```, which will take around 2.1 GB and 1.47 GB of disk space respectively.

<details>
<summary>[Evaluate GeoCalib]</summary>

To evaluate GeoCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.megadepth2k --conf geocalib-pinhole --tag geocalib --overwrite
```

To run the eval on the radially distorted images, run:

```bash
python -m siclib.eval.megadepth2k_radial --conf geocalib-pinhole --tag geocalib --overwrite model.camera_model=simple_radial
```

</details>

<details>
<summary>[Evaluate DeepCalib]</summary>

To evaluate DeepCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.megadepth2k --conf deepcalib --tag deepcalib --overwrite
```

</details>

<details>
<summary>[Evaluate Perspective Fields]</summary>

Coming soon!

</details>

<details>
<summary>[Evaluate UVP]</summary>

To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:

```bash
python -m siclib.eval.megadepth2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
```

</details>

<details>
<summary>[Evaluate your own model]</summary>

If you have trained your own model, you can evaluate it by running:

```bash
python -m siclib.eval.megadepth2k --checkpoint <experiment name> --tag <eval name> --overwrite
```

</details>

<details>
<summary>[Results]</summary>

Here are the results as the Area Under the Curve (AUC) of the roll, pitch, and field of view (FoV) errors at 1/5/10 degrees for the different methods:

| Approach  | Roll               | Pitch              | FoV                |
| --------- | ------------------ | ------------------ | ------------------ |
| DeepCalib | 34.6 / 65.4 / 79.4 | 11.9 / 27.8 / 44.8 | 5.6 / 12.1 / 22.9  |
| ParamNet  | 43.4 / 70.7 / 82.2 | 15.4 / 34.5 / 53.3 | 3.2 / 10.1 / 21.3  |
| UVP       | 69.2 / 81.6 / 86.9 | 21.6 / 36.2 / 47.4 | 8.2 / 18.7 / 29.8  |
| GeoCalib  | 82.6 / 90.6 / 94.0 | 32.4 / 53.3 / 67.5 | 13.6 / 31.7 / 48.2 |

</details>

### TartanAir

Running the evaluation commands will download the dataset to ```data/tartanair```, which will take around 1.85 GB of disk space.

<details>
<summary>[Evaluate GeoCalib]</summary>

To evaluate GeoCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.tartanair --conf geocalib-pinhole --tag geocalib --overwrite
```

</details>

<details>
<summary>[Evaluate DeepCalib]</summary>

To evaluate DeepCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.tartanair --conf deepcalib --tag deepcalib --overwrite
```

</details>

<details>
<summary>[Evaluate Perspective Fields]</summary>

Coming soon!

</details>

<details>
<summary>[Evaluate UVP]</summary>

To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:

```bash
python -m siclib.eval.tartanair --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
```

</details>

<details>
<summary>[Evaluate your own model]</summary>

If you have trained your own model, you can evaluate it by running:

```bash
python -m siclib.eval.tartanair --checkpoint <experiment name> --tag <eval name> --overwrite
```

</details>

<details>
<summary>[Results]</summary>

Here are the results as the Area Under the Curve (AUC) of the roll, pitch, and field of view (FoV) errors at 1/5/10 degrees for the different methods:

| Approach  | Roll               | Pitch              | FoV                |
| --------- | ------------------ | ------------------ | ------------------ |
| DeepCalib | 24.7 / 55.4 / 71.5 | 16.3 / 38.8 / 58.5 | 1.5 / 8.8 / 27.2   |
| ParamNet  | 34.5 / 59.2 / 73.9 | 19.4 / 42.0 / 60.3 | 6.0 / 16.8 / 31.6  |
| UVP       | 52.1 / 64.8 / 71.9 | 36.2 / 48.8 / 58.6 | 15.8 / 25.8 / 35.7 |
| GeoCalib  | 71.3 / 83.8 / 89.8 | 38.2 / 62.9 / 76.6 | 14.1 / 30.4 / 47.6 |

</details>

### Stanford2D3D

Before downloading and running the evaluation, you will need to agree to the [terms of use](https://docs.google.com/forms/d/e/1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1) for the Stanford2D3D dataset.
Running the evaluation commands will download the dataset to ```data/stanford2d3d```, which will take around 885 MB of disk space.

<details>
<summary>[Evaluate GeoCalib]</summary>

To evaluate GeoCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.stanford2d3d --conf geocalib-pinhole --tag geocalib --overwrite
```

</details>

<details>
<summary>[Evaluate DeepCalib]</summary>

To evaluate DeepCalib trained on the OpenPano dataset, run:

```bash
python -m siclib.eval.stanford2d3d --conf deepcalib --tag deepcalib --overwrite
```

</details>

<details>
<summary>[Evaluate Perspective Fields]</summary>

Coming soon!

</details>

<details>
<summary>[Evaluate UVP]</summary>

To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:

```bash
python -m siclib.eval.stanford2d3d --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
```

</details>

<details>
<summary>[Evaluate your own model]</summary>

If you have trained your own model, you can evaluate it by running:

```bash
python -m siclib.eval.stanford2d3d --checkpoint <experiment name> --tag <eval name> --overwrite
```

</details>

<details>
<summary>[Results]</summary>

Here are the results as the Area Under the Curve (AUC) of the roll, pitch, and field of view (FoV) errors at 1/5/10 degrees for the different methods:

| Approach  | Roll               | Pitch              | FoV                |
| --------- | ------------------ | ------------------ | ------------------ |
| DeepCalib | 33.8 / 63.9 / 79.2 | 21.6 / 46.9 / 65.7 | 8.1 / 20.6 / 37.6  |
| ParamNet  | 44.6 / 73.9 / 84.8 | 29.2 / 56.7 / 73.1 | 5.8 / 14.3 / 27.8  |
| UVP       | 65.3 / 74.6 / 79.1 | 51.2 / 63.0 / 69.2 | 22.2 / 39.5 / 51.3 |
| GeoCalib  | 83.1 / 91.8 / 94.8 | 52.3 / 74.8 / 84.6 | 17.4 / 40.0 / 59.4 |

</details>

### Evaluation options

If you want to provide priors during the evaluation, you can add one or multiple of the following flags:

```bash
python -m siclib.eval.<benchmark> --conf <config> \
    --tag <tag> \
    data.use_prior_focal=true \
    data.use_prior_gravity=true \
    data.use_prior_k1=true
```

<details>
<summary>[Visual inspection]</summary>

To visually inspect the results of the evaluation, you can run the following command:

```bash
python -m siclib.eval.inspect <benchmark> <one or multiple tags>
```

For example, to inspect the results of the evaluation of the GeoCalib model on the LaMAR dataset, you can run:

```bash
python -m siclib.eval.inspect lamar2k geocalib
```

</details>

## OpenPano Dataset

The OpenPano dataset is a new dataset for single-image calibration which contains about 2.8k panoramas from various sources, namely [HDRMAPS](https://hdrmaps.com/hdris/), [PolyHaven](https://polyhaven.com/hdris), and the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation). While this dataset is smaller than previous ones, it is publicly available and provides a better balance between indoor and outdoor scenes.

<details>
<summary>[Downloading and preparing the dataset]</summary>

In order to assemble the training set, first download the Laval dataset following the instructions on [the corresponding project page](http://hdrdb.com/indoor/#presentation) and place the panoramas in ```data/indoorDatasetCalibrated```. Then, tonemap the HDR images using the following command:

```bash
python -m siclib.datasets.utils.tonemapping --hdr_dir data/indoorDatasetCalibrated --out_dir data/laval-tonemap
```

We provide a script to download the PolyHaven and HDRMAPS panos. The script will create folders ```data/openpano/panoramas/{split}``` containing the panoramas specified by the ```{split}_panos.txt``` files. To run the script, execute the following command:

```bash
python -m siclib.datasets.utils.download_openpano --name openpano --laval_dir data/laval-tonemap
```

Alternatively, you can download the PolyHaven and HDRMAPS panos from [here](https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/).


After downloading the panoramas, you can create the training set by running the following command:

```bash
python -m siclib.datasets.create_dataset_from_pano --config-name openpano
```

The dataset creation can be sped up by using multiple workers and a GPU. To do so, add the following arguments to the command:

```bash
python -m siclib.datasets.create_dataset_from_pano --config-name openpano n_workers=10 device=cuda
```

This will create the training set in ```data/openpano/openpano``` with about 37k images for training, 2.1k for validation, and 2.1k for testing.

<details>
<summary>[Distorted OpenPano]</summary>

To create the OpenPano dataset with radial distortion, run the following command:

```bash
python -m siclib.datasets.create_dataset_from_pano --config-name openpano_radial
```

</details>

</details>

## Training
|
498 |
+
|
499 |
+
As for the evaluation, the training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed by:
|
500 |
+
|
501 |
+
```bash
|
502 |
+
python -m pip install -e siclib
|
503 |
+
```
|
504 |
+
|
505 |
+
Once the [OpenPano Dataset](#openpano-dataset) has been downloaded and prepared, we can train GeoCalib with it:
|
506 |
+
|
507 |
+
First download the pre-trained weights for the [MSCAN-B](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) backbone:
|
508 |
+
|
509 |
+
```bash
|
510 |
+
mkdir weights
|
511 |
+
wget "https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/files/?p=%2Fmscan_b.pth&dl=1" -O weights/mscan_b.pth
|
512 |
+
```
|
513 |
+
|
514 |
+
Then, start the training with the following command:
|
515 |
+
|
516 |
+
```bash
|
517 |
+
python -m siclib.train geocalib-pinhole-openpano --conf geocalib --distributed
|
518 |
+
```
|
519 |
+
|
520 |
+
Feel free to use any other experiment name. By default, the checkpoints will be written to ```outputs/training/```. The default batch size is 24 which requires 2x 4090 GPUs with 24GB of VRAM each. Configurations are managed by [Hydra](https://hydra.cc/) and can be overwritten from the command line.
|
521 |
+
For example, to train GeoCalib on a single GPU with a batch size of 5, run:
|
522 |
+
|
523 |
+
```bash
|
524 |
+
python -m siclib.train geocalib-pinhole-openpano \
|
525 |
+
--conf geocalib \
|
526 |
+
data.train_batch_size=5 # for 1x 2080 GPU
|
527 |
+
```
|
528 |
+
|
529 |
+
Be aware that this can impact the overall performance. You might need to adjust the learning rate and number of training steps accordingly.
|
530 |
+
|
531 |
+
If you want to log the training progress to [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://wandb.ai/), you can set the ```train.writer``` option:
|
532 |
+
|
533 |
+
```bash
|
534 |
+
python -m siclib.train geocalib-pinhole-openpano \
|
535 |
+
--conf geocalib \
|
536 |
+
--distributed \
|
537 |
+
train.writer=tensorboard
|
538 |
+
```
|
539 |
+
|
540 |
+
The model can then be evaluated using its experiment name:
|
541 |
+
|
542 |
+
```bash
|
543 |
+
python -m siclib.eval.<benchmark> --checkpoint geocalib-pinhole-openpano \
|
544 |
+
--tag geocalib-retrained
|
545 |
+
```
|
546 |
+
|
547 |
+
<details>
|
548 |
+
<summary>[Training DeepCalib]</summary>
|
549 |
+
|
550 |
+
To train DeepCalib on the OpenPano dataset, run:
|
551 |
+
|
552 |
+
```bash
|
553 |
+
python -m siclib.train deepcalib-openpano --conf deepcalib --distributed
|
554 |
+
```
|
555 |
+
|
556 |
+
Make sure that you have generated the [OpenPano Dataset](#openpano-dataset) with radial distortion or add
|
557 |
+
the flag ```data=openpano``` to the command to train on the pinhole images.
|
558 |
+
|
559 |
+
</details>
|
560 |
+
|
561 |
+
<details>
|
562 |
+
<summary>[Training Perspective Fields]</summary>
|
563 |
+
|
564 |
+
Coming soon!
|
565 |
+
|
566 |
+
</details>
|
567 |
+
|
568 |
+
## BibTeX citation
|
569 |
+
|
570 |
+
If you use any ideas from the paper or code from this repo, please consider citing:
|
571 |
+
|
572 |
+
```bibtex
|
573 |
+
@inproceedings{veicht2024geocalib,
|
574 |
+
author = {Alexander Veicht and
|
575 |
+
Paul-Edouard Sarlin and
|
576 |
+
Philipp Lindenberger and
|
577 |
+
Marc Pollefeys},
|
578 |
+
title = {{GeoCalib: Single-image Calibration with Geometric Optimization}},
|
579 |
+
booktitle = {ECCV},
|
580 |
+
year = {2024}
|
581 |
+
}
|
582 |
+
```
|
583 |
+
|
584 |
+
## License
|
585 |
+
|
586 |
+
The code is provided under the [Apache-2.0 License](LICENSE) while the weights of the trained model are provided under the [Creative Commons Attribution 4.0 International Public License](https://creativecommons.org/licenses/by/4.0/legalcode). Thanks to the authors of the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation) for allowing this.
|
587 |
|
|
assets/fisheye-dog-pool.jpg
ADDED
assets/fisheye-skyline.jpg
ADDED (stored with Git LFS)
assets/pinhole-church.jpg
ADDED
assets/pinhole-garden.jpg
ADDED
assets/teaser.gif
ADDED (stored with Git LFS)
demo.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:115886d06182eca375251eab9e301180e86465fd0ed152e917d43d7eb4cbd722
size 13275966
geocalib/__init__.py
ADDED
@@ -0,0 +1,17 @@
import logging

from geocalib.extractor import GeoCalib  # noqa

formatter = logging.Formatter(
    fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

__module_name__ = __name__
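
Importing the package therefore attaches an INFO-level stream handler to the `geocalib` logger. A minimal sketch of what this produces (the timestamp is illustrative):

```python
import logging

import geocalib  # importing the package configures the logger shown above

logging.getLogger("geocalib").info("model loaded")
# [09/01/2024 12:00:00 geocalib INFO] model loaded
```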
geocalib/camera.py
ADDED
@@ -0,0 +1,774 @@
"""Implementation of the pinhole, simple radial, and simple divisional camera models."""

from abc import abstractmethod
from typing import Dict, Optional, Tuple, Union

import torch
from torch.func import jacfwd, vmap
from torch.nn import functional as F

from geocalib.gravity import Gravity
from geocalib.misc import TensorWrapper, autocast
from geocalib.utils import deg2rad, focal2fov, fov2focal, rad2rotmat

# flake8: noqa: E741
# mypy: ignore-errors


class BaseCamera(TensorWrapper):
    """Camera tensor class."""

    eps = 1e-3

    @autocast
    def __init__(self, data: torch.Tensor):
        """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).

        Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where
        - w, h: image size in pixels
        - fx, fy: focal lengths in pixels
        - cx, cy: principal points in normalized image coordinates
        - dist: distortion parameters

        Args:
            data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
        """
        # w, h, fx, fy, cx, cy, dist
        assert data.shape[-1] in {6, 7, 8}, data.shape

        pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
        data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
        super().__init__(data)

    @classmethod
    def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
        """Create a Camera object from a dictionary of parameters.

        Args:
            param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.

        Returns:
            Camera: Camera object.
        """
        for key, value in param_dict.items():
            if not isinstance(value, torch.Tensor):
                param_dict[key] = torch.tensor(value)

        h, w = param_dict["height"], param_dict["width"]
        cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)

        if "f" in param_dict:
            f = param_dict["f"]
        elif "vfov" in param_dict:
            vfov = param_dict["vfov"]
            f = fov2focal(vfov, h)
        else:
            raise ValueError("Focal length or vertical field of view must be provided.")

        if "dist" in param_dict:
            k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
        elif "k1_hat" in param_dict:
            k1 = param_dict["k1_hat"] * (f / h) ** 2

            k2 = param_dict.get("k2", torch.zeros_like(k1))
        else:
            k1 = param_dict.get("k1", torch.zeros_like(f))
            k2 = param_dict.get("k2", torch.zeros_like(f))

        fx, fy = f, f
        if "scales" in param_dict:
            fx = fx * param_dict["scales"][..., 0] / param_dict["scales"][..., 1]

        params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
        return cls(params)

    def pinhole(self):
        """Return the pinhole camera model."""
        return self.__class__(self._data[..., :6])

    @property
    def size(self) -> torch.Tensor:
        """Size (width height) of the images, with shape (..., 2)."""
        return self._data[..., :2]

    @property
    def f(self) -> torch.Tensor:
        """Focal lengths (fx, fy) with shape (..., 2)."""
        return self._data[..., 2:4]

    @property
    def vfov(self) -> torch.Tensor:
        """Vertical field of view in radians."""
        return focal2fov(self.f[..., 1], self.size[..., 1])

    @property
    def hfov(self) -> torch.Tensor:
        """Horizontal field of view in radians."""
        return focal2fov(self.f[..., 0], self.size[..., 0])

    @property
    def c(self) -> torch.Tensor:
        """Principal points (cx, cy) with shape (..., 2)."""
        return self._data[..., 4:6]

    @property
    def K(self) -> torch.Tensor:
        """Returns the self intrinsic matrix with shape (..., 3, 3)."""
        shape = self.shape + (3, 3)
        K = self._data.new_zeros(shape)
        K[..., 0, 0] = self.f[..., 0]
        K[..., 1, 1] = self.f[..., 1]
        K[..., 0, 2] = self.c[..., 0]
        K[..., 1, 2] = self.c[..., 1]
        K[..., 2, 2] = 1
        return K

    def update_focal(self, delta: torch.Tensor, as_log: bool = False):
        """Update the self parameters after changing the focal length."""
        f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta

        # clamp focal length to a reasonable range for stability during training
        min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
        max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
        min_f = min_f.unsqueeze(-1).expand(-1, 2)
        max_f = max_f.unsqueeze(-1).expand(-1, 2)
        f = f.clamp(min=min_f, max=max_f)

        # make sure focal ratio stays the same (avoid inplace operations)
        fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
        f = torch.stack([fx, f[..., 1]], -1)

        dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
        return self.__class__(torch.cat([self.size, f, self.c, dist], -1))

    def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
        """Update the self parameters after resizing an image."""
        scales = (scales, scales) if isinstance(scales, (int, float)) else scales
        s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)

        dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
        return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))

    def crop(self, pad: Tuple[float]):
        """Update the self parameters after cropping an image."""
        pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
        size = self.size + pad.to(self.size)
        c = self.c + pad.to(self.c) / 2

        dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
        return self.__class__(torch.cat([size, self.f, c, dist], -1))

    @autocast
    def in_image(self, p2d: torch.Tensor):
        """Check if 2D points are within the image boundaries."""
        assert p2d.shape[-1] == 2
        size = self.size.unsqueeze(-2)
        return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)

    @autocast
    def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
        """Project 3D points into the self plane and check for visibility."""
        z = p3d[..., -1]
        valid = z > self.eps
        z = z.clamp(min=self.eps)
        p2d = p3d[..., :-1] / z.unsqueeze(-1)
        return p2d, valid

    def J_project(self, p3d: torch.Tensor):
        """Jacobian of the projection function."""
        x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
        zero = torch.zeros_like(z)
        z = z.clamp(min=self.eps)
        J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
        J = J.reshape(p3d.shape[:-1] + (2, 3))
        return J  # N x 2 x 3

    def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
        """Undo transforms done during scaling and cropping."""
        camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
        return camera.scale(1.0 / data["scales"])

    @abstractmethod
    def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
        """Distort normalized 2D coordinates and check for validity of the distortion model."""
        raise NotImplementedError("distort() must be implemented.")

    def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
        """Jacobian of the distortion function."""
        if wrt == "scale2pts":  # (..., 2)
            J = [
                vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
                for idx in range(p2d.shape[0])
            ]

            return torch.cat(J, dim=0).squeeze(-3, -2)

        elif wrt == "scale2dist":  # (..., 1)
            J = []
            for idx in range(p2d.shape[0]):  # loop to batch pts dimension

                def func(x):
                    params = torch.cat([self._data[idx, :6], x[None]], -1)
                    return self.__class__(params).distort(p2d[idx], return_scale=True)[0]

                J.append(vmap(jacfwd(func))(self[idx].dist))

            return torch.cat(J, dim=0)

        else:
            raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")

    @abstractmethod
    def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
        """Undistort normalized 2D coordinates and check for validity of the distortion model."""
        raise NotImplementedError("undistort() must be implemented.")

    def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
|
227 |
+
"""Jacobian of the undistortion function."""
|
228 |
+
if wrt == "pts": # (..., 2, 2)
|
229 |
+
J = [
|
230 |
+
vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
|
231 |
+
for idx in range(p2d.shape[0])
|
232 |
+
]
|
233 |
+
|
234 |
+
return torch.cat(J, dim=0).squeeze(-3)
|
235 |
+
|
236 |
+
elif wrt == "dist": # (..., 1)
|
237 |
+
J = []
|
238 |
+
for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
|
239 |
+
|
240 |
+
def func(x):
|
241 |
+
params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
|
242 |
+
return self.__class__(params).undistort(p2d[batch_idx])[0]
|
243 |
+
|
244 |
+
J.append(vmap(jacfwd(func))(self[batch_idx].dist))
|
245 |
+
|
246 |
+
return torch.cat(J, dim=0)
|
247 |
+
else:
|
248 |
+
raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
|
249 |
+
|
250 |
+
@autocast
|
251 |
+
def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
|
252 |
+
"""Compute the offset for the up-projection."""
|
253 |
+
return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
|
254 |
+
|
255 |
+
def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
|
256 |
+
"""Jacobian of the distortion offset for up-projection."""
|
257 |
+
if wrt == "uv": # (B, N, 2, 2)
|
258 |
+
J = [
|
259 |
+
vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
|
260 |
+
for idx in range(p2d.shape[0])
|
261 |
+
]
|
262 |
+
|
263 |
+
return torch.cat(J, dim=0)
|
264 |
+
|
265 |
+
elif wrt == "dist": # (B, N, 2)
|
266 |
+
J = []
|
267 |
+
for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
|
268 |
+
|
269 |
+
def func(x):
|
270 |
+
params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
|
271 |
+
return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
|
272 |
+
|
273 |
+
J.append(vmap(jacfwd(func))(self[batch_idx].dist))
|
274 |
+
|
275 |
+
return torch.cat(J, dim=0).squeeze(1)
|
276 |
+
else:
|
277 |
+
raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
|
278 |
+
|
279 |
+
@autocast
|
280 |
+
def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
|
281 |
+
"""Convert normalized 2D coordinates into pixel coordinates."""
|
282 |
+
return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
|
283 |
+
|
284 |
+
def J_denormalize(self):
|
285 |
+
"""Jacobian of the denormalization function."""
|
286 |
+
return torch.diag_embed(self.f) # ..., 2 x 2
|
287 |
+
|
288 |
+
@autocast
|
289 |
+
def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
|
290 |
+
"""Convert pixel coordinates into normalized 2D coordinates."""
|
291 |
+
return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
|
292 |
+
|
293 |
+
def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
|
294 |
+
"""Jacobian of the normalization function."""
|
295 |
+
# ... x N x 2 x 2
|
296 |
+
if wrt == "f":
|
297 |
+
J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
|
298 |
+
return torch.diag_embed(J_f)
|
299 |
+
elif wrt == "pts":
|
300 |
+
J_pts = 1 / self.f
|
301 |
+
return torch.diag_embed(J_pts)
|
302 |
+
else:
|
303 |
+
raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
|
304 |
+
|
305 |
+
def pixel_coordinates(self) -> torch.Tensor:
|
306 |
+
"""Pixel coordinates in self frame.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
|
310 |
+
"""
|
311 |
+
w, h = self.size[0].unbind(-1)
|
312 |
+
h, w = h.round().to(int), w.round().to(int)
|
313 |
+
|
314 |
+
# create grid
|
315 |
+
x = torch.arange(0, w, dtype=self.dtype, device=self.device)
|
316 |
+
y = torch.arange(0, h, dtype=self.dtype, device=self.device)
|
317 |
+
x, y = torch.meshgrid(x, y, indexing="xy")
|
318 |
+
xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
|
319 |
+
|
320 |
+
# add batch dimension (normalize() would broadcast but we make it explicit)
|
321 |
+
B = self.shape[0]
|
322 |
+
xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
|
323 |
+
|
324 |
+
return xy.to(self.device).to(self.dtype)
|
325 |
+
|
326 |
+
@autocast
|
327 |
+
def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
|
328 |
+
"""Get the bearing vectors of pixel coordinates by normalizing them."""
|
329 |
+
return F.normalize(p3d, dim=-1)
|
330 |
+
|
331 |
+
@autocast
|
332 |
+
def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
|
333 |
+
"""Transform 3D points into 2D pixel coordinates."""
|
334 |
+
p2d, visible = self.project(p3d)
|
335 |
+
p2d, mask = self.distort(p2d)
|
336 |
+
p2d = self.denormalize(p2d)
|
337 |
+
valid = visible & mask & self.in_image(p2d)
|
338 |
+
return p2d, valid
|
339 |
+
|
340 |
+
@autocast
|
341 |
+
def J_world2image(self, p3d: torch.Tensor):
|
342 |
+
"""Jacobian of the world2image function."""
|
343 |
+
p2d_proj, valid = self.project(p3d)
|
344 |
+
|
345 |
+
J_dnorm = self.J_denormalize()
|
346 |
+
J_dist = self.J_distort(p2d_proj)
|
347 |
+
J_proj = self.J_project(p3d)
|
348 |
+
|
349 |
+
J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
|
350 |
+
return J, valid
|
351 |
+
|
352 |
+
@autocast
|
353 |
+
def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
354 |
+
"""Transform point in the image plane to 3D world coordinates."""
|
355 |
+
p2d = self.normalize(p2d)
|
356 |
+
p2d, valid = self.undistort(p2d)
|
357 |
+
ones = p2d.new_ones(p2d.shape[:-1] + (1,))
|
358 |
+
p3d = torch.cat([p2d, ones], -1)
|
359 |
+
return p3d, valid
|
360 |
+
|
361 |
+
@autocast
|
362 |
+
def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
|
363 |
+
"""Jacobian of the image2world function."""
|
364 |
+
if wrt == "dist":
|
365 |
+
p2d_norm = self.normalize(p2d)
|
366 |
+
return self.J_undistort(p2d_norm, wrt)
|
367 |
+
elif wrt == "f":
|
368 |
+
J_norm2f = self.J_normalize(p2d, wrt)
|
369 |
+
p2d_norm = self.normalize(p2d)
|
370 |
+
J_dist2norm = self.J_undistort(p2d_norm, "pts")
|
371 |
+
|
372 |
+
return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
|
373 |
+
else:
|
374 |
+
raise ValueError(f"Unknown wrt: {wrt}")
|
375 |
+
|
376 |
+
@autocast
|
377 |
+
def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
|
378 |
+
"""Undistort an image using the distortion model."""
|
379 |
+
assert self.shape[0] == 1, "Batch size must be 1."
|
380 |
+
W, H = self.size.unbind(-1)
|
381 |
+
H, W = H.int().item(), W.int().item()
|
382 |
+
|
383 |
+
x, y = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing="xy")
|
384 |
+
coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
|
385 |
+
|
386 |
+
p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
|
387 |
+
p2d, _ = self.world2image(p3d)
|
388 |
+
|
389 |
+
mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
|
390 |
+
grid = torch.stack((mapx, mapy), dim=-1)
|
391 |
+
grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
|
392 |
+
return F.grid_sample(img, grid, align_corners=True)
|
393 |
+
|
394 |
+
def get_img_from_pano(
|
395 |
+
self,
|
396 |
+
pano_img: torch.Tensor,
|
397 |
+
gravity: Gravity,
|
398 |
+
yaws: torch.Tensor = 0.0,
|
399 |
+
resize_factor: Optional[torch.Tensor] = None,
|
400 |
+
) -> torch.Tensor:
|
401 |
+
"""Render an image from a panorama.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
|
405 |
+
gravity (Gravity): Gravity direction of the camera.
|
406 |
+
yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
|
407 |
+
resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the
|
408 |
+
field of view. Defaults to 1.
|
409 |
+
|
410 |
+
Returns:
|
411 |
+
torch.Tensor: Image rendered from the panorama.
|
412 |
+
"""
|
413 |
+
B = self.shape[0]
|
414 |
+
if B > 0:
|
415 |
+
assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
|
416 |
+
assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
|
417 |
+
|
418 |
+
w, h = self.size[0].unbind(-1)
|
419 |
+
h, w = h.round().to(int), w.round().to(int)
|
420 |
+
|
421 |
+
if isinstance(yaws, (int, float)):
|
422 |
+
yaws = [yaws]
|
423 |
+
if isinstance(resize_factor, (int, float)):
|
424 |
+
resize_factor = [resize_factor]
|
425 |
+
|
426 |
+
yaws = (
|
427 |
+
yaws.to(self.dtype).to(self.device)
|
428 |
+
if isinstance(yaws, torch.Tensor)
|
429 |
+
else self.new_tensor(yaws)
|
430 |
+
)
|
431 |
+
|
432 |
+
if isinstance(resize_factor, torch.Tensor):
|
433 |
+
resize_factor = resize_factor.to(self.dtype).to(self.device)
|
434 |
+
elif resize_factor is not None:
|
435 |
+
resize_factor = self.new_tensor(resize_factor)
|
436 |
+
|
437 |
+
assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
|
438 |
+
pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x H x W x 3
|
439 |
+
|
440 |
+
pano_imgs = []
|
441 |
+
for i, yaw in enumerate(yaws):
|
442 |
+
if resize_factor is not None:
|
443 |
+
# resize the panorama such that the fov of the panorama has the same height as the
|
444 |
+
# image
|
445 |
+
vfov = self.vfov[i] if B != 0 else self.vfov
|
446 |
+
scale = torch.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i]
|
447 |
+
pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale))
|
448 |
+
|
449 |
+
mode = "bicubic" if scale >= 1 else "area"
|
450 |
+
resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
|
451 |
+
else:
|
452 |
+
# make sure to copy: resized_pano = pano_img
|
453 |
+
resized_pano = pano_img
|
454 |
+
pano_shape = pano_img.shape[-2:][::-1]
|
455 |
+
|
456 |
+
pano_imgs.append((resized_pano, pano_shape))
|
457 |
+
|
458 |
+
xy = self.pixel_coordinates()
|
459 |
+
uv1, _ = self.image2world(xy)
|
460 |
+
bearings = self.pixel_bearing_many(uv1)
|
461 |
+
|
462 |
+
# rotate bearings
|
463 |
+
R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
|
464 |
+
rotated_bearings = bearings @ gravity.R @ R_yaw
|
465 |
+
|
466 |
+
# spherical coordinates
|
467 |
+
lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
|
468 |
+
lat = torch.atan2(
|
469 |
+
rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
|
470 |
+
)
|
471 |
+
|
472 |
+
images = []
|
473 |
+
for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
|
474 |
+
min_lon, max_lon = -torch.pi, torch.pi
|
475 |
+
min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
|
476 |
+
min_x, max_x = 0, pano_shape[0] - 1.0
|
477 |
+
min_y, max_y = 0, pano_shape[1] - 1.0
|
478 |
+
|
479 |
+
# map Spherical Coordinates to Panoramic Coordinates
|
480 |
+
nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
|
481 |
+
ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
|
482 |
+
|
483 |
+
# reshape and cast to numpy for remap
|
484 |
+
mapx, mapy = nx.reshape((1, h, w)), ny.reshape((1, h, w))
|
485 |
+
|
486 |
+
grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension
|
487 |
+
# Normalize to [-1, 1]
|
488 |
+
grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
|
489 |
+
# Apply grid sample
|
490 |
+
image = F.grid_sample(resized_pano, grid, align_corners=True)
|
491 |
+
images.append(image)
|
492 |
+
|
493 |
+
return torch.concatenate(images, 0) if B > 0 else images[0]
|
494 |
+
|
495 |
+
def __repr__(self):
|
496 |
+
"""Print the Camera object."""
|
497 |
+
return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
|
498 |
+
|
499 |
+
|
500 |
+
class Pinhole(BaseCamera):
|
501 |
+
"""Implementation of the pinhole camera model.
|
502 |
+
|
503 |
+
Use this model for undistorted images.
|
504 |
+
"""
|
505 |
+
|
506 |
+
def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
|
507 |
+
"""Distort normalized 2D coordinates."""
|
508 |
+
if return_scale:
|
509 |
+
return p2d.new_ones(p2d.shape[:-1] + (1,))
|
510 |
+
|
511 |
+
return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
|
512 |
+
|
513 |
+
def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
|
514 |
+
"""Jacobian of the distortion function."""
|
515 |
+
if wrt == "pts":
|
516 |
+
return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
|
517 |
+
|
518 |
+
raise ValueError(f"Unknown wrt: {wrt}")
|
519 |
+
|
520 |
+
def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
|
521 |
+
"""Undistort normalized 2D coordinates."""
|
522 |
+
return pts, pts.new_ones((pts.shape[0], 1)).bool()
|
523 |
+
|
524 |
+
def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
|
525 |
+
"""Jacobian of the undistortion function."""
|
526 |
+
if wrt == "pts":
|
527 |
+
return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
|
528 |
+
|
529 |
+
raise ValueError(f"Unknown wrt: {wrt}")
|
530 |
+
|
531 |
+
def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
|
532 |
+
"""Jacobian of the up-projection offset."""
|
533 |
+
if wrt == "uv":
|
534 |
+
return torch.zeros(p2d.shape[:-1] + (2, 2), device=p2d.device, dtype=p2d.dtype)
|
535 |
+
|
536 |
+
raise ValueError(f"Unknown wrt: {wrt}")
|
537 |
+
|
538 |
+
|
539 |
+
class SimpleRadial(BaseCamera):
|
540 |
+
"""Implementation of the simple radial camera model.
|
541 |
+
|
542 |
+
Use this model for weakly distorted images.
|
543 |
+
|
544 |
+
The distortion model is 1 + k1 * r^2 where r^2 = x^2 + y^2.
|
545 |
+
The undistortion model is 1 - k1 * r^2 estimated as in
|
546 |
+
"An Exact Formula for Calculating Inverse Radial Lens Distortions" by Pierre Drap.
|
547 |
+
"""
|
548 |
+
|
549 |
+
@property
|
550 |
+
def dist(self) -> torch.Tensor:
|
551 |
+
"""Distortion parameters, with shape (..., 1)."""
|
552 |
+
return self._data[..., 6:]
|
553 |
+
|
554 |
+
@property
|
555 |
+
def k1(self) -> torch.Tensor:
|
556 |
+
"""Distortion parameters, with shape (...)."""
|
557 |
+
return self._data[..., 6]
|
558 |
+
|
559 |
+
def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
|
560 |
+
"""Update the self parameters after changing the k1 distortion parameter."""
|
561 |
+
delta_dist = self.new_ones(self.dist.shape) * delta
|
562 |
+
dist = (self.dist + delta_dist).clamp(*dist_range)
|
563 |
+
data = torch.cat([self.size, self.f, self.c, dist], -1)
|
564 |
+
return self.__class__(data)
|
565 |
+
|
566 |
+
@autocast
|
567 |
+
def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
|
568 |
+
"""Check if the distorted points are valid."""
|
569 |
+
return p2d.new_ones(p2d.shape[:-1]).bool()
|
570 |
+
|
571 |
+
def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
|
572 |
+
"""Distort normalized 2D coordinates and check for validity of the distortion model."""
|
573 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
574 |
+
radial = 1 + self.k1[..., None, None] * r2
|
575 |
+
|
576 |
+
if return_scale:
|
577 |
+
return radial, None
|
578 |
+
|
579 |
+
return p2d * radial, self.check_valid(p2d)
|
580 |
+
|
581 |
+
def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
|
582 |
+
"""Jacobian of the distortion function."""
|
583 |
+
if wrt == "scale2dist": # (..., 1)
|
584 |
+
return torch.sum(p2d**2, -1, keepdim=True)
|
585 |
+
elif wrt == "scale2pts": # (..., 2)
|
586 |
+
return 2 * self.k1[..., None, None] * p2d
|
587 |
+
else:
|
588 |
+
return super().J_distort(p2d, wrt)
|
589 |
+
|
590 |
+
@autocast
|
591 |
+
def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
|
592 |
+
"""Undistort normalized 2D coordinates and check for validity of the distortion model."""
|
593 |
+
b1 = -self.k1[..., None, None]
|
594 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
595 |
+
radial = 1 + b1 * r2
|
596 |
+
return p2d * radial, self.check_valid(p2d)
|
597 |
+
|
598 |
+
@autocast
|
599 |
+
def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
|
600 |
+
"""Jacobian of the undistortion function."""
|
601 |
+
b1 = -self.k1[..., None, None]
|
602 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
603 |
+
if wrt == "dist":
|
604 |
+
return -r2 * p2d
|
605 |
+
elif wrt == "pts":
|
606 |
+
radial = 1 + b1 * r2
|
607 |
+
radial_diag = torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
|
608 |
+
ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
|
609 |
+
return (2 * b1[..., None] * ppT) + radial_diag
|
610 |
+
else:
|
611 |
+
return super().J_undistort(p2d, wrt)
|
612 |
+
|
613 |
+
def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
|
614 |
+
"""Jacobian of the up-projection offset."""
|
615 |
+
if wrt == "uv": # (..., 2, 2)
|
616 |
+
return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
|
617 |
+
elif wrt == "dist":
|
618 |
+
return 2 * p2d # (..., 2)
|
619 |
+
else:
|
620 |
+
return super().J_up_projection_offset(p2d, wrt)
|
621 |
+
|
622 |
+
|
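# Illustrative note (not part of the original file): the SimpleRadial undistortion above is
# only a first-order inverse of the distortion, as the class docstring says. For a normalized
# point with r^2 = 0.1 and k1 = -0.1, distorting scales the point by 1 + k1 * r^2 = 0.99;
# undistorting the result scales it by 1 - k1 * r'^2 ~= 1.0098, so the round trip recovers
# 0.99 * 1.0098 ~= 0.9997 of the original point, i.e. it is exact up to O((k1 * r^2)^2).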
623 |
+
class SimpleDivisional(BaseCamera):
|
624 |
+
"""Implementation of the simple divisional camera model.
|
625 |
+
|
626 |
+
Use this model for strongly distorted images.
|
627 |
+
|
628 |
+
The distortion model is (1 - sqrt(1 - 4 * k1 * r^2)) / (2 * k1 * r^2) where r^2 = x^2 + y^2.
|
629 |
+
The undistortion model is 1 / (1 + k1 * r^2).
|
630 |
+
"""
|
631 |
+
|
632 |
+
@property
|
633 |
+
def dist(self) -> torch.Tensor:
|
634 |
+
"""Distortion parameters, with shape (..., 1)."""
|
635 |
+
return self._data[..., 6:]
|
636 |
+
|
637 |
+
@property
|
638 |
+
def k1(self) -> torch.Tensor:
|
639 |
+
"""Distortion parameters, with shape (...)."""
|
640 |
+
return self._data[..., 6]
|
641 |
+
|
642 |
+
def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
|
643 |
+
"""Update the self parameters after changing the k1 distortion parameter."""
|
644 |
+
delta_dist = self.new_ones(self.dist.shape) * delta
|
645 |
+
dist = (self.dist + delta_dist).clamp(*dist_range)
|
646 |
+
data = torch.cat([self.size, self.f, self.c, dist], -1)
|
647 |
+
return self.__class__(data)
|
648 |
+
|
649 |
+
@autocast
|
650 |
+
def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
|
651 |
+
"""Check if the distorted points are valid."""
|
652 |
+
return p2d.new_ones(p2d.shape[:-1]).bool()
|
653 |
+
|
654 |
+
def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
|
655 |
+
"""Distort normalized 2D coordinates and check for validity of the distortion model."""
|
656 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
657 |
+
radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
|
658 |
+
denom = 2 * self.k1[..., None, None] * r2
|
659 |
+
|
660 |
+
ones = radial.new_ones(radial.shape)
|
661 |
+
radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
|
662 |
+
|
663 |
+
if return_scale:
|
664 |
+
return radial, None
|
665 |
+
|
666 |
+
return p2d * radial, self.check_valid(p2d)
|
667 |
+
|
668 |
+
def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
|
669 |
+
"""Jacobian of the distortion function."""
|
670 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
671 |
+
t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
|
672 |
+
if wrt == "scale2pts": # (B, N, 2)
|
673 |
+
d1 = t0 * 2 * r2
|
674 |
+
d2 = self.k1[..., None, None] * r2**2
|
675 |
+
denom = d1 * d2
|
676 |
+
return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
|
677 |
+
|
678 |
+
elif wrt == "scale2dist":
|
679 |
+
d1 = 2 * self.k1[..., None, None] * t0
|
680 |
+
d2 = 2 * r2 * self.k1[..., None, None] ** 2
|
681 |
+
denom = d1 * d2
|
682 |
+
return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
|
683 |
+
|
684 |
+
else:
|
685 |
+
return super().J_distort(p2d, wrt)
|
686 |
+
|
687 |
+
@autocast
|
688 |
+
def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
|
689 |
+
"""Undistort normalized 2D coordinates and check for validity of the distortion model."""
|
690 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
691 |
+
denom = 1 + self.k1[..., None, None] * r2
|
692 |
+
radial = 1 / denom.masked_fill(denom == 0, 1e6)
|
693 |
+
return p2d * radial, self.check_valid(p2d)
|
694 |
+
|
695 |
+
def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
|
696 |
+
"""Jacobian of the undistortion function."""
|
697 |
+
# return super().J_undistort(p2d, wrt)
|
698 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
699 |
+
k1 = self.k1[..., None, None]
|
700 |
+
if wrt == "dist":
|
701 |
+
denom = (1 + k1 * r2) ** 2
|
702 |
+
return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
|
703 |
+
elif wrt == "pts":
|
704 |
+
t0 = 1 + k1 * r2
|
705 |
+
t0 = t0.masked_fill(t0 == 0, 1e6)
|
706 |
+
ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
|
707 |
+
J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
|
708 |
+
return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
|
709 |
+
|
710 |
+
else:
|
711 |
+
return super().J_undistort(p2d, wrt)
|
712 |
+
|
713 |
+
def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
|
714 |
+
"""Jacobian of the up-projection offset.
|
715 |
+
|
716 |
+
func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
|
717 |
+
- (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
|
718 |
+
"""
|
719 |
+
k1 = self.k1[..., None, None]
|
720 |
+
r2 = torch.sum(p2d**2, -1, keepdim=True)
|
721 |
+
t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
|
722 |
+
t1 = torch.sqrt(t0)
|
723 |
+
if wrt == "dist":
|
724 |
+
denom = 4 * t0 ** (3 / 2)
|
725 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
726 |
+
J = 16 / denom
|
727 |
+
|
728 |
+
denom = r2 * t1 * k1
|
729 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
730 |
+
J = J - 2 / denom
|
731 |
+
|
732 |
+
denom = (r2 * k1) ** 2
|
733 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
734 |
+
J = J + (1 - t1) / denom
|
735 |
+
|
736 |
+
return J * p2d
|
737 |
+
elif wrt == "uv":
|
738 |
+
# ! unstable (gradient checker might fail), rewrite to use single division (by denom)
|
739 |
+
ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
|
740 |
+
|
741 |
+
denom = 2 * r2 * t1
|
742 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
743 |
+
J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
|
744 |
+
|
745 |
+
denom = 4 * t1 * r2**2
|
746 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
747 |
+
J = J - 16 / denom[..., None] * ppT
|
748 |
+
|
749 |
+
denom = 4 * r2 * t0 ** (3 / 2)
|
750 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
751 |
+
J = J + (32 * k1[..., None]) / denom[..., None] * ppT
|
752 |
+
|
753 |
+
denom = r2**2 * t1
|
754 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
755 |
+
J = J - 4 / denom[..., None] * ppT
|
756 |
+
|
757 |
+
denom = k1 * r2**3
|
758 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
759 |
+
J = J + (4 * (1 - t1) / denom)[..., None] * ppT
|
760 |
+
|
761 |
+
denom = k1 * r2**2
|
762 |
+
denom = denom.masked_fill(denom == 0, 1e6)
|
763 |
+
J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
|
764 |
+
|
765 |
+
return J
|
766 |
+
else:
|
767 |
+
return super().J_up_projection_offset(p2d, wrt)
|
768 |
+
|
769 |
+
|
770 |
+
camera_models = {
|
771 |
+
"pinhole": Pinhole,
|
772 |
+
"simple_radial": SimpleRadial,
|
773 |
+
"simple_divisional": SimpleDivisional,
|
774 |
+
}
|
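geocalib/camera.py ends with the camera_models dictionary above, which maps the configuration names "pinhole", "simple_radial" and "simple_divisional" to the classes defined in this file. A rough usage sketch of that API follows; all numeric values are made up for illustration and not taken from the repository:

import torch
from geocalib.camera import camera_models

# (w, h, fx, fy, cx, cy, k1, k2), the layout stacked by BaseCamera.from_dict
params = torch.tensor([[640.0, 480.0, 500.0, 500.0, 320.0, 240.0, -0.1, 0.0]])
cam = camera_models["simple_radial"](params)

p3d = torch.tensor([[[0.2, -0.1, 2.0]]])  # one 3D point in the camera frame
p2d, valid = cam.world2image(p3d)         # project, distort, and map to pixels
ray, _ = cam.image2world(p2d)             # back-project to a ray with z = 1
print(p2d, valid, cam.vfov)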
geocalib/extractor.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
"""Simple interface for GeoCalib model."""
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Dict, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.functional import interpolate
|
9 |
+
|
10 |
+
from geocalib.camera import BaseCamera
|
11 |
+
from geocalib.geocalib import GeoCalib as Model
|
12 |
+
from geocalib.utils import ImagePreprocessor, load_image
|
13 |
+
|
14 |
+
|
15 |
+
class GeoCalib(nn.Module):
|
16 |
+
"""Simple interface for GeoCalib model."""
|
17 |
+
|
18 |
+
def __init__(self, weights: str = "pinhole"):
|
19 |
+
"""Initialize the model with optional config overrides.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
weights (str): trained variant, "pinhole" (default) or "distorted".
|
23 |
+
"""
|
24 |
+
super().__init__()
|
25 |
+
if weights == "pinhole":
|
26 |
+
url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
|
27 |
+
elif weights == "distorted":
|
28 |
+
url = (
|
29 |
+
"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
raise ValueError(f"Unknown weights: {weights}")
|
33 |
+
|
34 |
+
# load checkpoint
|
35 |
+
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
36 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
37 |
+
url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
|
38 |
+
)
|
39 |
+
|
40 |
+
self.model = Model()
|
41 |
+
self.model.flexible_load(state_dict["model"])
|
42 |
+
self.model.eval()
|
43 |
+
|
44 |
+
self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})
|
45 |
+
|
46 |
+
def load_image(self, path: Path) -> torch.Tensor:
|
47 |
+
"""Load image from path."""
|
48 |
+
return load_image(path)
|
49 |
+
|
50 |
+
def _post_process(
|
51 |
+
self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor]
|
52 |
+
) -> tuple[BaseCamera, dict[str, torch.Tensor]]:
|
53 |
+
"""Post-process model output by undoing scaling and cropping."""
|
54 |
+
camera = camera.undo_scale_crop(img_data)
|
55 |
+
|
56 |
+
w, h = camera.size.unbind(-1)
|
57 |
+
h = h[0].round().int().item()
|
58 |
+
w = w[0].round().int().item()
|
59 |
+
|
60 |
+
for k in ["latitude_field", "up_field"]:
|
61 |
+
out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
|
62 |
+
for k in ["up_confidence", "latitude_confidence"]:
|
63 |
+
out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]
|
64 |
+
|
65 |
+
inverse_scales = 1.0 / img_data["scales"]
|
66 |
+
zero = camera.new_zeros(camera.f.shape[0])
|
67 |
+
out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
|
68 |
+
return camera, out
|
69 |
+
|
70 |
+
@torch.no_grad()
|
71 |
+
def calibrate(
|
72 |
+
self,
|
73 |
+
img: torch.Tensor,
|
74 |
+
camera_model: str = "pinhole",
|
75 |
+
priors: Optional[Dict[str, torch.Tensor]] = None,
|
76 |
+
shared_intrinsics: bool = False,
|
77 |
+
) -> Dict[str, torch.Tensor]:
|
78 |
+
"""Perform calibration with online resizing.
|
79 |
+
|
80 |
+
Assumes input image is in range [0, 1] and in RGB format.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W)
|
84 |
+
camera_model (str, optional): Camera model. Defaults to "pinhole".
|
85 |
+
priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}.
|
86 |
+
shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties.
|
90 |
+
"""
|
91 |
+
if len(img.shape) == 3:
|
92 |
+
img = img[None] # add batch dim
|
93 |
+
if not shared_intrinsics:
|
94 |
+
assert len(img.shape) == 4 and img.shape[0] == 1
|
95 |
+
|
96 |
+
img_data = self.image_processor(img)
|
97 |
+
|
98 |
+
if priors is None:
|
99 |
+
priors = {}
|
100 |
+
|
101 |
+
prior_values = {}
|
102 |
+
if prior_focal := priors.get("focal"):
|
103 |
+
prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
|
104 |
+
prior_values["prior_focal"] = prior_focal * img_data["scales"][1]
|
105 |
+
|
106 |
+
if "gravity" in priors:
|
107 |
+
prior_gravity = priors["gravity"]
|
108 |
+
prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
|
109 |
+
prior_values["prior_gravity"] = prior_gravity
|
110 |
+
|
111 |
+
self.model.optimizer.set_camera_model(camera_model)
|
112 |
+
self.model.optimizer.shared_intrinsics = shared_intrinsics
|
113 |
+
|
114 |
+
out = self.model(img_data | prior_values)
|
115 |
+
|
116 |
+
camera, gravity = out["camera"], out["gravity"]
|
117 |
+
camera, out = self._post_process(camera, img_data, out)
|
118 |
+
|
119 |
+
return {
|
120 |
+
"camera": camera,
|
121 |
+
"gravity": gravity,
|
122 |
+
"covariance": out["covariance"],
|
123 |
+
**{k: out[k] for k in out.keys() if "field" in k},
|
124 |
+
**{k: out[k] for k in out.keys() if "confidence" in k},
|
125 |
+
**{k: out[k] for k in out.keys() if "uncertainty" in k},
|
126 |
+
}
|
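geocalib/extractor.py above is the user-facing entry point: it downloads the released weights, resizes the input for the network, and undoes the resizing in the returned camera. A minimal usage sketch (the example image is one of the assets shipped in this upload):

import torch
from geocalib.extractor import GeoCalib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib(weights="pinhole").to(device)

image = model.load_image("assets/pinhole-church.jpg").to(device)  # RGB in [0, 1]
result = model.calibrate(image, camera_model="pinhole")

camera, gravity = result["camera"], result["gravity"]
print(camera.f)    # estimated focal lengths (fx, fy) in pixels
print(gravity.rp)  # estimated roll and pitch in radians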
geocalib/geocalib.py
ADDED
@@ -0,0 +1,150 @@
1 |
+
"""GeoCalib model definition."""
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from geocalib.lm_optimizer import LMOptimizer
|
11 |
+
from geocalib.modules import MSCAN, ConvModule, LightHamHead
|
12 |
+
|
13 |
+
# mypy: ignore-errors
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class LowLevelEncoder(nn.Module):
|
19 |
+
"""Very simple low-level encoder."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
"""Simple low-level encoder."""
|
23 |
+
super().__init__()
|
24 |
+
self.in_channel = 3
|
25 |
+
self.feat_dim = 64
|
26 |
+
|
27 |
+
self.conv1 = ConvModule(self.in_channel, self.feat_dim, kernel_size=3, padding=1)
|
28 |
+
self.conv2 = ConvModule(self.feat_dim, self.feat_dim, kernel_size=3, padding=1)
|
29 |
+
|
30 |
+
def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
31 |
+
"""Forward pass."""
|
32 |
+
x = data["image"]
|
33 |
+
|
34 |
+
assert (
|
35 |
+
x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0
|
36 |
+
), "Image size must be multiple of 32 if not using single image input."
|
37 |
+
|
38 |
+
c1 = self.conv1(x)
|
39 |
+
c2 = self.conv2(c1)
|
40 |
+
|
41 |
+
return {"features": c2}
|
42 |
+
|
43 |
+
|
44 |
+
class UpDecoder(nn.Module):
|
45 |
+
"""Minimal implementation of UpDecoder."""
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
"""Up decoder."""
|
49 |
+
super().__init__()
|
50 |
+
self.decoder = LightHamHead()
|
51 |
+
self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, 2, kernel_size=1)
|
52 |
+
|
53 |
+
def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
54 |
+
"""Forward pass."""
|
55 |
+
x, log_confidence = self.decoder(data["features"])
|
56 |
+
up = self.linear_pred_up(x)
|
57 |
+
return {"up_field": F.normalize(up, dim=1), "up_confidence": torch.sigmoid(log_confidence)}
|
58 |
+
|
59 |
+
|
60 |
+
class LatitudeDecoder(nn.Module):
|
61 |
+
"""Minimal implementation of LatitudeDecoder."""
|
62 |
+
|
63 |
+
def __init__(self):
|
64 |
+
"""Latitude decoder."""
|
65 |
+
super().__init__()
|
66 |
+
self.decoder = LightHamHead()
|
67 |
+
self.linear_pred_latitude = nn.Conv2d(self.decoder.out_channels, 1, kernel_size=1)
|
68 |
+
|
69 |
+
def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
70 |
+
"""Forward pass."""
|
71 |
+
x, log_confidence = self.decoder(data["features"])
|
72 |
+
eps = 1e-5 # avoid nan in backward of asin
|
73 |
+
lat = torch.tanh(self.linear_pred_latitude(x))
|
74 |
+
lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps))
|
75 |
+
return {"latitude_field": lat, "latitude_confidence": torch.sigmoid(log_confidence)}
|
76 |
+
|
77 |
+
|
78 |
+
class PerspectiveDecoder(nn.Module):
|
79 |
+
"""Minimal implementation of PerspectiveDecoder."""
|
80 |
+
|
81 |
+
def __init__(self):
|
82 |
+
"""Perspective decoder wrapping up and latitude decoders."""
|
83 |
+
super().__init__()
|
84 |
+
self.up_head = UpDecoder()
|
85 |
+
self.latitude_head = LatitudeDecoder()
|
86 |
+
|
87 |
+
def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
88 |
+
"""Forward pass."""
|
89 |
+
return self.up_head(data) | self.latitude_head(data)
|
90 |
+
|
91 |
+
|
92 |
+
class GeoCalib(nn.Module):
|
93 |
+
"""GeoCalib inference model."""
|
94 |
+
|
95 |
+
def __init__(self, **optimizer_options):
|
96 |
+
"""Initialize the GeoCalib inference model.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
optimizer_options: Options for the lm optimizer.
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.backbone = MSCAN()
|
103 |
+
self.ll_enc = LowLevelEncoder()
|
104 |
+
self.perspective_decoder = PerspectiveDecoder()
|
105 |
+
|
106 |
+
self.optimizer = LMOptimizer({**optimizer_options})
|
107 |
+
|
108 |
+
def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
109 |
+
"""Forward pass."""
|
110 |
+
features = {"hl": self.backbone(data)["features"], "ll": self.ll_enc(data)["features"]}
|
111 |
+
out = self.perspective_decoder({"features": features})
|
112 |
+
|
113 |
+
out |= {
|
114 |
+
k: data[k]
|
115 |
+
for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"]
|
116 |
+
if k in data
|
117 |
+
}
|
118 |
+
|
119 |
+
out |= self.optimizer(out)
|
120 |
+
|
121 |
+
return out
|
122 |
+
|
123 |
+
def flexible_load(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
124 |
+
"""Load a checkpoint with flexible key names."""
|
125 |
+
dict_params = set(state_dict.keys())
|
126 |
+
model_params = set(map(lambda n: n[0], self.named_parameters()))
|
127 |
+
|
128 |
+
if dict_params == model_params: # perfect fit
|
129 |
+
logger.info("Loading all parameters of the checkpoint.")
|
130 |
+
self.load_state_dict(state_dict, strict=True)
|
131 |
+
return
|
132 |
+
elif len(dict_params & model_params) == 0: # perfect mismatch
|
133 |
+
strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
|
134 |
+
state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
|
135 |
+
dict_params = set(state_dict.keys())
|
136 |
+
if len(dict_params & model_params) == 0:
|
137 |
+
raise ValueError(
|
138 |
+
"Could not manage to load the checkpoint with"
|
139 |
+
"parameters:" + "\n\t".join(sorted(dict_params))
|
140 |
+
)
|
141 |
+
common_params = dict_params & model_params
|
142 |
+
left_params = dict_params - model_params
|
143 |
+
left_params = [
|
144 |
+
p for p in left_params if "running" not in p and "num_batches_tracked" not in p
|
145 |
+
]
|
146 |
+
logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
|
147 |
+
if left_params:
|
148 |
+
# ignore running stats of batchnorm
|
149 |
+
logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
|
150 |
+
self.load_state_dict(state_dict, strict=False)
|
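flexible_load above accepts checkpoints whose parameter names carry an extra prefix (e.g. from a training wrapper) and falls back to a non-strict load with a warning for anything it cannot match. A short sketch of loading the released pinhole checkpoint into this core model, mirroring what geocalib/extractor.py does (the URL is the one used there):

import torch
from geocalib.geocalib import GeoCalib

url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")

model = GeoCalib()
model.flexible_load(state_dict["model"])  # tolerant of renamed or missing keys
model.eval()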
geocalib/gravity.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
"""Tensor class for gravity vector in camera frame."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from geocalib.misc import EuclideanManifold, SphericalManifold, TensorWrapper, autocast
|
7 |
+
from geocalib.utils import rad2rotmat
|
8 |
+
|
9 |
+
# mypy: ignore-errors
|
10 |
+
|
11 |
+
|
12 |
+
class Gravity(TensorWrapper):
|
13 |
+
"""Gravity vector in camera frame."""
|
14 |
+
|
15 |
+
eps = 1e-4
|
16 |
+
|
17 |
+
@autocast
|
18 |
+
def __init__(self, data: torch.Tensor) -> None:
|
19 |
+
"""Create gravity vector from data.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
data (torch.Tensor): gravity vector as 3D vector in camera frame.
|
23 |
+
"""
|
24 |
+
assert data.shape[-1] == 3, data.shape
|
25 |
+
|
26 |
+
data = F.normalize(data, dim=-1)
|
27 |
+
|
28 |
+
super().__init__(data)
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
|
32 |
+
"""Create gravity vector from roll and pitch angles."""
|
33 |
+
if not isinstance(roll, torch.Tensor):
|
34 |
+
roll = torch.tensor(roll)
|
35 |
+
if not isinstance(pitch, torch.Tensor):
|
36 |
+
pitch = torch.tensor(pitch)
|
37 |
+
|
38 |
+
sr, cr = torch.sin(roll), torch.cos(roll)
|
39 |
+
sp, cp = torch.sin(pitch), torch.cos(pitch)
|
40 |
+
return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
|
41 |
+
|
42 |
+
@property
|
43 |
+
def vec3d(self) -> torch.Tensor:
|
44 |
+
"""Return the gravity vector in the representation."""
|
45 |
+
return self._data
|
46 |
+
|
47 |
+
@property
|
48 |
+
def x(self) -> torch.Tensor:
|
49 |
+
"""Return first component of the gravity vector."""
|
50 |
+
return self._data[..., 0]
|
51 |
+
|
52 |
+
@property
|
53 |
+
def y(self) -> torch.Tensor:
|
54 |
+
"""Return second component of the gravity vector."""
|
55 |
+
return self._data[..., 1]
|
56 |
+
|
57 |
+
@property
|
58 |
+
def z(self) -> torch.Tensor:
|
59 |
+
"""Return third component of the gravity vector."""
|
60 |
+
return self._data[..., 2]
|
61 |
+
|
62 |
+
@property
|
63 |
+
def roll(self) -> torch.Tensor:
|
64 |
+
"""Return the roll angle of the gravity vector."""
|
65 |
+
roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
|
66 |
+
offset = -torch.pi * torch.sign(self.x)
|
67 |
+
return torch.where(self.y < 0, roll, -roll + offset)
|
68 |
+
|
69 |
+
def J_roll(self) -> torch.Tensor:
|
70 |
+
"""Return the Jacobian of the roll angle of the gravity vector."""
|
71 |
+
cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
|
72 |
+
cr, sr = torch.cos(self.roll), torch.sin(self.roll)
|
73 |
+
Jr = self.new_zeros(self.shape + (3,))
|
74 |
+
Jr[..., 0] = -cr * cp
|
75 |
+
Jr[..., 1] = sr * cp
|
76 |
+
return Jr
|
77 |
+
|
78 |
+
@property
|
79 |
+
def pitch(self) -> torch.Tensor:
|
80 |
+
"""Return the pitch angle of the gravity vector."""
|
81 |
+
return torch.asin(self.z)
|
82 |
+
|
83 |
+
def J_pitch(self) -> torch.Tensor:
|
84 |
+
"""Return the Jacobian of the pitch angle of the gravity vector."""
|
85 |
+
cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
|
86 |
+
cr, sr = torch.cos(self.roll), torch.sin(self.roll)
|
87 |
+
|
88 |
+
Jp = self.new_zeros(self.shape + (3,))
|
89 |
+
Jp[..., 0] = sr * sp
|
90 |
+
Jp[..., 1] = cr * sp
|
91 |
+
Jp[..., 2] = cp
|
92 |
+
return Jp
|
93 |
+
|
94 |
+
@property
|
95 |
+
def rp(self) -> torch.Tensor:
|
96 |
+
"""Return the roll and pitch angles of the gravity vector."""
|
97 |
+
return torch.stack([self.roll, self.pitch], dim=-1)
|
98 |
+
|
99 |
+
def J_rp(self) -> torch.Tensor:
|
100 |
+
"""Return the Jacobian of the roll and pitch angles of the gravity vector."""
|
101 |
+
return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
|
102 |
+
|
103 |
+
@property
|
104 |
+
def R(self) -> torch.Tensor:
|
105 |
+
"""Return the rotation matrix from the gravity vector."""
|
106 |
+
return rad2rotmat(roll=self.roll, pitch=self.pitch)
|
107 |
+
|
108 |
+
def J_R(self) -> torch.Tensor:
|
109 |
+
"""Return the Jacobian of the rotation matrix from the gravity vector."""
|
110 |
+
raise NotImplementedError
|
111 |
+
|
112 |
+
def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
|
113 |
+
"""Update the gravity vector by adding a delta."""
|
114 |
+
if spherical:
|
115 |
+
data = SphericalManifold.plus(self.vec3d, delta)
|
116 |
+
return self.__class__(data)
|
117 |
+
|
118 |
+
data = EuclideanManifold.plus(self.rp, delta)
|
119 |
+
return self.from_rp(data[..., 0], data[..., 1])
|
120 |
+
|
121 |
+
def J_update(self, spherical: bool = False) -> torch.Tensor:
|
122 |
+
"""Return the Jacobian of the update."""
|
123 |
+
return (
|
124 |
+
SphericalManifold.J_plus(self.vec3d)
|
125 |
+
if spherical
|
126 |
+
else EuclideanManifold.J_plus(self.vec3d)
|
127 |
+
)
|
128 |
+
|
129 |
+
def __repr__(self):
|
130 |
+
"""Print the Camera object."""
|
131 |
+
return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
|
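The Gravity wrapper above stores a unit vector in the camera frame and exposes roll/pitch angles and the corresponding rotation matrix. A small sketch (the angle values are illustrative):

import torch
from geocalib.gravity import Gravity

roll = torch.deg2rad(torch.tensor([5.0]))
pitch = torch.deg2rad(torch.tensor([-10.0]))
g = Gravity.from_rp(roll, pitch)

print(g.vec3d)              # unit gravity direction in the camera frame
print(torch.rad2deg(g.rp))  # recovered (roll, pitch) in degrees
print(g.R.shape)            # (1, 3, 3) rotation built from roll and pitch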
geocalib/interactive_demo.py
ADDED
@@ -0,0 +1,450 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import queue
|
4 |
+
import threading
|
5 |
+
import time
|
6 |
+
from time import time
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from geocalib.extractor import GeoCalib
|
14 |
+
from geocalib.perspective_fields import get_perspective_field
|
15 |
+
from geocalib.utils import get_device, rad2deg
|
16 |
+
|
17 |
+
# flake8: noqa
|
18 |
+
# mypy: ignore-errors
|
19 |
+
|
20 |
+
|
21 |
+
description = """
|
22 |
+
-------------------------
|
23 |
+
GeoCalib Interactive Demo
|
24 |
+
-------------------------
|
25 |
+
|
26 |
+
This script is an interactive demo for GeoCalib. It will open a window showing the camera feed and
|
27 |
+
the calibration results.
|
28 |
+
|
29 |
+
Arguments:
|
30 |
+
- '--camera_id': Camera ID to use. If none, will ask for ip of droidcam (https://droidcam.app)
|
31 |
+
|
32 |
+
You can toggle different features using the following keys:
|
33 |
+
|
34 |
+
- 'h': Toggle horizon line
|
35 |
+
- 'u': Toggle up vector field
|
36 |
+
- 'l': Toggle latitude heatmap
|
37 |
+
- 'c': Toggle confidence heatmap
|
38 |
+
- 'd': Toggle undistorted image
|
39 |
+
- 'g': Toggle grid of points
|
40 |
+
- 'b': Toggle box object
|
41 |
+
|
42 |
+
You can also change the camera model using the following keys:
|
43 |
+
|
44 |
+
- '1': Pinhole
|
45 |
+
- '2': Simple Radial
|
46 |
+
- '3': Simple Divisional
|
47 |
+
|
48 |
+
Press 'q' to quit the demo.
|
49 |
+
"""
|
50 |
+
|
51 |
+
|
52 |
+
# Custom VideoCapture class to get the most recent frame instead of FIFO
|
53 |
+
class VideoCapture:
|
54 |
+
def __init__(self, name):
|
55 |
+
self.cap = cv2.VideoCapture(name)
|
56 |
+
self.q = queue.Queue()
|
57 |
+
t = threading.Thread(target=self._reader)
|
58 |
+
t.daemon = True
|
59 |
+
t.start()
|
60 |
+
|
61 |
+
# read frames as soon as they are available, keeping only most recent one
|
62 |
+
def _reader(self):
|
63 |
+
while True:
|
64 |
+
ret, frame = self.cap.read()
|
65 |
+
if not ret:
|
66 |
+
break
|
67 |
+
if not self.q.empty():
|
68 |
+
try:
|
69 |
+
self.q.get_nowait() # discard previous (unprocessed) frame
|
70 |
+
except queue.Empty:
|
71 |
+
pass
|
72 |
+
self.q.put(frame)
|
73 |
+
|
74 |
+
def read(self):
|
75 |
+
return 1, self.q.get()
|
76 |
+
|
77 |
+
def isOpened(self):
|
78 |
+
return self.cap.isOpened()
|
79 |
+
|
80 |
+
|
81 |
+
def add_text(frame, text, align_left=True, align_top=True):
|
82 |
+
"""Add text to a plot."""
|
83 |
+
h, w = frame.shape[:2]
|
84 |
+
sc = min(h / 640.0, 2.0)
|
85 |
+
Ht = int(40 * sc) # text height
|
86 |
+
|
87 |
+
for i, l in enumerate(text.split("\n")):
|
88 |
+
max_line = len(max([l for l in text.split("\n")], key=len))
|
89 |
+
x = int(8 * sc if align_left else w - (max_line) * sc * 18)
|
90 |
+
y = Ht * (i + 1) if align_top else h - Ht * (len(text.split("\n")) - i - 1) - int(8 * sc)
|
91 |
+
|
92 |
+
c_back, c_front = (0, 0, 0), (255, 255, 255)
|
93 |
+
font, style = cv2.FONT_HERSHEY_DUPLEX, cv2.LINE_AA
|
94 |
+
cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_back, int(6 * sc), style)
|
95 |
+
cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_front, int(1 * sc), style)
|
96 |
+
return frame
|
97 |
+
|
98 |
+
|
99 |
+
def is_corner(p, h, w):
|
100 |
+
"""Check if a point is a corner."""
|
101 |
+
return p in [(0, 0), (0, h - 1), (w - 1, 0), (w - 1, h - 1)]
|
102 |
+
|
103 |
+
|
104 |
+
def plot_latitude(frame, latitude):
|
105 |
+
"""Plot latitude heatmap."""
|
106 |
+
if not isinstance(latitude, np.ndarray):
|
107 |
+
latitude = latitude.cpu().numpy()
|
108 |
+
|
109 |
+
cmap = plt.get_cmap("seismic")
|
110 |
+
h, w = frame.shape[0], frame.shape[1]
|
111 |
+
sc = min(h / 640.0, 2.0)
|
112 |
+
|
113 |
+
vmin, vmax = -90, 90
|
114 |
+
latitude = (latitude - vmin) / (vmax - vmin)
|
115 |
+
|
116 |
+
colors = (cmap(latitude)[..., :3] * 255).astype(np.uint8)[..., ::-1]
|
117 |
+
frame = cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
|
118 |
+
|
119 |
+
for contour_line in np.linspace(vmin, vmax, 15):
|
120 |
+
contour_line = (contour_line - vmin) / (vmax - vmin)
|
121 |
+
|
122 |
+
mask = (latitude > contour_line).astype(np.uint8)
|
123 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
124 |
+
|
125 |
+
for contour in contours:
|
126 |
+
color = (np.array(cmap(contour_line))[:3] * 255).astype(np.uint8)[::-1]
|
127 |
+
|
128 |
+
# remove corners
|
129 |
+
contour = [p for p in contour if not is_corner(tuple(p[0]), h, w)]
|
130 |
+
for index, item in enumerate(contour[:-1]):
|
131 |
+
cv2.line(frame, item[0], contour[index + 1][0], color.tolist(), int(5 * sc))
|
132 |
+
|
133 |
+
return frame
|
134 |
+
|
135 |
+
|
136 |
+
def draw_horizon_line(frame, heatmap):
|
137 |
+
"""Draw a horizon line."""
|
138 |
+
if not isinstance(heatmap, np.ndarray):
|
139 |
+
heatmap = heatmap.cpu().numpy()
|
140 |
+
|
141 |
+
h, w = frame.shape[0], frame.shape[1]
|
142 |
+
sc = min(h / 640.0, 2.0)
|
143 |
+
|
144 |
+
color = (0, 255, 255)
|
145 |
+
vmin, vmax = -90, 90
|
146 |
+
heatmap = (heatmap - vmin) / (vmax - vmin)
|
147 |
+
|
148 |
+
contours, _ = cv2.findContours(
|
149 |
+
(heatmap > 0.5).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
150 |
+
)
|
151 |
+
if contours:
|
152 |
+
contour = [p for p in contours[0] if not is_corner(tuple(p[0]), h, w)]
|
153 |
+
for index, item in enumerate(contour[:-1]):
|
154 |
+
cv2.line(frame, item[0], contour[index + 1][0], color, int(5 * sc))
|
155 |
+
return frame
|
156 |
+
|
157 |
+
|
158 |
+
def plot_confidence(frame, confidence):
|
159 |
+
"""Plot confidence heatmap."""
|
160 |
+
if not isinstance(confidence, np.ndarray):
|
161 |
+
confidence = confidence.cpu().numpy()
|
162 |
+
|
163 |
+
confidence = np.log10(confidence.clip(1e-6)).clip(-4)
|
164 |
+
confidence = (confidence - confidence.min()) / (confidence.max() - confidence.min())
|
165 |
+
|
166 |
+
cmap = plt.get_cmap("turbo")
|
167 |
+
colors = (cmap(confidence)[..., :3] * 255).astype(np.uint8)[..., ::-1]
|
168 |
+
return cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
|
169 |
+
|
170 |
+
|
171 |
+
def plot_vector_field(frame, vector_field):
|
172 |
+
"""Plot a vector field."""
|
173 |
+
if not isinstance(vector_field, np.ndarray):
|
174 |
+
vector_field = vector_field.cpu().numpy()
|
175 |
+
|
176 |
+
H, W = frame.shape[:2]
|
177 |
+
sc = min(H / 640.0, 2.0)
|
178 |
+
|
179 |
+
subsample = min(W, H) // 10
|
180 |
+
offset_x = ((W % subsample) + subsample) // 2
|
181 |
+
samples_x = np.arange(offset_x, W, subsample)
|
182 |
+
samples_y = np.arange(int(subsample * 0.9), H, subsample)
|
183 |
+
|
184 |
+
vec_len = 40 * sc
|
185 |
+
x_grid, y_grid = np.meshgrid(samples_x, samples_y)
|
186 |
+
x, y = vector_field[:, samples_y][:, :, samples_x]
|
187 |
+
    for xi, yi, xi_dir, yi_dir in zip(x_grid.ravel(), y_grid.ravel(), x.ravel(), y.ravel()):
        start = (xi, yi)
        end = (int(xi + xi_dir * vec_len), int(yi + yi_dir * vec_len))
        cv2.arrowedLine(
            frame, start, end, (0, 255, 0), int(5 * sc), line_type=cv2.LINE_AA, tipLength=0.3
        )

    return frame


def plot_box(frame, gravity, camera):
    """Plot a box object."""
    pts = np.array(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]
    )
    pts = pts - np.array([0.5, 1, 0.5])
    rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
    t = np.array([0, 0, 1], dtype=float)
    K = camera.K[0].cpu().numpy().astype(float)
    dist = np.zeros(4, dtype=float)
    axis_points, _ = cv2.projectPoints(
        0.1 * pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist
    )

    h = frame.shape[0]
    sc = min(h / 640.0, 2.0)

    color = (85, 108, 228)
    for p in axis_points:
        center = tuple((int(p[0][0]), int(p[0][1])))
        frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)

    for i in range(0, 4):
        p1 = axis_points[i].astype(int)
        p2 = axis_points[i + 4].astype(int)
        frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)

        p1 = axis_points[i].astype(int)
        p2 = axis_points[(i + 1) % 4].astype(int)
        frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)

        p1 = axis_points[i + 4].astype(int)
        p2 = axis_points[(i + 1) % 4 + 4].astype(int)
        frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)

    return frame


def plot_grid(frame, gravity, camera, grid_size=0.2, num_points=5):
    """Plot a grid of points."""
    h = frame.shape[0]
    sc = min(h / 640.0, 2.0)

    samples = np.linspace(-grid_size, grid_size, num_points)
    xz = np.meshgrid(samples, samples)
    pts = np.stack((xz[0].ravel(), np.zeros_like(xz[0].ravel()), xz[1].ravel()), axis=-1)

    # project points
    rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
    t = np.array([0, 0, 1], dtype=float)
    K = camera.K[0].cpu().numpy().astype(float)
    dist = np.zeros(4, dtype=float)
    axis_points, _ = cv2.projectPoints(pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist)

    color = (192, 77, 58)
    # draw points
    for p in axis_points:
        center = tuple((int(p[0][0]), int(p[0][1])))
        frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)

    # draw lines
    for i in range(num_points):
        for j in range(num_points - 1):
            p1 = axis_points[i * num_points + j].astype(int)
            p2 = axis_points[i * num_points + j + 1].astype(int)
            frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)

            p1 = axis_points[j * num_points + i].astype(int)
            p2 = axis_points[(j + 1) * num_points + i].astype(int)
            frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)

    return frame


def undistort_image(img, camera, padding=0.3):
    """Undistort an image."""
    W, H = camera.size.unbind(-1)
    H, W = H.int().item(), W.int().item()

    pad_h, pad_w = int(H * padding), int(W * padding)
    x, y = torch.meshgrid(torch.arange(0, W + pad_w), torch.arange(0, H + pad_h), indexing="xy")
    coords = torch.stack((x, y), dim=-1).reshape(-1, 2) - torch.tensor([pad_w / 2, pad_h / 2])

    p3d, _ = camera.pinhole().image2world(coords.to(camera.device).to(camera.dtype))
    p2d, _ = camera.world2image(p3d)

    p2d = p2d.float().numpy().reshape(H + pad_h, W + pad_w, 2)
    img = cv2.remap(img, p2d[..., 0], p2d[..., 1], cv2.INTER_LINEAR, borderValue=(254, 254, 254))
    return cv2.resize(img, (W, H))


class InteractiveDemo:
    def __init__(self, capture: VideoCapture, device: str) -> None:
        self.cap = capture

        self.device = torch.device(device)
        self.model = GeoCalib().to(device)

        self.up_toggle = False
        self.lat_toggle = False
        self.conf_toggle = False

        self.hl_toggle = False
        self.grid_toggle = False
        self.box_toggle = False

        self.undist_toggle = False

        self.camera_model = "pinhole"

    def render_frame(self, frame, calibration):
        """Render the frame with the calibration results."""
        camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()

        if self.undist_toggle:
            return undistort_image(frame, camera)

        up, lat = get_perspective_field(camera, gravity)

        if gravity.pitch[0] > 0:
            frame = plot_box(frame, gravity, camera) if self.box_toggle else frame
            frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
        else:
            frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
            frame = plot_box(frame, gravity, camera) if self.box_toggle else frame

        frame = draw_horizon_line(frame, lat[0, 0]) if self.hl_toggle else frame

        if self.conf_toggle and self.up_toggle:
            frame = plot_confidence(frame, calibration["up_confidence"][0])
        frame = plot_vector_field(frame, up[0]) if self.up_toggle else frame

        if self.conf_toggle and self.lat_toggle:
            frame = plot_confidence(frame, calibration["latitude_confidence"][0])
        frame = plot_latitude(frame, rad2deg(lat)[0, 0]) if self.lat_toggle else frame

        return frame

    def format_results(self, calibration):
        """Format the calibration results."""
        camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()

        vfov, focal = camera.vfov[0].item(), camera.f[0, 0].item()
        fov_unc = rad2deg(calibration["vfov_uncertainty"].item())
        f_unc = calibration["focal_uncertainty"].item()

        roll, pitch = gravity.rp[0].unbind(-1)
        roll, pitch, vfov = rad2deg(roll), rad2deg(pitch), rad2deg(vfov)
        roll_unc = rad2deg(calibration["roll_uncertainty"].item())
        pitch_unc = rad2deg(calibration["pitch_uncertainty"].item())

        text = f"{self.camera_model.replace('_', ' ').title()}\n"
        text += f"Roll: {roll:.2f} (+- {roll_unc:.2f})\n"
        text += f"Pitch: {pitch:.2f} (+- {pitch_unc:.2f})\n"
        text += f"vFoV: {vfov:.2f} (+- {fov_unc:.2f})\n"
        text += f"Focal: {focal:.2f} (+- {f_unc:.2f})"

        if hasattr(camera, "k1"):
            text += f"\nK1: {camera.k1[0].item():.2f}"

        return text

    def update_toggles(self):
        """Update the toggles."""
        key = cv2.waitKey(100) & 0xFF
        if key == ord("h"):
            self.hl_toggle = not self.hl_toggle
        elif key == ord("u"):
            self.up_toggle = not self.up_toggle
        elif key == ord("l"):
            self.lat_toggle = not self.lat_toggle
        elif key == ord("c"):
            self.conf_toggle = not self.conf_toggle
        elif key == ord("d"):
            self.undist_toggle = not self.undist_toggle
        elif key == ord("g"):
            self.grid_toggle = not self.grid_toggle
        elif key == ord("b"):
            self.box_toggle = not self.box_toggle

        elif key == ord("1"):
            self.camera_model = "pinhole"
        elif key == ord("2"):
            self.camera_model = "simple_radial"
        elif key == ord("3"):
            self.camera_model = "simple_divisional"

        elif key == ord("q"):
            return True

        return False

    def run(self):
        """Run the interactive demo."""
        while True:
            start = time()
            ret, frame = self.cap.read()

            if not ret:
                print("Error: Failed to retrieve frame.")
                break

            # create tensor from frame
            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = torch.tensor(img).permute(2, 0, 1) / 255.0

            calibration = self.model.calibrate(img.to(self.device), camera_model=self.camera_model)

            # render results to the frame
            frame = self.render_frame(frame, calibration)
            frame = add_text(frame, self.format_results(calibration))

            end = time()
            frame = add_text(
                frame, f"FPS: {1 / (end - start):04.1f}", align_left=False, align_top=False
            )

            cv2.imshow("GeoCalib Demo", frame)

            if self.update_toggles():
                break


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--camera_id",
        type=int,
        default=None,
        help="Camera ID to use. If none, will ask for ip of droidcam.",
    )
    args = parser.parse_args()

    print(description)

    device = get_device()
    print(f"Running on: {device}")

    # setup video capture
    if args.camera_id is not None:
        cap = VideoCapture(args.camera_id)
    else:
        ip = input("Enter the IP address of the camera: ")
        cap = VideoCapture(f"http://{ip}:4747/video/force/1920x1080")

    if not cap.isOpened():
        raise ValueError("Error: Could not open camera.")

    demo = InteractiveDemo(cap, device)
    demo.run()


if __name__ == "__main__":
    main()
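Usage note: the demo above drives `GeoCalib.calibrate` frame by frame, and the same call works offline on a single image. Below is a minimal sketch, assuming `GeoCalib` is importable from the package root as in the demo's imports (not shown in this hunk); the image path is a placeholder.

import cv2
import torch

from geocalib import GeoCalib

model = GeoCalib()  # runs on CPU unless moved with .to(device)
frame = cv2.imread("my_image.jpg")  # placeholder path
img = torch.tensor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) / 255.0

# same result keys as used by render_frame / format_results above
calibration = model.calibrate(img, camera_model="pinhole")
print(calibration["camera"].vfov, calibration["gravity"].rp)

The interactive demo itself would be launched as a module, e.g. `python -m geocalib.interactive_demo --camera_id 0` (assuming it is run from the repository root).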
geocalib/lm_optimizer.py
ADDED
@@ -0,0 +1,642 @@
"""Implementation of the Levenberg-Marquardt optimizer for camera calibration."""

import logging
import time
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple

import torch
import torch.nn as nn

from geocalib.camera import BaseCamera, camera_models
from geocalib.gravity import Gravity
from geocalib.misc import J_focal2fov
from geocalib.perspective_fields import J_perspective_field, get_perspective_field
from geocalib.utils import focal2fov, rad2deg

logger = logging.getLogger(__name__)


def get_trivial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> Tuple[BaseCamera, Gravity]:
    """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w).

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary.
        camera_model (BaseCamera): Camera model to use.

    Returns:
        Tuple[BaseCamera, Gravity]: Initial camera and gravity for optimization.
    """
    ref = data.get("up_field", data["latitude_field"])
    ref = ref.detach()

    h, w = ref.shape[-2:]
    batch_h, batch_w = (
        ref.new_ones((ref.shape[0],)) * h,
        ref.new_ones((ref.shape[0],)) * w,
    )

    init_r = ref.new_zeros((ref.shape[0],))
    init_p = ref.new_zeros((ref.shape[0],))

    focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w))
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    params |= {"scales": data["scales"]} if "scales" in data else {}
    params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
    camera = camera_model.from_dict(params)
    camera = camera.float().to(ref.device)

    gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(ref.device)
        gravity = Gravity(gravity) if isinstance(gravity, torch.Tensor) else gravity

    return camera, gravity


def scaled_loss(
    x: torch.Tensor, fn: Callable, a: float
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply a loss function to a tensor and pre- and post-scale it.

    Args:
        x: the data tensor, should already be squared: `x = y**2`.
        fn: the loss function, with signature `fn(x) -> y`.
        a: the scale parameter.

    Returns:
        The value of the loss, and its first and second derivatives.
    """
    a2 = a**2
    loss, loss_d1, loss_d2 = fn(x / a2)
    return loss * a2, loss_d1, loss_d2 / a2


def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """The classical robust Huber loss, with first and second derivatives."""
    mask = x <= 1
    sx = torch.sqrt(x + 1e-8)  # avoid nan in backward pass
    isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
    loss = torch.where(mask, x, 2 * sx - 1)
    loss_d1 = torch.where(mask, torch.ones_like(x), isx)
    loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
    return loss, loss_d1, loss_d2


def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
    """Early stopping criterion based on cost convergence."""
    return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol)


def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
) -> torch.Tensor:
    """Update damping factor for Levenberg-Marquardt optimization."""
    new_lamb = lamb.new_zeros(lamb.shape)
    new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1)
    lamb = torch.clamp(new_lamb, lambda_min, lambda_max)
    return lamb


def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """One optimization step with Gauss-Newton or Levenberg-Marquardt.

    Args:
        G (torch.Tensor): Batched gradient tensor of size (..., N).
        H (torch.Tensor): Batched hessian tensor of size (..., N, N).
        lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
        eps (float, optional): Epsilon for damping. Defaults to 1e-6.

    Returns:
        torch.Tensor: Batched update tensor of size (..., N).
    """
    diag = H.diagonal(dim1=-2, dim2=-1)
    diag = diag * lambda_.unsqueeze(-1)  # (B, 3)

    H = H + diag.clamp(min=eps).diag_embed()

    H_, G_ = H.cpu(), G.cpu()
    try:
        U = torch.linalg.cholesky(H_)
    except RuntimeError:
        logger.warning("Cholesky decomposition failed. Stopping.")
        delta = H.new_zeros((H.shape[0], H.shape[-1]))  # (B, 3)
    else:
        delta = torch.cholesky_solve(G_[..., None], U)[..., 0]

    return delta.to(H.device)


# mypy: ignore-errors
class LMOptimizer(nn.Module):
    """Levenberg-Marquardt optimizer for camera calibration."""

    default_conf = {
        # Camera model parameters
        "camera_model": "pinhole",  # {"pinhole", "simple_radial", "simple_spherical"}
        "shared_intrinsics": False,  # share focal length across all images in batch
        # LM optimizer parameters
        "num_steps": 30,
        "lambda_": 0.1,
        "fix_lambda": False,
        "early_stop": True,
        "atol": 1e-8,
        "rtol": 1e-8,
        "use_spherical_manifold": True,  # use spherical manifold for gravity optimization
        "use_log_focal": True,  # use log focal length for optimization
        # Loss function parameters
        "up_loss_fn_scale": 1e-2,
        "lat_loss_fn_scale": 1e-2,
        # Misc
        "verbose": False,
    }

    def __init__(self, conf: Dict[str, Any]):
        """Initialize the LM optimizer."""
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        self.num_steps = conf.num_steps

        self.set_camera_model(conf.camera_model)
        self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)

    def set_camera_model(self, camera_model: str) -> None:
        """Set the camera model to use for the optimization.

        Args:
            camera_model (str): Camera model to use.
        """
        assert (
            camera_model in camera_models.keys()
        ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
        self.camera_model = camera_models[camera_model]
        self.camera_has_distortion = hasattr(self.camera_model, "dist")

        logger.debug(
            f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
        )

    def setup_optimization_and_priors(
        self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
    ) -> None:
        """Setup the optimization and priors for the LM optimizer.

        Args:
            data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults
                to None.
            shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
                Defaults to False.
        """
        if data is None:
            data = {}
        self.shared_intrinsics = shared_intrinsics

        if shared_intrinsics:  # si => must use pinhole
            assert (
                self.camera_model == camera_models["pinhole"]
            ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}"

        self.estimate_gravity = True
        if "prior_gravity" in data:
            self.estimate_gravity = False
            logger.debug("Using provided gravity as prior.")

        self.estimate_focal = True
        if "prior_focal" in data:
            self.estimate_focal = False
            logger.debug("Using provided focal as prior.")

        self.estimate_k1 = True
        if "prior_k1" in data:
            self.estimate_k1 = False
            logger.debug("Using provided k1 as prior.")

        self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
        self.focal_delta_dims = (
            (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
        )
        self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,)

        logger.debug(f"Camera Model: {self.camera_model}")
        logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})")
        logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})")
        logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})")

        logger.debug(f"Shared intrinsics: {self.shared_intrinsics}")

    def calculate_residuals(
        self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Calculate the residuals for the optimization.

        Args:
            camera (BaseCamera): Optimized camera.
            gravity (Gravity): Optimized gravity.
            data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields.

        Returns:
            Dict[str, torch.Tensor]: Residuals for the optimization.
        """
        perspective_up, perspective_lat = get_perspective_field(camera, gravity)
        perspective_lat = torch.sin(perspective_lat)

        residuals = {}
        if "up_field" in data:
            up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
            residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)

        if "latitude_field" in data:
            target_lat = torch.sin(data["latitude_field"])
            lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
            residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)

        return residuals

    def calculate_costs(
        self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Calculate the costs and weights for the optimization.

        Args:
            residuals (torch.Tensor): Residuals for the optimization.
            data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence.

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the
                optimization.
        """
        costs, weights = {}, {}

        if "up_residual" in residuals:
            up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
            up_cost, up_weight, _ = scaled_loss(up_cost, huber_loss, self.conf.up_loss_fn_scale)

            if "up_confidence" in data:
                up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
                up_weight = up_weight * up_conf
                up_cost = up_cost * up_conf

            costs["up_cost"] = up_cost
            weights["up_weights"] = up_weight

        if "latitude_residual" in residuals:
            lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
            lat_cost, lat_weight, _ = scaled_loss(lat_cost, huber_loss, self.conf.lat_loss_fn_scale)

            if "latitude_confidence" in data:
                lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
                lat_weight = lat_weight * lat_conf
                lat_cost = lat_cost * lat_conf

            costs["latitude_cost"] = lat_cost
            weights["latitude_weights"] = lat_weight

        return costs, weights

    def calculate_gradient_and_hessian(
        self,
        J: torch.Tensor,
        residuals: torch.Tensor,
        weights: torch.Tensor,
        shared_intrinsics: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the gradient and Hessian given the Jacobian, residuals, and weights.

        Args:
            J (torch.Tensor): Jacobian.
            residuals (torch.Tensor): Residuals.
            weights (torch.Tensor): Weights.
            shared_intrinsics (bool): Whether to share the intrinsics across the batch.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian.
        """
        dims = ()
        if self.estimate_gravity:
            dims = (0, 1)
        if self.estimate_focal:
            dims += (2,)
        if self.camera_has_distortion and self.estimate_k1:
            dims += (3,)
        assert dims, "No parameters to optimize"

        J = J[..., dims]

        Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
        Grad = weights[..., None] * Grad
        Grad = Grad.sum(-2)  # (B, N_params)

        if shared_intrinsics:
            # reshape to (1, B * (N_params-1) + 1)
            Grad_g = Grad[..., :2].reshape(1, -1)
            Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
            Grad = torch.cat([Grad_g, Grad_f], dim=-1)

        Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
        Hess = weights[..., None, None] * Hess
        Hess = Hess.sum(-3)

        if shared_intrinsics:
            H_g = torch.block_diag(*list(Hess[..., :2, :2]))
            J_fg = Hess[..., :2, 2].flatten()
            J_gf = Hess[..., 2, :2].flatten()
            J_f = Hess[..., 2, 2].sum()
            dims = H_g.shape[-1] + 1
            Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
            Hess[:-1, :-1] = H_g
            Hess[-1, :-1] = J_gf
            Hess[:-1, -1] = J_fg
            Hess[-1, -1] = J_f
            Hess = Hess.unsqueeze(0)

        return Grad, Hess

    def setup_system(
        self,
        camera: BaseCamera,
        gravity: Gravity,
        residuals: Dict[str, torch.Tensor],
        weights: Dict[str, torch.Tensor],
        as_rpf: bool = False,
        shared_intrinsics: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the gradient and Hessian for the optimization.

        Args:
            camera (BaseCamera): Optimized camera.
            gravity (Gravity): Optimized gravity.
            residuals (Dict[str, torch.Tensor]): Residuals for the optimization.
            weights (Dict[str, torch.Tensor]): Weights for the optimization.
            as_rpf (bool, optional): Whether to calculate the gradient and Hessian with respect to
                roll, pitch, and focal length. Defaults to False.
            shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
                Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization.
        """
        J_up, J_lat = J_perspective_field(
            camera,
            gravity,
            spherical=self.conf.use_spherical_manifold and not as_rpf,
            log_focal=self.conf.use_log_focal and not as_rpf,
        )

        J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1])  # (B, N, 2, 3)
        J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1])  # (B, N, 1, 3)

        n_params = (
            2 * self.estimate_gravity
            + self.estimate_focal
            + (self.camera_has_distortion and self.estimate_k1)
        )
        Grad = J_up.new_zeros(J_up.shape[0], n_params)
        Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)

        if shared_intrinsics:
            N_params = Grad.shape[0] * (n_params - 1) + 1
            Grad = Grad.new_zeros(1, N_params)
            Hess = Hess.new_zeros(1, N_params, N_params)

        if "up_residual" in residuals:
            Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
                J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
            )

            if self.conf.verbose:
                logger.info(f"Up J:\n{Up_Grad.mean(0)}")

            Grad = Grad + Up_Grad
            Hess = Hess + Up_Hess

        if "latitude_residual" in residuals:
            Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
                J_lat,
                residuals["latitude_residual"],
                weights["latitude_weights"],
                shared_intrinsics,
            )

            if self.conf.verbose:
                logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")

            Grad = Grad + Lat_Grad
            Hess = Hess + Lat_Hess

        return Grad, Hess

    def estimate_uncertainty(
        self,
        camera_opt: BaseCamera,
        gravity_opt: Gravity,
        errors: Dict[str, torch.Tensor],
        weights: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Estimate the uncertainty of the optimized camera and gravity at the final step.

        Args:
            camera_opt (BaseCamera): Final optimized camera.
            gravity_opt (Gravity): Final optimized gravity.
            errors (Dict[str, torch.Tensor]): Costs for the optimization.
            weights (Dict[str, torch.Tensor]): Weights for the optimization.

        Returns:
            Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity.
        """
        _, Hess = self.setup_system(
            camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
        )
        Cov = torch.inverse(Hess)

        roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        if self.estimate_gravity:
            roll_uncertainty = Cov[..., 0, 0]
            pitch_uncertainty = Cov[..., 1, 1]

            try:
                delta_uncertainty = Cov[..., :2, :2]
                eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
                gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
            except RuntimeError:
                logger.warning("Could not calculate gravity uncertainty")
                gravity_uncertainty = Cov.new_zeros(Cov.shape[0])

        focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
        if self.estimate_focal:
            focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
            fov_uncertainty = (
                J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
            )

        return {
            "covariance": Cov,
            "roll_uncertainty": torch.sqrt(roll_uncertainty),
            "pitch_uncertainty": torch.sqrt(pitch_uncertainty),
            "gravity_uncertainty": torch.sqrt(gravity_uncertainty),
            "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
            "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
        }

    def update_estimate(
        self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
    ) -> Tuple[BaseCamera, Gravity]:
        """Update the camera and gravity estimates with the given delta.

        Args:
            camera (BaseCamera): Optimized camera.
            gravity (Gravity): Optimized gravity.
            delta (torch.Tensor): Delta to update the camera and gravity estimates.

        Returns:
            Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates.
        """
        delta_gravity = (
            delta[..., self.gravity_delta_dims]
            if self.estimate_gravity
            else delta.new_zeros(delta.shape[:-1] + (2,))
        )
        new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)

        delta_f = (
            delta[..., self.focal_delta_dims]
            if self.estimate_focal
            else delta.new_zeros(delta.shape[:-1] + (1,))
        )
        new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)

        delta_dist = (
            delta[..., self.k1_delta_dims]
            if self.camera_has_distortion and self.estimate_k1
            else delta.new_zeros(delta.shape[:-1] + (1,))
        )
        if self.camera_has_distortion:
            new_camera = new_camera.update_dist(delta_dist)

        return new_camera, new_gravity

    def optimize(
        self,
        data: Dict[str, torch.Tensor],
        camera_opt: BaseCamera,
        gravity_opt: Gravity,
    ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:
        """Optimize the camera and gravity estimates.

        Args:
            data (Dict[str, torch.Tensor]): Input data.
            camera_opt (BaseCamera): Optimized camera.
            gravity_opt (Gravity): Optimized gravity.

        Returns:
            Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity
                estimates and optimization information.
        """
        key = list(data.keys())[0]
        B = data[key].shape[0]

        lamb = data[key].new_ones(B) * self.conf.lambda_
        if self.shared_intrinsics:
            lamb = data[key].new_ones(1) * self.conf.lambda_

        infos = {"stop_at": self.num_steps}
        for i in range(self.num_steps):
            if self.conf.verbose:
                logger.info(f"Step {i+1}/{self.num_steps}")

            errors = self.calculate_residuals(camera_opt, gravity_opt, data)
            costs, weights = self.calculate_costs(errors, data)

            if i == 0:
                prev_cost = sum(c.mean(-1) for c in costs.values())
                for k, c in costs.items():
                    infos[f"initial_{k}"] = c.mean(-1)

                infos["initial_cost"] = prev_cost

            Grad, Hess = self.setup_system(
                camera_opt,
                gravity_opt,
                errors,
                weights,
                shared_intrinsics=self.shared_intrinsics,
            )
            delta = optimizer_step(Grad, Hess, lamb)  # (B, N_params)

            if self.shared_intrinsics:
                delta_g = delta[..., :-1].reshape(B, 2)
                delta_f = delta[..., -1].expand(B, 1)
                delta = torch.cat([delta_g, delta_f], dim=-1)

            # calculate new cost
            camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
            new_cost, _ = self.calculate_costs(
                self.calculate_residuals(camera_opt, gravity_opt, data), data
            )
            new_cost = sum(c.mean(-1) for c in new_cost.values())

            if not self.conf.fix_lambda and not self.shared_intrinsics:
                lamb = update_lambda(lamb, prev_cost, new_cost)

            if self.conf.verbose:
                logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}")
                logger.info(f"Camera:\n{camera_opt._data}")

            if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
                infos["stop_at"] = min(i + 1, infos["stop_at"])

                if self.conf.early_stop:
                    if self.conf.verbose:
                        logger.info(f"Early stopping at step {i+1}")
                    break

            prev_cost = new_cost

            if i == self.num_steps - 1 and self.conf.early_stop:
                logger.warning("Reached maximum number of steps without convergence.")

        final_errors = self.calculate_residuals(camera_opt, gravity_opt, data)  # (B, N, 3)
        final_cost, weights = self.calculate_costs(final_errors, data)  # (B, N)

        if not self.training:
            infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)

        infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
        for k, c in final_cost.items():
            infos[f"final_{k}"] = c.mean(-1)

        infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())

        return camera_opt, gravity_opt, infos

    def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the LM optimization."""
        camera_init, gravity_init = get_trivial_estimation(data, self.camera_model)

        self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)

        start = time.time()
        camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)

        if self.conf.verbose:
            logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")

            logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
            logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")

            logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
            logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")

        return {"camera": camera_opt, "gravity": gravity_opt, **infos}
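Usage note: the optimizer is normally fed by the network's perspective-field predictions, but its `forward` only needs a dict with `up_field` and/or `latitude_field` (plus optional confidences and priors). Below is a minimal standalone sketch, assuming the camera, gravity, and perspective-field modules imported above are available; the random tensors are placeholders for real predictions, so the printed values are meaningless apart from showing the output keys.

import torch
from geocalib.lm_optimizer import LMOptimizer

optimizer = LMOptimizer({"camera_model": "pinhole", "num_steps": 10})
data = {
    # per-pixel up vectors, (B, 2, H, W), normalized along the channel dim
    "up_field": torch.nn.functional.normalize(torch.randn(1, 2, 60, 80), dim=1),
    # per-pixel latitude in radians, (B, 1, H, W)
    "latitude_field": (torch.rand(1, 1, 60, 80) - 0.5) * 3.0,
}
out = optimizer(data)
print(out["camera"].vfov, out["gravity"].rp, out["final_cost"])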
geocalib/misc.py
ADDED
@@ -0,0 +1,318 @@
"""Miscellaneous functions and classes for the geocalib_inference package."""

import functools
import inspect
import logging
from typing import Callable, List

import numpy as np
import torch

logger = logging.getLogger(__name__)

# mypy: ignore-errors


def autocast(func: Callable) -> Callable:
    """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.

    Use the device and dtype of the wrapper.

    Args:
        func (Callable): Method of a TensorWrapper class.

    Returns:
        Callable: Wrapped method.
    """

    @functools.wraps(func)
    def wrap(self, *args):
        device = torch.device("cpu")
        dtype = None
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        cast_args = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                arg = torch.from_numpy(arg)
                arg = arg.to(device=device, dtype=dtype)
            cast_args.append(arg)
        return func(self, *cast_args)

    return wrap


class TensorWrapper:
    """Wrapper for PyTorch tensors."""

    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        """Wrapper for PyTorch tensors."""
        self._data = data

    @property
    def shape(self) -> torch.Size:
        """Shape of the underlying tensor."""
        return self._data.shape[:-1]

    @property
    def device(self) -> torch.device:
        """Get the device of the underlying tensor."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Get the dtype of the underlying tensor."""
        return self._data.dtype

    def __getitem__(self, index) -> torch.Tensor:
        """Get the underlying tensor."""
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        """Set the underlying tensor."""
        self._data[index] = item.data

    def to(self, *args, **kwargs):
        """Move the underlying tensor to a new device."""
        return self.__class__(self._data.to(*args, **kwargs))

    def cpu(self):
        """Move the underlying tensor to the CPU."""
        return self.__class__(self._data.cpu())

    def cuda(self):
        """Move the underlying tensor to the GPU."""
        return self.__class__(self._data.cuda())

    def pin_memory(self):
        """Pin the underlying tensor to memory."""
        return self.__class__(self._data.pin_memory())

    def float(self):
        """Cast the underlying tensor to float."""
        return self.__class__(self._data.float())

    def double(self):
        """Cast the underlying tensor to double."""
        return self.__class__(self._data.double())

    def detach(self):
        """Detach the underlying tensor."""
        return self.__class__(self._data.detach())

    def numpy(self):
        """Convert the underlying tensor to a numpy array."""
        return self._data.detach().cpu().numpy()

    def new_tensor(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self._data.new_tensor(*args, **kwargs)

    def new_zeros(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self._data.new_zeros(*args, **kwargs)

    def new_ones(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self._data.new_ones(*args, **kwargs)

    def new_full(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self._data.new_full(*args, **kwargs)

    def new_empty(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self._data.new_empty(*args, **kwargs)

    def unsqueeze(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self.__class__(self._data.unsqueeze(*args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Create a new tensor of the same type and device."""
        return self.__class__(self._data.squeeze(*args, **kwargs))

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        """Stack a list of objects with the same type and shape."""
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Support torch functions."""
        if kwargs is None:
            kwargs = {}
        return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented


class EuclideanManifold:
    """Simple euclidean manifold."""

    @staticmethod
    def J_plus(x: torch.Tensor) -> torch.Tensor:
        """Plus operator Jacobian."""
        return torch.eye(x.shape[-1]).to(x)

    @staticmethod
    def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Plus operator."""
        return x + delta


class SphericalManifold:
    """Implementation of the spherical manifold.

    Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
    Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).

    Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
    Golub et al.
    """

    @staticmethod
    def householder_vector(x: torch.Tensor) -> torch.Tensor:
        """Return the Householder vector and beta.

        Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
        in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
        first.

        This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
        orthogonal and H * x = ||x||_2 * e_n.

        Args:
            x (torch.Tensor): [..., n] tensor.

        Returns:
            torch.Tensor: v of shape [..., n]
            torch.Tensor: beta of shape [...]
        """
        sigma = torch.sum(x[..., :-1] ** 2, -1)
        xpiv = x[..., -1]
        norm = torch.norm(x, dim=-1)
        if torch.any(sigma < 1e-7):
            sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
            logger.warning("sigma < 1e-7")

        vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
        beta = 2 * vpiv**2 / (sigma + vpiv**2)
        v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
        return v, beta

    @staticmethod
    def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Apply Householder transformation.

        Args:
            y (torch.Tensor): Vector to transform of shape [..., n].
            v (torch.Tensor): Householder vector of shape [..., n].
            beta (torch.Tensor): Householder beta of shape [...].

        Returns:
            torch.Tensor: Transformed vector of shape [..., n].
        """
        return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]

    @classmethod
    def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
        """Plus operator Jacobian."""
        v, beta = cls.householder_vector(x)
        H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
        H = H + torch.eye(H.shape[-1]).to(H)
        return H[..., :-1]  # J

    @classmethod
    def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Plus operator.

        Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
        Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
        element of the input vector as pivot instead of first.

        Args:
            x: point on the manifold
            delta: tangent vector
        """
        eps = 1e-7
        # the norm of x is not necessarily 1, so keep it to rescale the result
        nx = torch.norm(x, dim=-1, keepdim=True)
        nd = torch.norm(delta, dim=-1, keepdim=True)

        # make sure we don't divide by zero in backward as torch.where computes grad for both
        # branches
        nd_ = torch.where(nd < eps, nd + eps, nd)
        sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)

        # cos is applied to last dim instead of first
        exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)

        v, beta = cls.householder_vector(x)
        return nx * cls.apply_householder(exp_delta, v, beta)


@torch.jit.script
def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
    """Compute the jacobian of vec / norm2(vec).

    Args:
        vec (torch.Tensor): [..., D] tensor.

    Returns:
        torch.Tensor: [..., D, D] Jacobian.
    """
    D = vec.shape[-1]
    norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1)  # (..., 1, 1)

    if (norm_x == 0).any():
        norm_x = norm_x + 1e-6

    xxT = torch.einsum("...i,...j->...ij", vec, vec)  # (..., D, D)
    identity = torch.eye(D, device=vec.device, dtype=vec.dtype)  # (D, D)

    return identity / norm_x - (xxT / norm_x**3)  # (..., D, D)


@torch.jit.script
def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Compute the jacobian of the focal2fov function."""
    return -4 * h / (4 * focal**2 + h**2)


@torch.jit.script
def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
    """Compute the jacobian of the up-vector projection.

    Args:
        uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
        abc (torch.Tensor): Gravity vector of shape (..., 3).
        wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".

    Raises:
        ValueError: If the wrt parameter is unknown.

    Returns:
        torch.Tensor: Jacobian with respect to the parameter.
    """
    if wrt == "uv":
        c = abc[..., 2][..., None, None, None]
        return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))

    elif wrt == "abc":
        J = uv.new_zeros(uv.shape[:-1] + (2, 3))
        J[..., 0, 0] = 1
        J[..., 1, 1] = 1
        J[..., 0, 2] = -uv[..., 0]
        J[..., 1, 2] = -uv[..., 1]
        return J

    else:
        raise ValueError(f"Unknown wrt: {wrt}")
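Usage note: the `SphericalManifold.plus` operator above maps a 2D tangent-space delta onto the sphere through a Householder frame, so the norm of the input point is preserved. A minimal sanity-check sketch, with an arbitrary unit vector standing in for a gravity direction:

import torch
from geocalib.misc import SphericalManifold

x = torch.tensor([[0.0, 0.6, 0.8]])    # unit-norm point on the sphere
delta = torch.tensor([[0.05, -0.02]])  # update in the 2D tangent space
x_new = SphericalManifold.plus(x, delta)
print(x_new, torch.linalg.norm(x_new, dim=-1))  # norm stays (numerically) 1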
geocalib/modules.py
ADDED
@@ -0,0 +1,575 @@
1 |
+
"""Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic
|
2 |
+
Segmentation (NeurIPS 2022) adapted from
|
3 |
+
|
4 |
+
https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/backbones/mscan.py
|
5 |
+
|
6 |
+
|
7 |
+
Light Hamburger Decoder adapted from:
|
8 |
+
|
9 |
+
https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py
|
10 |
+
"""
|
11 |
+
|
12 |
+
from typing import Dict, Tuple
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch.nn.modules.utils import _pair as to_2tuple
|
18 |
+
|
19 |
+
# flake8: noqa: E266
|
20 |
+
# mypy: ignore-errors
|
21 |
+
|
22 |
+
|
23 |
+
class ConvModule(nn.Module):
|
24 |
+
"""Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency."""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
in_channels: int,
|
29 |
+
out_channels: int,
|
30 |
+
kernel_size: int,
|
31 |
+
padding: int = 0,
|
32 |
+
use_norm: bool = False,
|
33 |
+
bias: bool = True,
|
34 |
+
):
|
35 |
+
"""Simple convolution block.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
in_channels (int): Input channels.
|
39 |
+
out_channels (int): Output channels.
|
40 |
+
kernel_size (int): Kernel size.
|
41 |
+
padding (int, optional): Padding. Defaults to 0.
|
42 |
+
use_norm (bool, optional): Whether to use normalization. Defaults to False.
|
43 |
+
bias (bool, optional): Whether to use bias. Defaults to True.
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
|
47 |
+
self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
|
48 |
+
self.activate = nn.ReLU(inplace=True)
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
51 |
+
"""Forward pass."""
|
52 |
+
x = self.conv(x)
|
53 |
+
x = self.bn(x)
|
54 |
+
return self.activate(x)
|
55 |
+
|
56 |
+
|
57 |
+
class ResidualConvUnit(nn.Module):
|
58 |
+
"""Residual convolution module."""
|
59 |
+
|
60 |
+
def __init__(self, features):
|
61 |
+
"""Simple residual convolution block.
|
62 |
+
|
63 |
+
Args:
|
64 |
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)

        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(self, features: int, unit2only=False, upsample=True):
        """Feature fusion block.

        Args:
            features (int): Number of features.
            unit2only (bool, optional): Whether to use only the second unit. Defaults to False.
            upsample (bool, optional): Whether to upsample. Defaults to True.
        """
        super().__init__()
        self.upsample = upsample

        if not unit2only:
            self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        output = xs[0]

        if len(xs) == 2:
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        if self.upsample:
            output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False)

        return output


###################################################
########### Light Hamburger Decoder ###############
###################################################


class NMF2D(nn.Module):
    """Non-negative Matrix Factorization (NMF) for 2D data."""

    def __init__(self):
        """Non-negative Matrix Factorization (NMF) for 2D data."""
        super().__init__()
        self.S, self.D, self.R = 1, 512, 64
        self.train_steps = 6
        self.eval_steps = 7
        self.inv_t = 1

    def _build_bases(self, B: int, S: int, D: int, R: int, device: str = "cpu") -> torch.Tensor:
        bases = torch.rand((B * S, D, R)).to(device)
        return F.normalize(bases, dim=1)

    def local_step(
        self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update bases and coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # Multiplicative Update
        coef = coef * numerator / (denominator + 1e-6)
        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
        numerator = torch.bmm(x, coef)
        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
        # Multiplicative Update
        bases = bases * numerator / (denominator + 1e-6)
        return bases, coef

    def compute_coef(
        self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
    ) -> torch.Tensor:
        """Compute coefficient."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        # multiplication update
        return coef * numerator / (denominator + 1e-6)

    def local_inference(
        self, x: torch.Tensor, bases: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Local inference."""
        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
        coef = torch.bmm(x.transpose(1, 2), bases)
        coef = F.softmax(self.inv_t * coef, dim=-1)

        steps = self.train_steps if self.training else self.eval_steps
        for _ in range(steps):
            bases, coef = self.local_step(x, bases, coef)

        return bases, coef

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B * S, D, N)
        D = C // self.S
        N = H * W
        x = x.view(B * self.S, D, N)

        # (S, D, R) -> (B * S, D, R)
        bases = self._build_bases(B, self.S, D, self.R, device=x.device)
        bases, coef = self.local_inference(x, bases)
        # (B * S, N, R)
        coef = self.compute_coef(x, bases, coef)
        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
        x = torch.bmm(bases, coef.transpose(1, 2))
        # (B * S, D, N) -> (B, C, H, W)
        x = x.view(B, C, H, W)
        # (B * H, D, R) -> (B, H, N, D)
        bases = bases.view(B, self.S, D, self.R)

        return x

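A minimal usage sketch for NMF2D (illustrative, not part of this commit; the batch size, spatial size, and the 512-channel assumption are made up — only the channel count has to match D per group):

    import torch
    from geocalib.modules import NMF2D

    nmf = NMF2D()
    feats = torch.rand(2, 512, 40, 30)   # hypothetical backbone features (B, C, H, W)
    out = nmf(feats)                     # low-rank (R=64) reconstruction of the features
    assert out.shape == feats.shape      # (2, 512, 40, 30)
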
class Hamburger(nn.Module):
    """Hamburger Module."""

    def __init__(self, ham_channels: int = 512):
        """Hamburger Module.

        Args:
            ham_channels (int, optional): Number of channels in the hamburger module. Defaults to
                512.
        """
        super().__init__()
        self.ham_in = ConvModule(ham_channels, ham_channels, 1)
        self.ham = NMF2D()
        self.ham_out = ConvModule(ham_channels, ham_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        enjoy = self.ham_in(x)
        enjoy = F.relu(enjoy, inplace=False)
        enjoy = self.ham(enjoy)
        enjoy = self.ham_out(enjoy)
        ham = F.relu(x + enjoy, inplace=False)
        return ham


class LightHamHead(nn.Module):
    """Is Attention Better Than Matrix Decomposition?

    This head is the implementation of `HamNet <https://arxiv.org/abs/2109.04553>`.
    """

    def __init__(self):
        """Light hamburger decoder head."""
        super().__init__()
        self.in_index = [0, 1, 2, 3]
        self.in_channels = [64, 128, 320, 512]
        self.out_channels = 64
        self.ham_channels = 512
        self.align_corners = False

        self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1)

        self.hamburger = Hamburger(self.ham_channels)

        self.align = ConvModule(self.ham_channels, self.out_channels, 1)

        self.linear_pred_uncertainty = nn.Sequential(
            ConvModule(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
        )

        self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1, bias=False)
        self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)

    def forward(self, features: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass."""
        inputs = [features["hl"][i] for i in self.in_index]

        inputs = [
            F.interpolate(
                level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
            for level in inputs
        ]

        inputs = torch.cat(inputs, dim=1)
        x = self.squeeze(inputs)

        x = self.hamburger(x)

        feats = self.align(x)

        assert "ll" in features, "Low-level features are required for this model"
        feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
        feats = self.out_conv(feats)
        feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
        feats = self.ll_fusion(feats, features["ll"].clone())

        uncertainty = self.linear_pred_uncertainty(feats).squeeze(1)

        return feats, uncertainty

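For orientation, a minimal sketch of the tensors this decoder head expects (illustrative, not from this commit; the dict keys "hl"/"ll" come from the forward pass above, but the concrete 64x64 input size and the stride/channel layout are assumptions matching the MSCAN stages below):

    import torch
    from geocalib.modules import LightHamHead

    head = LightHamHead()
    hl = [torch.rand(1, c, 64 // s, 64 // s) for c, s in zip([64, 128, 320, 512], [4, 8, 16, 32])]
    ll = torch.rand(1, 64, 64, 64)                     # full-resolution low-level features
    feats, uncertainty = head({"hl": hl, "ll": ll})    # (1, 64, 64, 64), (1, 64, 64)
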
###################################################
########### MSCAN ################
###################################################


class DWConv(nn.Module):
    """Depthwise convolution."""

    def __init__(self, dim: int = 768):
        """Depthwise convolution.

        Args:
            dim (int, optional): Number of features. Defaults to 768.
        """
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        return self.dwconv(x)


class Mlp(nn.Module):
    """MLP module."""

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        """Initialize the MLP."""
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Forward pass."""
        x = self.fc1(x)

        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x


class StemConv(nn.Module):
    """Simple stem convolution module."""

    def __init__(self, in_channels: int, out_channels: int):
        """Simple stem convolution module.

        Args:
            in_channels (int): Input channels.
            out_channels (int): Output channels.
        """
        super().__init__()

        self.proj = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
            ),
            nn.BatchNorm2d(out_channels // 2),
            nn.GELU(),
            nn.Conv2d(
                out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
            ),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Forward pass."""
        x = self.proj(x)
        _, _, H, W = x.size()
        x = x.flatten(2).transpose(1, 2)
        return x, H, W


class AttentionModule(nn.Module):
    """Attention module."""

    def __init__(self, dim: int):
        """Attention module.

        Args:
            dim (int): Number of features.
        """
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)

        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)

        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        u = x.clone()
        attn = self.conv0(x)

        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)

        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)

        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)
        attn = attn + attn_0 + attn_1 + attn_2

        attn = self.conv3(attn)
        return attn * u


class SpatialAttention(nn.Module):
    """Spatial attention module."""

    def __init__(self, dim: int):
        """Spatial attention module.

        Args:
            dim (int): Number of features.
        """
        super().__init__()
        self.d_model = dim
        self.proj_1 = nn.Conv2d(dim, dim, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(dim)
        self.proj_2 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x


class Block(nn.Module):
    """MSCAN block."""

    def __init__(
        self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0, act_layer: nn.Module = nn.GELU
    ):
        """MSCAN block.

        Args:
            dim (int): Number of features.
            mlp_ratio (float, optional): Ratio of the hidden features in the MLP. Defaults to 4.0.
            drop (float, optional): Dropout rate. Defaults to 0.0.
            act_layer (nn.Module, optional): Activation layer. Defaults to nn.GELU.
        """
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim)
        self.drop_path = nn.Identity()  # only used in training
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True
        )
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True
        )

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Forward pass."""
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)
        x = x + self.drop_path(self.layer_scale_1[..., None, None] * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2[..., None, None] * self.mlp(self.norm2(x)))
        return x.view(B, C, N).permute(0, 2, 1)


class OverlapPatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self, patch_size: int = 7, stride: int = 4, in_chans: int = 3, embed_dim: int = 768
    ):
        """Image to Patch Embedding.

        Args:
            patch_size (int, optional): Image patch size. Defaults to 7.
            stride (int, optional): Stride. Defaults to 4.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            embed_dim (int, optional): Embedding dimension. Defaults to 768.
        """
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Forward pass."""
        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)
        x = x.flatten(2).transpose(1, 2)
        return x, H, W


class MSCAN(nn.Module):
    """Multi-scale convolutional attention network."""

    def __init__(self):
        """Multi-scale convolutional attention network."""
        super().__init__()
        self.in_channels = 3
        self.embed_dims = [64, 128, 320, 512]
        self.mlp_ratios = [8, 8, 4, 4]
        self.drop_rate = 0.0
        self.drop_path_rate = 0.1
        self.depths = [3, 3, 12, 3]
        self.num_stages = 4

        for i in range(self.num_stages):
            if i == 0:
                patch_embed = StemConv(3, self.embed_dims[0])
            else:
                patch_embed = OverlapPatchEmbed(
                    patch_size=7 if i == 0 else 3,
                    stride=4 if i == 0 else 2,
                    in_chans=self.in_channels if i == 0 else self.embed_dims[i - 1],
                    embed_dim=self.embed_dims[i],
                )

            block = nn.ModuleList(
                [
                    Block(
                        dim=self.embed_dims[i],
                        mlp_ratio=self.mlp_ratios[i],
                        drop=self.drop_rate,
                    )
                    for _ in range(self.depths[i])
                ]
            )
            norm = nn.LayerNorm(self.embed_dims[i])

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

    def forward(self, data):
        """Forward pass."""
        # rgb -> bgr and from [0, 1] to [0, 255]
        x = data["image"][:, [2, 1, 0], :, :] * 255.0
        B = x.shape[0]

        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return {"features": outs}
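A minimal sketch of running the MSCAN encoder (illustrative, not part of this commit; the 64x64 input resolution is arbitrary — the module only expects an "image" tensor in [0, 1] with RGB channel order):

    import torch
    from geocalib.modules import MSCAN

    encoder = MSCAN()
    data = {"image": torch.rand(1, 3, 64, 64)}     # hypothetical input
    outs = encoder(data)["features"]               # four stage feature maps
    print([tuple(f.shape) for f in outs])          # strides 4/8/16/32, channels 64/128/320/512
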
geocalib/perspective_fields.py
ADDED
@@ -0,0 +1,366 @@
"""Implementation of perspective fields.

Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
"""

from typing import Tuple

import torch
from torch.nn import functional as F

from geocalib.camera import BaseCamera
from geocalib.gravity import Gravity
from geocalib.misc import J_up_projection, J_vecnorm, SphericalManifold

# flake8: noqa: E266


def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
    """Get the horizon line from the camera parameters.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.

    Returns:
        torch.Tensor: In image frame, fraction of image left/right border intersection with
        respect to image height.
    """
    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    # project horizon midpoint to image plane
    horizon_midpoint = camera.new_tensor([0, 0, 1])
    horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
    midpoint = horizon_midpoint[:2] / horizon_midpoint[2]

    # compute left and right offset to borders
    left_offset = midpoint[0] * torch.tan(gravity.roll)
    right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
    left, right = midpoint[1] + left_offset, midpoint[1] - right_offset

    horizon = camera.new_tensor([left, right])
    return horizon / camera.size[1] if relative else horizon


def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
    """Get the up vector field from the camera parameters.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        normalize (bool, optional): Whether to normalize the up vector. Defaults to True.

    Returns:
        torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
    """
    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    uv = camera.normalize(camera.pixel_coordinates())

    # projected up is (a, b) - c * (u, v)
    abc = gravity.vec3d
    projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv  # (..., N, 2)

    if hasattr(camera, "dist"):
        d_uv = camera.distort(uv, return_scale=True)[0]  # (..., N, 1)
        d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,)))  # (..., N, 2, 2)
        offset = camera.up_projection_offset(uv)  # (..., N, 2)
        offset = torch.einsum("...i,...j->...ij", offset, uv)  # (..., N, 2, 2)

        # (..., N, 2)
        projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)

    if normalize:
        projected_up2d = F.normalize(projected_up2d, dim=-1)  # (..., N, 2)

    return projected_up2d.reshape(camera.shape[0], h, w, 2)

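The core of `get_up_field` is the per-pixel expression `(a, b) - c * (u, v)`; a tiny numeric sketch of that formula in isolation (values are arbitrary, pure torch, no camera objects involved):

    import torch
    from torch.nn import functional as F

    abc = F.normalize(torch.tensor([0.0, 0.9, 0.1]), dim=-1)   # example gravity direction (a, b, c)
    uv = torch.tensor([0.2, -0.1])                             # example normalized pixel (u, v)
    up = F.normalize(abc[:2] - abc[2] * uv, dim=-1)            # unit-length up vector at that pixel
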
def J_up_field(
    camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
) -> torch.Tensor:
    """Get the jacobian of the up field.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
        log_focal (bool, optional): Whether to use log-focal length. Defaults to False.

    Returns:
        torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3).
    """
    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    # Forward
    xy = camera.pixel_coordinates()
    uv = camera.normalize(xy)

    projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv

    # Backward
    J = []

    # (..., N, 2, 2)
    J_norm2proj = J_vecnorm(
        get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
    )

    # distortion values
    if hasattr(camera, "dist"):
        d_uv = camera.distort(uv, return_scale=True)[0]  # (..., N, 1)
        d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,)))  # (..., N, 2, 2)
        offset = camera.up_projection_offset(uv)  # (..., N, 2)
        offset_uv = torch.einsum("...i,...j->...ij", offset, uv)  # (..., N, 2, 2)

    ######################
    ## Gravity Jacobian ##
    ######################

    J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc")  # (..., N, 2, 3)

    if hasattr(camera, "dist"):
        # (..., N, 2, 3)
        J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)

    J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
    J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
    J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
    J.append(J_up2delta)

    ######################
    ### Focal Jacobian ###
    ######################

    J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv")  # (..., N, 2, 2)

    if hasattr(camera, "dist"):
        J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
        J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)

        inner = (uv * projected_up2d).sum(-1)[..., None, None]
        J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
        J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d)  # (..., N, 2, 2)
        J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up

    J_uv2f = camera.J_normalize(xy)  # (..., N, 2, 2)

    if log_focal:
        J_uv2f = J_uv2f * camera.f[..., None, None, :]  # (..., N, 2, 2)

    J_uv2f = J_uv2f.sum(-1)  # (..., N, 2)

    J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f)  # (..., N, 2)
    J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None]  # (..., N, 2, 1)
    J.append(J_up2f)

    ######################
    ##### K1 Jacobian ####
    ######################

    if hasattr(camera, "dist"):
        J_duv = camera.J_distort(uv, wrt="scale2dist")
        J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,)))  # (..., N, 2, 2)
        J_offset = torch.einsum(
            "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
        )
        J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
        J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
        J.append(J_k1)

    n_params = sum(j.shape[-1] for j in J)
    return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params)


def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
    """Get the latitudes of the camera pixels in radians.

    Latitudes are defined as the angle between the ray and the up vector.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.

    Returns:
        torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
    """
    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    uv1, _ = camera.image2world(camera.pixel_coordinates())
    rays = camera.pixel_bearing_many(uv1)

    lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)

    eps = 1e-6
    lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))

    return lat_asin.reshape(camera.shape[0], h, w, 1)


def J_latitude_field(
    camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
) -> torch.Tensor:
    """Get the jacobian of the latitude field.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
        log_focal (bool, optional): Whether to use log-focal length. Defaults to False.

    Returns:
        torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3).
    """
    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    # Forward
    xy = camera.pixel_coordinates()
    uv1, _ = camera.image2world(xy)
    uv1_norm = camera.pixel_bearing_many(uv1)  # (..., N, 3)

    # Backward
    J = []
    J_norm2w_to_img = J_vecnorm(uv1)[..., :2]  # (..., N, 2)

    ######################
    ## Gravity Jacobian ##
    ######################

    J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
    J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta)  # (..., N, 2)
    J.append(J_delta)

    ######################
    ### Focal Jacobian ###
    ######################

    J_w_to_img2f = camera.J_image2world(xy, "f")  # (..., N, 2, 2)
    if log_focal:
        J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
    J_w_to_img2f = J_w_to_img2f.sum(-1)  # (..., N, 2)

    J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f)  # (..., N, 3)
    J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1)  # (..., N, 1)
    J.append(J_f)

    ######################
    ##### K1 Jacobian ####
    ######################

    if hasattr(camera, "dist"):
        J_w_to_img2k1 = camera.J_image2world(xy, "dist")  # (..., N, 2)
        # (..., N, 2)
        J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
        # (..., N, 1)
        J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
        J.append(J_k1)

    n_params = sum(j.shape[-1] for j in J)
    return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params)


def get_perspective_field(
    camera: BaseCamera,
    gravity: Gravity,
    use_up: bool = True,
    use_latitude: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the perspective field from the camera parameters.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        use_up (bool, optional): Whether to include the up vector field. Defaults to True.
        use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
        (..., 2, h, w) and (..., 1, h, w).
    """
    assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."

    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    if use_up:
        permute = (0, 3, 1, 2)
        # (..., 2, h, w)
        up = get_up_field(camera, gravity).permute(permute)
    else:
        shape = (camera.shape[0], 2, h, w)
        up = camera.new_zeros(shape)

    if use_latitude:
        permute = (0, 3, 1, 2)
        # (..., 1, h, w)
        lat = get_latitude_field(camera, gravity).permute(permute)
    else:
        shape = (camera.shape[0], 1, h, w)
        lat = camera.new_zeros(shape)

    return up, lat


def J_perspective_field(
    camera: BaseCamera,
    gravity: Gravity,
    use_up: bool = True,
    use_latitude: bool = True,
    spherical: bool = False,
    log_focal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the jacobian of the perspective field.

    Args:
        camera (Camera): Camera parameters.
        gravity (Gravity): Gravity vector.
        use_up (bool, optional): Whether to include the up vector field. Defaults to True.
        use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
        spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
        log_focal (bool, optional): Whether to use log-focal length. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
        (..., h, w, 2, 4) and (..., h, w, 1, 4).
    """
    assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."

    camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
    gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity

    w, h = camera.size[0].unbind(-1)
    h, w = h.round().to(int), w.round().to(int)

    if use_up:
        J_up = J_up_field(camera, gravity, spherical, log_focal)  # (..., h, w, 2, 4)
    else:
        shape = (camera.shape[0], h, w, 2, 4)
        J_up = camera.new_zeros(shape)

    if use_latitude:
        J_lat = J_latitude_field(camera, gravity, spherical, log_focal)  # (..., h, w, 1, 4)
    else:
        shape = (camera.shape[0], h, w, 1, 4)
        J_lat = camera.new_zeros(shape)

    return J_up, J_lat
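A minimal usage sketch for this module (illustrative, not part of this commit; `camera` and `gravity` are assumed to come from GeoCalib's results dict, see geocalib/geocalib.py — only the calls below are defined in this file):

    up, lat = get_perspective_field(camera, gravity)     # (B, 2, h, w) and (B, 1, h, w)
    J_up, J_lat = J_perspective_field(camera, gravity)   # Jacobians w.r.t. gravity (2), focal (1), and k1 if present
    print(up.shape, lat.shape, J_up.shape, J_lat.shape)
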
geocalib/utils.py
ADDED
@@ -0,0 +1,325 @@
"""Image loading and general conversion utilities."""

import collections.abc as collections
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, Optional, Tuple

import cv2
import kornia
import numpy as np
import torch
import torchvision

# mypy: ignore-errors


def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
    """Get padding to make the image size a multiple of the given number.

    Args:
        x (torch.Tensor): Input tensor.
        multiple (int, optional): Multiple to fit to.
        crop (bool, optional): Whether to crop or pad. Defaults to False.

    Returns:
        torch.Tensor: Padding.
    """
    h, w = x.shape[-2:]

    if crop:
        pad_w = (w // multiple) * multiple - w
        pad_h = (h // multiple) * multiple - h
    else:
        pad_w = (multiple - w % multiple) % multiple
        pad_h = (multiple - h % multiple) % multiple

    if mode == "center":
        pad_l = pad_w // 2
        pad_r = pad_w - pad_l
        pad_t = pad_h // 2
        pad_b = pad_h - pad_t
    elif mode == "left":
        pad_l, pad_r = 0, pad_w
        pad_t, pad_b = 0, pad_h
    else:
        raise ValueError(f"Unknown mode {mode}")

    return (pad_l, pad_r, pad_t, pad_b)


def fit_features_to_multiple(
    features: torch.Tensor, multiple: int = 32, crop: bool = False
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Pad or crop image to a multiple of the given number.

    Args:
        features (torch.Tensor): Input features.
        multiple (int, optional): Multiple. Defaults to 32.
        crop (bool, optional): Whether to crop or pad. Defaults to False.

    Returns:
        Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
    """
    pad = fit_to_multiple(features, multiple, crop=crop)
    return torch.nn.functional.pad(features, pad, mode="reflect"), pad


class ImagePreprocessor:
    """Preprocess images for calibration."""

    default_conf = {
        "resize": 320,  # target edge length, None for no resizing
        "edge_divisible_by": None,
        "side": "short",
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_crop": False,
        "add_padding_mask": False,
        "resize_backend": "kornia",  # torchvision, kornia
    }

    def __init__(self, conf) -> None:
        """Initialize the image preprocessor."""
        self.conf = {**self.default_conf, **conf}
        self.conf = SimpleNamespace(**self.conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale."""
        h, w = img.shape[-2:]
        size = h, w

        if self.conf.square_crop:
            min_size = min(h, w)
            offset = (h - min_size) // 2, (w - min_size) // 2
            img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
            size = img.shape[-2:]

        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = self.resize(img, size, interpolation)

        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])

        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),
            "transform": T,
            "original_image_size": np.array([w, h]),
        }

        if self.conf.edge_divisible_by is not None:
            # crop to make the edge divisible by a number
            w_, h_ = img.shape[-1], img.shape[-2]
            img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
            crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
            data["crop_pad"] = crop_pad
            data["image_size"] = np.array([img.shape[-1], img.shape[-2]])

        data["image"] = img
        return data

    def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
        """Resize an image using the specified backend."""
        if self.conf.resize_backend == "kornia":
            return kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        elif self.conf.resize_backend == "torchvision":
            return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
        else:
            raise ValueError(f"{self.conf.resize_backend} not implemented.")

    def load_image(self, image_path: Path) -> dict:
        """Load an image from a path and preprocess it."""
        return self(load_image(image_path))

    def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
        """Get the new image size after resizing."""
        side = self.conf.side
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        return (
            (side_size, int(side_size * aspect_ratio))
            if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
            else (int(side_size / aspect_ratio), side_size)
        )

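A minimal usage sketch for the preprocessor (illustrative, not part of this commit; the image path is a placeholder and the conf keys mirror `default_conf`):

    preprocessor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})
    data = preprocessor(load_image("path/to/image.jpg"))
    print(data["image"].shape, data["scales"], data["image_size"])
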
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(image / 255.0, dtype=torch.float)


def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
    """Normalize and reorder the dimensions of an image tensor."""
    if image.ndim == 3:
        image = image.permute((1, 2, 0))  # CxHxW to HxWxC
    elif image.ndim == 2:
        image = image[None]  # add channel axis
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return (image.cpu().detach().numpy() * 255).astype(np.uint8)


def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from path as RGB or grayscale."""
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), mode)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    if not grayscale:
        image = image[..., ::-1]
    return image


def write_image(img: torch.Tensor, path: Path):
    """Write an image tensor to a file."""
    img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
    cv2.imwrite(str(path), img[..., ::-1])


def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
    """Load an image from a path and return as a tensor."""
    image = read_image(path, grayscale=grayscale)
    if return_tensor:
        return numpy_image_to_torch(image)

    assert image.ndim in [2, 3], f"Not an image: {image.shape}"
    image = image[None] if image.ndim == 2 else image
    return torch.tensor(image.copy(), dtype=torch.uint8)


def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
    """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).

    Args:
        v (torch.Tensor): Vector of size (..., 3).

    Returns:
        (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
    """
    z = torch.zeros_like(v[..., 0])
    return torch.stack(
        [z, -v[..., 2], v[..., 1], v[..., 2], z, -v[..., 0], -v[..., 1], v[..., 0], z], dim=-1
    ).reshape(v.shape[:-1] + (3, 3))


def rad2rotmat(
    roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.

    Args:
        roll (torch.Tensor): Roll angle in radians.
        pitch (torch.Tensor): Pitch angle in radians.
        yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.

    Returns:
        torch.Tensor: Rotation matrix of shape (..., 3, 3).
    """
    if yaw is None:
        yaw = roll.new_zeros(roll.shape)

    Rx = pitch.new_zeros(pitch.shape + (3, 3))
    Rx[..., 0, 0] = 1
    Rx[..., 1, 1] = torch.cos(pitch)
    Rx[..., 1, 2] = torch.sin(pitch)
    Rx[..., 2, 1] = -torch.sin(pitch)
    Rx[..., 2, 2] = torch.cos(pitch)

    Ry = yaw.new_zeros(yaw.shape + (3, 3))
    Ry[..., 0, 0] = torch.cos(yaw)
    Ry[..., 0, 2] = -torch.sin(yaw)
    Ry[..., 1, 1] = 1
    Ry[..., 2, 0] = torch.sin(yaw)
    Ry[..., 2, 2] = torch.cos(yaw)

    Rz = roll.new_zeros(roll.shape + (3, 3))
    Rz[..., 0, 0] = torch.cos(roll)
    Rz[..., 0, 1] = torch.sin(roll)
    Rz[..., 1, 0] = -torch.sin(roll)
    Rz[..., 1, 1] = torch.cos(roll)
    Rz[..., 2, 2] = 1

    return Rz @ Rx @ Ry


def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Compute focal length from (vertical/horizontal) field of view."""
    return size / 2 / torch.tan(fov / 2)


def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Compute (vertical/horizontal) field of view from focal length."""
    return 2 * torch.arctan(size / (2 * focal))


def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Compute the distance from principal point to the horizon."""
    return torch.tan(pitch) * f / h


def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Compute the pitch angle from the distance to the horizon."""
    return torch.atan(rho * h / f)


def rad2deg(rad: torch.Tensor) -> torch.Tensor:
    """Convert radians to degrees."""
    return rad / torch.pi * 180


def deg2rad(deg: torch.Tensor) -> torch.Tensor:
    """Convert degrees to radians."""
    return deg / 180 * torch.pi


def get_device() -> str:
    """Get the device (cpu, cuda, mps) available."""
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    return device


def print_calibration(results: Dict[str, torch.Tensor]) -> None:
    """Print the calibration results."""
    camera, gravity = results["camera"], results["gravity"]
    vfov = rad2deg(camera.vfov)
    roll, pitch = rad2deg(gravity.rp).unbind(-1)

    print("\nEstimated parameters (Pred):")
    print(f"Roll: {roll.item():.1f}° (± {rad2deg(results['roll_uncertainty']).item():.1f})°")
    print(f"Pitch: {pitch.item():.1f}° (± {rad2deg(results['pitch_uncertainty']).item():.1f})°")
    print(f"vFoV: {vfov.item():.1f}° (± {rad2deg(results['vfov_uncertainty']).item():.1f})°")
    print(f"Focal: {camera.f[0, 1].item():.1f} px (± {results['focal_uncertainty'].item():.1f} px)")

    if hasattr(camera, "k1"):
        print(f"K1: {camera.k1.item():.1f}")
geocalib/viz2d.py
ADDED
@@ -0,0 +1,502 @@
"""2D visualization primitives based on Matplotlib.

1) Plot images with `plot_images`.
2) Call functions to plot heatmaps, vector fields, and horizon lines.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""

import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import torch

from geocalib.perspective_fields import get_perspective_field
from geocalib.utils import rad2deg

# mypy: ignore-errors


def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
    """Plot a list of images.

    Args:
        imgs (List[np.ndarray]): List of images to plot.
        titles (List[str], optional): Titles. Defaults to None.
        cmaps (str, optional): Colormaps. Defaults to "gray".
        dpi (int, optional): Dots per inch. Defaults to 200.
        pad (float, optional): Padding. Defaults to 0.5.
        adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.

    Returns:
        plt.Figure: Figure of the images.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
    figsize = [sum(ratios) * 4.5, 4.5]
    fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
    if n == 1:
        axs = [axs]
    for i, (img, ax) in enumerate(zip(imgs, axs)):
        ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[i])
    fig.tight_layout(pad=pad)

    return fig


def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=3.0,
    return_fig=False,
    set_lim=False,
) -> plt.Figure:
    """Plot a grid of images.

    Args:
        imgs (List[np.ndarray]): List of images to plot.
        titles (List[str], optional): Titles. Defaults to None.
        cmaps (str, optional): Colormaps. Defaults to "gray".
        dpi (int, optional): Dots per inch. Defaults to 100.
        pad (float, optional): Padding. Defaults to 0.5.
        fig (_type_, optional): Figure to plot on. Defaults to None.
        adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
        figs (float, optional): Figure size. Defaults to 3.0.
        return_fig (bool, optional): Whether to return the figure. Defaults to False.
        set_lim (bool, optional): Whether to set the limits. Defaults to False.

    Returns:
        plt.Figure: Figure and axes or just axes.
    """
    nr, n = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs[0]]  # W / H
    else:
        ratios = [4 / 3] * n

    figsize = [sum(ratios) * figs, nr * figs]
    if fig is None:
        fig, axs = plt.subplots(
            nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
        )
    else:
        axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
        fig.figure.set_size_inches(figsize)

    if nr == 1 and n == 1:
        axs = [[axs]]
    elif n == 1:
        axs = axs[:, None]
    elif nr == 1:
        axs = [axs]

    for j in range(nr):
        for i in range(n):
            ax = axs[j][i]
            ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, imgs[j][i].shape[1]])
                ax.set_ylim([imgs[j][i].shape[0], 0])
            if titles:
                ax.set_title(titles[j][i])
    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    return (fig, axs) if return_fig else axs


def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=4,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Add text to a plot.

    Args:
        idx (int): Index of the axes.
        text (str): Text to add.
        pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
        fs (int, optional): Font size. Defaults to 15.
        color (str, optional): Text color. Defaults to "w".
        lcolor (str, optional): Line color. Defaults to "k".
        lwidth (int, optional): Line width. Defaults to 4.
        ha (str, optional): Horizontal alignment. Defaults to "left".
        va (str, optional): Vertical alignment. Defaults to "top".
        axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.

    Returns:
        plt.Text: Text object.
    """
    if axes is None:
        axes = plt.gcf().axes

    ax = axes[idx]

    t = ax.text(
        *pos,
        text,
        fontsize=fs,
        ha=ha,
        va=va,
        color=color,
        transform=ax.transAxes,
        zorder=5,
        **kwargs,
    )
    if lcolor is not None:
        t.set_path_effects(
            [
                path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
                path_effects.Normal(),
            ]
        )
    return t


def plot_heatmaps(
    heatmaps,
    vmin=-1e-6,  # include negative zero
    vmax=None,
    cmap="Spectral",
    a=0.5,
    axes=None,
    contours_every=None,
    contour_style="solid",
    colorbar=False,
):
    """Plot heatmaps with optional contours.

    To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.

    Args:
        heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
        vmin (float, optional): Min Value. Defaults to -1e-6.
        vmax (float, optional): Max Value. Defaults to None.
        cmap (str, optional): Colormap. Defaults to "Spectral".
        a (float, optional): Alpha value. Defaults to 0.5.
        axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
        contours_every (int, optional): If not none, will draw contours. Defaults to None.
        contour_style (str, optional): Style of the contours. Defaults to "solid".
        colorbar (bool, optional): Whether to show colorbar. Defaults to False.

    Returns:
        List[plt.Artist]: List of artists.
    """
    if axes is None:
        axes = plt.gcf().axes
    artists = []

    for i in range(len(axes)):
        a_ = a if isinstance(a, float) else a[i]

        if isinstance(heatmaps[i], torch.Tensor):
            heatmaps[i] = heatmaps[i].cpu().numpy()

        alpha = a_
        # Plot the heatmap
        art = axes[i].imshow(
            heatmaps[i],
            alpha=alpha,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        if colorbar:
            cmax = vmax or np.percentile(heatmaps[i], 99)
            art.set_clim(vmin, cmax)
            cbar = plt.colorbar(art, ax=axes[i])
            artists.append(cbar)

        artists.append(art)

        if contours_every is not None:
            # Add contour lines to the heatmap
            contour_data = np.arange(vmin, vmax + contours_every, contours_every)

            # Get the colormap colors for contour lines
            contour_colors = [
                plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
                for level in contour_data
            ]
            contours = axes[i].contour(
                heatmaps[i],
                levels=contour_data,
                linewidths=2,
                colors=contour_colors,
                linestyles=contour_style,
            )

            contours.set_clim(vmin, vmax)

            fmt = {
                level: f"{label}°"
                for level, label in zip(contour_data, contour_data.astype(int).astype(str))
            }
            t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")

            for label in t:
                label.set_path_effects(
                    [
                        path_effects.Stroke(linewidth=1, foreground="k"),
                        path_effects.Normal(),
                    ]
                )
            artists.append(contours)

    return artists


def plot_horizon_lines(
    cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
):
    """Plot horizon lines on the perspective field.

    Args:
        cameras (List[Camera]): List of cameras.
        gravities (List[Gravity]): Gravities.
        line_colors (str, optional): Line Colors. Defaults to "orange".
        lw (int, optional): Line width. Defaults to 2.
        styles (str, optional): Line styles. Defaults to "solid".
        alpha (float, optional): Alphas. Defaults to 1.0.
        ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(cameras)

    if not isinstance(styles, list):
        styles = [styles] * len(cameras)

    fig = plt.gcf()
    ax = fig.gca() if ax is None else ax

    if isinstance(ax, plt.Axes):
        ax = [ax] * len(cameras)

    assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"

    for i in range(len(cameras)):
        _, lat = get_perspective_field(cameras[i], gravities[i])
        # horizon line is zero level of the latitude field
        lat = lat[0, 0].cpu().numpy()
        contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
        for contour_line in contours.collections:
            contour_line.set_linestyle(styles[i])

def plot_vector_fields(
|
308 |
+
vector_fields,
|
309 |
+
cmap="lime",
|
310 |
+
subsample=15,
|
311 |
+
scale=None,
|
312 |
+
lw=None,
|
313 |
+
alphas=0.8,
|
314 |
+
axes=None,
|
315 |
+
):
|
316 |
+
"""Plot vector fields.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
|
320 |
+
cmap (str, optional): Color of the vectors. Defaults to "lime".
|
321 |
+
subsample (int, optional): Subsample the vector field. Defaults to 15.
|
322 |
+
scale (float, optional): Scale of the vectors. Defaults to None.
|
323 |
+
lw (float, optional): Line width of the vectors. Defaults to None.
|
324 |
+
alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
|
325 |
+
axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
List[plt.Artist]: List of artists.
|
329 |
+
"""
|
330 |
+
if axes is None:
|
331 |
+
axes = plt.gcf().axes
|
332 |
+
|
333 |
+
vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
|
334 |
+
|
335 |
+
artists = []
|
336 |
+
|
337 |
+
H, W = vector_fields[0].shape[-2:]
|
338 |
+
if scale is None:
|
339 |
+
scale = subsample / min(H, W)
|
340 |
+
|
341 |
+
if lw is None:
|
342 |
+
lw = 0.1 / subsample
|
343 |
+
|
344 |
+
if alphas is None:
|
345 |
+
alphas = np.ones_like(vector_fields[0][0])
|
346 |
+
alphas = np.stack([alphas] * len(vector_fields), 0)
|
347 |
+
elif isinstance(alphas, float):
|
348 |
+
alphas = np.ones_like(vector_fields[0][0]) * alphas
|
349 |
+
alphas = np.stack([alphas] * len(vector_fields), 0)
|
350 |
+
else:
|
351 |
+
alphas = np.array(alphas)
|
352 |
+
|
353 |
+
subsample = min(W, H) // subsample
|
354 |
+
offset_x = ((W % subsample) + subsample) // 2
|
355 |
+
|
356 |
+
samples_x = np.arange(offset_x, W, subsample)
|
357 |
+
samples_y = np.arange(int(subsample * 0.9), H, subsample)
|
358 |
+
|
359 |
+
x_grid, y_grid = np.meshgrid(samples_x, samples_y)
|
360 |
+
|
361 |
+
for i in range(len(axes)):
|
362 |
+
# vector field of shape (2, H, W) with vectors of norm == 1
|
363 |
+
vector_field = vector_fields[i]
|
364 |
+
|
365 |
+
a = alphas[i][samples_y][:, samples_x]
|
366 |
+
x, y = vector_field[:, samples_y][:, :, samples_x]
|
367 |
+
|
368 |
+
c = cmap
|
369 |
+
if not isinstance(cmap, str):
|
370 |
+
c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
|
371 |
+
|
372 |
+
s = scale * min(H, W)
|
373 |
+
arrows = axes[i].quiver(
|
374 |
+
x_grid,
|
375 |
+
y_grid,
|
376 |
+
x,
|
377 |
+
y,
|
378 |
+
scale=s,
|
379 |
+
scale_units="width" if H > W else "height",
|
380 |
+
units="width" if H > W else "height",
|
381 |
+
alpha=a,
|
382 |
+
color=c,
|
383 |
+
angles="xy",
|
384 |
+
antialiased=True,
|
385 |
+
width=lw,
|
386 |
+
headaxislength=3.5,
|
387 |
+
zorder=5,
|
388 |
+
)
|
389 |
+
|
390 |
+
artists.append(arrows)
|
391 |
+
|
392 |
+
return artists
|
393 |
+
|
394 |
+
|
395 |
+
def plot_latitudes(
|
396 |
+
latitude,
|
397 |
+
is_radians=True,
|
398 |
+
vmin=-90,
|
399 |
+
vmax=90,
|
400 |
+
cmap="seismic",
|
401 |
+
contours_every=15,
|
402 |
+
alpha=0.4,
|
403 |
+
axes=None,
|
404 |
+
**kwargs,
|
405 |
+
):
|
406 |
+
"""Plot latitudes.
|
407 |
+
|
408 |
+
Args:
|
409 |
+
latitude (List[torch.Tensor]): List of latitudes.
|
410 |
+
is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
|
411 |
+
vmin (int, optional): Min value to clip to. Defaults to -90.
|
412 |
+
vmax (int, optional): Max value to clip to. Defaults to 90.
|
413 |
+
cmap (str, optional): Colormap. Defaults to "seismic".
|
414 |
+
contours_every (int, optional): Contours every. Defaults to 15.
|
415 |
+
alpha (float, optional): Alpha value. Defaults to 0.4.
|
416 |
+
axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
List[plt.Artist]: List of artists.
|
420 |
+
"""
|
421 |
+
if axes is None:
|
422 |
+
axes = plt.gcf().axes
|
423 |
+
|
424 |
+
assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
|
425 |
+
lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
|
426 |
+
return plot_heatmaps(
|
427 |
+
lat,
|
428 |
+
vmin=vmin,
|
429 |
+
vmax=vmax,
|
430 |
+
cmap=cmap,
|
431 |
+
a=alpha,
|
432 |
+
axes=axes,
|
433 |
+
contours_every=contours_every,
|
434 |
+
**kwargs,
|
435 |
+
)
|
436 |
+
|
437 |
+
|
438 |
+
def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
|
439 |
+
"""Plot perspective fields.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
cameras (List[Camera]): List of cameras.
|
443 |
+
gravities (List[Gravity]): List of gravities.
|
444 |
+
axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
List[plt.Artist]: List of artists.
|
448 |
+
"""
|
449 |
+
if axes is None:
|
450 |
+
axes = plt.gcf().axes
|
451 |
+
|
452 |
+
assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"
|
453 |
+
|
454 |
+
artists = []
|
455 |
+
for i in range(len(axes)):
|
456 |
+
up, lat = get_perspective_field(cameras[i], gravities[i])
|
457 |
+
artists += plot_vector_fields([up[0]], axes=[axes[i]], **kwargs)
|
458 |
+
artists += plot_latitudes([lat[0, 0]], axes=[axes[i]], **kwargs)
|
459 |
+
|
460 |
+
return artists
|
461 |
+
|
462 |
+
|
463 |
+
def plot_confidences(
|
464 |
+
confidence,
|
465 |
+
as_log=True,
|
466 |
+
vmin=-4,
|
467 |
+
vmax=0,
|
468 |
+
cmap="turbo",
|
469 |
+
alpha=0.4,
|
470 |
+
axes=None,
|
471 |
+
**kwargs,
|
472 |
+
):
|
473 |
+
"""Plot confidences.
|
474 |
+
|
475 |
+
Args:
|
476 |
+
confidence (List[torch.Tensor]): Confidence maps.
|
477 |
+
as_log (bool, optional): Whether to plot in log scale. Defaults to True.
|
478 |
+
vmin (int, optional): Min value to clip to. Defaults to -4.
|
479 |
+
vmax (int, optional): Max value to clip to. Defaults to 0.
|
480 |
+
cmap (str, optional): Colormap. Defaults to "turbo".
|
481 |
+
alpha (float, optional): Alpha value. Defaults to 0.4.
|
482 |
+
axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
List[plt.Artist]: List of artists.
|
486 |
+
"""
|
487 |
+
if axes is None:
|
488 |
+
axes = plt.gcf().axes
|
489 |
+
|
490 |
+
assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
|
491 |
+
|
492 |
+
if as_log:
|
493 |
+
confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
|
494 |
+
|
495 |
+
# normalize to [0, 1]
|
496 |
+
confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
|
497 |
+
return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
|
498 |
+
|
499 |
+
|
500 |
+
def save_plot(path, **kw):
|
501 |
+
"""Save the current figure without any white margin."""
|
502 |
+
plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
|
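These helpers draw onto the axes of the current matplotlib figure, so a typical workflow lays out the images first and then overlays the perspective field and confidences. A minimal sketch, assuming `img` (HxWx3 array), `camera`, `gravity`, and a `latitude_confidence` tensor are already available from a GeoCalib result and that `plot_images` is defined earlier in this module:

```python
import matplotlib.pyplot as plt
from geocalib import viz2d

viz2d.plot_images([img], pad=0)                      # create the figure and one axis per image
viz2d.plot_perspective_fields([camera], [gravity])   # up-vector arrows + latitude heatmap
viz2d.plot_confidences([latitude_confidence])        # optional confidence overlay
viz2d.save_plot("calibration.png")                   # write the figure without white margins
plt.close()
```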
gradio_app.py
ADDED
@@ -0,0 +1,228 @@
"""Gradio app for GeoCalib inference."""

from copy import deepcopy
from time import time

import gradio as gr
import numpy as np
import spaces
import torch

from geocalib import viz2d
from geocalib.camera import camera_models
from geocalib.extractor import GeoCalib
from geocalib.perspective_fields import get_perspective_field
from geocalib.utils import rad2deg

# flake8: noqa
# mypy: ignore-errors

description = """
<p align="center">
  <h1 align="center"><ins>GeoCalib</ins> 📸<br>Single-image Calibration with Geometric Optimization</h1>
  <p align="center">
    <a href="https://www.linkedin.com/in/alexander-veicht/">Alexander Veicht</a>
    ·
    <a href="https://psarlin.com/">Paul-Edouard Sarlin</a>
    ·
    <a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
    ·
    <a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc Pollefeys</a>
  </p>
  <h2 align="center">
    <p>ECCV 2024</p>
    <a href="" align="center">Paper</a> | <!--TODO: update link-->
    <a href="https://github.com/cvg/GeoCalib" align="center">Code</a> |
    <a href="https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K" align="center">Colab</a>
  </h2>
</p>

## Getting Started
GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image by
combining geometric optimization with deep learning.

To get started, upload an image or select one of the examples below.
You can choose between different camera models and visualize the calibration results.

"""

example_images = [
    ["assets/pinhole-church.jpg"],
    ["assets/pinhole-garden.jpg"],
    ["assets/fisheye-skyline.jpg"],
    ["assets/fisheye-dog-pool.jpg"],
]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib().to(device)


def format_output(results):
    camera, gravity = results["camera"], results["gravity"]
    vfov = rad2deg(camera.vfov)
    roll, pitch = rad2deg(gravity.rp).unbind(-1)

    txt = "Estimated parameters:\n"
    txt += f"Roll: {roll.item():.2f}° (± {rad2deg(results['roll_uncertainty']).item():.2f})°\n"
    txt += f"Pitch: {pitch.item():.2f}° (± {rad2deg(results['pitch_uncertainty']).item():.2f})°\n"
    txt += f"vFoV: {vfov.item():.2f}° (± {rad2deg(results['vfov_uncertainty']).item():.2f})°\n"
    txt += (
        f"Focal: {camera.f[0, 1].item():.2f} px (± {results['focal_uncertainty'].item():.2f} px)\n"
    )
    if hasattr(camera, "k1"):
        txt += f"K1: {camera.k1[0].item():.2f}\n"
    return txt


@spaces.GPU(duration=10)
def inference(img, camera_model):
    out = model.calibrate(img.to(device), camera_model=camera_model)
    save_keys = ["camera", "gravity"] + [f"{k}_uncertainty" for k in ["roll", "pitch", "vfov"]]
    res = {k: v.cpu() for k, v in out.items() if k in save_keys}
    # not converting to numpy results in gpu abort
    res["up_confidence"] = out["up_confidence"].cpu().numpy()
    res["latitude_confidence"] = out["latitude_confidence"].cpu().numpy()
    return res


def process_results(
    image_path,
    camera_model,
    plot_up,
    plot_up_confidence,
    plot_latitude,
    plot_latitude_confidence,
    plot_undistort,
):
    """Process the image and return the calibration results."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    img = model.load_image(image_path)
    print("Running inference...")
    start = time()
    inference_result = inference(img, camera_model)
    print(f"Done ({time() - start:.2f}s)")
    inference_result["image"] = img.cpu()

    if inference_result is None:
        return ("", np.ones((128, 256, 3)), None)

    plot_img = update_plot(
        inference_result,
        plot_up,
        plot_up_confidence,
        plot_latitude,
        plot_latitude_confidence,
        plot_undistort,
    )

    return format_output(inference_result), plot_img, inference_result


def update_plot(
    inference_result,
    plot_up,
    plot_up_confidence,
    plot_latitude,
    plot_latitude_confidence,
    plot_undistort,
):
    """Update the plot based on the selected options."""
    if inference_result is None:
        gr.Error("Please calibrate an image first.")
        return np.ones((128, 256, 3))

    camera, gravity = inference_result["camera"], inference_result["gravity"]
    img = inference_result["image"].permute(1, 2, 0).numpy()

    if plot_undistort:
        if not hasattr(camera, "k1"):
            return img

        return camera.undistort_image(inference_result["image"][None])[0].permute(1, 2, 0).numpy()

    up, lat = get_perspective_field(camera, gravity)

    fig = viz2d.plot_images([img], pad=0)
    ax = fig.get_axes()

    if plot_up:
        viz2d.plot_vector_fields([up[0]], axes=[ax[0]])

    if plot_latitude:
        viz2d.plot_latitudes([lat[0, 0]], axes=[ax[0]])

    if plot_up_confidence:
        viz2d.plot_confidences([inference_result["up_confidence"][0]], axes=[ax[0]])

    if plot_latitude_confidence:
        viz2d.plot_confidences([inference_result["latitude_confidence"][0]], axes=[ax[0]])

    fig.canvas.draw()
    img = np.array(fig.canvas.renderer.buffer_rgba())

    return img


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Input Image""")
            image_path = gr.Image(label="Upload image to calibrate", type="filepath")
            choice_input = gr.Dropdown(
                choices=list(camera_models.keys()), label="Choose a camera model.", value="pinhole"
            )
            submit_btn = gr.Button("Calibrate 📸")
            gr.Examples(examples=example_images, inputs=[image_path, choice_input])

        with gr.Column():
            gr.Markdown("""## Results""")
            image_output = gr.Image(label="Calibration Results")
            gr.Markdown("### Plot Options")
            plot_undistort = gr.Checkbox(
                label="undistort",
                value=False,
                info="Undistorted image "
                + "(this is only available for models with distortion "
                + "parameters and will overwrite other options).",
            )

            with gr.Row():
                plot_up = gr.Checkbox(label="up-vectors", value=True)
                plot_up_confidence = gr.Checkbox(label="up confidence", value=False)
                plot_latitude = gr.Checkbox(label="latitude", value=True)
                plot_latitude_confidence = gr.Checkbox(label="latitude confidence", value=False)

            gr.Markdown("### Calibration Results")
            text_output = gr.Textbox(label="Estimated parameters", type="text", lines=5)

    # Define the action when the button is clicked
    inference_state = gr.State()
    plot_inputs = [
        inference_state,
        plot_up,
        plot_up_confidence,
        plot_latitude,
        plot_latitude_confidence,
        plot_undistort,
    ]
    submit_btn.click(
        fn=process_results,
        inputs=[image_path, choice_input] + plot_inputs[1:],
        outputs=[text_output, image_output, inference_state],
    )

    # Define the action when the plot checkboxes are clicked
    plot_up.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
    plot_up_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
    plot_latitude.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
    plot_latitude_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
    plot_undistort.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)


# Launch the app
demo.launch()
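The app wraps the same two calls that work in a plain script. A minimal sketch without the Gradio UI, using an example image from the repository (the camera model name is one of the keys of `camera_models`):

```python
import torch
from geocalib.extractor import GeoCalib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib().to(device)

# load_image returns a CHW tensor; calibrate returns camera, gravity and per-pixel confidences
img = model.load_image("assets/fisheye-skyline.jpg").to(device)
result = model.calibrate(img, camera_model="pinhole")
print(result["camera"], result["gravity"])
```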
hubconf.py
ADDED
@@ -0,0 +1,14 @@
"""Entrypoint for torch hub."""

dependencies = ["torch", "torchvision", "opencv-python", "kornia", "matplotlib"]

from geocalib import GeoCalib


def model(*args, **kwargs):
    """Pre-trained Geocalib model.

    Args:
        weights (str): trained variant, "pinhole" (default) or "distorted".
    """
    return GeoCalib(*args, **kwargs)
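With this entrypoint, the model can also be pulled through torch hub instead of installing the package. A sketch, assuming the repository is published under `cvg/GeoCalib` (the URL given in `pyproject.toml` below):

```python
import torch

# downloads the code on first use and builds the pre-trained model
model = torch.hub.load("cvg/GeoCalib", "model", trust_repo=True)
img = model.load_image("assets/pinhole-church.jpg")
result = model.calibrate(img)
```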
pyproject.toml
ADDED
@@ -0,0 +1,49 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "geocalib"
version = "1.0"
description = "GeoCalib Inference Package"
authors = [
    { name = "Alexander Veicht" },
    { name = "Paul-Edouard Sarlin" },
    { name = "Philipp Lindenberger" },
]
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]
urls = { Repository = "https://github.com/cvg/GeoCalib" }

dynamic = ["dependencies"]

[project.optional-dependencies]
dev = ["black==23.9.1", "flake8", "isort==5.12.0"]

[tool.setuptools]
packages = ["geocalib"]

[tool.setuptools.dynamic]
dependencies = { file = ["requirements.txt"] }


[tool.black]
line-length = 100
exclude = "(venv/|docs/|third_party/)"

[tool.isort]
profile = "black"
line_length = 100
atomic = true

[tool.flake8]
max-line-length = 100
docstring-convention = "google"
ignore = ["E203", "W503", "E402"]
exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"]
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
opencv-python
kornia
matplotlib
siclib/LICENSE
ADDED
@@ -0,0 +1,190 @@
(Apache License, Version 2.0, January 2004 — full standard license text, identical to the repository's root LICENSE; Copyright 2024 ETH Zurich.)
siclib/__init__.py
ADDED
@@ -0,0 +1,15 @@
import logging

formatter = logging.Formatter(
    fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

__module_name__ = __name__
siclib/configs/deepcalib.yaml
ADDED
@@ -0,0 +1,12 @@
defaults:
  - data: openpano-radial
  - train: deepcalib
  - model: deepcalib
  - _self_

data:
  train_batch_size: 32
  val_batch_size: 32
  test_batch_size: 32
  augmentations:
    name: "deepcalib"
siclib/configs/geocalib-radial.yaml
ADDED
@@ -0,0 +1,38 @@
defaults:
  - data: openpano-radial
  - train: geocalib
  - model: geocalib
  - _self_

data:
  # smaller batch size since lm takes more memory
  train_batch_size: 18
  val_batch_size: 18
  test_batch_size: 18

model:
  optimizer:
    camera_model: simple_radial

  weights: weights/geocalib.tar

train:
  lr: 1e-5 # smaller lr since we are fine-tuning
  num_steps: 200_000 # adapt to see same number of samples as previous training

  lr_schedule:
    type: SequentialLR
    on_epoch: false
    options:
      # adapt to see same number of samples as previous training
      milestones: [5_000]
      schedulers:
        - type: LinearLR
          options:
            start_factor: 1e-3
            total_iters: 5_000
        - type: MultiStepLR
          options:
            gamma: 0.1
            # adapt to see same number of samples as previous training
            milestones: [110_000, 170_000]
siclib/configs/geocalib.yaml
ADDED
@@ -0,0 +1,10 @@
defaults:
  - data: openpano
  - train: geocalib
  - model: geocalib
  - _self_

data:
  train_batch_size: 24
  val_batch_size: 24
  test_batch_size: 24
siclib/configs/model/deepcalib.yaml
ADDED
@@ -0,0 +1,7 @@
name: networks.deepcalib
bounds:
  roll: [-45, 45]
  # rho = torch.tan(pitch) / torch.tan(vfov / 2) / 2 -> rho in [-1/0.3526, 1/0.0872]
  rho: [-2.83607487, 2.83607487]
  vfov: [20, 105]
  k1_hat: [-0.7, 0.7]
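The `rho` bound follows from the other ranges via the formula in the comment. A quick check of the extreme value, assuming pitch shares the ±45° range of roll and using the smallest vfov of 20°:

```python
import torch

pitch = torch.deg2rad(torch.tensor(45.0))  # largest |pitch| (assumed equal to the roll bound)
vfov = torch.deg2rad(torch.tensor(20.0))   # smallest vfov gives the largest |rho|
rho = torch.tan(pitch) / torch.tan(vfov / 2) / 2
print(rho)  # ~2.836, matching the bound above
```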
siclib/configs/model/geocalib.yaml
ADDED
@@ -0,0 +1,31 @@
name: networks.geocalib

ll_enc:
  name: encoders.low_level_encoder

backbone:
  name: encoders.mscan
  weights: weights/mscan_b.pth

perspective_decoder:
  name: decoders.perspective_decoder

  up_decoder:
    name: decoders.up_decoder
    loss_type: l1
    use_uncertainty_loss: true
    decoder:
      name: decoders.light_hamburger
      predict_uncertainty: true

  latitude_decoder:
    name: decoders.latitude_decoder
    loss_type: l1
    use_uncertainty_loss: true
    decoder:
      name: decoders.light_hamburger
      predict_uncertainty: true

optimizer:
  name: optimization.lm_optimizer
  camera_model: pinhole
siclib/configs/train/deepcalib.yaml
ADDED
@@ -0,0 +1,22 @@
seed: 0
num_steps: 20_000
log_every_iter: 500
eval_every_iter: 3000
test_every_epoch: 1
writer: null
lr: 1.0e-4
clip_grad: 1.0
lr_schedule:
  type: null
optimizer: adam
submodules: []
median_metrics:
  - roll_error
  - pitch_error
  - vfov_error
recall_metrics:
  roll_error: [1, 5, 10]
  pitch_error: [1, 5, 10]
  vfov_error: [1, 5, 10]

plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
siclib/configs/train/geocalib.yaml
ADDED
@@ -0,0 +1,50 @@
seed: 0
num_steps: 150_000

writer: null
log_every_iter: 500
eval_every_iter: 1000

lr: 1e-4
optimizer: adamw
clip_grad: 1.0
best_key: loss/param_total

lr_schedule:
  type: SequentialLR
  on_epoch: false
  options:
    milestones: [4_000]
    schedulers:
      - type: LinearLR
        options:
          start_factor: 1e-3
          total_iters: 4_000
      - type: MultiStepLR
        options:
          gamma: 0.1
          milestones: [80_000, 130_000]

submodules: []

median_metrics:
  - roll_error
  - pitch_error
  - gravity_error
  - vfov_error
  - up_angle_error
  - latitude_angle_error
  - up_angle_recall@1
  - up_angle_recall@5
  - up_angle_recall@10
  - latitude_angle_recall@1
  - latitude_angle_recall@5
  - latitude_angle_recall@10

recall_metrics:
  roll_error: [1, 3, 5, 10]
  pitch_error: [1, 3, 5, 10]
  gravity_error: [1, 3, 5, 10]
  vfov_error: [1, 3, 5, 10]

plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
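The top-level configs compose these files Hydra-style through their `defaults` lists. A minimal OmegaConf sketch of what that merge amounts to, loading the files manually for illustration (paths relative to the repository root; the real training entrypoint handles this automatically):

```python
from omegaconf import OmegaConf

base = OmegaConf.load("siclib/configs/geocalib.yaml")
train = OmegaConf.load("siclib/configs/train/geocalib.yaml")
model = OmegaConf.load("siclib/configs/model/geocalib.yaml")

# the defaults list places each file under its group key before the base overrides apply
conf = OmegaConf.merge({"train": train, "model": model}, base)
print(conf.data.train_batch_size, conf.train.lr, conf.model.optimizer.camera_model)
```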
siclib/datasets/__init__.py
ADDED
@@ -0,0 +1,25 @@
import importlib.util

from siclib.datasets.base_dataset import BaseDataset
from siclib.utils.tools import get_class


def get_dataset(name):
    import_paths = [name, f"{__name__}.{name}"]
    for path in import_paths:
        try:
            spec = importlib.util.find_spec(path)
        except ModuleNotFoundError:
            spec = None
        if spec is not None:
            try:
                return get_class(path, BaseDataset)
            except AssertionError:
                mod = __import__(path, fromlist=[""])
                try:
                    return mod.__main_dataset__
                except AttributeError as exc:
                    print(exc)
                    continue

    raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(import_paths)}]')
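A usage sketch for this lookup. The `simple_dataset` module exists in this folder; the constructor argument and the split API are assumptions based on the `BaseDataset` interface documented in `base_dataset.py` below:

```python
from siclib.datasets import get_dataset

# resolves "simple_dataset" to siclib.datasets.simple_dataset and returns its dataset class
DatasetClass = get_dataset("simple_dataset")
dataset = DatasetClass({"num_threads": 1})   # illustrative config dict
train_split = dataset.get_dataset("train")   # a torch.utils.data.Dataset for the requested split
```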
siclib/datasets/augmentations.py
ADDED
@@ -0,0 +1,359 @@
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import albumentations as A
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from albumentations.pytorch.transforms import ToTensorV2
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
|
10 |
+
|
11 |
+
# flake8: noqa
|
12 |
+
# mypy: ignore-errors
|
13 |
+
class IdentityTransform(A.ImageOnlyTransform):
|
14 |
+
def apply(self, img, **params):
|
15 |
+
return img
|
16 |
+
|
17 |
+
def get_transform_init_args_names(self):
|
18 |
+
return ()
|
19 |
+
|
20 |
+
|
21 |
+
class RandomAdditiveShade(A.ImageOnlyTransform):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
nb_ellipses=10,
|
25 |
+
transparency_limit=[-0.5, 0.8],
|
26 |
+
kernel_size_limit=[150, 350],
|
27 |
+
always_apply=False,
|
28 |
+
p=0.5,
|
29 |
+
):
|
30 |
+
super().__init__(always_apply, p)
|
31 |
+
self.nb_ellipses = nb_ellipses
|
32 |
+
self.transparency_limit = transparency_limit
|
33 |
+
self.kernel_size_limit = kernel_size_limit
|
34 |
+
|
35 |
+
def apply(self, img, **params):
|
36 |
+
if img.dtype == np.float32:
|
37 |
+
shaded = self._py_additive_shade(img * 255.0)
|
38 |
+
shaded /= 255.0
|
39 |
+
elif img.dtype == np.uint8:
|
40 |
+
shaded = self._py_additive_shade(img.astype(np.float32))
|
41 |
+
shaded = shaded.astype(np.uint8)
|
42 |
+
else:
|
43 |
+
raise NotImplementedError(f"Data augmentation not available for type: {img.dtype}")
|
44 |
+
return shaded
|
45 |
+
|
46 |
+
def _py_additive_shade(self, img):
|
47 |
+
grayscale = len(img.shape) == 2
|
48 |
+
if grayscale:
|
49 |
+
img = img[None]
|
50 |
+
min_dim = min(img.shape[:2]) / 4
|
51 |
+
mask = np.zeros(img.shape[:2], img.dtype)
|
52 |
+
for i in range(self.nb_ellipses):
|
53 |
+
ax = int(max(np.random.rand() * min_dim, min_dim / 5))
|
54 |
+
ay = int(max(np.random.rand() * min_dim, min_dim / 5))
|
55 |
+
max_rad = max(ax, ay)
|
56 |
+
x = np.random.randint(max_rad, img.shape[1] - max_rad) # center
|
57 |
+
y = np.random.randint(max_rad, img.shape[0] - max_rad)
|
58 |
+
angle = np.random.rand() * 90
|
59 |
+
cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)
|
60 |
+
|
61 |
+
transparency = np.random.uniform(*self.transparency_limit)
|
62 |
+
ks = np.random.randint(*self.kernel_size_limit)
|
63 |
+
if (ks % 2) == 0: # kernel_size has to be odd
|
64 |
+
ks += 1
|
65 |
+
mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
|
66 |
+
shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
|
67 |
+
out = np.clip(shaded, 0, 255)
|
68 |
+
if grayscale:
|
69 |
+
out = out.squeeze(0)
|
70 |
+
return out
|
71 |
+
|
72 |
+
def get_transform_init_args_names(self):
|
73 |
+
return "transparency_limit", "kernel_size_limit", "nb_ellipses"
|
74 |
+
|
75 |
+
|
76 |
+
def kw(entry: Union[float, dict], n=None, **default):
|
77 |
+
if not isinstance(entry, dict):
|
78 |
+
entry = {"p": entry}
|
79 |
+
entry = OmegaConf.create(entry)
|
80 |
+
if n is not None:
|
81 |
+
entry = default.get(n, entry)
|
82 |
+
return OmegaConf.merge(default, entry)
|
83 |
+
|
84 |
+
|
85 |
+
def kwi(entry: Union[float, dict], n=None, **default):
|
86 |
+
conf = kw(entry, n=n, **default)
|
87 |
+
return {k: conf[k] for k in set(default.keys()).union(set(["p"]))}
|
88 |
+
|
89 |
+
|
90 |
+
def replay_str(transforms, s="Replay:\n", log_inactive=True):
|
91 |
+
for t in transforms:
|
92 |
+
if "transforms" in t.keys():
|
93 |
+
s = replay_str(t["transforms"], s=s)
|
94 |
+
elif t["applied"] or log_inactive:
|
95 |
+
s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n"
|
96 |
+
return s
|
97 |
+
|
98 |
+
|
99 |
+
class BaseAugmentation(object):
|
100 |
+
base_default_conf = {
|
101 |
+
"name": "???",
|
102 |
+
"shuffle": False,
|
103 |
+
"p": 1.0,
|
104 |
+
"verbose": False,
|
105 |
+
"dtype": "uint8", # (byte, float)
|
106 |
+
}
|
107 |
+
|
108 |
+
default_conf = {}
|
109 |
+
|
110 |
+
def __init__(self, conf={}):
|
111 |
+
"""Perform some logic and call the _init method of the child model."""
|
112 |
+
default_conf = OmegaConf.merge(
|
113 |
+
OmegaConf.create(self.base_default_conf),
|
114 |
+
OmegaConf.create(self.default_conf),
|
115 |
+
)
|
116 |
+
OmegaConf.set_struct(default_conf, True)
|
117 |
+
if isinstance(conf, dict):
|
118 |
+
conf = OmegaConf.create(conf)
|
119 |
+
self.conf = OmegaConf.merge(default_conf, conf)
|
120 |
+
OmegaConf.set_readonly(self.conf, True)
|
121 |
+
self._init(self.conf)
|
122 |
+
|
123 |
+
self.conf = OmegaConf.merge(self.conf, conf)
|
124 |
+
if self.conf.verbose:
|
125 |
+
self.compose = A.ReplayCompose
|
126 |
+
else:
|
127 |
+
self.compose = A.Compose
|
128 |
+
if self.conf.dtype == "uint8":
|
129 |
+
self.dtype = np.uint8
|
130 |
+
self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
|
131 |
+
self.postprocess = A.ToFloat(always_apply=True)
|
132 |
+
elif self.conf.dtype == "float32":
|
133 |
+
self.dtype = np.float32
|
134 |
+
self.preprocess = A.ToFloat(always_apply=True)
|
135 |
+
self.postprocess = IdentityTransform()
|
136 |
+
else:
|
137 |
+
raise ValueError(f"Unsupported dtype {self.conf.dtype}")
|
138 |
+
self.to_tensor = ToTensorV2()
|
139 |
+
|
140 |
+
def _init(self, conf):
|
141 |
+
"""Child class overwrites this, setting up a list of transforms"""
|
142 |
+
self.transforms = []
|
143 |
+
|
144 |
+
def __call__(self, image, return_tensor=False):
|
145 |
+
"""image as HW or HWC"""
|
146 |
+
if isinstance(image, torch.Tensor):
|
147 |
+
image = image.cpu().numpy()
|
148 |
+
data = {"image": image}
|
149 |
+
if image.dtype != self.dtype:
|
150 |
+
data = self.preprocess(**data)
|
151 |
+
transforms = self.transforms
|
152 |
+
if self.conf.shuffle:
|
153 |
+
order = [i for i, _ in enumerate(transforms)]
|
154 |
+
np.random.shuffle(order)
|
155 |
+
transforms = [transforms[i] for i in order]
|
156 |
+
transformed = self.compose(transforms, p=self.conf.p)(**data)
|
157 |
+
if self.conf.verbose:
|
158 |
+
print(replay_str(transformed["replay"]["transforms"]))
|
159 |
+
transformed = self.postprocess(**transformed)
|
160 |
+
if return_tensor:
|
161 |
+
return self.to_tensor(**transformed)["image"]
|
162 |
+
else:
|
163 |
+
return transformed["image"]
|
164 |
+
|
165 |
+
|
166 |
+
class IdentityAugmentation(BaseAugmentation):
|
167 |
+
default_conf = {}
|
168 |
+
|
169 |
+
def _init(self, conf):
|
170 |
+
self.transforms = [IdentityTransform(p=1.0)]
|
171 |
+
|
172 |
+
|
173 |
+
class DarkAugmentation(BaseAugmentation):
|
174 |
+
default_conf = {"p": 0.75}
|
175 |
+
|
176 |
+
def _init(self, conf):
|
177 |
+
bright_contr = 0.5
|
178 |
+
blur = 0.1
|
179 |
+
random_gamma = 0.1
|
180 |
+
hue = 0.1
|
181 |
+
self.transforms = [
|
182 |
+
A.RandomRain(p=0.2),
|
183 |
+
A.RandomBrightnessContrast(
|
184 |
+
**kw(
|
185 |
+
bright_contr,
|
186 |
+
brightness_limit=(-0.4, 0.0),
|
187 |
+
contrast_limit=(-0.3, 0.0),
|
188 |
+
)
|
189 |
+
),
|
190 |
+
A.OneOf(
|
191 |
+
[
|
192 |
+
A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
|
193 |
+
A.MotionBlur(**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")),
|
194 |
+
A.ISONoise(),
|
195 |
+
A.ImageCompression(),
|
196 |
+
],
|
197 |
+
**kwi(blur, p=0.1),
|
198 |
+
),
|
199 |
+
A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))),
|
200 |
+
A.OneOf(
|
201 |
+
[
|
202 |
+
A.Equalize(),
|
203 |
+
A.CLAHE(p=0.2),
|
204 |
+
A.ToGray(),
|
205 |
+
A.ToSepia(p=0.1),
|
206 |
+
A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))),
|
207 |
+
],
|
208 |
+
p=0.5,
|
209 |
+
),
|
210 |
+
]
|
211 |
+
|
212 |
+
|
213 |
+
class DefaultAugmentation(BaseAugmentation):
|
214 |
+
default_conf = {"p": 1.0}
|
215 |
+
|
216 |
+
def _init(self, conf):
|
217 |
+
self.transforms = [
|
218 |
+
A.RandomBrightnessContrast(p=0.2),
|
219 |
+
A.HueSaturationValue(p=0.2),
|
220 |
+
A.ToGray(p=0.2),
|
221 |
+
A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
|
222 |
+
A.OneOf(
|
223 |
+
[
|
224 |
+
A.MotionBlur(p=0.2),
|
225 |
+
A.MedianBlur(blur_limit=3, p=0.1),
|
226 |
+
A.Blur(blur_limit=3, p=0.1),
|
227 |
+
],
|
228 |
+
p=0.2,
|
229 |
+
),
|
230 |
+
]
|
231 |
+
|
232 |
+
|
233 |
+
class PerspectiveAugmentation(BaseAugmentation):
|
234 |
+
default_conf = {"p": 1.0}
|
235 |
+
|
236 |
+
def _init(self, conf):
|
237 |
+
self.transforms = [
|
238 |
+
A.RandomBrightnessContrast(p=0.2),
|
239 |
+
A.HueSaturationValue(p=0.2),
|
240 |
+
A.ToGray(p=0.2),
|
241 |
+
A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
|
242 |
+
A.OneOf(
|
243 |
+
[
|
244 |
+
A.MotionBlur(p=0.2),
|
245 |
+
A.MedianBlur(blur_limit=3, p=0.1),
|
246 |
+
A.Blur(blur_limit=3, p=0.1),
|
247 |
+
],
|
248 |
+
p=0.2,
|
249 |
+
),
|
250 |
+
]
|
251 |
+
|
252 |
+
|
253 |
+
class DeepCalibAugmentations(BaseAugmentation):
|
254 |
+
default_conf = {"p": 1.0}
|
255 |
+
|
256 |
+
def _init(self, conf):
|
257 |
+
self.transforms = [
|
258 |
+
A.RandomBrightnessContrast(p=0.5),
|
259 |
+
A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
|
260 |
+
A.Downscale(
|
261 |
+
scale_min=0.5,
|
262 |
+
scale_max=0.95,
|
263 |
+
interpolation=dict(downscale=cv2.INTER_AREA, upscale=cv2.INTER_LINEAR),
|
264 |
+
p=0.5,
|
265 |
+
),
|
266 |
+
A.Downscale(scale_min=0.5, scale_max=0.95, interpolation=cv2.INTER_LINEAR, p=0.5),
|
267 |
+
A.ImageCompression(quality_lower=20, quality_upper=85, p=1, always_apply=True),
|
268 |
+
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
|
269 |
+
A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5),
|
270 |
+
A.ToGray(always_apply=False, p=0.2),
|
271 |
+
A.GaussianBlur(blur_limit=(3, 5), sigma_limit=0, p=0.25),
|
272 |
+
A.MotionBlur(blur_limit=5, allow_shifted=True, p=0.25),
|
273 |
+
A.MultiplicativeNoise(multiplier=[0.85, 1.15], elementwise=True, p=0.5),
|
274 |
+
]
|
275 |
+
|
276 |
+
|
277 |
+
class GeoCalibAugmentations(BaseAugmentation):
|
278 |
+
default_conf = {"p": 1.0}
|
279 |
+
|
280 |
+
def _init(self, conf):
|
281 |
+
self.color_transforms = [
|
282 |
+
A.RandomGamma(gamma_limit=(80, 180), p=0.8),
|
283 |
+
A.RandomToneCurve(scale=0.1, p=0.5),
|
284 |
+
A.RandomBrightnessContrast(p=0.5),
|
285 |
+
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
|
286 |
+
A.OneOf([A.ToGray(p=0.1), A.ToSepia(p=0.1), IdentityTransform(p=0.8)], p=1),
|
287 |
+
]
|
288 |
+
|
289 |
+
self.noise_transforms = [
|
290 |
+
A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
|
291 |
+
A.ImageCompression(quality_lower=20, quality_upper=100, p=1),
|
292 |
+
A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5),
|
293 |
+
A.OneOrOther(
|
294 |
+
first=A.Compose(
|
295 |
+
[
|
296 |
+
A.AdvancedBlur(
|
297 |
+
p=1,
|
298 |
+
blur_limit=(3, 7),
|
299 |
+
sigmaX_limit=(0.2, 1.0),
|
300 |
+
sigmaY_limit=(0.2, 1.0),
|
301 |
+
rotate_limit=(-90, 90),
|
302 |
+
beta_limit=(0.5, 8.0),
|
303 |
+
noise_limit=(0.9, 1.1),
|
304 |
+
),
|
305 |
+
A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
|
306 |
+
]
|
307 |
+
),
|
308 |
+
second=A.Compose(
|
309 |
+
[
|
310 |
+
A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
|
311 |
+
A.AdvancedBlur(
|
312 |
+
p=1,
|
313 |
+
blur_limit=(3, 7),
|
314 |
+
sigmaX_limit=(0.2, 1.0),
|
315 |
+
sigmaY_limit=(0.2, 1.0),
|
316 |
+
rotate_limit=(-90, 90),
|
317 |
+
beta_limit=(0.5, 8.0),
|
318 |
+
noise_limit=(0.9, 1.1),
|
319 |
+
),
|
320 |
+
]
|
321 |
+
),
|
322 |
+
),
|
323 |
+
]
|
324 |
+
|
325 |
+
self.image_transforms = [
|
326 |
+
A.OneOf(
|
327 |
+
[
|
328 |
+
A.Downscale(
|
329 |
+
scale_min=0.5,
|
330 |
+
scale_max=0.99,
|
331 |
+
interpolation=dict(downscale=down, upscale=up),
|
332 |
+
p=1,
|
333 |
+
)
|
334 |
+
for down, up in [
|
335 |
+
(cv2.INTER_AREA, cv2.INTER_LINEAR),
|
336 |
+
(cv2.INTER_LINEAR, cv2.INTER_CUBIC),
|
337 |
+
(cv2.INTER_CUBIC, cv2.INTER_LINEAR),
|
338 |
+
(cv2.INTER_LINEAR, cv2.INTER_AREA),
|
339 |
+
]
|
340 |
+
],
|
341 |
+
p=1,
|
342 |
+
)
|
343 |
+
]
|
344 |
+
|
345 |
+
self.transforms = [
|
346 |
+
*self.color_transforms,
|
347 |
+
*self.noise_transforms,
|
348 |
+
*self.image_transforms,
|
349 |
+
]
|
350 |
+
|
351 |
+
|
352 |
+
augmentations = {
|
353 |
+
"default": DefaultAugmentation,
|
354 |
+
"dark": DarkAugmentation,
|
355 |
+
"perspective": PerspectiveAugmentation,
|
356 |
+
"deepcalib": DeepCalibAugmentations,
|
357 |
+
"geocalib": GeoCalibAugmentations,
|
358 |
+
"identity": IdentityAugmentation,
|
359 |
+
}
|
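The registry at the end of the file maps the `augmentations.name` values used in the dataset configs (e.g. `"deepcalib"`, `"geocalib"`) to these classes. A minimal sketch of applying one directly to an image array:

```python
import numpy as np
from siclib.datasets.augmentations import augmentations

aug = augmentations["geocalib"]({"p": 1.0})               # build from a (partial) config dict
image = np.random.rand(224, 224, 3).astype(np.float32)    # HWC image in [0, 1]
augmented = aug(image)                                     # numpy output in [0, 1]
tensor = aug(image, return_tensor=True)                    # CHW torch.Tensor output
```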
siclib/datasets/base_dataset.py
ADDED
@@ -0,0 +1,218 @@
"""Base class for dataset.

See mnist.py for an example of dataset.
"""

import collections
import logging
from abc import ABCMeta, abstractmethod

import omegaconf
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Sampler, get_worker_info
from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern

from siclib.utils.tensor import string_classes
from siclib.utils.tools import set_num_threads, set_seed

logger = logging.getLogger(__name__)

# mypy: ignore-errors


class LoopSampler(Sampler):
    """Infinite sampler that loops over a given number of elements."""

    def __init__(self, loop_size: int, total_size: int = None):
        """Initialize the sampler.

        Args:
            loop_size (int): Number of elements to loop over.
            total_size (int, optional): Total number of elements. Defaults to None.
        """
        self.loop_size = loop_size
        self.total_size = total_size - (total_size % loop_size)

    def __iter__(self):
        """Return an iterator over the elements."""
        return (i % self.loop_size for i in range(self.total_size))

    def __len__(self):
        """Return the number of elements."""
        return self.total_size


def worker_init_fn(i):
    """Initialize the workers with a different seed."""
    info = get_worker_info()
    if hasattr(info.dataset, "conf"):
        conf = info.dataset.conf
        set_seed(info.id + conf.seed)
        set_num_threads(conf.num_threads)
    else:
        set_num_threads(1)


def collate(batch):
    """Difference with PyTorch default_collate: it can stack other objects."""
    if not isinstance(batch, list):  # no batching
        return batch
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            try:
                _ = elem.untyped_storage()._new_shared(numel)
            except AttributeError:
                _ = elem.storage()._new_shared(numel)
        return torch.stack(batch, dim=0)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ in ["ndarray", "memmap"]:
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if any(len(elem) != elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    elif elem is None:
        return elem
    else:
        # try to stack anyway in case the object implements stacking.
        return torch.stack(batch, 0)


class BaseDataset(metaclass=ABCMeta):
    """Base class for dataset.

    What the dataset model is expected to declare:
        default_conf: dictionary of the default configuration of the dataset.
            It overwrites base_default_conf in BaseModel, and it is overwritten by
            the user-provided configuration passed to __init__.
            Configurations can be nested.

        _init(self, conf): initialization method, where conf is the final
            configuration object (also accessible with `self.conf`). Accessing
            unknown configuration entries will raise an error.

        get_dataset(self, split): method that returns an instance of
            torch.utils.data.Dataset corresponding to the requested split string,
            which can be `'train'`, `'val'`, or `'test'`.
    """

    base_default_conf = {
        "name": "???",
        "num_workers": "???",
        "train_batch_size": "???",
        "val_batch_size": "???",
        "test_batch_size": "???",
        "shuffle_training": True,
        "batch_size": 1,
        "num_threads": 1,
        "seed": 0,
        "prefetch_factor": 2,
    }
    default_conf = {}

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        default_conf = OmegaConf.merge(
            OmegaConf.create(self.base_default_conf),
            OmegaConf.create(self.default_conf),
        )
        OmegaConf.set_struct(default_conf, True)
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(self.conf, True)
        logger.info(f"Creating dataset {self.__class__.__name__}")
        self._init(self.conf)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def get_dataset(self, split):
        """To be implemented by the child class."""
        raise NotImplementedError

    def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
        """Return a data loader for a given split."""
        assert split in ["train", "val", "test"]
        dataset = self.get_dataset(split)
        try:
            batch_size = self.conf[f"{split}_batch_size"]
        except omegaconf.MissingMandatoryValue:
            batch_size = self.conf.batch_size
        num_workers = self.conf.get("num_workers", batch_size)
        if distributed:
            shuffle = False
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = None
            if shuffle is None:
                shuffle = split == "train" and self.conf.shuffle_training
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            pin_memory=pinned,
            collate_fn=collate,
            num_workers=num_workers,
            worker_init_fn=worker_init_fn,
            prefetch_factor=self.conf.prefetch_factor,
        )

    def get_overfit_loader(self, split: str):
        """Return an overfit data loader.

        The training set is composed of a single duplicated batch, while
        the validation and test sets contain a single copy of this same batch.
        This is useful to debug a model and make sure that losses and metrics
        correlate well.
        """
        assert split in {"train", "val", "test"}
        dataset = self.get_dataset("train")
        sampler = LoopSampler(
            self.conf.batch_size,
            len(dataset) if split == "train" else self.conf.batch_size,
        )
        num_workers = self.conf.get("num_workers", self.conf.batch_size)
        return DataLoader(
            dataset,
            batch_size=self.conf.batch_size,
            pin_memory=True,
            num_workers=num_workers,
            sampler=sampler,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
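To make the contract above concrete, here is a minimal, hypothetical subclass; the toy tensor dataset it wraps is illustrative only and not part of siclib.

# Hypothetical minimal subclass illustrating the BaseDataset contract
# (default_conf, _init, get_dataset); not part of the library itself.
import torch
from torch.utils.data import TensorDataset

from siclib.datasets.base_dataset import BaseDataset


class RandomPoints(BaseDataset):
    """Toy dataset: random 2D points with a dummy label."""

    default_conf = {
        "name": "random_points",
        "num_points": 128,
        "num_workers": 2,
        "train_batch_size": 4,
        "val_batch_size": 4,
        "test_batch_size": 4,
    }

    def _init(self, conf):
        pass  # conf is already merged and read-only at this point

    def get_dataset(self, split):
        # every split returns the same kind of data in this toy example
        points = torch.rand(self.conf.num_points, 2)
        labels = torch.zeros(self.conf.num_points)
        return TensorDataset(points, labels)


loader = RandomPoints({}).get_data_loader("val")  # uses collate/worker_init_fn defined above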
siclib/datasets/configs/openpano-radial.yaml
ADDED
@@ -0,0 +1,41 @@
name: openpano_radial
base_dir: data/openpano
pano_dir: "${.base_dir}/panoramas"
images_per_pano: 16
resize_factor: null
n_workers: 1
device: cpu
overwrite: true
parameter_dists:
  roll:
    type: uniform # uni[-45, 45]
    options:
      loc: -0.7853981633974483 # -45 degrees
      scale: 1.5707963267948966 # 90 degrees
  pitch:
    type: uniform # uni[-45, 45]
    options:
      loc: -0.7853981633974483 # -45 degrees
      scale: 1.5707963267948966 # 90 degrees
  vfov:
    type: uniform # uni[20, 105]
    options:
      loc: 0.3490658503988659 # 20 degrees
      scale: 1.48352986419518 # 85 degrees
  k1_hat:
    type: truncnorm
    options:
      a: -4.285714285714286 # corresponds to -0.3
      b: 4.285714285714286 # corresponds to 0.3
      loc: 0
      scale: 0.07
  resize_factor:
    type: uniform
    options:
      loc: 1.2
      scale: 0.5
  shape:
    type: fix
    value:
      - 640
      - 640
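As a cross-check of how these entries are interpreted (see `DatasetGenerator.sample_value` in create_dataset_from_pano.py below), `type` names a `scipy.stats` distribution and `options` are its keyword arguments, so `uniform` with `loc = -45°` and `scale = 90°` draws from [-45°, 45°]. A small sketch, assuming scipy is installed:

# Sketch: how a `parameter_dists` entry is turned into a sample.
# scipy's uniform(loc, scale) draws from [loc, loc + scale], so the roll entry
# above covers [-45, 45] degrees, expressed in radians.
import numpy as np
import scipy.stats

entry = {"type": "uniform", "options": {"loc": -0.7853981633974483, "scale": 1.5707963267948966}}
sampler = getattr(scipy.stats, entry["type"])
roll = sampler.rvs(random_state=np.random.default_rng(0), **entry["options"])
print(np.degrees(roll))  # some value in [-45, 45]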
siclib/datasets/configs/openpano.yaml
ADDED
@@ -0,0 +1,34 @@
name: openpano
base_dir: data/openpano
pano_dir: "${.base_dir}/panoramas"
images_per_pano: 16
resize_factor: null
n_workers: 1
device: cpu
overwrite: true
parameter_dists:
  roll:
    type: uniform # uni[-45, 45]
    options:
      loc: -0.7853981633974483 # -45 degrees
      scale: 1.5707963267948966 # 90 degrees
  pitch:
    type: uniform # uni[-45, 45]
    options:
      loc: -0.7853981633974483 # -45 degrees
      scale: 1.5707963267948966 # 90 degrees
  vfov:
    type: uniform # uni[20, 105]
    options:
      loc: 0.3490658503988659 # 20 degrees
      scale: 1.48352986419518 # 85 degrees
  resize_factor:
    type: uniform
    options:
      loc: 1.2
      scale: 0.5
  shape:
    type: fix
    value:
      - 640
      - 640
siclib/datasets/create_dataset_from_pano.py
ADDED
@@ -0,0 +1,350 @@
"""Script to create a dataset from panorama images."""

import hashlib
import logging
from concurrent import futures
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from siclib.geometry.camera import camera_models
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg
from siclib.utils.image import load_image, write_image

logger = logging.getLogger(__name__)


# mypy: ignore-errors


def max_radius(a, b):
    """Compute the maximum radius of a Brown distortion model."""
    discrim = a * a - 4 * b
    valid = torch.isfinite(discrim) & (discrim >= 0.0)
    discrim = torch.sqrt(discrim) - a
    valid &= discrim > 0.0
    return 2.0 / torch.where(valid, discrim, 0)


def brown_max_radius(k1, k2):
    """Compute the maximum radius of a Brown distortion model."""
    # fold the constants from the derivative into a and b
    a = k1 * 3
    b = k2 * 5
    return torch.sqrt(max_radius(a, b))


class ParallelProcessor:
    """Generic parallel processor class."""

    def __init__(self, max_workers):
        """Init processor and pbars."""
        self.max_workers = max_workers
        self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
        self.pbars = {}

    def update_pbar(self, pbar_key):
        """Update progressbar."""
        pbar = self.pbars.get(pbar_key)
        pbar.update(1)

    def submit_tasks(self, task_func, task_args, pbar_key):
        """Submit tasks."""
        pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)
        self.pbars[pbar_key] = pbar

        def update_pbar(future):
            self.update_pbar(pbar_key)

        futures = []
        for args in task_args:
            future = self.executor.submit(task_func, *args)
            future.add_done_callback(update_pbar)
            futures.append(future)

        return futures

    def wait_for_completion(self, futures):
        """Wait for completion and return results."""
        results = []
        for f in futures:
            results += f.result()

        for key in self.pbars.keys():
            self.pbars[key].close()

        return results

    def shutdown(self):
        """Close the executor."""
        self.executor.shutdown()


class DatasetGenerator:
    """Dataset generator class to create perspective datasets from panoramas."""

    default_conf = {
        "name": "???",
        # paths
        "base_dir": "???",
        "pano_dir": "${.base_dir}/panoramas",
        "pano_train": "${.pano_dir}/train",
        "pano_val": "${.pano_dir}/val",
        "pano_test": "${.pano_dir}/test",
        "perspective_dir": "${.base_dir}/${.name}",
        "perspective_train": "${.perspective_dir}/train",
        "perspective_val": "${.perspective_dir}/val",
        "perspective_test": "${.perspective_dir}/test",
        "train_csv": "${.perspective_dir}/train.csv",
        "val_csv": "${.perspective_dir}/val.csv",
        "test_csv": "${.perspective_dir}/test.csv",
        # data options
        "camera_model": "pinhole",
        "parameter_dists": {
            "roll": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},  # in [-45, 45]
            },
            "pitch": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},  # in [-45, 45]
            },
            "vfov": {
                "type": "uniform",
                "options": {"loc": deg2rad(20), "scale": deg2rad(85)},  # in [20, 105]
            },
            "resize_factor": {
                "type": "uniform",
                "options": {"loc": 1.0, "scale": 1.0},  # factor in [1.0, 2.0]
            },
            "shape": {"type": "fix", "value": (640, 640)},
        },
        "images_per_pano": 16,
        "n_workers": 10,
        "device": "cpu",
        "overwrite": False,
    }

    def __init__(self, conf):
        """Init the class by merging and storing the config."""
        self.conf = OmegaConf.merge(
            OmegaConf.create(self.default_conf),
            OmegaConf.create(conf),
        )
        logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")

        self.infos = {}
        self.device = self.conf.device

        self.camera_model = camera_models[self.conf.camera_model]

    def sample_value(self, parameter_name, seed=None):
        """Sample a value from the specified distribution."""
        param_conf = self.conf["parameter_dists"][parameter_name]

        if param_conf.type == "fix":
            return torch.tensor(param_conf.value)

        # fix seed for reproducibility
        generator = None
        if seed:
            if not isinstance(seed, (int, float)):
                seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
            generator = np.random.default_rng(seed)

        sampler = getattr(scipy.stats, param_conf.type)
        return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))

    def plot_distributions(self):
        """Plot parameter distributions."""
        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, split in enumerate(["train", "val", "test"]):
            roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]]
            ax[i, 0].hist(roll_vals, bins=100)
            ax[i, 0].set_xlabel("Roll (°)")
            ax[i, 0].set_ylabel(f"Count {split}")

            pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]]
            ax[i, 1].hist(pitch_vals, bins=100)
            ax[i, 1].set_xlabel("Pitch (°)")
            ax[i, 1].set_ylabel(f"Count {split}")

            vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]]
            ax[i, 2].hist(vfov_vals, bins=100)
            ax[i, 2].set_xlabel("vFoV (°)")
            ax[i, 2].set_ylabel(f"Count {split}")

        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf")

        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, k1 in enumerate(["roll", "pitch", "vfov"]):
            for j, k2 in enumerate(["roll", "pitch", "vfov"]):
                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["train"]],
                    [rad2deg(row[k2]) for row in self.infos["train"]],
                    s=1,
                    label="train",
                )

                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["val"]],
                    [rad2deg(row[k2]) for row in self.infos["val"]],
                    s=1,
                    label="val",
                )

                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["test"]],
                    [rad2deg(row[k2]) for row in self.infos["test"]],
                    s=1,
                    label="test",
                )

                ax[i, j].set_xlabel(k1)
                ax[i, j].set_ylabel(k2)
                ax[i, j].legend()

        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf")

    def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
        """Generate perspective images from a single panorama."""
        infos = []

        pano = load_image(pano_path).to(self.device)

        yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
        params = {
            k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
            for k in self.conf.parameter_dists
            if k != "shape"
        }
        shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
        params |= {
            "height": [shape[0] for shape in shapes],
            "width": [shape[1] for shape in shapes],
        }

        if "k1_hat" in params:
            height = torch.tensor(params["height"])
            width = torch.tensor(params["width"])
            k1_hat = torch.tensor(params["k1_hat"])
            vfov = torch.tensor(params["vfov"])
            focal = fov2focal(vfov, height)
            rel_focal = focal / height
            k1 = k1_hat * rel_focal

            # distance to image corner
            # r_max_im = f_px * r_max * (1 + k1*r_max**2)
            # function of r_max_im: f_px = r_max_im / (r_max * (1 + k1*r_max**2))
            min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
            r_max = brown_max_radius(k1=k1, k2=0)
            lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2))
            valid = lowest_possible_f_px <= focal

            f = torch.where(valid, focal, lowest_possible_f_px)
            vfov = focal2fov(f, height)

            params["vfov"] = vfov
            params |= {"k1": k1}

        cam = self.camera_model.from_dict(params).float().to(self.device)
        gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)

        if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
            for i in range(self.conf.images_per_pano):
                perspective_name = f"{pano_path.stem}_{i}.jpg"
                info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
                infos.append(info)

            logger.info(f"Perspectives for {pano_path.stem} already exist.")

            return infos

        perspective_images = cam.get_img_from_pano(
            pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
        )

        for i, perspective_image in enumerate(perspective_images):
            perspective_name = f"{pano_path.stem}_{i}.jpg"

            n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
            valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
            if not valid:
                logger.debug(f"Perspective {perspective_name} has too many black pixels.")
                continue

            write_image(perspective_image, out_dir / perspective_name)

            info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
            infos.append(info)

        return infos

    def generate_split(self, split: str, parallel_processor: ParallelProcessor):
        """Generate a single split of a dataset."""
        self.infos[split] = []
        panorama_paths = [
            path
            for path in Path(self.conf[f"pano_{split}"]).glob("*")
            if not path.name.startswith(".")
        ]

        out_dir = Path(self.conf[f"perspective_{split}"])
        logger.info(f"Writing perspective images to {str(out_dir)}")
        if not out_dir.exists():
            out_dir.mkdir(parents=True)

        futures = parallel_processor.submit_tasks(
            self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
        )
        self.infos[split] = parallel_processor.wait_for_completion(futures)

        metadata = pd.DataFrame(data=self.infos[split])
        metadata.to_csv(self.conf[f"{split}_csv"])

    def generate_dataset(self):
        """Generate all splits of a dataset."""
        out_dir = Path(self.conf.perspective_dir)
        if not out_dir.exists():
            out_dir.mkdir(parents=True)

        OmegaConf.save(self.conf, out_dir / "config.yaml")

        processor = ParallelProcessor(self.conf.n_workers)
        for split in ["train", "val", "test"]:
            self.generate_split(split=split, parallel_processor=processor)

        processor.shutdown()

        for split in ["train", "val", "test"]:
            logger.info(f"Generated {len(self.infos[split])} {split} images.")

        self.plot_distributions()


@hydra.main(version_base=None, config_path="configs", config_name="SUN360")
def main(cfg: DictConfig) -> None:
    """Run dataset generation."""
    generator = DatasetGenerator(conf=cfg)
    generator.generate_dataset()


if __name__ == "__main__":
    main()
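A minimal sketch of driving the generator programmatically with one of the configs above; the hydra entry point instead takes a config name under siclib/datasets/configs (e.g. `--config-name openpano` on the command line). Paths assume the OpenPano layout from the YAML files and that the panoramas have already been downloaded.

# Sketch: generate the OpenPano perspective dataset from the config above.
# Assumes data/openpano/panoramas/{train,val,test} exist (see download_openpano.py below).
from omegaconf import OmegaConf

from siclib.datasets.create_dataset_from_pano import DatasetGenerator

cfg = OmegaConf.load("siclib/datasets/configs/openpano.yaml")
generator = DatasetGenerator(conf=cfg)
generator.generate_dataset()  # writes images and train/val/test.csv under data/openpano/openpano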
siclib/datasets/simple_dataset.py
ADDED
@@ -0,0 +1,237 @@
"""Dataset for images created with 'create_dataset_from_pano.py'."""

import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple

import pandas as pd
import torch
from omegaconf import DictConfig

from siclib.datasets.augmentations import IdentityAugmentation, augmentations
from siclib.datasets.base_dataset import BaseDataset
from siclib.geometry.camera import SimpleRadial
from siclib.geometry.gravity import Gravity
from siclib.geometry.perspective_fields import get_perspective_field
from siclib.utils.conversions import fov2focal
from siclib.utils.image import ImagePreprocessor, load_image
from siclib.utils.tools import fork_rng

logger = logging.getLogger(__name__)

# mypy: ignore-errors


def load_csv(
    csv_file: Path, img_root: Path
) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
    """Load a CSV file containing image information.

    Args:
        csv_file (str): Path to the CSV file.
        img_root (str): Path to the root directory containing the images.

    Returns:
        list: List of dictionaries containing the image paths and camera parameters.
    """
    df = pd.read_csv(csv_file)

    infos, params, gravity = [], [], []
    for _, row in df.iterrows():
        h = row["height"]
        w = row["width"]
        px = row.get("px", w / 2)
        py = row.get("py", h / 2)
        vfov = row["vfov"]
        f = fov2focal(torch.tensor(vfov), h)
        k1 = row.get("k1", 0)
        k2 = row.get("k2", 0)
        params.append(torch.tensor([w, h, f, f, px, py, k1, k2]))

        roll = row["roll"]
        pitch = row["pitch"]
        gravity.append(torch.tensor([roll, pitch]))

        infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])})

    params = torch.stack(params).float()
    gravity = torch.stack(gravity).float()
    return infos, params, gravity


class SimpleDataset(BaseDataset):
    """Dataset for images created with 'create_dataset_from_pano.py'."""

    default_conf = {
        # paths
        "dataset_dir": "???",
        "train_img_dir": "${.dataset_dir}/train",
        "val_img_dir": "${.dataset_dir}/val",
        "test_img_dir": "${.dataset_dir}/test",
        "train_csv": "${.dataset_dir}/train.csv",
        "val_csv": "${.dataset_dir}/val.csv",
        "test_csv": "${.dataset_dir}/test.csv",
        # data options
        "use_up": True,
        "use_latitude": True,
        "use_prior_focal": False,
        "use_prior_gravity": False,
        "use_prior_k1": False,
        # image options
        "grayscale": False,
        "preprocessing": ImagePreprocessor.default_conf,
        "augmentations": {"name": "geocalib", "verbose": False},
        "p_rotate": 0.0,  # probability to rotate image by +/- 90°
        "reseed": False,
        "seed": 0,
        # data loader options
        "num_workers": 8,
        "prefetch_factor": 2,
        "train_batch_size": 32,
        "val_batch_size": 32,
        "test_batch_size": 32,
    }

    def _init(self, conf):
        pass

    def get_dataset(self, split: str) -> torch.utils.data.Dataset:
        """Return a dataset for a given split."""
        return _SimpleDataset(self.conf, split)


class _SimpleDataset(torch.utils.data.Dataset):
    """Dataset for images created with 'create_dataset_from_pano.py'."""

    def __init__(self, conf: DictConfig, split: str):
        """Initialize the dataset."""
        self.conf = conf
        self.split = split
        self.img_dir = Path(conf.get(f"{split}_img_dir"))

        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        # load image information
        assert f"{split}_csv" in conf, f"Missing {split}_csv in conf"
        infos_path = self.conf.get(f"{split}_csv")
        self.infos, self.parameters, self.gravity = load_csv(infos_path, self.img_dir)

        # define augmentations
        aug_name = conf.augmentations.name
        assert (
            aug_name in augmentations.keys()
        ), f'{aug_name} not in {" ".join(augmentations.keys())}'

        if self.split == "train":
            self.augmentation = augmentations[aug_name](conf.augmentations)
        else:
            self.augmentation = IdentityAugmentation()

    def __len__(self):
        return len(self.infos)

    def __getitem__(self, idx):
        if not self.conf.reseed:
            return self.getitem(idx)
        with fork_rng(self.conf.seed + idx, False):
            return self.getitem(idx)

    def _read_image(
        self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor
    ) -> Dict[str, Any]:
        path = Path(str(infos["file_name"]))

        # load image as uint8 and HWC for augmentation
        image = load_image(path, self.conf.grayscale, return_tensor=False)
        image = self.augmentation(image, return_tensor=True)

        # create radial camera -> same as pinhole if k1 = 0
        camera = SimpleRadial(parameters[None]).float()

        roll, pitch = gravity[None].unbind(-1)
        gravity = Gravity.from_rp(roll, pitch)

        # preprocess
        data = self.preprocessor(image)
        camera = camera.scale(data["scales"])
        camera = camera.crop(data["crop_pad"]) if "crop_pad" in data else camera

        priors = {"prior_gravity": gravity} if self.conf.use_prior_gravity else {}
        priors |= {"prior_focal": camera.f[..., 1]} if self.conf.use_prior_focal else {}
        priors |= {"prior_k1": camera.k1} if self.conf.use_prior_k1 else {}
        return {
            "name": infos["name"],
            "path": str(path),
            "camera": camera[0],
            "gravity": gravity[0],
            **priors,
            **data,
        }

    def _get_perspective(self, data):
        """Get perspective field."""
        camera = data["camera"]
        gravity = data["gravity"]

        up_field, lat_field = get_perspective_field(
            camera, gravity, use_up=self.conf.use_up, use_latitude=self.conf.use_latitude
        )

        out = {}
        if self.conf.use_up:
            out["up_field"] = up_field[0]
        if self.conf.use_latitude:
            out["latitude_field"] = lat_field[0]

        return out

    def getitem(self, idx: int):
        """Return a sample from the dataset."""
        infos = self.infos[idx]
        parameters = self.parameters[idx]
        gravity = self.gravity[idx]
        data = self._read_image(infos, parameters, gravity)

        if self.conf.use_up or self.conf.use_latitude:
            data |= self._get_perspective(data)

        return data


if __name__ == "__main__":
    # Create a dump of the dataset
    import argparse

    import matplotlib.pyplot as plt

    from siclib.visualization.visualize_batch import make_perspective_figures

    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--n_rows", type=int, default=4)
    parser.add_argument("--dpi", type=int, default=100)
    args = parser.parse_intermixed_args()

    dconf = SimpleDataset.default_conf
    dconf["name"] = args.name
    dconf["num_workers"] = 0
    dconf["prefetch_factor"] = None

    dconf["dataset_dir"] = args.data_dir
    dconf[f"{args.split}_batch_size"] = args.n_rows

    torch.set_grad_enabled(False)

    dataset = SimpleDataset(dconf)
    loader = dataset.get_data_loader(args.split, args.shuffle)

    with fork_rng(seed=42):
        for data in loader:
            pred = data
            break
    fig = make_perspective_figures(pred, data, n_pairs=args.n_rows)

    plt.show()
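Besides the __main__ visualization above, a short sketch of what a loaded batch looks like; it assumes the openpano dataset was generated as described earlier, and the exact image-related keys come from ImagePreprocessor, so the key list printed here is indicative only.

# Sketch: load one batch from a generated dataset and inspect its fields.
from siclib.datasets.simple_dataset import SimpleDataset

conf = {
    "name": "simple_dataset",               # required by BaseDataset
    "dataset_dir": "data/openpano/openpano",  # output of create_dataset_from_pano.py
    "num_workers": 0,
    "prefetch_factor": None,
    "train_batch_size": 2,
}
loader = SimpleDataset(conf).get_data_loader("train")
batch = next(iter(loader))
print(sorted(batch.keys()))  # e.g. camera, gravity, up_field, latitude_field, plus preprocessor outputs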
siclib/datasets/utils/__init__.py
ADDED
File without changes
siclib/datasets/utils/align_megadepth.py
ADDED
@@ -0,0 +1,41 @@
import argparse
import subprocess
from pathlib import Path

# flake8: noqa
# mypy: ignore-errors

parser = argparse.ArgumentParser(description="Aligns a COLMAP model and plots the horizon lines")
parser.add_argument(
    "--base_dir", type=str, help="Path to the base directory of the MegaDepth dataset"
)
parser.add_argument("--out_dir", type=str, help="Path to the output directory")
args = parser.parse_args()

base_dir = Path(args.base_dir)
out_dir = Path(args.out_dir)

scenes = [d.name for d in base_dir.iterdir() if d.is_dir()]
print(scenes[:3], len(scenes))

for scene in scenes:
    image_dir = base_dir / scene / "images"
    sfm_dir = base_dir / scene / "sparse" / "manhattan" / "0"

    # Align model
    align_dir = out_dir / scene / "sparse" / "align"
    align_dir.mkdir(exist_ok=True, parents=True)

    print(f"image_dir ({image_dir.exists()}): {image_dir}")
    print(f"sfm_dir ({sfm_dir.exists()}): {sfm_dir}")
    print(f"align_dir ({align_dir.exists()}): {align_dir}")

    cmd = (
        "colmap model_orientation_aligner "
        + f"--image_path {image_dir} "
        + f"--input_path {sfm_dir} "
        + f"--output_path {str(align_dir)}"
    )
    subprocess.run(cmd, shell=True)
siclib/datasets/utils/download_openpano.py
ADDED
@@ -0,0 +1,75 @@
"""Helper script to download and extract OpenPano dataset."""

import argparse
import shutil
from pathlib import Path

import torch
from tqdm import tqdm

from siclib import logger

PANO_URL = "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/openpano.zip"


def download_and_extract_dataset(name: str, url: Path, output: Path) -> None:
    """Download and extract a dataset from a URL."""
    dataset_dir = output / name
    if not output.exists():
        output.mkdir(parents=True)

    if dataset_dir.exists():
        logger.info(f"Dataset {name} already exists at {dataset_dir}, skipping download.")
        return

    zip_file = output / f"{name}.zip"

    if not zip_file.exists():
        logger.info(f"Downloading dataset {name} to {zip_file} from {url}.")
        torch.hub.download_url_to_file(url, zip_file)

    logger.info(f"Extracting dataset {name} in {output}.")
    shutil.unpack_archive(zip_file, output, format="zip")
    zip_file.unlink()


def main():
    """Prepare the OpenPano dataset."""
    parser = argparse.ArgumentParser(description="Download and extract OpenPano dataset.")
    parser.add_argument("--name", type=str, default="openpano", help="Name of the dataset.")
    parser.add_argument(
        "--laval_dir", type=str, default="data/laval-tonemap", help="Path to the Laval dataset."
    )

    args = parser.parse_args()

    out_dir = Path("data")
    download_and_extract_dataset(args.name, PANO_URL, out_dir)

    pano_dir = out_dir / args.name / "panoramas"
    for split in ["train", "test", "val"]:
        with open(pano_dir / f"{split}_panos.txt", "r") as f:
            pano_list = f.readlines()
        pano_list = [fname.strip() for fname in pano_list]

        for fname in tqdm(pano_list, ncols=80, desc=f"Copying {split} panoramas"):
            laval_path = Path(args.laval_dir) / fname
            target_path = pano_dir / split / fname

            # pano either exists in laval or is in split
            if target_path.exists():
                continue

            if laval_path.exists():
                shutil.copy(laval_path, target_path)
            else:  # not in laval and not in split
                logger.warning(f"Panorama {fname} not found in {args.laval_dir} or {split} split.")

    n_train = len(list(pano_dir.glob("train/*.jpg")))
    n_test = len(list(pano_dir.glob("test/*.jpg")))
    n_val = len(list(pano_dir.glob("val/*.jpg")))
    logger.info(f"{args.name} contains {n_train}/{n_test}/{n_val} train/test/val panoramas.")


if __name__ == "__main__":
    main()
siclib/datasets/utils/tonemapping.py
ADDED
@@ -0,0 +1,316 @@
import os

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import argparse

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import colors
from tqdm import tqdm

# flake8: noqa
# mypy: ignore-errors


class tonemap:
    def __init__(self):
        pass

    def process(self, img):
        return img

    def inv_process(self, img):
        return img


# Log correction
class log_tonemap(tonemap):
    # Constructor
    # Base of log
    # Scale of tonemapped
    # Offset
    def __init__(self, base, scale=1, offset=1):
        self.base = base
        self.scale = scale
        self.offset = offset

    def process(self, img):
        tonemapped = (np.log(img + self.offset) / np.log(self.base)) * self.scale
        return tonemapped

    def inv_process(self, img):
        inverse_tonemapped = np.power(self.base, (img) / self.scale) - self.offset
        return inverse_tonemapped


class log_tonemap_clip(tonemap):
    # Constructor
    # Base of log
    # Scale of tonemapped
    # Offset
    def __init__(self, base, scale=1, offset=1):
        self.base = base
        self.scale = scale
        self.offset = offset

    def process(self, img):
        tonemapped = np.clip((np.log(img * self.scale + self.offset) / np.log(self.base)), 0, 2) - 1
        return tonemapped

    def inv_process(self, img):
        inverse_tonemapped = (np.power(self.base, (img + 1)) - self.offset) / self.scale
        return inverse_tonemapped


# Gamma Tonemap
class gamma_tonemap(tonemap):
    def __init__(
        self,
        gamma,
    ):
        self.gamma = gamma

    def process(self, img):
        tonemapped = np.power(img, 1 / self.gamma)
        return tonemapped

    def inv_process(self, img):
        inverse_tonemapped = np.power(img, self.gamma)
        return inverse_tonemapped


class linear_clip(tonemap):
    def __init__(self, scale, mean):
        self.scale = scale
        self.mean = mean

    def process(self, img):
        tonemapped = np.clip((img - self.mean) / self.scale, -1, 1)
        return tonemapped

    def inv_process(self, img):
        inverse_tonemapped = img * self.scale + self.mean
        return inverse_tonemapped


def make_tonemap_HDR(opt):
    if opt.mode == "luminance":
        res_tonemap = log_tonemap_clip(10, 1.0, 1.0)
    else:  # temperature
        res_tonemap = linear_clip(5000.0, 5000.0)
    return res_tonemap


class LDRfromHDR:
    def __init__(
        self, tonemap="none", orig_scale=False, clip=True, quantization=0, color_jitter=0, noise=0
    ):
        self.tonemap_str, val = tonemap
        if tonemap[0] == "gamma":
            self.tonemap = gamma_tonemap(val)
        elif tonemap[0] == "log10":
            self.tonemap = log_tonemap(val)
        else:
            print("Warning: No tonemap specified, using linear")

        self.clip = clip
        self.orig_scale = orig_scale
        self.bits = quantization
        self.jitter = color_jitter
        self.noise = noise

        self.wbModel = None

    def process(self, HDR):
        LDR, normalized_scale = self.rescale(HDR)
        LDR = self.apply_clip(LDR)
        LDR = self.apply_scale(LDR, normalized_scale)
        LDR = self.apply_tonemap(LDR)
        LDR = self.colorJitter(LDR)
        LDR = self.gaussianNoise(LDR)
        LDR = self.quantize(LDR)
        LDR = self.apply_white_balance(LDR)
        return LDR, normalized_scale

    def rescale(self, img, percentile=90, max_mapping=0.8):
        r_percentile = np.percentile(img, percentile)
        alpha = max_mapping / (r_percentile + 1e-10)

        img_reexposed = img * alpha

        normalized_scale = normalizeScale(1 / alpha)

        return img_reexposed, normalized_scale

    def rescaleAlpha(self, img, percentile=90, max_mapping=0.8):
        r_percentile = np.percentile(img, percentile)
        alpha = max_mapping / (r_percentile + 1e-10)

        return alpha

    def apply_clip(self, img):
        if self.clip:
            img = np.clip(img, 0, 1)
        return img

    def apply_scale(self, img, scale):
        if self.orig_scale:
            scale = unNormalizeScale(scale)
            img = img * scale
        return img

    def apply_tonemap(self, img):
        if self.tonemap_str == "none":
            return img
        gammaed = self.tonemap.process(img)
        return gammaed

    def quantize(self, img):
        if self.bits == 0:
            return img
        max_val = np.power(2, self.bits)
        img = img * max_val
        img = np.floor(img)
        img = img / max_val
        return img

    def colorJitter(self, img):
        if self.jitter == 0:
            return img
        hsv = colors.rgb_to_hsv(img)
        hue_offset = np.random.normal(0, self.jitter, 1)
        hsv[:, :, 0] = (hsv[:, :, 0] + hue_offset) % 1.0
        rgb = colors.hsv_to_rgb(hsv)
        return rgb

    def gaussianNoise(self, img):
        if self.noise == 0:
            return img
        noise_amount = np.random.uniform(0, self.noise, 1)
        noise_img = np.random.normal(0, noise_amount, img.shape)
        img = img + noise_img
        img = np.clip(img, 0, 1).astype(np.float32)
        return img

    def apply_white_balance(self, img):
        if self.wbModel is None:
            return img
        img = self.wbModel.correctImage(img)
        return img.copy()


def make_LDRfromHDR(opt):
    LDR_from_HDR = LDRfromHDR(
        opt.tonemap_LDR, opt.orig_scale, opt.clip, opt.quantization, opt.color_jitter, opt.noise
    )
    return LDR_from_HDR


def torchnormalizeEV(EV, mean=5.12, scale=6, clip=True):
    # Normalize based on the computed distribution between -1 1
    EV -= mean
    EV = EV / scale

    if clip:
        EV = torch.clip(EV, min=-1, max=1)

    return EV


def torchnormalizeEV0(EV, mean=5.12, scale=6, clip=True):
    # Normalize based on the computed distribution between 0 1
    EV -= mean
    EV = EV / scale

    if clip:
        EV = torch.clip(EV, min=-1, max=1)

    EV += 0.5
    EV = EV / 2

    return EV


def normalizeScale(x, scale=4):
    x = np.log10(x + 1)

    x = x / (scale / 2)
    x = x - 1

    return x


def unNormalizeScale(x, scale=4):
    x = x + 1
    x = x * (scale / 2)

    x = np.power(10, x) - 1

    return x


def normalizeIlluminance(x, scale=5):
    x = np.log10(x + 1)

    x = x / (scale / 2)
    x = x - 1

    return x


def unNormalizeIlluminance(x, scale=5):
    x = x + 1
    x = x * (scale / 2)

    x = np.power(10, x) - 1

    return x


def main(args):
    processor = LDRfromHDR(
        # tonemap=("log10", 10),
        tonemap=("gamma", args.gamma),
        orig_scale=False,
        clip=True,
        quantization=0,
        color_jitter=0,
        noise=0,
    )

    img_list = list(os.listdir(args.hdr_dir))
    img_list = [f for f in img_list if f.endswith(args.extension)]
    img_list = [f for f in img_list if not f.startswith("._")]

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    for fname in tqdm(img_list):
        fname_out = ".".join(fname.split(".")[:-1])
        out = os.path.join(args.out_dir, f"{fname_out}.jpg")
        if os.path.exists(out) and not args.overwrite:
            continue

        fpath = os.path.join(args.hdr_dir, fname)
        img = cv2.imread(fpath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        ldr, scale = processor.process(img)

        ldr = (ldr * 255).astype(np.uint8)
        ldr = cv2.cvtColor(ldr, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out, ldr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hdr_dir", type=str, default="hdr")
    parser.add_argument("--out_dir", type=str, default="ldr")
    parser.add_argument("--extension", type=str, default=".exr")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--gamma", type=float, default=2)
    args = parser.parse_args()

    main(args)
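A short sketch of the LDR-from-HDR pipeline above applied in memory with a gamma curve; a real run would read an .exr with cv2.imread as in main(), whereas the array here is synthetic and only illustrates the call sequence.

# Sketch: apply the gamma tonemapping pipeline above to a synthetic HDR image.
import numpy as np

from siclib.datasets.utils.tonemapping import LDRfromHDR

proc = LDRfromHDR(tonemap=("gamma", 2.0), clip=True)
hdr = np.random.uniform(0, 50, (64, 128, 3)).astype(np.float32)  # fake linear HDR radiance
ldr, scale = proc.process(hdr)  # ldr in [0, 1]; scale encodes the exposure that was applied
ldr_8bit = (ldr * 255).astype(np.uint8)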
siclib/eval/__init__.py
ADDED
@@ -0,0 +1,18 @@
import torch

from siclib.eval.eval_pipeline import EvalPipeline
from siclib.utils.tools import get_class


def get_benchmark(benchmark):
    return get_class(f"{__name__}.{benchmark}", EvalPipeline)


@torch.no_grad()
def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
    """This overwrites existing benchmarks."""
    experiment_dir.mkdir(exist_ok=True, parents=True)
    bm = get_benchmark(benchmark)

    pipeline = bm(eval_conf)
    return pipeline.run(experiment_dir, model=model, overwrite=True, overwrite_eval=True)
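A hedged sketch of calling this entry point with one of the model configs below; the benchmark name must match a module under siclib/eval, and "lamar2k" is used here as an assumed example rather than a guaranteed module name.

# Sketch: run an evaluation benchmark with one of the eval configs below.
from pathlib import Path

from omegaconf import OmegaConf

from siclib.eval import run_benchmark

benchmark_name = "lamar2k"  # assumed example; must be a module under siclib/eval
eval_conf = OmegaConf.load("siclib/eval/configs/geocalib-pinhole.yaml")
summary = run_benchmark(benchmark_name, eval_conf, Path("outputs/benchmarks/geocalib"))
print(summary)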
siclib/eval/configs/deepcalib.yaml
ADDED
@@ -0,0 +1,3 @@
model:
  name: networks.deepcalib
  weights: weights/deepcalib.tar
siclib/eval/configs/geocalib-pinhole.yaml
ADDED
@@ -0,0 +1,2 @@
model:
  name: networks.geocalib_pretrained