diff --git a/vggsfm/.gitignore b/vggsfm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1c67c2a5bdadb8f9ea0f85dbb52c19ffee1928e4 --- /dev/null +++ b/vggsfm/.gitignore @@ -0,0 +1,143 @@ +.hydra/ +output/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Profiling data +.prof + +# Folder specific to your needs +**/tmp/ +**/outputs/ diff --git a/vggsfm/CHANGELOG.txt b/vggsfm/CHANGELOG.txt new file mode 100644 index 0000000000000000000000000000000000000000..ae6325250373a998ea59fa5725828f6064469fcc --- /dev/null +++ b/vggsfm/CHANGELOG.txt @@ -0,0 +1,19 @@ +VGGSfM 2.0 + +* More powerful camera and track predictor +* Save the GPU memory usage by around 50% +* Add Poselib as an option for Fundamental Matrix Estimation +* Support COLMAP-tyle output +* Provide focal length in pixel +* Normalize the scene after each BA +* Remove preliminary_cameras for simplicity +* Switch to lightglue instead of gluefactory +* Upgrade pycolmap from 0.5.0 to 0.6.1 + + + +TODO: +1. Make precision consistent +2. Provide a step-by-step instruction for using VGGSfM for 3D Gaussian +3. Support shared cameras + diff --git a/vggsfm/CODE_OF_CONDUCT.md b/vggsfm/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..08b500a221857ec3f451338e80b4a9ab1173a1af --- /dev/null +++ b/vggsfm/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/vggsfm/CONTRIBUTING.md b/vggsfm/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..c88cc4f734d301267f3e7c00f6cfe4baf9a8222c --- /dev/null +++ b/vggsfm/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to PoseDiffusion +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to PoseDiffusion, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/vggsfm/LICENSE.txt b/vggsfm/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..108b5f002fc31efe11d881de2cd05329ebe8cc37 --- /dev/null +++ b/vggsfm/LICENSE.txt @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/vggsfm/README.md b/vggsfm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5a4b49ed9f4a796a101223f1e0ff1630ff53fcc --- /dev/null +++ b/vggsfm/README.md @@ -0,0 +1,117 @@ +# VGGSfM: Visual Geometry Grounded Deep Structure From Motion + + +![Teaser](https://raw.githubusercontent.com/vggsfm/vggsfm.github.io/main/resources/vggsfm_teaser.gif) + +**[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)** + + +[Jianyuan Wang](https://jytime.github.io/), [Nikita Karaev](https://nikitakaraevv.github.io/), [Christian Rupprecht](https://chrirupp.github.io/), [David Novotny](https://d-novotny.github.io/) + + + +

[Paper] +[Project Page] +[Version 2.0] +

+ + +**Updates:** +- [Jun 25, 2024] Upgrade to VGGSfM 2.0! More memory efficient, more robust, more powerful, and easier to start! + + +- [Apr 23, 2024] Release the code and model weight for VGGSfM v1.1. + + + + +## Installation +We provide a simple installation script that, by default, sets up a conda environment with Python 3.10, PyTorch 2.1, and CUDA 12.1. + +```.bash +source install.sh +``` + +This script installs official pytorch3d, accelerate, lightglue, pycolmap, and visdom. Besides, it will also (optionally) install [poselib](https://github.com/PoseLib/PoseLib) using the python wheel under the folder ```wheels```, which is compiled by us instead of the official poselib team. + +## Demo + +### 1. Download Model +To get started, you need to first download the checkpoint. We provide the checkpoint for v2.0 model by [Hugging Face](https://huggingface.co./facebook/VGGSfM/blob/main/vggsfm_v2_0_0.bin) and [Google Drive](https://drive.google.com/file/d/163bHiqeTJhQ2_UnihRNPRA4Y9X8-gZ1-/view?usp=sharing). + +### 2. Run the Demo + +Now time to enjoy your 3D reconstruction! You can start by our provided examples, such as: + +```bash +python demo.py SCENE_DIR=examples/cake resume_ckpt=/PATH/YOUR/CKPT + +python demo.py SCENE_DIR=examples/british_museum query_frame_num=2 resume_ckpt=/PATH/YOUR/CKPT + +python demo.py SCENE_DIR=examples/apple query_frame_num=5 max_query_pts=1600 resume_ckpt=/PATH/YOUR/CKPT +``` + +All default settings for the flags are specified in `cfgs/demo.yaml`. For example, we have modified the values of `query_frame_num` and `max_query_pts` from the default settings of `3` and `4096` to `5` and `1600`, respectively, to ensure a 32 GB GPU can work for ```examples/apple```. + + +The reconstruction result (camera parameters and 3D points) will be automatically saved in the COLMAP format at ```output/seq_name```. You can use the [COLMAP GUI](https://colmap.github.io/gui.html) to view them. + +If you want to visualize it more easily, we provide an approach supported by [visdom](https://github.com/fossasia/visdom). To begin using Visdom, start the server by entering ```visdom``` in the command line. Once the server is running, access Visdom by navigating to ```http://localhost:8097``` in your web browser. Now every reconstruction will be visualized and saved to the visdom server by enabling ```visualize=True```: + +```bash +python demo.py visualize=True ...(other flags) +``` + +By doing so, you should see an interface such as: + +![UI](assets/ui.png) + + + +### 3. Try your own data + +You only need to specify the address of your data, such as: + +```bash +python demo.py SCENE_DIR=examples/YOUR_FOLDER ...(other flags) +``` + +Please ensure that the images are stored in ```YOUR_FOLDER/images```. This folder should contain only the images. Check the ```examples``` folder for the desired data structure. + + +Have fun and feel free to create an issue if you meet any problem. SfM is always about corner/hard cases. I am happy to help. If you prefer not to share your images publicly, please send them to me by email. + +### FAQ + +* What should I do if I encounter an out-of-memory error? + +To resolve an out-of-memory error, you can start by reducing the number of ```max_query_pts``` from the default ```4096``` to a lower value. If necessary, consider decreasing the ```query_frame_num```. Be aware that these adjustments may result in a sparser point cloud and could potentially impact the accuracy of the reconstruction. + + + +## Testing + +We are still preparing the testing script for VGGSfM v2. However, you can use our code for VGGSfM v1.1 to reproduce our benchmark results in the paper. Please refer to the branch ```v1.1```. + + +## Acknowledgement + +We are highly inspired by [colmap](https://github.com/colmap/colmap), [pycolmap](https://github.com/colmap/pycolmap), [posediffusion](https://github.com/facebookresearch/PoseDiffusion), [cotracker](https://github.com/facebookresearch/co-tracker), and [kornia](https://github.com/kornia/kornia). + + +## License +See the [LICENSE](./LICENSE) file for details about the license under which this code is made available. + + +## Citing VGGSfM + +If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work: + +```bibtex +@article{wang2023vggsfm, + title={VGGSfM: Visual Geometry Grounded Deep Structure From Motion}, + author={Wang, Jianyuan and Karaev, Nikita and Rupprecht, Christian and Novotny, David}, + journal={arXiv preprint arXiv:2312.04563}, + year={2023} +} diff --git a/vggsfm/assets/ui.png b/vggsfm/assets/ui.png new file mode 100644 index 0000000000000000000000000000000000000000..ed62366664b5827df911421eda20480fa1594d99 Binary files /dev/null and b/vggsfm/assets/ui.png differ diff --git a/vggsfm/cfgs/demo.yaml b/vggsfm/cfgs/demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5811653e330ff867048b7722e2186a17d195019a --- /dev/null +++ b/vggsfm/cfgs/demo.yaml @@ -0,0 +1,80 @@ +hydra: + run: + dir: . + +seed: 0 +img_size: 1024 + +viz_ip: 127.0.0.1 + +debug: False + +center_order: True +mixed_precision: fp16 +extract_color: True +filter_invalid_frame: True + +comple_nonvis: True +query_frame_num: 3 +robust_refine: 2 +BA_iters: 2 + + +load_gt: False +visualize: False +fmat_thres: 4.0 +max_reproj_error: 4.0 +init_max_reproj_error: 4.0 +max_query_pts: 4096 + + + +SCENE_DIR: examples/cake + +resume_ckpt: ckpt/vggsfm_v2_0_0.bin + + +query_method: "sp+sift" + +use_poselib: True + +MODEL: + _target_: vggsfm.models.VGGSfM + + TRACK: + _target_: vggsfm.models.TrackerPredictor + + efficient_corr: False + + COARSE: + stride: 4 + down_ratio: 2 + FEATURENET: + _target_: vggsfm.models.BasicEncoder + + PREDICTOR: + _target_: vggsfm.models.BaseTrackerPredictor + + FINE: + FEATURENET: + _target_: vggsfm.models.ShallowEncoder + + + PREDICTOR: + _target_: vggsfm.models.BaseTrackerPredictor + depth: 4 + corr_levels: 3 + corr_radius: 3 + latent_dim: 32 + hidden_size: 256 + fine: True + use_spaceatt: False + + CAMERA: + _target_: vggsfm.models.CameraPredictor + + + TRIANGULAE: + _target_: vggsfm.models.Triangulator + + diff --git a/vggsfm/demo.py b/vggsfm/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..d6fe79367b7fc7fecac8665a9d6cb5b7799b47a2 --- /dev/null +++ b/vggsfm/demo.py @@ -0,0 +1,489 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.cuda.amp import autocast +import hydra + +from omegaconf import DictConfig, OmegaConf +from hydra.utils import instantiate + +from lightglue import LightGlue, SuperPoint, SIFT, ALIKED + +import pycolmap + +from visdom import Visdom + + +from vggsfm.datasets.demo_loader import DemoLoader +from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras + +try: + import poselib + from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib + + print("Poselib is available") +except: + print("Poselib is not installed. Please disable use_poselib") + +from vggsfm.utils.utils import ( + set_seed_and_print, + farthest_point_sampling, + calculate_index_mappings, + switch_tensor_order, +) + + +@hydra.main(config_path="cfgs/", config_name="demo") +def demo_fn(cfg: DictConfig): + OmegaConf.set_struct(cfg, False) + + # Print configuration + print("Model Config:", OmegaConf.to_yaml(cfg)) + + torch.backends.cudnn.enabled = False + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + # Set seed + seed_all_random_engines(cfg.seed) + + # Model instantiation + model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = model.to(device) + + # Prepare test dataset + test_dataset = DemoLoader( + SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg + ) + + if cfg.resume_ckpt: + # Reload model + checkpoint = torch.load(cfg.resume_ckpt) + model.load_state_dict(checkpoint, strict=True) + print(f"Successfully resumed from {cfg.resume_ckpt}") + + if cfg.visualize: + from pytorch3d.structures import Pointclouds + from pytorch3d.vis.plotly_vis import plot_scene + from pytorch3d.renderer.cameras import PerspectiveCameras as PerspectiveCamerasVisual + + viz = Visdom() + + + sequence_list = test_dataset.sequence_list + + for seq_name in sequence_list: + print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50) + + # Load the data + batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True) + + # Send to GPU + images = batch["image"].to(device) + crop_params = batch["crop_params"].to(device) + + + # Unsqueeze to have batch size = 1 + images = images.unsqueeze(0) + crop_params = crop_params.unsqueeze(0) + + batch_size = len(images) + + with torch.no_grad(): + # Run the model + assert cfg.mixed_precision in ("None", "bf16", "fp16") + if cfg.mixed_precision == "None": + dtype = torch.float32 + elif cfg.mixed_precision == "bf16": + dtype = torch.bfloat16 + elif cfg.mixed_precision == "fp16": + dtype = torch.float16 + else: + raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now") + + predictions = run_one_scene( + model, + images, + crop_params=crop_params, + query_frame_num=cfg.query_frame_num, + image_paths=image_paths, + dtype=dtype, + cfg=cfg, + ) + + # Export prediction as colmap format + reconstruction_pycolmap = predictions["reconstruction"] + output_path = os.path.join("output", seq_name) + print("-" * 50) + print(f"The output has been saved in COLMAP style at: {output_path} ") + os.makedirs(output_path, exist_ok=True) + reconstruction_pycolmap.write(output_path) + + pred_cameras_PT3D = predictions["pred_cameras_PT3D"] + + if cfg.visualize: + if "points3D_rgb" in predictions: + pcl = Pointclouds(points=predictions["points3D"][None], features=predictions["points3D_rgb"][None]) + else: + pcl = Pointclouds(points=predictions["points3D"][None]) + + visual_cameras = PerspectiveCamerasVisual( + R=pred_cameras_PT3D.R, + T=pred_cameras_PT3D.T, + device=pred_cameras_PT3D.device, + ) + + visual_dict = {"scenes": {"points": pcl, "cameras": visual_cameras}} + + fig = plot_scene(visual_dict, camera_scale=0.05) + + env_name = f"demo_visual_{seq_name}" + print(f"Visualizing the scene by visdom at env: {env_name}") + viz.plotlyplot(fig, env=env_name, win="3D") + + return True + + +def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None): + """ + images have been normalized to the range [0, 1] instead of [0, 255] + """ + batch_num, frame_num, image_dim, height, width = images.shape + device = images.device + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + + predictions = {} + extra_dict = {} + + camera_predictor = model.camera_predictor + track_predictor = model.track_predictor + triangulator = model.triangulator + + # Find the query frames + # First use DINO to find the most common frame among all the input frames + # i.e., the one has highest (average) cosine similarity to all others + # Then use farthest_point_sampling to find the next ones + # The number of query frames is determined by query_frame_num + + with autocast(dtype=dtype): + query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num) + + image_paths = [os.path.basename(imgpath) for imgpath in image_paths] + + if cfg.center_order: + # The code below switchs the first frame (frame 0) to the most common frame + center_frame_index = query_frame_indexes[0] + center_order = calculate_index_mappings(center_frame_index, frame_num, device=device) + + images, crop_params = switch_tensor_order([images, crop_params], center_order, dim=1) + reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0] + + image_paths = [image_paths[i] for i in center_order.cpu().numpy().tolist()] + + # Also update query_frame_indexes: + query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes] + query_frame_indexes[0] = 0 + + # only pick query_frame_num + query_frame_indexes = query_frame_indexes[:query_frame_num] + + # Prepare image feature maps for tracker + fmaps_for_tracker = track_predictor.process_images_to_fmaps(images) + + # Predict tracks + with autocast(dtype=dtype): + pred_track, pred_vis, pred_score = predict_tracks( + cfg.query_method, + cfg.max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + query_frame_indexes, + frame_num, + device, + cfg, + ) + + if cfg.comple_nonvis: + pred_track, pred_vis, pred_score = comple_nonvis_frames( + track_predictor, + images, + fmaps_for_tracker, + frame_num, + device, + pred_track, + pred_vis, + pred_score, + 500, + cfg=cfg, + ) + + torch.cuda.empty_cache() + + # If necessary, force all the predictions at the padding areas as non-visible + if crop_params is not None: + boundaries = crop_params[:, :, -4:-2].abs().to(device) + boundaries = torch.cat([boundaries, reshaped_image.shape[-1] - boundaries], dim=-1) + hvis = torch.logical_and( + pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4] + ) + wvis = torch.logical_and( + pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3] + ) + force_vis = torch.logical_and(hvis, wvis) + pred_vis = pred_vis * force_vis.float() + + # TODO: plot 2D matches + if cfg.use_poselib: + estimate_preliminary_cameras_fn = estimate_preliminary_cameras_poselib + else: + estimate_preliminary_cameras_fn = estimate_preliminary_cameras + + # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches + # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC + # All the operations are batched and differentiable (if necessary) + # except when you enable use_poselib to save GPU memory + _, preliminary_dict = estimate_preliminary_cameras_fn( + pred_track, + pred_vis, + width, + height, + tracks_score=pred_score, + max_error=cfg.fmat_thres, + loopresidual=True, + # max_ransac_iters=cfg.max_ransac_iters, + ) + + pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num) + + pred_cameras = pose_predictions["pred_cameras"] + + # Conduct Triangulation and Bundle Adjustment + ( + BA_cameras_PT3D, + extrinsics_opencv, + intrinsics_opencv, + points3D, + points3D_rgb, + reconstruction, + valid_frame_mask, + ) = triangulator( + pred_cameras, + pred_track, + pred_vis, + images, + preliminary_dict, + image_paths=image_paths, + crop_params=crop_params, + pred_score=pred_score, + fmat_thres=cfg.fmat_thres, + BA_iters=cfg.BA_iters, + max_reproj_error = cfg.max_reproj_error, + init_max_reproj_error=cfg.init_max_reproj_error, + cfg=cfg, + ) + + if cfg.center_order: + # NOTE we changed the image order previously, now we need to switch it back + BA_cameras_PT3D = BA_cameras_PT3D[center_order] + extrinsics_opencv = extrinsics_opencv[center_order] + intrinsics_opencv = intrinsics_opencv[center_order] + + predictions["pred_cameras_PT3D"] = BA_cameras_PT3D + predictions["extrinsics_opencv"] = extrinsics_opencv + predictions["intrinsics_opencv"] = intrinsics_opencv + predictions["points3D"] = points3D + predictions["points3D_rgb"] = points3D_rgb + predictions["reconstruction"] = reconstruction + return predictions + + +def predict_tracks( + query_method, + max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + query_frame_indexes, + frame_num, + device, + cfg=None, +): + pred_track_list = [] + pred_vis_list = [] + pred_score_list = [] + + for query_index in query_frame_indexes: + print(f"Predicting tracks with query_index = {query_index}") + + # Find query_points at the query frame + query_points = get_query_points(images[:, query_index], query_method, max_query_pts) + + # Switch so that query_index frame stays at the first frame + # This largely simplifies the code structure of tracker + new_order = calculate_index_mappings(query_index, frame_num, device=device) + images_feed, fmaps_feed = switch_tensor_order([images, fmaps_for_tracker], new_order) + + # Feed into track predictor + fine_pred_track, _, pred_vis, pred_score = track_predictor(images_feed, query_points, fmaps=fmaps_feed) + + # Switch back the predictions + fine_pred_track, pred_vis, pred_score = switch_tensor_order([fine_pred_track, pred_vis, pred_score], new_order) + + # Append predictions for different queries + pred_track_list.append(fine_pred_track) + pred_vis_list.append(pred_vis) + pred_score_list.append(pred_score) + + pred_track = torch.cat(pred_track_list, dim=2) + pred_vis = torch.cat(pred_vis_list, dim=2) + pred_score = torch.cat(pred_score_list, dim=2) + + return pred_track, pred_vis, pred_score + + +def comple_nonvis_frames( + track_predictor, + images, + fmaps_for_tracker, + frame_num, + device, + pred_track, + pred_vis, + pred_score, + min_vis=500, + cfg=None, +): + # if a frame has too few visible inlier, use it as a query + non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist() + last_query = -1 + while len(non_vis_frames) > 0: + print("Processing non visible frames") + print(non_vis_frames) + if non_vis_frames[0] == last_query: + print("The non vis frame still does not has enough 2D matches") + pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks( + "sp+sift+aliked", + cfg.max_query_pts // 2, + track_predictor, + images, + fmaps_for_tracker, + non_vis_frames, + frame_num, + device, + cfg, + ) + # concat predictions + pred_track = torch.cat([pred_track, pred_track_comple], dim=2) + pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2) + pred_score = torch.cat([pred_score, pred_score_comple], dim=2) + break + + non_vis_query_list = [non_vis_frames[0]] + last_query = non_vis_frames[0] + pred_track_comple, pred_vis_comple, pred_score_comple = predict_tracks( + cfg.query_method, + cfg.max_query_pts, + track_predictor, + images, + fmaps_for_tracker, + non_vis_query_list, + frame_num, + device, + cfg, + ) + + # concat predictions + pred_track = torch.cat([pred_track, pred_track_comple], dim=2) + pred_vis = torch.cat([pred_vis, pred_vis_comple], dim=2) + pred_score = torch.cat([pred_score, pred_score_comple], dim=2) + non_vis_frames = torch.nonzero((pred_vis.squeeze(0) > 0.05).sum(-1) < min_vis).squeeze(-1).tolist() + return pred_track, pred_vis, pred_score + + +def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336): + # Downsample image to image_size x image_size + # because we found it is unnecessary to use high resolution + rgbs = F.interpolate(reshaped_image, (image_size, image_size), mode="bilinear", align_corners=True) + rgbs = camera_predictor._resnet_normalize_image(rgbs) + + # Get the image features (patch level) + frame_feat = camera_predictor.backbone(rgbs, is_training=True) + frame_feat = frame_feat["x_norm_patchtokens"] + frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) + + # Compute the similiarty matrix + frame_feat_norm = frame_feat_norm.permute(1, 0, 2) + similarity_matrix = torch.bmm(frame_feat_norm, frame_feat_norm.transpose(-1, -2)) + similarity_matrix = similarity_matrix.mean(dim=0) + distance_matrix = 1 - similarity_matrix.clone() + + # Ignore self-pairing + similarity_matrix.fill_diagonal_(0) + + similarity_sum = similarity_matrix.sum(dim=1) + + # Find the most common frame + most_common_frame_index = torch.argmax(similarity_sum).item() + + # Conduct FPS sampling + # Starting from the most_common_frame_index, + # try to find the farthest frame, + # then the farthest to the last found frame + # (frames are not allowed to be found twice) + fps_idx = farthest_point_sampling(distance_matrix, query_frame_num, most_common_frame_index) + + return fps_idx + + +def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005): + # Run superpoint and sift on the target frame + # Feel free to modify for your own + + methods = query_method.split("+") + pred_points = [] + + for method in methods: + if "sp" in method: + extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval() + elif "sift" in method: + extractor = SIFT(max_num_keypoints=max_query_num).cuda().eval() + elif "aliked" in method: + extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres).cuda().eval() + else: + raise NotImplementedError(f"query method {method} is not supprted now") + + query_points = extractor.extract(query_image)["keypoints"] + pred_points.append(query_points) + + query_points = torch.cat(pred_points, dim=1) + + if query_points.shape[1] > max_query_num: + random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num] + query_points = query_points[:, random_point_indices, :] + + return query_points + + +def seed_all_random_engines(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + +if __name__ == "__main__": + with torch.no_grad(): + demo_fn() diff --git a/vggsfm/examples/apple/images/frame000007.jpg b/vggsfm/examples/apple/images/frame000007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..154d162cbdf8d818dac50a87bbdbb7c57b82608c Binary files /dev/null and b/vggsfm/examples/apple/images/frame000007.jpg differ diff --git a/vggsfm/examples/apple/images/frame000012.jpg b/vggsfm/examples/apple/images/frame000012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..21b713cfe3076c1abc1ca5fed1a6dcd7c9180570 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000012.jpg differ diff --git a/vggsfm/examples/apple/images/frame000017.jpg b/vggsfm/examples/apple/images/frame000017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9335e6b02e059fc582d2625076d39739eca57e2 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000017.jpg differ diff --git a/vggsfm/examples/apple/images/frame000019.jpg b/vggsfm/examples/apple/images/frame000019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..13f4d6e5288cb13aef26d0acc8267b8416310600 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000019.jpg differ diff --git a/vggsfm/examples/apple/images/frame000024.jpg b/vggsfm/examples/apple/images/frame000024.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3f038320f8414e1b67c42630005be54e45fd497 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000024.jpg differ diff --git a/vggsfm/examples/apple/images/frame000025.jpg b/vggsfm/examples/apple/images/frame000025.jpg new file mode 100644 index 0000000000000000000000000000000000000000..30126ef69f4afe7d40e626415350a3e44479b5a5 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000025.jpg differ diff --git a/vggsfm/examples/apple/images/frame000043.jpg b/vggsfm/examples/apple/images/frame000043.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d175b48ff17c8136b846f1c5add1f8d3b2a1218a Binary files /dev/null and b/vggsfm/examples/apple/images/frame000043.jpg differ diff --git a/vggsfm/examples/apple/images/frame000052.jpg b/vggsfm/examples/apple/images/frame000052.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e709e4996fbe49e6b71f279089dba15836b6d0e2 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000052.jpg differ diff --git a/vggsfm/examples/apple/images/frame000070.jpg b/vggsfm/examples/apple/images/frame000070.jpg new file mode 100644 index 0000000000000000000000000000000000000000..793297b4fa9ba79d74e725290d68526ecaab20d8 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000070.jpg differ diff --git a/vggsfm/examples/apple/images/frame000077.jpg b/vggsfm/examples/apple/images/frame000077.jpg new file mode 100644 index 0000000000000000000000000000000000000000..13d70f8643d1554bc02cd0a2e260d4be52214928 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000077.jpg differ diff --git a/vggsfm/examples/apple/images/frame000085.jpg b/vggsfm/examples/apple/images/frame000085.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2988869ef0b88f1e900bcf1dedab5afdada85bc1 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000085.jpg differ diff --git a/vggsfm/examples/apple/images/frame000096.jpg b/vggsfm/examples/apple/images/frame000096.jpg new file mode 100644 index 0000000000000000000000000000000000000000..515852c7cc303a9bfdd3df5880f01d10fb633b07 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000096.jpg differ diff --git a/vggsfm/examples/apple/images/frame000128.jpg b/vggsfm/examples/apple/images/frame000128.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bbed5e0e78569e2c6a93a0b1e0e3a39d87c9de4f Binary files /dev/null and b/vggsfm/examples/apple/images/frame000128.jpg differ diff --git a/vggsfm/examples/apple/images/frame000145.jpg b/vggsfm/examples/apple/images/frame000145.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b4e917619741bcd45bc72c35bf579ebe5b5fd618 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000145.jpg differ diff --git a/vggsfm/examples/apple/images/frame000160.jpg b/vggsfm/examples/apple/images/frame000160.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bf50a01c0ef06ba2043ba9c0afdfe068e67c6683 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000160.jpg differ diff --git a/vggsfm/examples/apple/images/frame000162.jpg b/vggsfm/examples/apple/images/frame000162.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6764171b9584e544b7a2609c75d2589315d77c9b Binary files /dev/null and b/vggsfm/examples/apple/images/frame000162.jpg differ diff --git a/vggsfm/examples/apple/images/frame000168.jpg b/vggsfm/examples/apple/images/frame000168.jpg new file mode 100644 index 0000000000000000000000000000000000000000..727e5b395ee0ea145799bc3268a1628089ee0c0e Binary files /dev/null and b/vggsfm/examples/apple/images/frame000168.jpg differ diff --git a/vggsfm/examples/apple/images/frame000172.jpg b/vggsfm/examples/apple/images/frame000172.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0e96c742636192831f61aeabfbc260a3da16158d Binary files /dev/null and b/vggsfm/examples/apple/images/frame000172.jpg differ diff --git a/vggsfm/examples/apple/images/frame000191.jpg b/vggsfm/examples/apple/images/frame000191.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d91efdf1687a6fbbd1e7e63dc7be3796970b0498 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000191.jpg differ diff --git a/vggsfm/examples/apple/images/frame000200.jpg b/vggsfm/examples/apple/images/frame000200.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b4adaa2094f0616998d4eee449d26a38f7361d91 Binary files /dev/null and b/vggsfm/examples/apple/images/frame000200.jpg differ diff --git a/vggsfm/examples/british_museum/images/29057984_287139632.jpg b/vggsfm/examples/british_museum/images/29057984_287139632.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b8d4103af3ae3ed7d5db6ee38cfdc9f411ef64a Binary files /dev/null and b/vggsfm/examples/british_museum/images/29057984_287139632.jpg differ diff --git a/vggsfm/examples/british_museum/images/32630292_7166579210.jpg b/vggsfm/examples/british_museum/images/32630292_7166579210.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d6905a3c1e4a9f98bd33c679f2d649935d511d6f Binary files /dev/null and b/vggsfm/examples/british_museum/images/32630292_7166579210.jpg differ diff --git a/vggsfm/examples/british_museum/images/45839934_4117745134.jpg b/vggsfm/examples/british_museum/images/45839934_4117745134.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db390cbd2b5710a43cb51a04a28e2953ffcf5495 Binary files /dev/null and b/vggsfm/examples/british_museum/images/45839934_4117745134.jpg differ diff --git a/vggsfm/examples/british_museum/images/51004432_567773767.jpg b/vggsfm/examples/british_museum/images/51004432_567773767.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fadfc0df07cb422d1ac7ea44e0ec78a8bb82f5d4 Binary files /dev/null and b/vggsfm/examples/british_museum/images/51004432_567773767.jpg differ diff --git a/vggsfm/examples/british_museum/images/62620282_3728576515.jpg b/vggsfm/examples/british_museum/images/62620282_3728576515.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0f71489713d21fb212b3c994cf866aedf728acca Binary files /dev/null and b/vggsfm/examples/british_museum/images/62620282_3728576515.jpg differ diff --git a/vggsfm/examples/british_museum/images/71931631_7212707886.jpg b/vggsfm/examples/british_museum/images/71931631_7212707886.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a42de7b5eac0fc3a151297e161460a3a42176fb6 Binary files /dev/null and b/vggsfm/examples/british_museum/images/71931631_7212707886.jpg differ diff --git a/vggsfm/examples/british_museum/images/78600497_407639599.jpg b/vggsfm/examples/british_museum/images/78600497_407639599.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af39d2a6dd4111492001f92f0c0a8b2a1c0e86c5 Binary files /dev/null and b/vggsfm/examples/british_museum/images/78600497_407639599.jpg differ diff --git a/vggsfm/examples/british_museum/images/80340357_5029510336.jpg b/vggsfm/examples/british_museum/images/80340357_5029510336.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6d81c675fa7cc861db656c39758951899e2d172 Binary files /dev/null and b/vggsfm/examples/british_museum/images/80340357_5029510336.jpg differ diff --git a/vggsfm/examples/british_museum/images/81272348_2712949069.jpg b/vggsfm/examples/british_museum/images/81272348_2712949069.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1279c226a4b775402b9d2dd3b0483d2b3384f14d Binary files /dev/null and b/vggsfm/examples/british_museum/images/81272348_2712949069.jpg differ diff --git a/vggsfm/examples/british_museum/images/93266801_2335569192.jpg b/vggsfm/examples/british_museum/images/93266801_2335569192.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d83eb80ac1fe6ed458c06d63ee616c0bb178dce Binary files /dev/null and b/vggsfm/examples/british_museum/images/93266801_2335569192.jpg differ diff --git a/vggsfm/examples/cake/images/frame000020.jpg b/vggsfm/examples/cake/images/frame000020.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0cfabaf20c5eafde87d189d2ca775019821cccd6 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000020.jpg differ diff --git a/vggsfm/examples/cake/images/frame000069.jpg b/vggsfm/examples/cake/images/frame000069.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eeb477fe6deb02f529c8e5388c85ebdf97196098 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000069.jpg differ diff --git a/vggsfm/examples/cake/images/frame000096.jpg b/vggsfm/examples/cake/images/frame000096.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c897f37aa13cdfe0830bee5912b310ad84d70c1b Binary files /dev/null and b/vggsfm/examples/cake/images/frame000096.jpg differ diff --git a/vggsfm/examples/cake/images/frame000112.jpg b/vggsfm/examples/cake/images/frame000112.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d25c4a7b655fd0526daaa899da3252b40be19c9 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000112.jpg differ diff --git a/vggsfm/examples/cake/images/frame000146.jpg b/vggsfm/examples/cake/images/frame000146.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f6d84bdf0511b82be21dbeab4ae7a2b65e1fe050 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000146.jpg differ diff --git a/vggsfm/examples/cake/images/frame000149.jpg b/vggsfm/examples/cake/images/frame000149.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef4304b6494e265f186653a65bbae911a4f19c40 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000149.jpg differ diff --git a/vggsfm/examples/cake/images/frame000166.jpg b/vggsfm/examples/cake/images/frame000166.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c3516f3e0bf14594c04247bf57f6ca01fc28db0f Binary files /dev/null and b/vggsfm/examples/cake/images/frame000166.jpg differ diff --git a/vggsfm/examples/cake/images/frame000169.jpg b/vggsfm/examples/cake/images/frame000169.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6531dca027fe4e20889c93db42f3b4dd28cf5173 Binary files /dev/null and b/vggsfm/examples/cake/images/frame000169.jpg differ diff --git a/vggsfm/install.sh b/vggsfm/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..b1c5000b6851bbe35583254c3a598c11b5b6ac58 --- /dev/null +++ b/vggsfm/install.sh @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# This Script Assumes Python 3.10, CUDA 12.1 + +conda deactivate + +# Set environment variables +export ENV_NAME=vggsfm +export PYTHON_VERSION=3.10 +export PYTORCH_VERSION=2.1.0 +export CUDA_VERSION=12.1 + +# Create a new conda environment and activate it +conda create -n $ENV_NAME python=$PYTHON_VERSION +conda activate $ENV_NAME + +# Install PyTorch, torchvision, and PyTorch3D using conda +conda install pytorch=$PYTORCH_VERSION torchvision pytorch-cuda=$CUDA_VERSION -c pytorch -c nvidia +conda install -c fvcore -c iopath -c conda-forge fvcore iopath +conda install pytorch3d -c pytorch3d + +# Install pip packages +pip install hydra-core --upgrade +pip install omegaconf opencv-python einops visdom +pip install accelerate==0.24.0 + +# Install LightGlue +git clone https://github.com/cvg/LightGlue.git dependency/LightGlue + +cd dependency/LightGlue/ +python -m pip install -e . # editable mode +cd ../../ + +# Force numpy <2 +pip install numpy==1.26.3 + +# Ensure the version of pycolmap is 0.6.1 +pip install pycolmap==0.6.1 + +# (Optional) Install poselib +pip install https://huggingface.co./facebook/VGGSfM/resolve/main/poselib-2.0.2-cp310-cp310-linux_x86_64.whl + + diff --git a/vggsfm/minipytorch3d/__init__.py b/vggsfm/minipytorch3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggsfm/minipytorch3d/cameras.py b/vggsfm/minipytorch3d/cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..58416e21760c968406d2fce8cec1d517ca20f5f4 --- /dev/null +++ b/vggsfm/minipytorch3d/cameras.py @@ -0,0 +1,1722 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import math +import warnings +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +# from pytorch3d.common.datatypes import Device + +from .device_utils import Device, get_device, make_device +from .transform3d import Rotate, Transform3d, Translate +from .renderer_utils import convert_to_tensors_and_broadcast, TensorProperties + + +# Default values for rotation and translation matrices. +_R = torch.eye(3)[None] # (1, 3, 3) +_T = torch.zeros(1, 3) # (1, 3) + +# An input which is a float per batch element +_BatchFloatType = Union[float, Sequence[float], torch.Tensor] + +# one or two floats per batch element +_FocalLengthType = Union[float, Sequence[Tuple[float]], Sequence[Tuple[float, float]], torch.Tensor] + + +class CamerasBase(TensorProperties): + """ + `CamerasBase` implements a base class for all cameras. + + For cameras, there are four different coordinate systems (or spaces) + - World coordinate system: This is the system the object lives - the world. + - Camera view coordinate system: This is the system that has its origin on + the camera and the Z-axis perpendicular to the image plane. + In PyTorch3D, we assume that +X points left, and +Y points up and + +Z points out from the image plane. + The transformation from world --> view happens after applying a rotation (R) + and translation (T) + - NDC coordinate system: This is the normalized coordinate system that confines + points in a volume the rendered part of the object or scene, also known as + view volume. For square images, given the PyTorch3D convention, (+1, +1, znear) + is the top left near corner, and (-1, -1, zfar) is the bottom right far + corner of the volume. + The transformation from view --> NDC happens after applying the camera + projection matrix (P) if defined in NDC space. + For non square images, we scale the points such that smallest side + has range [-1, 1] and the largest side has range [-u, u], with u > 1. + - Screen coordinate system: This is another representation of the view volume with + the XY coordinates defined in image space instead of a normalized space. + + An illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md. + + CameraBase defines methods that are common to all camera models: + - `get_camera_center` that returns the optical center of the camera in + world coordinates + - `get_world_to_view_transform` which returns a 3D transform from + world coordinates to the camera view coordinates (R, T) + - `get_full_projection_transform` which composes the projection + transform (P) with the world-to-view transform (R, T) + - `transform_points` which takes a set of input points in world coordinates and + projects to the space the camera is defined in (NDC or screen) + - `get_ndc_camera_transform` which defines the transform from screen/NDC to + PyTorch3D's NDC space + - `transform_points_ndc` which takes a set of points in world coordinates and + projects them to PyTorch3D's NDC space + - `transform_points_screen` which takes a set of points in world coordinates and + projects them to screen space + + For each new camera, one should implement the `get_projection_transform` + routine that returns the mapping from camera view coordinates to camera + coordinates (NDC or screen). + + Another useful function that is specific to each camera model is + `unproject_points` which sends points from camera coordinates (NDC or screen) + back to camera view or world coordinates depending on the `world_coordinates` + boolean argument of the function. + """ + + # Used in __getitem__ to index the relevant fields + # When creating a new camera, this should be set in the __init__ + _FIELDS: Tuple[str, ...] = () + + # Names of fields which are a constant property of the whole batch, rather + # than themselves a batch of data. + # When joining objects into a batch, they will have to agree. + _SHARED_FIELDS: Tuple[str, ...] = () + + def get_projection_transform(self, **kwargs): + """ + Calculate the projective transformation matrix. + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + + Return: + a `Transform3d` object which represents a batch of projection + matrices of shape (N, 3, 3) + """ + raise NotImplementedError() + + def unproject_points(self, xy_depth: torch.Tensor, **kwargs): + """ + Transform input points from camera coordinates (NDC or screen) + to the world / camera coordinates. + + Each of the input points `xy_depth` of shape (..., 3) is + a concatenation of the x, y location and its depth. + + For instance, for an input 2D tensor of shape `(num_points, 3)` + `xy_depth` takes the following form: + `xy_depth[i] = [x[i], y[i], depth[i]]`, + for a each point at an index `i`. + + The following example demonstrates the relationship between + `transform_points` and `unproject_points`: + + .. code-block:: python + + cameras = # camera object derived from CamerasBase + xyz = # 3D points of shape (batch_size, num_points, 3) + # transform xyz to the camera view coordinates + xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz) + # extract the depth of each point as the 3rd coord of xyz_cam + depth = xyz_cam[:, :, 2:] + # project the points xyz to the camera + xy = cameras.transform_points(xyz)[:, :, :2] + # append depth to xy + xy_depth = torch.cat((xy, depth), dim=2) + # unproject to the world coordinates + xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True) + print(torch.allclose(xyz, xyz_unproj_world)) # True + # unproject to the camera coordinates + xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False) + print(torch.allclose(xyz_cam, xyz_unproj)) # True + + Args: + xy_depth: torch tensor of shape (..., 3). + world_coordinates: If `True`, unprojects the points back to world + coordinates using the camera extrinsics `R` and `T`. + `False` ignores `R` and `T` and unprojects to + the camera view coordinates. + from_ndc: If `False` (default), assumes xy part of input is in + NDC space if self.in_ndc(), otherwise in screen space. If + `True`, assumes xy is in NDC space even if the camera + is defined in screen space. + + Returns + new_points: unprojected points with the same shape as `xy_depth`. + """ + raise NotImplementedError() + + def get_camera_center(self, **kwargs) -> torch.Tensor: + """ + Return the 3D location of the camera optical center + in the world coordinates. + + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + + Setting R or T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + + Returns: + C: a batch of 3D locations of shape (N, 3) denoting + the locations of the center of each camera in the batch. + """ + w2v_trans = self.get_world_to_view_transform(**kwargs) + P = w2v_trans.inverse().get_matrix() + # the camera center is the translation component (the first 3 elements + # of the last row) of the inverted world-to-view + # transform (4x4 RT matrix) + C = P[:, 3, :3] + return C + + def get_world_to_view_transform(self, **kwargs) -> Transform3d: + """ + Return the world-to-view transform. + + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + + Returns: + A Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + R: torch.Tensor = kwargs.get("R", self.R) + T: torch.Tensor = kwargs.get("T", self.T) + self.R = R + self.T = T + world_to_view_transform = get_world_to_view_transform(R=R, T=T) + return world_to_view_transform + + def get_full_projection_transform(self, **kwargs) -> Transform3d: + """ + Return the full world-to-camera transform composing the + world-to-view and view-to-camera transforms. + If camera is defined in NDC space, the projected points are in NDC space. + If camera is defined in screen space, the projected points are in screen space. + + Args: + **kwargs: parameters for the projection transforms can be passed in + as keyword arguments to override the default values + set in __init__. + + Setting R and T here will update the values set in init as these + values may be needed later on in the rendering pipeline e.g. for + lighting calculations. + + Returns: + a Transform3d object which represents a batch of transforms + of shape (N, 3, 3) + """ + self.R: torch.Tensor = kwargs.get("R", self.R) + self.T: torch.Tensor = kwargs.get("T", self.T) + world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) + view_to_proj_transform = self.get_projection_transform(**kwargs) + return world_to_view_transform.compose(view_to_proj_transform) + + def transform_points(self, points, eps: Optional[float] = None, **kwargs) -> torch.Tensor: + """ + Transform input points from world to camera space. + If camera is defined in NDC space, the projected points are in NDC space. + If camera is defined in screen space, the projected points are in screen space. + + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the camera plane. + + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_proj_transform = self.get_full_projection_transform(**kwargs) + return world_to_proj_transform.transform_points(points, eps=eps) + + def get_ndc_camera_transform(self, **kwargs) -> Transform3d: + """ + Returns the transform from camera projection space (screen or NDC) to NDC space. + For cameras that can be specified in screen space, this transform + allows points to be converted from screen to NDC space. + The default transform scales the points from [0, W]x[0, H] + to [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. + This function should be modified per camera definitions if need be, + e.g. for Perspective/Orthographic cameras we provide a custom implementation. + This transform assumes PyTorch3D coordinate system conventions for + both the NDC space and the input points. + + This transform interfaces with the PyTorch3D renderer which assumes + input points to the renderer to be in NDC space. + """ + if self.in_ndc(): + return Transform3d(device=self.device, dtype=torch.float32) + else: + # For custom cameras which can be defined in screen space, + # users might might have to implement the screen to NDC transform based + # on the definition of the camera parameters. + # See PerspectiveCameras/OrthographicCameras for an example. + # We don't flip xy because we assume that world points are in + # PyTorch3D coordinates, and thus conversion from screen to ndc + # is a mere scaling from image to [-1, 1] scale. + image_size = kwargs.get("image_size", self.get_image_size()) + return get_screen_to_ndc_transform(self, with_xyflip=False, image_size=image_size) + + def transform_points_ndc(self, points, eps: Optional[float] = None, **kwargs) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to NDC space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in NDC space: +X left, +Y up, origin at image center. + + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + + Returns + new_points: transformed points with the same shape as the input. + """ + world_to_ndc_transform = self.get_full_projection_transform(**kwargs) + if not self.in_ndc(): + to_ndc_transform = self.get_ndc_camera_transform(**kwargs) + world_to_ndc_transform = world_to_ndc_transform.compose(to_ndc_transform) + + return world_to_ndc_transform.transform_points(points, eps=eps) + + def transform_points_screen( + self, points, eps: Optional[float] = None, with_xyflip: bool = True, **kwargs + ) -> torch.Tensor: + """ + Transforms points from PyTorch3D world/camera space to screen space. + Input points follow the PyTorch3D coordinate system conventions: +X left, +Y up. + Output points are in screen space: +X right, +Y down, origin at top left corner. + + Args: + points: torch tensor of shape (..., 3). + eps: If eps!=None, the argument is used to clamp the + divisor in the homogeneous normalization of the points + transformed to the ndc space. Please see + `transforms.Transform3d.transform_points` for details. + + For `CamerasBase.transform_points`, setting `eps > 0` + stabilizes gradients since it leads to avoiding division + by excessively low numbers for points close to the + camera plane. + with_xyflip: If True, flip x and y directions. In world/camera/ndc coords, + +x points to the left and +y up. If with_xyflip is true, in screen + coords +x points right, and +y down, following the usual RGB image + convention. Warning: do not set to False unless you know what you're + doing! + + Returns + new_points: transformed points with the same shape as the input. + """ + points_ndc = self.transform_points_ndc(points, eps=eps, **kwargs) + image_size = kwargs.get("image_size", self.get_image_size()) + return get_ndc_to_screen_transform(self, with_xyflip=with_xyflip, image_size=image_size).transform_points( + points_ndc, eps=eps + ) + + def clone(self): + """ + Returns a copy of `self`. + """ + cam_type = type(self) + other = cam_type(device=self.device) + return super().clone(other) + + def is_perspective(self): + raise NotImplementedError() + + def in_ndc(self): + """ + Specifies whether the camera is defined in NDC space + or in screen (image) space + """ + raise NotImplementedError() + + def get_znear(self): + return getattr(self, "znear", None) + + def get_image_size(self): + """ + Returns the image size, if provided, expected in the form of (height, width) + The image size is used for conversion of projected points to screen coordinates. + """ + return getattr(self, "image_size", None) + + def __getitem__(self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]) -> "CamerasBase": + """ + Override for the __getitem__ method in TensorProperties which needs to be + refactored. + + Args: + index: an integer index, list/tensor of integer indices, or tensor of boolean + indicators used to filter all the fields in the cameras given by self._FIELDS. + Returns: + an instance of the current cameras class with only the values at the selected index. + """ + + kwargs = {} + + tensor_types = { + # pyre-fixme[16]: Module `cuda` has no attribute `BoolTensor`. + "bool": (torch.BoolTensor, torch.cuda.BoolTensor), + # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`. + "long": (torch.LongTensor, torch.cuda.LongTensor), + } + if not isinstance(index, (int, list, *tensor_types["bool"], *tensor_types["long"])) or ( + isinstance(index, list) and not all(isinstance(i, int) and not isinstance(i, bool) for i in index) + ): + msg = "Invalid index type, expected int, List[int] or Bool/LongTensor; got %r" + raise ValueError(msg % type(index)) + + if isinstance(index, int): + index = [index] + + if isinstance(index, tensor_types["bool"]): + # pyre-fixme[16]: Item `List` of `Union[List[int], BoolTensor, + # LongTensor]` has no attribute `ndim`. + # pyre-fixme[16]: Item `List` of `Union[List[int], BoolTensor, + # LongTensor]` has no attribute `shape`. + if index.ndim != 1 or index.shape[0] != len(self): + raise ValueError( + # pyre-fixme[16]: Item `List` of `Union[List[int], BoolTensor, + # LongTensor]` has no attribute `shape`. + f"Boolean index of shape {index.shape} does not match cameras" + ) + elif max(index) >= len(self): + raise IndexError(f"Index {max(index)} is out of bounds for select cameras") + + for field in self._FIELDS: + val = getattr(self, field, None) + if val is None: + continue + + # e.g. "in_ndc" is set as attribute "_in_ndc" on the class + # but provided as "in_ndc" on initialization + if field.startswith("_"): + field = field[1:] + + if isinstance(val, (str, bool)): + kwargs[field] = val + elif isinstance(val, torch.Tensor): + # In the init, all inputs will be converted to + # tensors before setting as attributes + kwargs[field] = val[index] + else: + raise ValueError(f"Field {field} type is not supported for indexing") + + kwargs["device"] = self.device + return self.__class__(**kwargs) + + +############################################################ +# Field of View Camera Classes # +############################################################ + + +def OpenGLPerspectiveCameras( + znear: _BatchFloatType = 1.0, + zfar: _BatchFloatType = 100.0, + aspect_ratio: _BatchFloatType = 1.0, + fov: _BatchFloatType = 60.0, + degrees: bool = True, + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "FoVPerspectiveCameras": + """ + OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead. + Preserving OpenGLPerspectiveCameras for backward compatibility. + """ + + warnings.warn( + """OpenGLPerspectiveCameras is deprecated, + Use FoVPerspectiveCameras instead. + OpenGLPerspectiveCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return FoVPerspectiveCameras( + znear=znear, zfar=zfar, aspect_ratio=aspect_ratio, fov=fov, degrees=degrees, R=R, T=T, device=device + ) + + +class FoVPerspectiveCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + projection matrices by specifying the field of view. + The definitions of the parameters follow the OpenGL perspective camera. + + The extrinsics of the camera (R and T matrices) can also be set in the + initializer or passed in to `get_full_projection_transform` to get + the full transformation from world -> ndc. + + The `transform_points` method calculates the full world -> ndc transform + and then applies it to the input points. + + The transforms can also be returned separately as Transform3d objects. + + * Setting the Aspect Ratio for Non Square Images * + + If the desired output image size is non square (i.e. a tuple of (H, W) where H != W) + the aspect ratio needs special consideration: There are two aspect ratios + to be aware of: + - the aspect ratio of each pixel + - the aspect ratio of the output image + The `aspect_ratio` setting in the FoVPerspectiveCameras sets the + pixel aspect ratio. When using this camera with the differentiable rasterizer + be aware that in the rasterizer we assume square pixels, but allow + variable image aspect ratio (i.e rectangle images). + + In most cases you will want to set the camera `aspect_ratio=1.0` + (i.e. square pixels) and only vary the output image dimensions in pixels + for rasterization. + """ + + # For __getitem__ + _FIELDS = ("K", "znear", "zfar", "aspect_ratio", "fov", "R", "T", "degrees") + + _SHARED_FIELDS = ("degrees",) + + def __init__( + self, + znear: _BatchFloatType = 1.0, + zfar: _BatchFloatType = 100.0, + aspect_ratio: _BatchFloatType = 1.0, + fov: _BatchFloatType = 60.0, + degrees: bool = True, + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + ) -> None: + """ + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + fov: field of view angle of the camera. + degrees: bool, set to True if fov is specified in degrees. + R: Rotation matrix of shape (N, 3, 3) + T: Translation matrix of shape (N, 3) + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need znear, zfar, fov, aspect_ratio, degrees + device: Device (as str or torch.device) + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__(device=device, znear=znear, zfar=zfar, aspect_ratio=aspect_ratio, fov=fov, R=R, T=T, K=K) + + # No need to convert to tensor or broadcast. + self.degrees = degrees + + def compute_projection_matrix(self, znear, zfar, fov, aspect_ratio, degrees: bool) -> torch.Tensor: + """ + Compute the calibration matrix K of shape (N, 4, 4) + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + fov: field of view angle of the camera. + aspect_ratio: aspect ratio of the image pixels. + 1.0 indicates square pixels. + degrees: bool, set to True if fov is specified in degrees. + + Returns: + torch.FloatTensor of the calibration matrix with shape (N, 4, 4) + """ + K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32) + ones = torch.ones((self._N), dtype=torch.float32, device=self.device) + if degrees: + fov = (np.pi / 180) * fov + + if not torch.is_tensor(fov): + fov = torch.tensor(fov, device=self.device) + tanHalfFov = torch.tan((fov / 2)) + max_y = tanHalfFov * znear + min_y = -max_y + max_x = max_y * aspect_ratio + min_x = -max_x + + # NOTE: In OpenGL the projection matrix changes the handedness of the + # coordinate frame. i.e the NDC space positive z direction is the + # camera space negative z direction. This is because the sign of the z + # in the projection matrix is set to -1.0. + # In pytorch3d we maintain a right handed coordinate system throughout + # so the so the z sign is 1.0. + z_sign = 1.0 + + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + K[:, 0, 0] = 2.0 * znear / (max_x - min_x) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + K[:, 1, 1] = 2.0 * znear / (max_y - min_y) + K[:, 0, 2] = (max_x + min_x) / (max_x - min_x) + K[:, 1, 2] = (max_y + min_y) / (max_y - min_y) + K[:, 3, 2] = z_sign * ones + + # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point + # is at the near clipping plane and z = 1 when the point is at the far + # clipping plane. + K[:, 2, 2] = z_sign * zfar / (zfar - znear) + K[:, 2, 3] = -(zfar * znear) / (zfar - znear) + + return K + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the perspective projection matrix with a symmetric + viewing frustrum. Use column major order. + The viewing frustrum will be projected into ndc, s.t. + (max_x, max_y) -> (+1, +1) + (min_x, min_y) -> (-1, -1) + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in `__init__`. + + Return: + a Transform3d object which represents a batch of projection + matrices of shape (N, 4, 4) + + .. code-block:: python + + h1 = (max_y + min_y)/(max_y - min_y) + w1 = (max_x + min_x)/(max_x - min_x) + tanhalffov = tan((fov/2)) + s1 = 1/tanhalffov + s2 = 1/(tanhalffov * (aspect_ratio)) + + # To map z to the range [0, 1] use: + f1 = far / (far - near) + f2 = -(far * near) / (far - near) + + # Projection matrix + K = [ + [s1, 0, w1, 0], + [0, s2, h1, 0], + [0, 0, f1, f2], + [0, 0, 1, 0], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = self.compute_projection_matrix( + kwargs.get("znear", self.znear), + kwargs.get("zfar", self.zfar), + kwargs.get("fov", self.fov), + kwargs.get("aspect_ratio", self.aspect_ratio), + kwargs.get("degrees", self.degrees), + ) + + # Transpose the projection matrix as PyTorch3D transforms use row vectors. + transform = Transform3d(matrix=K.transpose(1, 2).contiguous(), device=self.device) + return transform + + def unproject_points( + self, xy_depth: torch.Tensor, world_coordinates: bool = True, scaled_depth_input: bool = False, **kwargs + ) -> torch.Tensor: + """>! + FoV cameras further allow for passing depth in world units + (`scaled_depth_input=False`) or in the [0, 1]-normalized units + (`scaled_depth_input=True`) + + Args: + scaled_depth_input: If `True`, assumes the input depth is in + the [0, 1]-normalized units. If `False` the input depth is in + the world units. + """ + + # obtain the relevant transformation to ndc + if world_coordinates: + to_ndc_transform = self.get_full_projection_transform() + else: + to_ndc_transform = self.get_projection_transform() + + if scaled_depth_input: + # the input is scaled depth, so we don't have to do anything + xy_sdepth = xy_depth + else: + # parse out important values from the projection matrix + K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix() + # parse out f1, f2 from K_matrix + unsqueeze_shape = [1] * xy_depth.dim() + unsqueeze_shape[0] = K_matrix.shape[0] + f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape) + f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape) + # get the scaled depth + sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3] + # concatenate xy + scaled depth + xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1) + + # unproject with inverse of the projection + unprojection_transform = to_ndc_transform.inverse() + return unprojection_transform.transform_points(xy_sdepth) + + def is_perspective(self): + return True + + def in_ndc(self): + return True + + +def OpenGLOrthographicCameras( + znear: _BatchFloatType = 1.0, + zfar: _BatchFloatType = 100.0, + top: _BatchFloatType = 1.0, + bottom: _BatchFloatType = -1.0, + left: _BatchFloatType = -1.0, + right: _BatchFloatType = 1.0, + scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "FoVOrthographicCameras": + """ + OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead. + Preserving OpenGLOrthographicCameras for backward compatibility. + """ + + warnings.warn( + """OpenGLOrthographicCameras is deprecated, + Use FoVOrthographicCameras instead. + OpenGLOrthographicCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return FoVOrthographicCameras( + znear=znear, + zfar=zfar, + max_y=top, + min_y=bottom, + max_x=right, + min_x=left, + scale_xyz=scale_xyz, + R=R, + T=T, + device=device, + ) + + +class FoVOrthographicCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + projection matrices by specifying the field of view. + The definitions of the parameters follow the OpenGL orthographic camera. + """ + + # For __getitem__ + _FIELDS = ("K", "znear", "zfar", "R", "T", "max_y", "min_y", "max_x", "min_x", "scale_xyz") + + def __init__( + self, + znear: _BatchFloatType = 1.0, + zfar: _BatchFloatType = 100.0, + max_y: _BatchFloatType = 1.0, + min_y: _BatchFloatType = -1.0, + max_x: _BatchFloatType = 1.0, + min_x: _BatchFloatType = -1.0, + scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + ): + """ + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + max_y: maximum y coordinate of the frustrum. + min_y: minimum y coordinate of the frustrum. + max_x: maximum x coordinate of the frustrum. + min_x: minimum x coordinate of the frustrum + scale_xyz: scale factors for each axis of shape (N, 3). + R: Rotation matrix of shape (N, 3, 3). + T: Translation of shape (N, 3). + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need znear, zfar, max_y, min_y, max_x, min_x, scale_xyz + device: torch.device or string. + + Only need to set min_x, max_x, min_y, max_y for viewing frustrums + which are non symmetric about the origin. + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__( + device=device, + znear=znear, + zfar=zfar, + max_y=max_y, + min_y=min_y, + max_x=max_x, + min_x=min_x, + scale_xyz=scale_xyz, + R=R, + T=T, + K=K, + ) + + def compute_projection_matrix(self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz) -> torch.Tensor: + """ + Compute the calibration matrix K of shape (N, 4, 4) + + Args: + znear: near clipping plane of the view frustrum. + zfar: far clipping plane of the view frustrum. + max_x: maximum x coordinate of the frustrum. + min_x: minimum x coordinate of the frustrum + max_y: maximum y coordinate of the frustrum. + min_y: minimum y coordinate of the frustrum. + scale_xyz: scale factors for each axis of shape (N, 3). + """ + K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device) + ones = torch.ones((self._N), dtype=torch.float32, device=self.device) + # NOTE: OpenGL flips handedness of coordinate system between camera + # space and NDC space so z sign is -ve. In PyTorch3D we maintain a + # right handed coordinate system throughout. + z_sign = +1.0 + + K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0] + K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1] + K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x) + K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y) + K[:, 3, 3] = ones + + # NOTE: This maps the z coordinate to the range [0, 1] and replaces the + # the OpenGL z normalization to [-1, 1] + K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2] + K[:, 2, 3] = -znear / (zfar - znear) + + return K + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the orthographic projection matrix. + Use column major order. + + Args: + **kwargs: parameters for the projection can be passed in to + override the default values set in __init__. + Return: + a Transform3d object which represents a batch of projection + matrices of shape (N, 4, 4) + + .. code-block:: python + + scale_x = 2 / (max_x - min_x) + scale_y = 2 / (max_y - min_y) + scale_z = 2 / (far-near) + mid_x = (max_x + min_x) / (max_x - min_x) + mix_y = (max_y + min_y) / (max_y - min_y) + mid_z = (far + near) / (far - near) + + K = [ + [scale_x, 0, 0, -mid_x], + [0, scale_y, 0, -mix_y], + [0, 0, -scale_z, -mid_z], + [0, 0, 0, 1], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = self.compute_projection_matrix( + kwargs.get("znear", self.znear), + kwargs.get("zfar", self.zfar), + kwargs.get("max_x", self.max_x), + kwargs.get("min_x", self.min_x), + kwargs.get("max_y", self.max_y), + kwargs.get("min_y", self.min_y), + kwargs.get("scale_xyz", self.scale_xyz), + ) + + transform = Transform3d(matrix=K.transpose(1, 2).contiguous(), device=self.device) + return transform + + def unproject_points( + self, xy_depth: torch.Tensor, world_coordinates: bool = True, scaled_depth_input: bool = False, **kwargs + ) -> torch.Tensor: + """>! + FoV cameras further allow for passing depth in world units + (`scaled_depth_input=False`) or in the [0, 1]-normalized units + (`scaled_depth_input=True`) + + Args: + scaled_depth_input: If `True`, assumes the input depth is in + the [0, 1]-normalized units. If `False` the input depth is in + the world units. + """ + + if world_coordinates: + to_ndc_transform = self.get_full_projection_transform(**kwargs.copy()) + else: + to_ndc_transform = self.get_projection_transform(**kwargs.copy()) + + if scaled_depth_input: + # the input depth is already scaled + xy_sdepth = xy_depth + else: + # we have to obtain the scaled depth first + K = self.get_projection_transform(**kwargs).get_matrix() + unsqueeze_shape = [1] * K.dim() + unsqueeze_shape[0] = K.shape[0] + mid_z = K[:, 3, 2].reshape(unsqueeze_shape) + scale_z = K[:, 2, 2].reshape(unsqueeze_shape) + scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z + # cat xy and scaled depth + xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1) + # finally invert the transform + unprojection_transform = to_ndc_transform.inverse() + return unprojection_transform.transform_points(xy_sdepth) + + def is_perspective(self): + return False + + def in_ndc(self): + return True + + +############################################################ +# MultiView Camera Classes # +############################################################ +""" +Note that the MultiView Cameras accept parameters in NDC space. +""" + + +def SfMPerspectiveCameras( + focal_length: _FocalLengthType = 1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "PerspectiveCameras": + """ + SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead. + Preserving SfMPerspectiveCameras for backward compatibility. + """ + + warnings.warn( + """SfMPerspectiveCameras is deprecated, + Use PerspectiveCameras instead. + SfMPerspectiveCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return PerspectiveCameras(focal_length=focal_length, principal_point=principal_point, R=R, T=T, device=device) + + +class PerspectiveCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + transformation matrices using the multi-view geometry convention for + perspective camera. + + Parameters for this camera are specified in NDC if `in_ndc` is set to True. + If parameters are specified in screen space, `in_ndc` must be set to False. + """ + + # For __getitem__ + _FIELDS = ( + "K", + "R", + "T", + "focal_length", + "principal_point", + "_in_ndc", # arg is in_ndc but attribute set as _in_ndc + "image_size", + ) + + _SHARED_FIELDS = ("_in_ndc",) + + def __init__( + self, + focal_length: _FocalLengthType = 1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + in_ndc: bool = True, + image_size: Optional[Union[List, Tuple, torch.Tensor]] = None, + ) -> None: + """ + + Args: + focal_length: Focal length of the camera in world units. + A tensor of shape (N, 1) or (N, 2) for + square and non-square pixels respectively. + principal_point: xy coordinates of the center of + the principal point of the camera in pixels. + A tensor of shape (N, 2). + in_ndc: True if camera parameters are specified in NDC. + If camera parameters are in screen space, it must + be set to False. + R: Rotation matrix of shape (N, 3, 3) + T: Translation matrix of shape (N, 3) + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need focal_length, principal_point + image_size: (height, width) of image size. + A tensor of shape (N, 2) or a list/tuple. Required for screen cameras. + device: torch.device or string + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + kwargs = {"image_size": image_size} if image_size is not None else {} + super().__init__( + device=device, + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + K=K, + _in_ndc=in_ndc, + **kwargs, # pyre-ignore + ) + if image_size is not None: + if (self.image_size < 1).any(): # pyre-ignore + raise ValueError("Image_size provided has invalid values") + else: + self.image_size = None + + # When focal length is provided as one value, expand to + # create (N, 2) shape tensor + if self.focal_length.ndim == 1: # (N,) + self.focal_length = self.focal_length[:, None] # (N, 1) + self.focal_length = self.focal_length.expand(-1, 2) # (N, 2) + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the projection matrix using the + multi-view geometry convention. + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in __init__. + + Returns: + A `Transform3d` object with a batch of `N` projection transforms. + + .. code-block:: python + + fx = focal_length[:, 0] + fy = focal_length[:, 1] + px = principal_point[:, 0] + py = principal_point[:, 1] + + K = [ + [fx, 0, px, 0], + [0, fy, py, 0], + [0, 0, 0, 1], + [0, 0, 1, 0], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = _get_sfm_calibration_matrix( + self._N, + self.device, + kwargs.get("focal_length", self.focal_length), + kwargs.get("principal_point", self.principal_point), + orthographic=False, + ) + + transform = Transform3d(matrix=K.transpose(1, 2).contiguous(), device=self.device) + return transform + + def unproject_points( + self, xy_depth: torch.Tensor, world_coordinates: bool = True, from_ndc: bool = False, **kwargs + ) -> torch.Tensor: + """ + Args: + from_ndc: If `False` (default), assumes xy part of input is in + NDC space if self.in_ndc(), otherwise in screen space. If + `True`, assumes xy is in NDC space even if the camera + is defined in screen space. + """ + if world_coordinates: + to_camera_transform = self.get_full_projection_transform(**kwargs) + else: + to_camera_transform = self.get_projection_transform(**kwargs) + if from_ndc: + to_camera_transform = to_camera_transform.compose(self.get_ndc_camera_transform()) + + unprojection_transform = to_camera_transform.inverse() + xy_inv_depth = torch.cat((xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1) # type: ignore + return unprojection_transform.transform_points(xy_inv_depth) + + def get_principal_point(self, **kwargs) -> torch.Tensor: + """ + Return the camera's principal point + + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + """ + proj_mat = self.get_projection_transform(**kwargs).get_matrix() + return proj_mat[:, 2, :2] + + def get_ndc_camera_transform(self, **kwargs) -> Transform3d: + """ + Returns the transform from camera projection space (screen or NDC) to NDC space. + If the camera is defined already in NDC space, the transform is identity. + For cameras defined in screen space, we adjust the principal point computation + which is defined in the image space (commonly) and scale the points to NDC space. + + This transform leaves the depth unchanged. + + Important: This transforms assumes PyTorch3D conventions for the input points, + i.e. +X left, +Y up. + """ + if self.in_ndc(): + ndc_transform = Transform3d(device=self.device, dtype=torch.float32) + else: + # when cameras are defined in screen/image space, the principal point is + # provided in the (+X right, +Y down), aka image, coordinate system. + # Since input points are defined in the PyTorch3D system (+X left, +Y up), + # we need to adjust for the principal point transform. + pr_point_fix = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32) + pr_point_fix[:, 0, 0] = 1.0 + pr_point_fix[:, 1, 1] = 1.0 + pr_point_fix[:, 2, 2] = 1.0 + pr_point_fix[:, 3, 3] = 1.0 + pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs) + pr_point_fix_transform = Transform3d(matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device) + image_size = kwargs.get("image_size", self.get_image_size()) + screen_to_ndc_transform = get_screen_to_ndc_transform(self, with_xyflip=False, image_size=image_size) + ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform) + + return ndc_transform + + def is_perspective(self): + return True + + def in_ndc(self): + return self._in_ndc + + +def SfMOrthographicCameras( + focal_length: _FocalLengthType = 1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + device: Device = "cpu", +) -> "OrthographicCameras": + """ + SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead. + Preserving SfMOrthographicCameras for backward compatibility. + """ + + warnings.warn( + """SfMOrthographicCameras is deprecated, + Use OrthographicCameras instead. + SfMOrthographicCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return OrthographicCameras(focal_length=focal_length, principal_point=principal_point, R=R, T=T, device=device) + + +class OrthographicCameras(CamerasBase): + """ + A class which stores a batch of parameters to generate a batch of + transformation matrices using the multi-view geometry convention for + orthographic camera. + + Parameters for this camera are specified in NDC if `in_ndc` is set to True. + If parameters are specified in screen space, `in_ndc` must be set to False. + """ + + # For __getitem__ + _FIELDS = ("K", "R", "T", "focal_length", "principal_point", "_in_ndc", "image_size") + + _SHARED_FIELDS = ("_in_ndc",) + + def __init__( + self, + focal_length: _FocalLengthType = 1.0, + principal_point=((0.0, 0.0),), + R: torch.Tensor = _R, + T: torch.Tensor = _T, + K: Optional[torch.Tensor] = None, + device: Device = "cpu", + in_ndc: bool = True, + image_size: Optional[Union[List, Tuple, torch.Tensor]] = None, + ) -> None: + """ + + Args: + focal_length: Focal length of the camera in world units. + A tensor of shape (N, 1) or (N, 2) for + square and non-square pixels respectively. + principal_point: xy coordinates of the center of + the principal point of the camera in pixels. + A tensor of shape (N, 2). + in_ndc: True if camera parameters are specified in NDC. + If False, then camera parameters are in screen space. + R: Rotation matrix of shape (N, 3, 3) + T: Translation matrix of shape (N, 3) + K: (optional) A calibration matrix of shape (N, 4, 4) + If provided, don't need focal_length, principal_point, image_size + image_size: (height, width) of image size. + A tensor of shape (N, 2) or list/tuple. Required for screen cameras. + device: torch.device or string + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + kwargs = {"image_size": image_size} if image_size is not None else {} + super().__init__( + device=device, + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + K=K, + _in_ndc=in_ndc, + **kwargs, # pyre-ignore + ) + if image_size is not None: + if (self.image_size < 1).any(): # pyre-ignore + raise ValueError("Image_size provided has invalid values") + else: + self.image_size = None + + # When focal length is provided as one value, expand to + # create (N, 2) shape tensor + if self.focal_length.ndim == 1: # (N,) + self.focal_length = self.focal_length[:, None] # (N, 1) + self.focal_length = self.focal_length.expand(-1, 2) # (N, 2) + + def get_projection_transform(self, **kwargs) -> Transform3d: + """ + Calculate the projection matrix using + the multi-view geometry convention. + + Args: + **kwargs: parameters for the projection can be passed in as keyword + arguments to override the default values set in __init__. + + Returns: + A `Transform3d` object with a batch of `N` projection transforms. + + .. code-block:: python + + fx = focal_length[:,0] + fy = focal_length[:,1] + px = principal_point[:,0] + py = principal_point[:,1] + + K = [ + [fx, 0, 0, px], + [0, fy, 0, py], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + """ + K = kwargs.get("K", self.K) + if K is not None: + if K.shape != (self._N, 4, 4): + msg = "Expected K to have shape of (%r, 4, 4)" + raise ValueError(msg % (self._N)) + else: + K = _get_sfm_calibration_matrix( + self._N, + self.device, + kwargs.get("focal_length", self.focal_length), + kwargs.get("principal_point", self.principal_point), + orthographic=True, + ) + + transform = Transform3d(matrix=K.transpose(1, 2).contiguous(), device=self.device) + return transform + + def unproject_points( + self, xy_depth: torch.Tensor, world_coordinates: bool = True, from_ndc: bool = False, **kwargs + ) -> torch.Tensor: + """ + Args: + from_ndc: If `False` (default), assumes xy part of input is in + NDC space if self.in_ndc(), otherwise in screen space. If + `True`, assumes xy is in NDC space even if the camera + is defined in screen space. + """ + if world_coordinates: + to_camera_transform = self.get_full_projection_transform(**kwargs) + else: + to_camera_transform = self.get_projection_transform(**kwargs) + if from_ndc: + to_camera_transform = to_camera_transform.compose(self.get_ndc_camera_transform()) + + unprojection_transform = to_camera_transform.inverse() + return unprojection_transform.transform_points(xy_depth) + + def get_principal_point(self, **kwargs) -> torch.Tensor: + """ + Return the camera's principal point + + Args: + **kwargs: parameters for the camera extrinsics can be passed in + as keyword arguments to override the default values + set in __init__. + """ + proj_mat = self.get_projection_transform(**kwargs).get_matrix() + return proj_mat[:, 3, :2] + + def get_ndc_camera_transform(self, **kwargs) -> Transform3d: + """ + Returns the transform from camera projection space (screen or NDC) to NDC space. + If the camera is defined already in NDC space, the transform is identity. + For cameras defined in screen space, we adjust the principal point computation + which is defined in the image space (commonly) and scale the points to NDC space. + + Important: This transforms assumes PyTorch3D conventions for the input points, + i.e. +X left, +Y up. + """ + if self.in_ndc(): + ndc_transform = Transform3d(device=self.device, dtype=torch.float32) + else: + # when cameras are defined in screen/image space, the principal point is + # provided in the (+X right, +Y down), aka image, coordinate system. + # Since input points are defined in the PyTorch3D system (+X left, +Y up), + # we need to adjust for the principal point transform. + pr_point_fix = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32) + pr_point_fix[:, 0, 0] = 1.0 + pr_point_fix[:, 1, 1] = 1.0 + pr_point_fix[:, 2, 2] = 1.0 + pr_point_fix[:, 3, 3] = 1.0 + pr_point_fix[:, :2, 3] = -2.0 * self.get_principal_point(**kwargs) + pr_point_fix_transform = Transform3d(matrix=pr_point_fix.transpose(1, 2).contiguous(), device=self.device) + image_size = kwargs.get("image_size", self.get_image_size()) + screen_to_ndc_transform = get_screen_to_ndc_transform(self, with_xyflip=False, image_size=image_size) + ndc_transform = pr_point_fix_transform.compose(screen_to_ndc_transform) + + return ndc_transform + + def is_perspective(self): + return False + + def in_ndc(self): + return self._in_ndc + + +################################################ +# Helper functions for cameras # +################################################ + + +def _get_sfm_calibration_matrix( + N: int, device: Device, focal_length, principal_point, orthographic: bool = False +) -> torch.Tensor: + """ + Returns a calibration matrix of a perspective/orthographic camera. + + Args: + N: Number of cameras. + focal_length: Focal length of the camera. + principal_point: xy coordinates of the center of + the principal point of the camera in pixels. + orthographic: Boolean specifying if the camera is orthographic or not + + The calibration matrix `K` is set up as follows: + + .. code-block:: python + + fx = focal_length[:,0] + fy = focal_length[:,1] + px = principal_point[:,0] + py = principal_point[:,1] + + for orthographic==True: + K = [ + [fx, 0, 0, px], + [0, fy, 0, py], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + else: + K = [ + [fx, 0, px, 0], + [0, fy, py, 0], + [0, 0, 0, 1], + [0, 0, 1, 0], + ] + + Returns: + A calibration matrix `K` of the SfM-conventioned camera + of shape (N, 4, 4). + """ + + if not torch.is_tensor(focal_length): + focal_length = torch.tensor(focal_length, device=device) + + if focal_length.ndim in (0, 1) or focal_length.shape[1] == 1: + fx = fy = focal_length + else: + fx, fy = focal_length.unbind(1) + + if not torch.is_tensor(principal_point): + principal_point = torch.tensor(principal_point, device=device) + + px, py = principal_point.unbind(1) + + K = fx.new_zeros(N, 4, 4) + K[:, 0, 0] = fx + K[:, 1, 1] = fy + if orthographic: + K[:, 0, 3] = px + K[:, 1, 3] = py + K[:, 2, 2] = 1.0 + K[:, 3, 3] = 1.0 + else: + K[:, 0, 2] = px + K[:, 1, 2] = py + K[:, 3, 2] = 1.0 + K[:, 2, 3] = 1.0 + + return K + + +################################################ +# Helper functions for world to view transforms +################################################ + + +def get_world_to_view_transform(R: torch.Tensor = _R, T: torch.Tensor = _T) -> Transform3d: + """ + This function returns a Transform3d representing the transformation + matrix to go from world space to view space by applying a rotation and + a translation. + + PyTorch3D uses the same convention as Hartley & Zisserman. + I.e., for camera extrinsic parameters R (rotation) and T (translation), + we map a 3D point `X_world` in world coordinates to + a point `X_cam` in camera coordinates with: + `X_cam = X_world R + T` + + Args: + R: (N, 3, 3) matrix representing the rotation. + T: (N, 3) matrix representing the translation. + + Returns: + a Transform3d object which represents the composed RT transformation. + + """ + # TODO: also support the case where RT is specified as one matrix + # of shape (N, 4, 4). + + if T.shape[0] != R.shape[0]: + msg = "Expected R, T to have the same batch dimension; got %r, %r" + raise ValueError(msg % (R.shape[0], T.shape[0])) + if T.dim() != 2 or T.shape[1:] != (3,): + msg = "Expected T to have shape (N, 3); got %r" + raise ValueError(msg % repr(T.shape)) + if R.dim() != 3 or R.shape[1:] != (3, 3): + msg = "Expected R to have shape (N, 3, 3); got %r" + raise ValueError(msg % repr(R.shape)) + + # Create a Transform3d object + T_ = Translate(T, device=T.device) + R_ = Rotate(R, device=R.device) + return R_.compose(T_) + + +def camera_position_from_spherical_angles( + distance: float, elevation: float, azimuth: float, degrees: bool = True, device: Device = "cpu" +) -> torch.Tensor: + """ + Calculate the location of the camera based on the distance away from + the target point, the elevation and azimuth angles. + + Args: + distance: distance of the camera from the object. + elevation, azimuth: angles. + The inputs distance, elevation and azimuth can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N) or (1) + degrees: bool, whether the angles are specified in degrees or radians. + device: str or torch.device, device for new tensors to be placed on. + + The vectors are broadcast against each other so they all have shape (N, 1). + + Returns: + camera_position: (N, 3) xyz location of the camera. + """ + broadcasted_args = convert_to_tensors_and_broadcast(distance, elevation, azimuth, device=device) + dist, elev, azim = broadcasted_args + if degrees: + elev = math.pi / 180.0 * elev + azim = math.pi / 180.0 * azim + x = dist * torch.cos(elev) * torch.sin(azim) + y = dist * torch.sin(elev) + z = dist * torch.cos(elev) * torch.cos(azim) + camera_position = torch.stack([x, y, z], dim=1) + if camera_position.dim() == 0: + camera_position = camera_position.view(1, -1) # add batch dim. + return camera_position.view(-1, 3) + + +def look_at_rotation(camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: Device = "cpu") -> torch.Tensor: + """ + This function takes a vector 'camera_position' which specifies the location + of the camera in world coordinates and two vectors `at` and `up` which + indicate the position of the object and the up directions of the world + coordinate system respectively. The object is assumed to be centered at + the origin. + + The output is a rotation matrix representing the transformation + from world coordinates -> view coordinates. + + Args: + camera_position: position of the camera in world coordinates + at: position of the object in world coordinates + up: vector specifying the up direction in the world coordinate frame. + + The inputs camera_position, at and up can each be a + - 3 element tuple/list + - torch tensor of shape (1, 3) + - torch tensor of shape (N, 3) + + The vectors are broadcast against each other so they all have shape (N, 3). + + Returns: + R: (N, 3, 3) batched rotation matrices + """ + # Format input and broadcast + broadcasted_args = convert_to_tensors_and_broadcast(camera_position, at, up, device=device) + camera_position, at, up = broadcasted_args + for t, n in zip([camera_position, at, up], ["camera_position", "at", "up"]): + if t.shape[-1] != 3: + msg = "Expected arg %s to have shape (N, 3); got %r" + raise ValueError(msg % (n, t.shape)) + z_axis = F.normalize(at - camera_position, eps=1e-5) + x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5) + y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5) + is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True) + if is_close.any(): + replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5) + x_axis = torch.where(is_close, replacement, x_axis) + R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1) + return R.transpose(1, 2) + + +def look_at_view_transform( + dist: _BatchFloatType = 1.0, + elev: _BatchFloatType = 0.0, + azim: _BatchFloatType = 0.0, + degrees: bool = True, + eye: Optional[Union[Sequence, torch.Tensor]] = None, + at=((0, 0, 0),), # (1, 3) + up=((0, 1, 0),), # (1, 3) + device: Device = "cpu", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function returns a rotation and translation matrix + to apply the 'Look At' transformation from world -> view coordinates [0]. + + Args: + dist: distance of the camera from the object + elev: angle in degrees or radians. This is the angle between the + vector from the object to the camera, and the horizontal plane y = 0 (xz-plane). + azim: angle in degrees or radians. The vector from the object to + the camera is projected onto a horizontal plane y = 0. + azim is the angle between the projected vector and a + reference vector at (0, 0, 1) on the reference plane (the horizontal plane). + dist, elev and azim can be of shape (1), (N). + degrees: boolean flag to indicate if the elevation and azimuth + angles are specified in degrees or radians. + eye: the position of the camera(s) in world coordinates. If eye is not + None, it will override the camera position derived from dist, elev, azim. + up: the direction of the x axis in the world coordinate system. + at: the position of the object(s) in world coordinates. + eye, up and at can be of shape (1, 3) or (N, 3). + + Returns: + 2-element tuple containing + + - **R**: the rotation to apply to the points to align with the camera. + - **T**: the translation to apply to the points to align with the camera. + + References: + [0] https://www.scratchapixel.com + """ + + if eye is not None: + broadcasted_args = convert_to_tensors_and_broadcast(eye, at, up, device=device) + eye, at, up = broadcasted_args + C = eye + else: + broadcasted_args = convert_to_tensors_and_broadcast(dist, elev, azim, at, up, device=device) + dist, elev, azim, at, up = broadcasted_args + C = camera_position_from_spherical_angles(dist, elev, azim, degrees=degrees, device=device) + at + + R = look_at_rotation(C, at, up, device=device) + T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0] + return R, T + + +def get_ndc_to_screen_transform( + cameras, with_xyflip: bool = False, image_size: Optional[Union[List, Tuple, torch.Tensor]] = None +) -> Transform3d: + """ + PyTorch3D NDC to screen conversion. + Conversion from PyTorch3D's NDC space (+X left, +Y up) to screen/image space + (+X right, +Y down, origin top left). + + Args: + cameras + with_xyflip: flips x- and y-axis if set to True. + Optional kwargs: + image_size: ((height, width),) specifying the height, width + of the image. If not provided, it reads it from cameras. + + We represent the NDC to screen conversion as a Transform3d + with projection matrix + + K = [ + [s, 0, 0, cx], + [0, s, 0, cy], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + + """ + # We require the image size, which is necessary for the transform + if image_size is None: + msg = "For NDC to screen conversion, image_size=(height, width) needs to be specified." + raise ValueError(msg) + + K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32) + if not torch.is_tensor(image_size): + image_size = torch.tensor(image_size, device=cameras.device) + # pyre-fixme[16]: Item `List` of `Union[List[typing.Any], Tensor, Tuple[Any, + # ...]]` has no attribute `view`. + image_size = image_size.view(-1, 2) # of shape (1 or B)x2 + height, width = image_size.unbind(1) + + # For non square images, we scale the points such that smallest side + # has range [-1, 1] and the largest side has range [-u, u], with u > 1. + # This convention is consistent with the PyTorch3D renderer + scale = (image_size.min(dim=1).values - 0.0) / 2.0 + + K[:, 0, 0] = scale + K[:, 1, 1] = scale + K[:, 0, 3] = -1.0 * (width - 0.0) / 2.0 + K[:, 1, 3] = -1.0 * (height - 0.0) / 2.0 + K[:, 2, 2] = 1.0 + K[:, 3, 3] = 1.0 + + # Transpose the projection matrix as PyTorch3D transforms use row vectors. + transform = Transform3d(matrix=K.transpose(1, 2).contiguous(), device=cameras.device) + + if with_xyflip: + # flip x, y axis + xyflip = torch.eye(4, device=cameras.device, dtype=torch.float32) + xyflip[0, 0] = -1.0 + xyflip[1, 1] = -1.0 + xyflip = xyflip.view(1, 4, 4).expand(cameras._N, -1, -1) + xyflip_transform = Transform3d(matrix=xyflip.transpose(1, 2).contiguous(), device=cameras.device) + transform = transform.compose(xyflip_transform) + return transform + + +def get_screen_to_ndc_transform( + cameras, with_xyflip: bool = False, image_size: Optional[Union[List, Tuple, torch.Tensor]] = None +) -> Transform3d: + """ + Screen to PyTorch3D NDC conversion. + Conversion from screen/image space (+X right, +Y down, origin top left) + to PyTorch3D's NDC space (+X left, +Y up). + + Args: + cameras + with_xyflip: flips x- and y-axis if set to True. + Optional kwargs: + image_size: ((height, width),) specifying the height, width + of the image. If not provided, it reads it from cameras. + + We represent the screen to NDC conversion as a Transform3d + with projection matrix + + K = [ + [1/s, 0, 0, cx/s], + [ 0, 1/s, 0, cy/s], + [ 0, 0, 1, 0], + [ 0, 0, 0, 1], + ] + + """ + transform = get_ndc_to_screen_transform(cameras, with_xyflip=with_xyflip, image_size=image_size).inverse() + return transform + + +def try_get_projection_transform(cameras: CamerasBase, cameras_kwargs: Dict[str, Any]) -> Optional[Transform3d]: + """ + Try block to get projection transform from cameras and cameras_kwargs. + + Args: + cameras: cameras instance, can be linear cameras or nonliear cameras + cameras_kwargs: camera parameters to be passed to cameras + + Returns: + If the camera implemented projection_transform, return the + projection transform; Otherwise, return None + """ + + transform = None + try: + transform = cameras.get_projection_transform(**cameras_kwargs) + except NotImplementedError: + pass + return transform diff --git a/vggsfm/minipytorch3d/device_utils.py b/vggsfm/minipytorch3d/device_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b5aa08d9ea32465ccda996ceeaafb3dec5fbea1c --- /dev/null +++ b/vggsfm/minipytorch3d/device_utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Optional, Union + +import torch + + +Device = Union[str, torch.device] + + +def make_device(device: Device) -> torch.device: + """ + Makes an actual torch.device object from the device specified as + either a string or torch.device object. If the device is `cuda` without + a specific index, the index of the current device is assigned. + + Args: + device: Device (as str or torch.device) + + Returns: + A matching torch.device object + """ + device = torch.device(device) if isinstance(device, str) else device + if device.type == "cuda" and device.index is None: + # If cuda but with no index, then the current cuda device is indicated. + # In that case, we fix to that device + device = torch.device(f"cuda:{torch.cuda.current_device()}") + return device + + +def get_device(x, device: Optional[Device] = None) -> torch.device: + """ + Gets the device of the specified variable x if it is a tensor, or + falls back to a default CPU device otherwise. Allows overriding by + providing an explicit device. + + Args: + x: a torch.Tensor to get the device from or another type + device: Device (as str or torch.device) to fall back to + + Returns: + A matching torch.device object + """ + + # User overrides device + if device is not None: + return make_device(device) + + # Set device based on input tensor + if torch.is_tensor(x): + return x.device + + # Default device is cpu + return torch.device("cpu") diff --git a/vggsfm/minipytorch3d/harmonic_embedding.py b/vggsfm/minipytorch3d/harmonic_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b02a9eeb05aafbac89ab993f0d8a59fa3ec43c9d --- /dev/null +++ b/vggsfm/minipytorch3d/harmonic_embedding.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Optional + +import torch + + +class HarmonicEmbedding(torch.nn.Module): + def __init__( + self, n_harmonic_functions: int = 6, omega_0: float = 1.0, logspace: bool = True, append_input: bool = True + ) -> None: + """ + The harmonic embedding layer supports the classical + Nerf positional encoding described in + `NeRF `_ + and the integrated position encoding in + `MIP-NeRF `_. + + During the inference you can provide the extra argument `diag_cov`. + + If `diag_cov is None`, it converts + rays parametrized with a `ray_bundle` to 3D points by + extending each ray according to the corresponding length. + Then it converts each feature + (i.e. vector along the last dimension) in `x` + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i], # only present if append_input is True. + ] + + where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. + + + If `diag_cov is not None`, it approximates + conical frustums following a ray bundle as gaussians, + defined by x, the means of the gaussians and diag_cov, + the diagonal covariances. + Then it converts each gaussian + into a series of harmonic features `embedding`, + where for each i in range(dim) the following are present + in embedding[...]:: + + [ + sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]), + ... + sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]), + cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),, + ... + cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]), + x[..., i], # only present if append_input is True. + ] + + where N equals `n_harmonic_functions-1`, and f_i is a scalar + denoting the i-th frequency of the harmonic embedding. + + If `logspace==True`, the frequencies `[f_1, ..., f_N]` are + powers of 2: + `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)` + + If `logspace==False`, frequencies are linearly spaced between + `1.0` and `2**(n_harmonic_functions-1)`: + `f_1, ..., f_N = torch.linspace( + 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions + )` + + Note that `x` is also premultiplied by the base frequency `omega_0` + before evaluating the harmonic functions. + + Args: + n_harmonic_functions: int, number of harmonic + features + omega_0: float, base frequency + logspace: bool, Whether to space the frequencies in + logspace or linear space + append_input: bool, whether to concat the original + input to the harmonic embedding. If true the + output is of the form (embed.sin(), embed.cos(), x) + """ + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (n_harmonic_functions - 1), n_harmonic_functions, dtype=torch.float32 + ) + + self.register_buffer("_frequencies", frequencies * omega_0, persistent=False) + self.register_buffer("_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False) + self.append_input = append_input + + def forward(self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + """ + Args: + x: tensor of shape [..., dim] + diag_cov: An optional tensor of shape `(..., dim)` + representing the diagonal covariance matrices of our Gaussians, joined with x + as means of the Gaussians. + + Returns: + embedding: a harmonic embedding of `x` of shape + [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray] + """ + # [..., dim, n_harmonic_functions] + embed = x[..., None] * self._frequencies + # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions] + embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None] + # Use the trig identity cos(x) = sin(x + pi/2) + # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)). + embed = embed.sin() + if diag_cov is not None: + x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2) + exp_var = torch.exp(-0.5 * x_var) + # [..., 2, dim, n_harmonic_functions] + embed = embed * exp_var[..., None, :, :] + + embed = embed.reshape(*x.shape[:-1], -1) + + if self.append_input: + return torch.cat([embed, x], dim=-1) + return embed + + @staticmethod + def get_output_dim_static(input_dims: int, n_harmonic_functions: int, append_input: bool) -> int: + """ + Utility to help predict the shape of the output of `forward`. + + Args: + input_dims: length of the last dimension of the input tensor + n_harmonic_functions: number of embedding frequencies + append_input: whether or not to concat the original + input to the harmonic embedding + Returns: + int: the length of the last dimension of the output tensor + """ + return input_dims * (2 * n_harmonic_functions + int(append_input)) + + def get_output_dim(self, input_dims: int = 3) -> int: + """ + Same as above. The default for input_dims is 3 for 3D applications + which use harmonic embedding for positional encoding, + so the input might be xyz. + """ + return self.get_output_dim_static(input_dims, len(self._frequencies), self.append_input) diff --git a/vggsfm/minipytorch3d/renderer_utils.py b/vggsfm/minipytorch3d/renderer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6dde2ae19920873558ab83e09e2a4804e314eb19 --- /dev/null +++ b/vggsfm/minipytorch3d/renderer_utils.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +import copy +import inspect +import warnings +from typing import Any, List, Optional, Tuple, TypeVar, Union + +import numpy as np +import torch +import torch.nn as nn + +from .device_utils import Device, make_device + + +class TensorAccessor(nn.Module): + """ + A helper class to be used with the __getitem__ method. This can be used for + getting/setting the values for an attribute of a class at one particular + index. This is useful when the attributes of a class are batched tensors + and one element in the batch needs to be modified. + """ + + def __init__(self, class_object, index: Union[int, slice]) -> None: + """ + Args: + class_object: this should be an instance of a class which has + attributes which are tensors representing a batch of + values. + index: int/slice, an index indicating the position in the batch. + In __setattr__ and __getattr__ only the value of class + attributes at this index will be accessed. + """ + self.__dict__["class_object"] = class_object + self.__dict__["index"] = index + + def __setattr__(self, name: str, value: Any): + """ + Update the attribute given by `name` to the value given by `value` + at the index specified by `self.index`. + + Args: + name: str, name of the attribute. + value: value to set the attribute to. + """ + v = getattr(self.class_object, name) + if not torch.is_tensor(v): + msg = "Can only set values on attributes which are tensors; got %r" + raise AttributeError(msg % type(v)) + + # Convert the attribute to a tensor if it is not a tensor. + if not torch.is_tensor(value): + value = torch.tensor(value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad) + + # Check the shapes match the existing shape and the shape of the index. + if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]: + msg = "Expected value to have shape %r; got %r" + raise ValueError(msg % (v.shape, value.shape)) + if v.dim() == 0 and isinstance(self.index, slice) and len(value) != len(self.index): + msg = "Expected value to have len %r; got %r" + raise ValueError(msg % (len(self.index), len(value))) + self.class_object.__dict__[name][self.index] = value + + def __getattr__(self, name: str): + """ + Return the value of the attribute given by "name" on self.class_object + at the index specified in self.index. + + Args: + name: string of the attribute name + """ + if hasattr(self.class_object, name): + return self.class_object.__dict__[name][self.index] + else: + msg = "Attribute %s not found on %r" + return AttributeError(msg % (name, self.class_object.__name__)) + + +BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray) + + +class TensorProperties(nn.Module): + """ + A mix-in class for storing tensors as properties with helper methods. + """ + + def __init__(self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs) -> None: + """ + Args: + dtype: data type to set for the inputs + device: Device (as str or torch.device) + kwargs: any number of keyword arguments. Any arguments which are + of type (float/int/list/tuple/tensor/array) are broadcasted and + other keyword arguments are set as attributes. + """ + super().__init__() + self.device = make_device(device) + self._N = 0 + if kwargs is not None: + # broadcast all inputs which are float/int/list/tuple/tensor/array + # set as attributes anything else e.g. strings, bools + args_to_broadcast = {} + for k, v in kwargs.items(): + if v is None or isinstance(v, (str, bool)): + setattr(self, k, v) + elif isinstance(v, BROADCAST_TYPES): + args_to_broadcast[k] = v + else: + msg = "Arg %s with type %r is not broadcastable" + warnings.warn(msg % (k, type(v))) + + names = args_to_broadcast.keys() + # convert from type dict.values to tuple + values = tuple(v for v in args_to_broadcast.values()) + + if len(values) > 0: + broadcasted_values = convert_to_tensors_and_broadcast(*values, device=device) + + # Set broadcasted values as attributes on self. + for i, n in enumerate(names): + setattr(self, n, broadcasted_values[i]) + if self._N == 0: + self._N = broadcasted_values[i].shape[0] + + def __len__(self) -> int: + return self._N + + def isempty(self) -> bool: + return self._N == 0 + + def __getitem__(self, index: Union[int, slice]) -> TensorAccessor: + """ + + Args: + index: an int or slice used to index all the fields. + + Returns: + if `index` is an index int/slice return a TensorAccessor class + with getattribute/setattribute methods which return/update the value + at the index in the original class. + """ + if isinstance(index, (int, slice)): + return TensorAccessor(class_object=self, index=index) + + msg = "Expected index of type int or slice; got %r" + raise ValueError(msg % type(index)) + + # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently. + def to(self, device: Device = "cpu") -> "TensorProperties": + """ + In place operation to move class properties which are tensors to a + specified device. If self has a property "device", update this as well. + """ + device_ = make_device(device) + for k in dir(self): + v = getattr(self, k) + if k == "device": + setattr(self, k, device_) + if torch.is_tensor(v) and v.device != device_: + setattr(self, k, v.to(device_)) + return self + + def cpu(self) -> "TensorProperties": + return self.to("cpu") + + # pyre-fixme[14]: `cuda` overrides method defined in `Module` inconsistently. + def cuda(self, device: Optional[int] = None) -> "TensorProperties": + return self.to(f"cuda:{device}" if device is not None else "cuda") + + def clone(self, other) -> "TensorProperties": + """ + Update the tensor properties of other with the cloned properties of self. + """ + for k in dir(self): + v = getattr(self, k) + if inspect.ismethod(v) or k.startswith("__") or type(v) is TypeVar: + continue + if torch.is_tensor(v): + v_clone = v.clone() + else: + v_clone = copy.deepcopy(v) + setattr(other, k, v_clone) + return other + + def gather_props(self, batch_idx) -> "TensorProperties": + """ + This is an in place operation to reformat all tensor class attributes + based on a set of given indices using torch.gather. This is useful when + attributes which are batched tensors e.g. shape (N, 3) need to be + multiplied with another tensor which has a different first dimension + e.g. packed vertices of shape (V, 3). + + Example + + .. code-block:: python + + self.specular_color = (N, 3) tensor of specular colors for each mesh + + A lighting calculation may use + + .. code-block:: python + + verts_packed = meshes.verts_packed() # (V, 3) + + To multiply these two tensors the batch dimension needs to be the same. + To achieve this we can do + + .. code-block:: python + + batch_idx = meshes.verts_packed_to_mesh_idx() # (V) + + This gives index of the mesh for each vertex in verts_packed. + + .. code-block:: python + + self.gather_props(batch_idx) + self.specular_color = (V, 3) tensor with the specular color for + each packed vertex. + + torch.gather requires the index tensor to have the same shape as the + input tensor so this method takes care of the reshaping of the index + tensor to use with class attributes with arbitrary dimensions. + + Args: + batch_idx: shape (B, ...) where `...` represents an arbitrary + number of dimensions + + Returns: + self with all properties reshaped. e.g. a property with shape (N, 3) + is transformed to shape (B, 3). + """ + # Iterate through the attributes of the class which are tensors. + for k in dir(self): + v = getattr(self, k) + if torch.is_tensor(v): + if v.shape[0] > 1: + # There are different values for each batch element + # so gather these using the batch_idx. + # First clone the input batch_idx tensor before + # modifying it. + _batch_idx = batch_idx.clone() + idx_dims = _batch_idx.shape + tensor_dims = v.shape + if len(idx_dims) > len(tensor_dims): + msg = "batch_idx cannot have more dimensions than %s. " + msg += "got shape %r and %s has shape %r" + raise ValueError(msg % (k, idx_dims, k, tensor_dims)) + if idx_dims != tensor_dims: + # To use torch.gather the index tensor (_batch_idx) has + # to have the same shape as the input tensor. + new_dims = len(tensor_dims) - len(idx_dims) + new_shape = idx_dims + (1,) * new_dims + expand_dims = (-1,) + tensor_dims[1:] + _batch_idx = _batch_idx.view(*new_shape) + _batch_idx = _batch_idx.expand(*expand_dims) + + v = v.gather(0, _batch_idx) + setattr(self, k, v) + return self + + +def format_tensor(input, dtype: torch.dtype = torch.float32, device: Device = "cpu") -> torch.Tensor: + """ + Helper function for converting a scalar value to a tensor. + + Args: + input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor + dtype: data type for the input + device: Device (as str or torch.device) on which the tensor should be placed. + + Returns: + input_vec: torch tensor with optional added batch dimension. + """ + device_ = make_device(device) + if not torch.is_tensor(input): + input = torch.tensor(input, dtype=dtype, device=device_) + + if input.dim() == 0: + input = input.view(1) + + if input.device == device_: + return input + + input = input.to(device=device) + return input + + +def convert_to_tensors_and_broadcast(*args, dtype: torch.dtype = torch.float32, device: Device = "cpu"): + """ + Helper function to handle parsing an arbitrary number of inputs (*args) + which all need to have the same batch dimension. + The output is a list of tensors. + + Args: + *args: an arbitrary number of inputs + Each of the values in `args` can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are + an arbitrary number of dimensions which can vary for each + value in args. In this case each input is broadcast to a + tensor of shape (N, K_i) + dtype: data type to use when creating new tensors. + device: torch device on which the tensors should be placed. + + Output: + args: A list of tensors of shape (N, K_i) + """ + # Convert all inputs to tensors with a batch dimension + args_1d = [format_tensor(c, dtype, device) for c in args] + + # Find broadcast size + sizes = [c.shape[0] for c in args_1d] + N = max(sizes) + + args_Nd = [] + for c in args_1d: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r" % sizes + raise ValueError(msg) + + # Expand broadcast dim and keep non broadcast dims the same size + expand_sizes = (N,) + (-1,) * len(c.shape[1:]) + args_Nd.append(c.expand(*expand_sizes)) + + return args_Nd + + +def ndc_grid_sample( + input: torch.Tensor, grid_ndc: torch.Tensor, *, align_corners: bool = False, **grid_sample_kwargs +) -> torch.Tensor: + """ + Samples a tensor `input` of shape `(B, dim, H, W)` at 2D locations + specified by a tensor `grid_ndc` of shape `(B, ..., 2)` using + the `torch.nn.functional.grid_sample` function. + `grid_ndc` is specified in PyTorch3D NDC coordinate frame. + + Args: + input: The tensor of shape `(B, dim, H, W)` to be sampled. + grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of + 2D locations at which `input` is sampled. + See [1] for a detailed description of the NDC coordinates. + align_corners: Forwarded to the `torch.nn.functional.grid_sample` + call. See its docstring. + grid_sample_kwargs: Additional arguments forwarded to the + `torch.nn.functional.grid_sample` call. See the corresponding + docstring for a listing of the corresponding arguments. + + Returns: + sampled_input: A tensor of shape `(B, dim, ...)` containing the samples + of `input` at 2D locations `grid_ndc`. + + References: + [1] https://pytorch3d.org/docs/cameras + """ + + batch, *spatial_size, pt_dim = grid_ndc.shape + if batch != input.shape[0]: + raise ValueError("'input' and 'grid_ndc' have to have the same batch size.") + if input.ndim != 4: + raise ValueError("'input' has to be a 4-dimensional Tensor.") + if pt_dim != 2: + raise ValueError("The last dimension of 'grid_ndc' has to be == 2.") + + grid_ndc_flat = grid_ndc.reshape(batch, -1, 1, 2) + + # pyre-fixme[6]: For 2nd param expected `Tuple[int, int]` but got `Size`. + grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:]) + + sampled_input_flat = torch.nn.functional.grid_sample( + input, grid_flat, align_corners=align_corners, **grid_sample_kwargs + ) + + sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size]) + + return sampled_input + + +def ndc_to_grid_sample_coords(xy_ndc: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor: + """ + Convert from the PyTorch3D's NDC coordinates to + `torch.nn.functional.grid_sampler`'s coordinates. + + Args: + xy_ndc: Tensor of shape `(..., 2)` containing 2D points in the + PyTorch3D's NDC coordinates. + image_size_hw: A tuple `(image_height, image_width)` denoting the + height and width of the image tensor to sample. + Returns: + xy_grid_sample: Tensor of shape `(..., 2)` containing 2D points in the + `torch.nn.functional.grid_sample` coordinates. + """ + if len(image_size_hw) != 2 or any(s <= 0 for s in image_size_hw): + raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers") + aspect = min(image_size_hw) / max(image_size_hw) + xy_grid_sample = -xy_ndc # first negate the coords + if image_size_hw[0] >= image_size_hw[1]: + xy_grid_sample[..., 1] *= aspect + else: + xy_grid_sample[..., 0] *= aspect + return xy_grid_sample + + +def parse_image_size(image_size: Union[List[int], Tuple[int, int], int]) -> Tuple[int, int]: + """ + Args: + image_size: A single int (for square images) or a tuple/list of two ints. + + Returns: + A tuple of two ints. + + Throws: + ValueError if got more than two ints, any negative numbers or non-ints. + """ + if not isinstance(image_size, (tuple, list)): + return (image_size, image_size) + if len(image_size) != 2: + raise ValueError("Image size can only be a tuple/list of (H, W)") + if not all(i > 0 for i in image_size): + raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size) + if not all(isinstance(i, int) for i in image_size): + raise ValueError("Image sizes must be integers; got %f, %f" % image_size) + return tuple(image_size) diff --git a/vggsfm/minipytorch3d/rotation_conversions.py b/vggsfm/minipytorch3d/rotation_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..731116a4a96d270b9470e4bcfe5642418e2c4463 --- /dev/null +++ b/vggsfm/minipytorch3d/rotation_conversions.py @@ -0,0 +1,560 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Optional + +import torch +import torch.nn.functional as F + +from .device_utils import Device, get_device, make_device + + +""" +The transformation matrices returned from the functions in this file assume +the points on which the transformation will be applied are column vectors. +i.e. the R matrix is structured as + + R = [ + [Rxx, Rxy, Rxz], + [Ryx, Ryy, Ryz], + [Rzx, Rzy, Rzz], + ] # (3, 3) + +This matrix can be applied to column vectors by post multiplication +by the points e.g. + + points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point + transformed_points = R * points + +To apply the same matrix to points which are row vectors, the R matrix +can be transposed and pre multiplied by the points: + +e.g. + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + transformed_points = points * R.transpose(1, 0) +""" + + +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Return a tensor where each element has the absolute value taken from the, + corresponding element of a, with sign taken from the corresponding + element of b. This is like the standard copysign floating-point operation, + but is not careful about negative 0 and NaN. + + Args: + a: source tensor. + b: tensor whose signs will be used, of the same shape as a. + + Returns: + Tensor of the same shape as a with the signs of b. + """ + signs_differ = (a < 0) != (b < 0) + return torch.where(signs_differ, -a, a) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +def _angle_from_tan(axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin(matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan), + central_angle, + _angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan), + ) + return torch.stack(o, -1) + + +def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate random quaternions representing rotations, + i.e. versors with nonnegative real part. + + Args: + n: Number of quaternions in a batch to return. + dtype: Type to return. + device: Desired device of returned tensor. Default: + uses the current device for the default tensor type. + + Returns: + Quaternions as tensor of shape (N, 4). + """ + if isinstance(device, str): + device = torch.device(device) + o = torch.randn((n, 4), dtype=dtype, device=device) + s = (o * o).sum(1) + o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] + return o + + +def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate random rotations as 3x3 rotation matrices. + + Args: + n: Number of rotation matrices in a batch to return. + dtype: Type to return. + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type. + + Returns: + Rotation matrices as tensor of shape (n, 3, 3). + """ + quaternions = random_quaternions(n, dtype=dtype, device=device) + return quaternion_to_matrix(quaternions) + + +def random_rotation(dtype: Optional[torch.dtype] = None, device: Optional[Device] = None) -> torch.Tensor: + """ + Generate a single random 3x3 rotation matrix. + + Args: + dtype: Type to return + device: Device of returned tensor. Default: if None, + uses the current device for the default tensor type + + Returns: + Rotation matrix as tensor of shape (3, 3). + """ + return random_rotations(1, dtype, device)[0] + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions shape (..., 4). + """ + aw, ax, ay, az = torch.unbind(a, -1) + bw, bx, by, bz = torch.unbind(b, -1) + ow = aw * bw - ax * bx - ay * by - az * bz + ox = aw * bx + ax * bw + ay * bz - az * by + oy = aw * by - ax * bz + ay * bw + az * bx + oz = aw * bz + ax * by - ay * bx + az * bw + return torch.stack((ow, ox, oy, oz), -1) + + +def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions representing rotations, returning the quaternion + representing their composition, i.e. the versor with nonnegative real part. + Usual torch rules for broadcasting apply. + + Args: + a: Quaternions as tensor of shape (..., 4), real part first. + b: Quaternions as tensor of shape (..., 4), real part first. + + Returns: + The product of a and b, a tensor of quaternions of shape (..., 4). + """ + ab = quaternion_raw_multiply(a, b) + return standardize_quaternion(ab) + + +def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor: + """ + Given a quaternion representing rotation, get the quaternion representing + its inverse. + + Args: + quaternion: Quaternions as tensor of shape (..., 4), with real part + first, which must be versors (unit quaternions). + + Returns: + The inverse, a tensor of quaternions of shape (..., 4). + """ + + scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device) + return quaternion * scaling + + +def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor: + """ + Apply the rotation given by a quaternion to a 3D point. + Usual torch rules for broadcasting apply. + + Args: + quaternion: Tensor of quaternions, real part first, of shape (..., 4). + point: Tensor of 3D points of shape (..., 3). + + Returns: + Tensor of rotated points of shape (..., 3). + """ + if point.size(-1) != 3: + raise ValueError(f"Points are not in 3D, {point.shape}.") + real_parts = point.new_zeros(point.shape[:-1] + (1,)) + point_as_quaternion = torch.cat((real_parts, point), -1) + out = quaternion_raw_multiply( + quaternion_raw_multiply(quaternion, point_as_quaternion), quaternion_invert(quaternion) + ) + return out[..., 1:] + + +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1) + return quaternions + + +def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) diff --git a/vggsfm/minipytorch3d/transform3d.py b/vggsfm/minipytorch3d/transform3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c5443b89d1f98a9b64edb4ba9d35b4512e001726 --- /dev/null +++ b/vggsfm/minipytorch3d/transform3d.py @@ -0,0 +1,793 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import math +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .device_utils import Device, get_device, make_device + +# from ..common.workaround import _safe_det_3x3 +from .rotation_conversions import _axis_angle_rotation + + +def _safe_det_3x3(t: torch.Tensor): + """ + Fast determinant calculation for a batch of 3x3 matrices. + + Note, result of this function might not be the same as `torch.det()`. + The differences might be in the last significant digit. + + Args: + t: Tensor of shape (N, 3, 3). + + Returns: + Tensor of shape (N) with determinants. + """ + + det = ( + t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1]) + - t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2]) + + t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1]) + ) + + return det + + +class Transform3d: + """ + A Transform3d object encapsulates a batch of N 3D transformations, and knows + how to transform points and normal vectors. Suppose that t is a Transform3d; + then we can do the following: + + .. code-block:: python + + N = len(t) + points = torch.randn(N, P, 3) + normals = torch.randn(N, P, 3) + points_transformed = t.transform_points(points) # => (N, P, 3) + normals_transformed = t.transform_normals(normals) # => (N, P, 3) + + + BROADCASTING + Transform3d objects supports broadcasting. Suppose that t1 and tN are + Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we + can broadcast transforms like this: + + .. code-block:: python + + t1.transform_points(torch.randn(P, 3)) # => (P, 3) + t1.transform_points(torch.randn(1, P, 3)) # => (1, P, 3) + t1.transform_points(torch.randn(M, P, 3)) # => (M, P, 3) + tN.transform_points(torch.randn(P, 3)) # => (N, P, 3) + tN.transform_points(torch.randn(1, P, 3)) # => (N, P, 3) + + + COMBINING TRANSFORMS + Transform3d objects can be combined in two ways: composing and stacking. + Composing is function composition. Given Transform3d objects t1, t2, t3, + the following all compute the same thing: + + .. code-block:: python + + y1 = t3.transform_points(t2.transform_points(t1.transform_points(x))) + y2 = t1.compose(t2).compose(t3).transform_points(x) + y3 = t1.compose(t2, t3).transform_points(x) + + + Composing transforms should broadcast. + + .. code-block:: python + + if len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N. + + We can also stack a sequence of Transform3d objects, which represents + composition along the batch dimension; then the following should compute the + same thing. + + .. code-block:: python + + N, M = len(tN), len(tM) + xN = torch.randn(N, P, 3) + xM = torch.randn(M, P, 3) + y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0) + y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0)) + + BUILDING TRANSFORMS + We provide convenience methods for easily building Transform3d objects + as compositions of basic transforms. + + .. code-block:: python + + # Scale by 0.5, then translate by (1, 2, 3) + t1 = Transform3d().scale(0.5).translate(1, 2, 3) + + # Scale each axis by a different amount, then translate, then scale + t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0) + + t3 = t1.compose(t2) + tN = t1.stack(t3, t3) + + + BACKPROP THROUGH TRANSFORMS + When building transforms, we can also parameterize them by Torch tensors; + in this case we can backprop through the construction and application of + Transform objects, so they could be learned via gradient descent or + predicted by a neural network. + + .. code-block:: python + + s1_params = torch.randn(N, requires_grad=True) + t_params = torch.randn(N, 3, requires_grad=True) + s2_params = torch.randn(N, 3, requires_grad=True) + + t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params) + x = torch.randn(N, 3) + y = t.transform_points(x) + loss = compute_loss(y) + loss.backward() + + with torch.no_grad(): + s1_params -= lr * s1_params.grad + t_params -= lr * t_params.grad + s2_params -= lr * s2_params.grad + + CONVENTIONS + We adopt a right-hand coordinate system, meaning that rotation about an axis + with a positive angle results in a counter clockwise rotation. + + This class assumes that transformations are applied on inputs which + are row vectors. The internal representation of the Nx4x4 transformation + matrix is of the form: + + .. code-block:: python + + M = [ + [Rxx, Ryx, Rzx, 0], + [Rxy, Ryy, Rzy, 0], + [Rxz, Ryz, Rzz, 0], + [Tx, Ty, Tz, 1], + ] + + To apply the transformation to points, which are row vectors, the latter are + converted to homogeneous (4D) coordinates and right-multiplied by the M matrix: + + .. code-block:: python + + points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point + [transformed_points, 1] ∝ [points, 1] @ M + + """ + + def __init__( + self, dtype: torch.dtype = torch.float32, device: Device = "cpu", matrix: Optional[torch.Tensor] = None + ) -> None: + """ + Args: + dtype: The data type of the transformation matrix. + to be used if `matrix = None`. + device: The device for storing the implemented transformation. + If `matrix != None`, uses the device of input `matrix`. + matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4) + representing the 4x4 3D transformation matrix. + If `None`, initializes with identity using + the specified `device` and `dtype`. + """ + + if matrix is None: + self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4) + else: + if matrix.ndim not in (2, 3): + raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.') + if matrix.shape[-2] != 4 or matrix.shape[-1] != 4: + raise ValueError('"matrix" has to be a tensor of shape (minibatch, 4, 4) or (4, 4).') + # set dtype and device from matrix + dtype = matrix.dtype + device = matrix.device + self._matrix = matrix.view(-1, 4, 4) + + self._transforms = [] # store transforms to compose + self._lu = None + self.device = make_device(device) + self.dtype = dtype + + def __len__(self) -> int: + return self.get_matrix().shape[0] + + def __getitem__(self, index: Union[int, List[int], slice, torch.BoolTensor, torch.LongTensor]) -> "Transform3d": + """ + Args: + index: Specifying the index of the transform to retrieve. + Can be an int, slice, list of ints, boolean, long tensor. + Supports negative indices. + + Returns: + Transform3d object with selected transforms. The tensors are not cloned. + """ + if isinstance(index, int): + index = [index] + return self.__class__(matrix=self.get_matrix()[index]) + + def compose(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new Transform3d representing the composition of self with the + given other transforms, which will be stored as an internal list. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d with the stored transforms + """ + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = self._matrix.clone() + for other in others: + if not isinstance(other, Transform3d): + msg = "Only possible to compose Transform3d objects; got %s" + raise ValueError(msg % type(other)) + out._transforms = self._transforms + list(others) + return out + + def get_matrix(self) -> torch.Tensor: + """ + Returns a 4×4 matrix corresponding to each transform in the batch. + + If the transform was composed from others, the matrix for the composite + transform will be returned. + For example, if self.transforms contains transforms t1, t2, and t3, and + given a set of points x, the following should be true: + + .. code-block:: python + + y1 = t1.compose(t2, t3).transform(x) + y2 = t3.transform(t2.transform(t1.transform(x))) + y1.get_matrix() == y2.get_matrix() + + Where necessary, those transforms are broadcast against each other. + + Returns: + A (N, 4, 4) batch of transformation matrices representing + the stored transforms. See the class documentation for the conventions. + """ + composed_matrix = self._matrix.clone() + if len(self._transforms) > 0: + for other in self._transforms: + other_matrix = other.get_matrix() + composed_matrix = _broadcast_bmm(composed_matrix, other_matrix) + return composed_matrix + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + return torch.inverse(self._matrix) + + def inverse(self, invert_composed: bool = False) -> "Transform3d": + """ + Returns a new Transform3d object that represents an inverse of the + current transformation. + + Args: + invert_composed: + - True: First compose the list of stored transformations + and then apply inverse to the result. This is + potentially slower for classes of transformations + with inverses that can be computed efficiently + (e.g. rotations and translations). + - False: Invert the individual stored transformations + independently without composing them. + + Returns: + A new Transform3d object containing the inverse of the original + transformation. + """ + + tinv = Transform3d(dtype=self.dtype, device=self.device) + + if invert_composed: + # first compose then invert + tinv._matrix = torch.inverse(self.get_matrix()) + else: + # self._get_matrix_inverse() implements efficient inverse + # of self._matrix + i_matrix = self._get_matrix_inverse() + + # 2 cases: + if len(self._transforms) > 0: + # a) Either we have a non-empty list of transforms: + # Here we take self._matrix and append its inverse at the + # end of the reverted _transforms list. After composing + # the transformations with get_matrix(), this correctly + # right-multiplies by the inverse of self._matrix + # at the end of the composition. + tinv._transforms = [t.inverse() for t in reversed(self._transforms)] + last = Transform3d(dtype=self.dtype, device=self.device) + last._matrix = i_matrix + tinv._transforms.append(last) + else: + # b) Or there are no stored transformations + # we just set inverted matrix + tinv._matrix = i_matrix + + return tinv + + def stack(self, *others: "Transform3d") -> "Transform3d": + """ + Return a new batched Transform3d representing the batch elements from + self and all the given other transforms all batched together. + + Args: + *others: Any number of Transform3d objects + + Returns: + A new Transform3d. + """ + transforms = [self] + list(others) + matrix = torch.cat([t.get_matrix() for t in transforms], dim=0) + out = Transform3d(dtype=self.dtype, device=self.device) + out._matrix = matrix + return out + + def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor: + """ + Use this transform to transform a set of 3D points. Assumes row major + ordering of the input points. + + Args: + points: Tensor of shape (P, 3) or (N, P, 3) + eps: If eps!=None, the argument is used to clamp the + last coordinate before performing the final division. + The clamping corresponds to: + last_coord := (last_coord.sign() + (last_coord==0)) * + torch.clamp(last_coord.abs(), eps), + i.e. the last coordinates that are exactly 0 will + be clamped to +eps. + + Returns: + points_out: points of shape (N, P, 3) or (P, 3) depending + on the dimensions of the transform + """ + points_batch = points.clone() + if points_batch.dim() == 2: + points_batch = points_batch[None] # (P, 3) -> (1, P, 3) + if points_batch.dim() != 3: + msg = "Expected points to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % repr(points.shape)) + + N, P, _3 = points_batch.shape + ones = torch.ones(N, P, 1, dtype=points.dtype, device=points.device) + points_batch = torch.cat([points_batch, ones], dim=2) + + composed_matrix = self.get_matrix() + points_out = _broadcast_bmm(points_batch, composed_matrix) + denom = points_out[..., 3:] # denominator + if eps is not None: + denom_sign = denom.sign() + (denom == 0.0).type_as(denom) + denom = denom_sign * torch.clamp(denom.abs(), eps) + points_out = points_out[..., :3] / denom + + # When transform is (1, 4, 4) and points is (P, 3) return + # points_out of shape (P, 3) + if points_out.shape[0] == 1 and points.dim() == 2: + points_out = points_out.reshape(points.shape) + + return points_out + + def transform_normals(self, normals) -> torch.Tensor: + """ + Use this transform to transform a set of normal vectors. + + Args: + normals: Tensor of shape (P, 3) or (N, P, 3) + + Returns: + normals_out: Tensor of shape (P, 3) or (N, P, 3) depending + on the dimensions of the transform + """ + if normals.dim() not in [2, 3]: + msg = "Expected normals to have dim = 2 or dim = 3: got shape %r" + raise ValueError(msg % (normals.shape,)) + composed_matrix = self.get_matrix() + + # TODO: inverse is bad! Solve a linear system instead + mat = composed_matrix[:, :3, :3] + normals_out = _broadcast_bmm(normals, mat.transpose(1, 2).inverse()) + + # This doesn't pass unit tests. TODO investigate further + # if self._lu is None: + # self._lu = self._matrix[:, :3, :3].transpose(1, 2).lu() + # normals_out = normals.lu_solve(*self._lu) + + # When transform is (1, 4, 4) and normals is (P, 3) return + # normals_out of shape (P, 3) + if normals_out.shape[0] == 1 and normals.dim() == 2: + normals_out = normals_out.reshape(normals.shape) + + return normals_out + + def translate(self, *args, **kwargs) -> "Transform3d": + return self.compose(Translate(*args, device=self.device, dtype=self.dtype, **kwargs)) + + def scale(self, *args, **kwargs) -> "Transform3d": + return self.compose(Scale(*args, device=self.device, dtype=self.dtype, **kwargs)) + + def rotate(self, *args, **kwargs) -> "Transform3d": + return self.compose(Rotate(*args, device=self.device, dtype=self.dtype, **kwargs)) + + def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d": + return self.compose(RotateAxisAngle(*args, device=self.device, dtype=self.dtype, **kwargs)) + + def clone(self) -> "Transform3d": + """ + Deep copy of Transforms object. All internal tensors are cloned + individually. + + Returns: + new Transforms object. + """ + other = Transform3d(dtype=self.dtype, device=self.device) + if self._lu is not None: + other._lu = [elem.clone() for elem in self._lu] + other._matrix = self._matrix.clone() + other._transforms = [t.clone() for t in self._transforms] + return other + + def to(self, device: Device, copy: bool = False, dtype: Optional[torch.dtype] = None) -> "Transform3d": + """ + Match functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. + If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + device: Device (as str or torch.device) for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. + dtype: If not None, casts the internal tensor variables + to a given torch.dtype. + + Returns: + Transform3d object. + """ + device_ = make_device(device) + dtype_ = self.dtype if dtype is None else dtype + skip_to = self.device == device_ and self.dtype == dtype_ + + if not copy and skip_to: + return self + + other = self.clone() + + if skip_to: + return other + + other.device = device_ + other.dtype = dtype_ + other._matrix = other._matrix.to(device=device_, dtype=dtype_) + other._transforms = [t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms] + return other + + def cpu(self) -> "Transform3d": + return self.to("cpu") + + def cuda(self) -> "Transform3d": + return self.to("cuda") + + +class Translate(Transform3d): + def __init__(self, x, y=None, z=None, dtype: torch.dtype = torch.float32, device: Optional[Device] = None) -> None: + """ + Create a new Transform3d representing 3D translations. + + Option I: Translate(xyz, dtype=torch.float32, device='cpu') + xyz should be a tensor of shape (N, 3) + + Option II: Translate(x, y, z, dtype=torch.float32, device='cpu') + Here x, y, and z will be broadcast against each other and + concatenated to form the translation. Each can be: + - A python scalar + - A torch scalar + - A 1D torch tensor + """ + xyz = _handle_input(x, y, z, dtype, device, "Translate") + super().__init__(device=xyz.device, dtype=dtype) + N = xyz.shape[0] + + mat = torch.eye(4, dtype=dtype, device=self.device) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, 3, :3] = xyz + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + inv_mask = self._matrix.new_ones([1, 4, 4]) + inv_mask[0, 3, :3] = -1.0 + i_matrix = self._matrix * inv_mask + return i_matrix + + +class Scale(Transform3d): + def __init__(self, x, y=None, z=None, dtype: torch.dtype = torch.float32, device: Optional[Device] = None) -> None: + """ + A Transform3d representing a scaling operation, with different scale + factors along each coordinate axis. + + Option I: Scale(s, dtype=torch.float32, device='cpu') + s can be one of + - Python scalar or torch scalar: Single uniform scale + - 1D torch tensor of shape (N,): A batch of uniform scale + - 2D torch tensor of shape (N, 3): Scale differently along each axis + + Option II: Scale(x, y, z, dtype=torch.float32, device='cpu') + Each of x, y, and z can be one of + - python scalar + - torch scalar + - 1D torch tensor + """ + xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True) + super().__init__(device=xyz.device, dtype=dtype) + N = xyz.shape[0] + + # TODO: Can we do this all in one go somehow? + mat = torch.eye(4, dtype=dtype, device=self.device) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, 0, 0] = xyz[:, 0] + mat[:, 1, 1] = xyz[:, 1] + mat[:, 2, 2] = xyz[:, 2] + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + xyz = torch.stack([self._matrix[:, i, i] for i in range(4)], dim=1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + ixyz = 1.0 / xyz + # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`. + imat = torch.diag_embed(ixyz, dim1=1, dim2=2) + return imat + + +class Rotate(Transform3d): + def __init__( + self, + R: torch.Tensor, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + orthogonal_tol: float = 1e-5, + ) -> None: + """ + Create a new Transform3d representing 3D rotation using a rotation + matrix as the input. + + Args: + R: a tensor of shape (3, 3) or (N, 3, 3) + orthogonal_tol: tolerance for the test of the orthogonality of R + + """ + device_ = get_device(R, device) + super().__init__(device=device_, dtype=dtype) + if R.dim() == 2: + R = R[None] + if R.shape[-2:] != (3, 3): + msg = "R must have shape (3, 3) or (N, 3, 3); got %s" + raise ValueError(msg % repr(R.shape)) + R = R.to(device=device_, dtype=dtype) + if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1": + # Note: aten::all_close in the check is computationally slow, so we + # only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on. + _check_valid_rotation_matrix(R, tol=orthogonal_tol) + N = R.shape[0] + mat = torch.eye(4, dtype=dtype, device=device_) + mat = mat.view(1, 4, 4).repeat(N, 1, 1) + mat[:, :3, :3] = R + self._matrix = mat + + def _get_matrix_inverse(self) -> torch.Tensor: + """ + Return the inverse of self._matrix. + """ + return self._matrix.permute(0, 2, 1).contiguous() + + +class RotateAxisAngle(Rotate): + def __init__( + self, + angle, + axis: str = "X", + degrees: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[Device] = None, + ) -> None: + """ + Create a new Transform3d representing 3D rotation about an axis + by an angle. + + Assuming a right-hand coordinate system, positive rotation angles result + in a counter clockwise rotation. + + Args: + angle: + - A torch tensor of shape (N,) + - A python scalar + - A torch scalar + axis: + string: one of ["X", "Y", "Z"] indicating the axis about which + to rotate. + NOTE: All batch elements are rotated about the same axis. + """ + axis = axis.upper() + if axis not in ["X", "Y", "Z"]: + msg = "Expected axis to be one of ['X', 'Y', 'Z']; got %s" + raise ValueError(msg % axis) + angle = _handle_angle_input(angle, dtype, device, "RotateAxisAngle") + angle = (angle / 180.0 * math.pi) if degrees else angle + # We assume the points on which this transformation will be applied + # are row vectors. The rotation matrix returned from _axis_angle_rotation + # is for transforming column vectors. Therefore we transpose this matrix. + # R will always be of shape (N, 3, 3) + R = _axis_angle_rotation(axis, angle).transpose(1, 2) + super().__init__(device=angle.device, R=R, dtype=dtype) + + +def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + """ + Helper function for _handle_input. + + Args: + c: Python scalar, torch scalar, or 1D torch tensor + + Returns: + c_vec: 1D torch tensor + """ + if not torch.is_tensor(c): + c = torch.tensor(c, dtype=dtype, device=device) + if c.dim() == 0: + c = c.view(1) + if c.device != device or c.dtype != dtype: + c = c.to(device=device, dtype=dtype) + return c + + +def _handle_input( + x, y, z, dtype: torch.dtype, device: Optional[Device], name: str, allow_singleton: bool = False +) -> torch.Tensor: + """ + Helper function to handle parsing logic for building transforms. The output + is always a tensor of shape (N, 3), but there are several types of allowed + input. + + Case I: Single Matrix + In this case x is a tensor of shape (N, 3), and y and z are None. Here just + return x. + + Case II: Vectors and Scalars + In this case each of x, y, and z can be one of the following + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + In this case x, y and z are broadcast to tensors of shape (N, 1) + and concatenated to a tensor of shape (N, 3) + + Case III: Singleton (only if allow_singleton=True) + In this case y and z are None, and x can be one of the following: + - Python scalar + - Torch scalar + - Torch tensor of shape (N, 1) or (1, 1) + Here x will be duplicated 3 times, and we return a tensor of shape (N, 3) + + Returns: + xyz: Tensor of shape (N, 3) + """ + device_ = get_device(x, device) + # If x is actually a tensor of shape (N, 3) then just return it + if torch.is_tensor(x) and x.dim() == 2: + if x.shape[1] != 3: + msg = "Expected tensor of shape (N, 3); got %r (in %s)" + raise ValueError(msg % (x.shape, name)) + if y is not None or z is not None: + msg = "Expected y and z to be None (in %s)" % name + raise ValueError(msg) + return x.to(device=device_, dtype=dtype) + + if allow_singleton and y is None and z is None: + y = x + z = x + + # Convert all to 1D tensors + xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]] + + # Broadcast and concatenate + sizes = [c.shape[0] for c in xyz] + N = max(sizes) + for c in xyz: + if c.shape[0] != 1 and c.shape[0] != N: + msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name) + raise ValueError(msg) + xyz = [c.expand(N) for c in xyz] + xyz = torch.stack(xyz, dim=1) + return xyz + + +def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str) -> torch.Tensor: + """ + Helper function for building a rotation function using angles. + The output is always of shape (N,). + + The input can be one of: + - Torch tensor of shape (N,) + - Python scalar + - Torch scalar + """ + device_ = get_device(x, device) + if torch.is_tensor(x) and x.dim() > 1: + msg = "Expected tensor of shape (N,); got %r (in %s)" + raise ValueError(msg % (x.shape, name)) + else: + return _handle_coord(x, dtype, device_) + + +def _broadcast_bmm(a, b) -> torch.Tensor: + """ + Batch multiply two matrices and broadcast if necessary. + + Args: + a: torch tensor of shape (P, K) or (M, P, K) + b: torch tensor of shape (N, K, K) + + Returns: + a and b broadcast multiplied. The output batch dimension is max(N, M). + + To broadcast transforms across a batch dimension if M != N then + expect that either M = 1 or N = 1. The tensor with batch dimension 1 is + expanded to have shape N or M. + """ + if a.dim() == 2: + a = a[None] + if len(a) != len(b): + if not ((len(a) == 1) or (len(b) == 1)): + msg = "Expected batch dim for bmm to be equal or 1; got %r, %r" + raise ValueError(msg % (a.shape, b.shape)) + if len(a) == 1: + a = a.expand(len(b), -1, -1) + if len(b) == 1: + b = b.expand(len(a), -1, -1) + return a.bmm(b) + + +@torch.no_grad() +def _check_valid_rotation_matrix(R, tol: float = 1e-7) -> None: + """ + Determine if R is a valid rotation matrix by checking it satisfies the + following conditions: + + ``RR^T = I and det(R) = 1`` + + Args: + R: an (N, 3, 3) matrix + + Returns: + None + + Emits a warning if R is an invalid rotation matrix. + """ + N = R.shape[0] + eye = torch.eye(3, dtype=R.dtype, device=R.device) + eye = eye.view(1, 3, 3).expand(N, -1, -1) + orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol) + det_R = _safe_det_3x3(R) + no_distortion = torch.allclose(det_R, torch.ones_like(det_R)) + if not (orthogonal and no_distortion): + msg = "R is not a valid rotation matrix" + warnings.warn(msg) + return diff --git a/vggsfm/vggsfm/datasets/camera_transform.py b/vggsfm/vggsfm/datasets/camera_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..83cba22b0d8aab5d9b9ab98ee805e1233ba26690 --- /dev/null +++ b/vggsfm/vggsfm/datasets/camera_transform.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Adapted from https://github.com/amyxlase/relpose-plus-plus + +import torch +import numpy as np +import math + + + +from minipytorch3d.cameras import PerspectiveCameras +from minipytorch3d.transform3d import Rotate, Translate +from minipytorch3d.rotation_conversions import matrix_to_quaternion, quaternion_to_matrix + +# from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +# from pytorch3d.transforms import Rotate, Translate +# from pytorch3d.transforms.rotation_conversions import matrix_to_quaternion, quaternion_to_matrix + + +def bbox_xyxy_to_xywh(xyxy): + wh = xyxy[2:] - xyxy[:2] + xywh = np.concatenate([xyxy[:2], wh]) + return xywh + + +def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor): + focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh) + + principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2] + + focal_length, principal_point_cropped = _convert_pixels_to_ndc( + focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:] + ) + + return focal_length, principal_point_cropped + + +def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor): + focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh) + + # now scale and convert from pixels to NDC + image_size_wh_output = new_size_wh.float() + scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values + focal_length_px_scaled = focal_length_px * scale + principal_point_px_scaled = principal_point_px * scale + + focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc( + focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output + ) + return focal_length_scaled, principal_point_scaled + + +def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor): + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point_px = half_image_size - principal_point * rescale + focal_length_px = focal_length * rescale + return focal_length_px, principal_point_px + + +def _convert_pixels_to_ndc( + focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor +): + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point = (half_image_size - principal_point_px) / rescale + focal_length = focal_length_px / rescale + return focal_length, principal_point + + +def normalize_cameras( + cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False +): + """ + Normalizes cameras such that + (1) the optical axes point to the origin and the average distance to the origin is 1 + (2) the first camera is the origin + (3) the translation vector is normalized + + TODO: some transforms overlap with others. no need to do so many transforms + Args: + cameras (List[camera]). + """ + # Let distance from first camera to origin be unit + new_cameras = cameras.clone() + + if compute_optical: + new_cameras, points = compute_optical_transform(new_cameras, points=points) + + if first_camera: + new_cameras, points = first_camera_transform(new_cameras, points=points) + + if normalize_trans: + new_cameras, points = normalize_translation(new_cameras, points=points, max_norm=max_norm) + + return new_cameras, points + + +def compute_optical_transform(new_cameras, points=None): + """ + adapted from https://github.com/amyxlase/relpose-plus-plus + """ + + new_transform = new_cameras.get_world_to_view_transform() + p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras) + t = Translate(p_intersect) + scale = dist.squeeze()[0] + + if points is not None: + points = t.inverse().transform_points(points) + points = points / scale + + # Degenerate case + if scale == 0: + scale = torch.norm(new_cameras.T, dim=(0, 1)) + scale = torch.sqrt(scale) + new_cameras.T = new_cameras.T / scale + else: + new_matrix = t.compose(new_transform).get_matrix() + new_cameras.R = new_matrix[:, :3, :3] + new_cameras.T = new_matrix[:, 3, :3] / scale + + return new_cameras, points + + +def compute_optical_axis_intersection(cameras): + centers = cameras.get_camera_center() + principal_points = cameras.principal_point + + one_vec = torch.ones((len(cameras), 1)) + optical_axis = torch.cat((principal_points, one_vec), -1) + + pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True) + + pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])] + + directions = pp2 - centers + centers = centers.unsqueeze(0).unsqueeze(0) + directions = directions.unsqueeze(0).unsqueeze(0) + + p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None) + + p_intersect = p_intersect.squeeze().unsqueeze(0) + dist = (p_intersect - centers).norm(dim=-1) + + return p_intersect, dist, p_line_intersect, pp2, r + + +def intersect_skew_line_groups(p, r, mask): + # p, r both of shape (B, N, n_intersected_lines, 3) + # mask of shape (B, N, n_intersected_lines) + p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask) + _, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p)) + intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1) + return p_intersect, p_line_intersect, intersect_dist_squared, r + + +def intersect_skew_lines_high_dim(p, r, mask=None): + # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions + dim = p.shape[-1] + # make sure the heading vectors are l2-normed + if mask is None: + mask = torch.ones_like(p[..., 0]) + r = torch.nn.functional.normalize(r, dim=-1) + + eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None] + I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None] + sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) + p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] + + if torch.any(torch.isnan(p_intersect)): + print(p_intersect) + raise ValueError(f"p_intersect is NaN") + + return p_intersect, r + + +def _point_line_distance(p1, r1, p2): + df = p2 - p1 + proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1) + line_pt_nearest = p2 - proj_vector + d = (proj_vector).norm(dim=-1) + return d, line_pt_nearest + + +def first_camera_transform(cameras, rotation_only=False, points=None): + """ + Transform so that the first camera is the origin + """ + + new_cameras = cameras.clone() + new_transform = new_cameras.get_world_to_view_transform() + + tR = Rotate(new_cameras.R[0].unsqueeze(0)) + if rotation_only: + t = tR.inverse() + else: + tT = Translate(new_cameras.T[0].unsqueeze(0)) + t = tR.compose(tT).inverse() + + if points is not None: + points = t.inverse().transform_points(points) + + new_matrix = t.compose(new_transform).get_matrix() + + new_cameras.R = new_matrix[:, :3, :3] + new_cameras.T = new_matrix[:, 3, :3] + + return new_cameras, points + + +def normalize_translation(new_cameras, points=None, max_norm=False): + t_gt = new_cameras.T.clone() + t_gt = t_gt[1:, :] + + if max_norm: + t_gt_scale = torch.norm(t_gt, dim=(-1)) + t_gt_scale = t_gt_scale.max() + t_gt_scale = t_gt_scale.clamp(min=0.01, max=100) + else: + t_gt_scale = torch.norm(t_gt, dim=(0, 1)) + t_gt_scale = t_gt_scale / math.sqrt(len(t_gt)) + t_gt_scale = t_gt_scale / 2 + t_gt_scale = t_gt_scale.clamp(min=0.01, max=100) + + new_cameras.T = new_cameras.T / t_gt_scale + + if points is not None: + points = points / t_gt_scale + + return new_cameras, points diff --git a/vggsfm/vggsfm/datasets/demo_loader.py b/vggsfm/vggsfm/datasets/demo_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..1a31c169a30613f22f44f7d37520ad6a394bca2c --- /dev/null +++ b/vggsfm/vggsfm/datasets/demo_loader.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import json +import random + +import glob +import torch +import numpy as np + +from PIL import Image, ImageFile +from torchvision import transforms +from torch.utils.data import Dataset + +from minipytorch3d.cameras import PerspectiveCameras + +import pycolmap + +from ..utils.tensor_to_pycolmap import pycolmap_to_batch_matrix + + +from .camera_transform import ( + normalize_cameras, + adjust_camera_to_bbox_crop_, + adjust_camera_to_image_scale_, + bbox_xyxy_to_xywh, +) + + +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class DemoLoader(Dataset): + def __init__( + self, + SCENE_DIR, + transform=None, + img_size=1024, + eval_time=True, + normalize_cameras=True, + sort_by_filename=True, + load_gt=False, + cfg=None, + ): + self.cfg = cfg + + self.sequences = {} + + if SCENE_DIR == None: + raise NotImplementedError + + print(f"SCENE_DIR is {SCENE_DIR}") + + self.SCENE_DIR = SCENE_DIR + self.crop_longest = True + self.load_gt = load_gt + self.sort_by_filename = sort_by_filename + + bag_name = os.path.basename(SCENE_DIR) + img_filenames = glob.glob(os.path.join(SCENE_DIR, "images/*")) + + if self.sort_by_filename: + img_filenames = sorted(img_filenames) + + filtered_data = [] + + if self.load_gt: + """ + We assume the ground truth cameras exist in the format of colmap + """ + reconstruction = pycolmap.Reconstruction(os.path.join(SCENE_DIR, "sparse", "0")) + + calib_dict = {} + for image_id, image in reconstruction.images.items(): + extrinsic = reconstruction.images[image_id].cam_from_world.matrix + camera_id = image.camera_id + intrinsic = reconstruction.cameras[camera_id].calibration_matrix() + + R = torch.from_numpy(extrinsic[:, :3]) + T = torch.from_numpy(extrinsic[:, 3]) + fl = torch.from_numpy(intrinsic[[0, 1], [0, 1]]) + pp = torch.from_numpy(intrinsic[[0, 1], [2, 2]]) + + calib_dict[image.name] = {"R": R, "T": T, "focal_length": fl, "principal_point": pp} + + for img_name in img_filenames: + frame_dict = {} + frame_dict["filepath"] = img_name + + if self.load_gt: + anno_dict = calib_dict[os.path.basename(img_name)] + frame_dict.update(anno_dict) + + filtered_data.append(frame_dict) + self.sequences[bag_name] = filtered_data + + self.sequence_list = sorted(self.sequences.keys()) + + if transform is None: + self.transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(img_size, antialias=True)]) + else: + self.transform = transform + + self.jitter_scale = [1, 1] + self.jitter_trans = [0, 0] + + self.img_size = img_size + self.eval_time = eval_time + + self.normalize_cameras = normalize_cameras + + print(f"Data size of Sequence: {len(self)}") + + def __len__(self): + return len(self.sequence_list) + + def _crop_image(self, image, bbox, white_bg=False): + if white_bg: + # Only support PIL Images + image_crop = Image.new("RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)) + image_crop.paste(image, (-bbox[0], -bbox[1])) + else: + image_crop = transforms.functional.crop( + image, top=bbox[1], left=bbox[0], height=bbox[3] - bbox[1], width=bbox[2] - bbox[0] + ) + return image_crop + + def __getitem__(self, idx_N): + if self.eval_time: + return self.get_data(index=idx_N, ids=None) + else: + raise NotImplementedError("Do not train on Sequence.") + + def get_data(self, index=None, sequence_name=None, ids=None, return_path=False): + if sequence_name is None: + sequence_name = self.sequence_list[index] + + metadata = self.sequences[sequence_name] + + if ids is None: + ids = np.arange(len(metadata)) + + annos = [metadata[i] for i in ids] + + if self.sort_by_filename: + annos = sorted(annos, key=lambda x: x["filepath"]) + + images = [] + image_paths = [] + + if self.load_gt: + rotations = [] + translations = [] + focal_lengths = [] + principal_points = [] + + for anno in annos: + image_path = anno["filepath"] + + image = Image.open(image_path).convert("RGB") + + images.append(image) + image_paths.append(image_path) + + if self.load_gt: + rotations.append(anno["R"]) + translations.append(anno["T"]) + + # focal length and principal point + # from OPENCV to PT3D + original_size_wh = np.array(image.size) + scale = min(original_size_wh) / 2 + c0 = original_size_wh / 2.0 + focal_pytorch3d = anno["focal_length"] / scale + + # mirrored principal point + p0_pytorch3d = -(anno["principal_point"] - c0) / scale + focal_lengths.append(focal_pytorch3d) + principal_points.append(p0_pytorch3d) + + batch = {"seq_name": sequence_name, "frame_num": len(metadata)} + + crop_parameters = [] + images_transformed = [] + + if self.load_gt: + new_fls = [] + new_pps = [] + + for i, (anno, image) in enumerate(zip(annos, images)): + w, h = image.width, image.height + + if self.crop_longest: + crop_dim = max(h, w) + top = (h - crop_dim) // 2 + left = (w - crop_dim) // 2 + bbox = np.array([left, top, left + crop_dim, top + crop_dim]) + else: + bbox = np.array(anno["bbox"]) + + crop_paras = calculate_crop_parameters(image, bbox, crop_dim, self.img_size) + crop_parameters.append(crop_paras) + + # Crop image by bbox + image = self._crop_image(image, bbox) + + images_transformed.append(self.transform(image)) + + if self.load_gt: + bbox_xywh = torch.FloatTensor(bbox_xyxy_to_xywh(bbox)) + + # Cropping images + focal_length_cropped, principal_point_cropped = adjust_camera_to_bbox_crop_( + focal_lengths[i], principal_points[i], torch.FloatTensor(image.size), bbox_xywh + ) + + # Resizing images + new_focal_length, new_principal_point = adjust_camera_to_image_scale_( + focal_length_cropped, + principal_point_cropped, + torch.FloatTensor(image.size), + torch.FloatTensor([self.img_size, self.img_size]), + ) + + new_fls.append(new_focal_length) + new_pps.append(new_principal_point) + + images = images_transformed + + if self.load_gt: + new_fls = torch.stack(new_fls) + new_pps = torch.stack(new_pps) + + batchR = torch.cat([data["R"][None] for data in annos]) + batchT = torch.cat([data["T"][None] for data in annos]) + + batch["rawR"] = batchR.clone() + batch["rawT"] = batchT.clone() + + # From OPENCV/COLMAP to PT3D + batchR = batchR.clone().permute(0, 2, 1) + batchT = batchT.clone() + batchR[:, :, :2] *= -1 + batchT[:, :2] *= -1 + + cameras = PerspectiveCameras( + focal_length=new_fls.float(), principal_point=new_pps.float(), R=batchR.float(), T=batchT.float() + ) + + if self.normalize_cameras: + # Move the cameras so that they stay in a better coordinate + # This will not affect the evaluation result + normalized_cameras, points = normalize_cameras(cameras, points=None) + + if normalized_cameras == -1: + print("Error in normalizing cameras: camera scale was 0") + raise RuntimeError + + batch["R"] = normalized_cameras.R + batch["T"] = normalized_cameras.T + + batch["fl"] = normalized_cameras.focal_length + batch["pp"] = normalized_cameras.principal_point + + if torch.any(torch.isnan(batch["T"])): + print(ids) + # print(category) + print(sequence_name) + raise RuntimeError + else: + batch["R"] = cameras.R + batch["T"] = cameras.T + batch["fl"] = cameras.focal_length + batch["pp"] = cameras.principal_point + + batch["crop_params"] = torch.stack(crop_parameters) + + # Add images + if self.transform is not None: + images = torch.stack(images) + + if not self.eval_time: + raise ValueError("color aug should not happen for Sequence") + + batch["image"] = images.clamp(0, 1) + + if return_path: + return batch, image_paths + + return batch + + +def calculate_crop_parameters(image, bbox, crop_dim, img_size): + crop_center = (bbox[:2] + bbox[2:]) / 2 + # convert crop center to correspond to a "square" image + width, height = image.size + length = max(width, height) + s = length / min(width, height) + crop_center = crop_center + (length - np.array([width, height])) / 2 + # convert to NDC + cc = s - 2 * s * crop_center / length + crop_width = 2 * s * (bbox[2] - bbox[0]) / length + bbox_after = bbox / crop_dim * img_size + crop_parameters = torch.tensor( + [width, height, crop_width, s, bbox_after[0], bbox_after[1], bbox_after[2], bbox_after[3]] + ).float() + return crop_parameters diff --git a/vggsfm/vggsfm/datasets/imc.py b/vggsfm/vggsfm/datasets/imc.py new file mode 100644 index 0000000000000000000000000000000000000000..c52264305cb53085c1502106a6101909cbb0a681 --- /dev/null +++ b/vggsfm/vggsfm/datasets/imc.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import json +import random + +import glob +import torch +import numpy as np + +from PIL import Image, ImageFile +from torchvision import transforms +from torch.utils.data import Dataset + +from minipytorch3d.cameras import PerspectiveCameras + +from .camera_transform import ( + normalize_cameras, + adjust_camera_to_bbox_crop_, + adjust_camera_to_image_scale_, + bbox_xyxy_to_xywh, +) + +from .imc_helper import parse_file_to_list, load_calib + + +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class IMCDataset(Dataset): + def __init__( + self, + IMC_DIR, + split="train", + transform=None, + img_size=1024, + eval_time=True, + normalize_cameras=True, + sort_by_filename=True, + cfg=None, + ): + self.cfg = cfg + + self.sequences = {} + + if IMC_DIR == None: + raise NotImplementedError + + print(f"IMC_DIR is {IMC_DIR}") + + if split == "train": + raise ValueError("We don't want to train on IMC") + elif split == "test": + bag_names = glob.glob(os.path.join(IMC_DIR, "*/set_100/sub_set/*.txt")) + + if cfg.imc_scene_eight: + # In some settings, the scene london_bridge is removed from IMC + bag_names = [name for name in bag_names if "london_bridge" not in name] + + for bag_name in bag_names: + parts = bag_name.split("/") # Split the string into parts by '/' + location = parts[-4] # The location part is at index 5 + bag_info = parts[-1].split(".")[0] # The bag info part is the last part, and remove '.txt' + new_bag_name = f"{bag_info}_{location}" # Format the new bag name + + img_filenames = parse_file_to_list(bag_name, "/".join(parts[:-2])) + filtered_data = [] + + for img_name in img_filenames: + calib_file = img_name.replace("images", "calibration").replace("jpg", "h5") + calib_file = "/".join( + calib_file.rsplit("/", 1)[:-1] + ["calibration_" + calib_file.rsplit("/", 1)[-1]] + ) + calib_dict = load_calib([calib_file]) + + calib = calib_dict[os.path.basename(img_name).split(".")[0]] + intri = torch.from_numpy(np.copy(calib["K"])) + + R = torch.from_numpy(np.copy(calib["R"])) + + tvec = torch.from_numpy(np.copy(calib["T"]).reshape((3,))) + + fl = torch.from_numpy(np.stack([intri[0, 0], intri[1, 1]], axis=0)) + pp = torch.from_numpy(np.stack([intri[0, 2], intri[1, 2]], axis=0)) + + filtered_data.append( + { + "filepath": img_name, + "R": R, + "T": tvec, + "focal_length": fl, + "principal_point": pp, + "calib": calib, + } + ) + self.sequences[new_bag_name] = filtered_data + else: + raise ValueError("please specify correct set") + + self.IMC_DIR = IMC_DIR + self.crop_longest = True + + self.sequence_list = sorted(self.sequences.keys()) + + self.split = split + self.sort_by_filename = sort_by_filename + + if transform is None: + self.transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(img_size, antialias=True)]) + else: + self.transform = transform + + random_aug = False # do not use random_aug for IMC + + if random_aug and not eval_time: + self.jitter_scale = cfg.jitter_scale + self.jitter_trans = cfg.jitter_trans + else: + self.jitter_scale = [1, 1] + self.jitter_trans = [0, 0] + + self.img_size = img_size + self.eval_time = eval_time + + self.normalize_cameras = normalize_cameras + + print(f"Data size of IMC: {len(self)}") + + def __len__(self): + return len(self.sequence_list) + + def _crop_image(self, image, bbox, white_bg=False): + if white_bg: + # Only support PIL Images + image_crop = Image.new("RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)) + image_crop.paste(image, (-bbox[0], -bbox[1])) + else: + image_crop = transforms.functional.crop( + image, top=bbox[1], left=bbox[0], height=bbox[3] - bbox[1], width=bbox[2] - bbox[0] + ) + return image_crop + + def __getitem__(self, idx_N): + if self.eval_time: + return self.get_data(index=idx_N, ids=None) + else: + raise NotImplementedError("Do not train on IMC.") + + def get_data(self, index=None, sequence_name=None, ids=None, return_path=False): + if sequence_name is None: + sequence_name = self.sequence_list[index] + + metadata = self.sequences[sequence_name] + + if ids is None: + ids = np.arange(len(metadata)) + + annos = [metadata[i] for i in ids] + + if self.sort_by_filename: + annos = sorted(annos, key=lambda x: x["filepath"]) + + images = [] + image_paths = [] + rotations = [] + translations = [] + focal_lengths = [] + principal_points = [] + + for anno in annos: + filepath = anno["filepath"] + + image_path = os.path.join(self.IMC_DIR, filepath) + image = Image.open(image_path).convert("RGB") + + images.append(image) + image_paths.append(image_path) + rotations.append(anno["R"]) + translations.append(anno["T"]) + + # focal length and principal point + # from OPENCV to PT3D + original_size_wh = np.array(image.size) + scale = min(original_size_wh) / 2 + c0 = original_size_wh / 2.0 + focal_pytorch3d = anno["focal_length"] / scale + + # mirrored principal point + p0_pytorch3d = -(anno["principal_point"] - c0) / scale + focal_lengths.append(focal_pytorch3d) + principal_points.append(p0_pytorch3d) + + batch = {"seq_name": sequence_name, "frame_num": len(metadata)} + + crop_parameters = [] + images_transformed = [] + new_fls = [] + new_pps = [] + + for i, (anno, image) in enumerate(zip(annos, images)): + w, h = image.width, image.height + + if self.crop_longest: + crop_dim = max(h, w) + top = (h - crop_dim) // 2 + left = (w - crop_dim) // 2 + bbox = np.array([left, top, left + crop_dim, top + crop_dim]) + else: + bbox = np.array(anno["bbox"]) + + if self.eval_time: + bbox_jitter = bbox + else: + # No you should not go here for IMC + # because we never use IMC for training + bbox_jitter = self._jitter_bbox(bbox) + + bbox_xywh = torch.FloatTensor(bbox_xyxy_to_xywh(bbox_jitter)) + + # Cropping images + focal_length_cropped, principal_point_cropped = adjust_camera_to_bbox_crop_( + focal_lengths[i], principal_points[i], torch.FloatTensor(image.size), bbox_xywh + ) + + crop_paras = calculate_crop_parameters(image, bbox_jitter, crop_dim, self.img_size) + crop_parameters.append(crop_paras) + + # Crop image by bbox_jitter + image = self._crop_image(image, bbox_jitter) + + # Resizing images + new_focal_length, new_principal_point = adjust_camera_to_image_scale_( + focal_length_cropped, + principal_point_cropped, + torch.FloatTensor(image.size), + torch.FloatTensor([self.img_size, self.img_size]), + ) + + images_transformed.append(self.transform(image)) + new_fls.append(new_focal_length) + new_pps.append(new_principal_point) + + images = images_transformed + + new_fls = torch.stack(new_fls) + new_pps = torch.stack(new_pps) + + batchR = torch.cat([data["R"][None] for data in annos]) + batchT = torch.cat([data["T"][None] for data in annos]) + + batch["rawR"] = batchR.clone() + batch["rawT"] = batchT.clone() + + # From OPENCV/COLMAP to PT3D + batchR = batchR.clone().permute(0, 2, 1) + batchT = batchT.clone() + batchR[:, :, :2] *= -1 + batchT[:, :2] *= -1 + + cameras = PerspectiveCameras( + focal_length=new_fls.float(), principal_point=new_pps.float(), R=batchR.float(), T=batchT.float() + ) + + if self.normalize_cameras: + # Move the cameras so that they stay in a better coordinate + # This will not affect the evaluation result + normalized_cameras, points = normalize_cameras(cameras, points=None) + + if normalized_cameras == -1: + print("Error in normalizing cameras: camera scale was 0") + raise RuntimeError + + batch["R"] = normalized_cameras.R + batch["T"] = normalized_cameras.T + + batch["fl"] = normalized_cameras.focal_length + batch["pp"] = normalized_cameras.principal_point + + if torch.any(torch.isnan(batch["T"])): + print(ids) + # print(category) + print(sequence_name) + raise RuntimeError + else: + batch["R"] = cameras.R + batch["T"] = cameras.T + batch["fl"] = cameras.focal_length + batch["pp"] = cameras.principal_point + + batch["crop_params"] = torch.stack(crop_parameters) + + # Add images + if self.transform is not None: + images = torch.stack(images) + + if not self.eval_time: + raise ValueError("color aug should not happen for IMC") + + batch["image"] = images.clamp(0, 1) + + if return_path: + return batch, image_paths + + return batch + + +def calculate_crop_parameters(image, bbox_jitter, crop_dim, img_size): + crop_center = (bbox_jitter[:2] + bbox_jitter[2:]) / 2 + # convert crop center to correspond to a "square" image + width, height = image.size + length = max(width, height) + s = length / min(width, height) + crop_center = crop_center + (length - np.array([width, height])) / 2 + # convert to NDC + cc = s - 2 * s * crop_center / length + crop_width = 2 * s * (bbox_jitter[2] - bbox_jitter[0]) / length + bbox_after = bbox_jitter / crop_dim * img_size + crop_parameters = torch.tensor( + [-cc[0], -cc[1], crop_width, s, bbox_after[0], bbox_after[1], bbox_after[2], bbox_after[3]] + ).float() + return crop_parameters diff --git a/vggsfm/vggsfm/datasets/imc_helper.py b/vggsfm/vggsfm/datasets/imc_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..b71564e29f7ade7f72038a7075b4014801899861 --- /dev/null +++ b/vggsfm/vggsfm/datasets/imc_helper.py @@ -0,0 +1,1145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# Adapted from https://github.com/ubc-vision/image-matching-benchmark +# and +# https://github.com/colmap/colmap + +import os +import numpy as np +import json + +from copy import deepcopy +from datetime import datetime + + +import cv2 +import h5py + + +import collections +import struct +import argparse + + +CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple("Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple("Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, width=width, height=height, params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params) + cameras[camera_id] = Camera( + id=camera_id, model=model_name, width=width, height=height, params=np.array(params) + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + "# Number of cameras: {}\n".format(len(cameras)) + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes(fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24 * num_points2D, format_char_sequence="ddq" * num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes(fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes(fid, num_bytes=8 * track_length, format_char_sequence="ii" * track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, error=error, image_ids=image_ids, point2D_idxs=point2D_idxs + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + + "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + print("Provide model format: '.bin' or '.txt'") + return + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +################################################################# + + +def build_composite_image(image_path1, image_path2, axis=1, margin=0, background=1): + """ + Load two images and returns a composite image. + + Parameters + ---------- + image_path1: Fullpath to image 1. + image_path2: Fullpath to image 2. + in: Space between images + ite) + + Returns + ------- + (Composite image, + (vertical_offset1, vertical_offset2), + (horizontal_offset1, horizontal_offset2)) + """ + + if background != 0 and background != 1: + background = 1 + if axis != 0 and axis != 1: + raise RuntimeError("Axis must be 0 (vertical) or 1 (horizontal") + + im1 = cv2.imread(image_path1) + if im1.ndim == 3: + im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB) + elif im1.ndim == 2: + im1 = cv2.cvtColor(im1, cv2.COLOR_GRAY2RGB) + else: + raise RuntimeError("invalid image format") + + im2 = cv2.imread(image_path2) + if im2.ndim == 3: + im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB) + elif im2.ndim == 2: + im2 = cv2.cvtColor(im2, cv2.COLOR_GRAY2RGB) + else: + raise RuntimeError("invalid image format") + + h1, w1, _ = im1.shape + h2, w2, _ = im2.shape + + if axis == 1: + composite = np.zeros((max(h1, h2), w1 + w2 + margin, 3), dtype=np.uint8) + 255 * background + if h1 > h2: + voff1, voff2 = 0, (h1 - h2) // 2 + else: + voff1, voff2 = (h2 - h1) // 2, 0 + hoff1, hoff2 = 0, w1 + margin + else: + composite = np.zeros((h1 + h2 + margin, max(w1, w2), 3), dtype=np.uint8) + 255 * background + if w1 > w2: + hoff1, hoff2 = 0, (w1 - w2) // 2 + else: + hoff1, hoff2 = (w2 - w1) // 2, 0 + voff1, voff2 = 0, h1 + margin + composite[voff1 : voff1 + h1, hoff1 : hoff1 + w1, :] = im1 + composite[voff2 : voff2 + h2, hoff2 : hoff2 + w2, :] = im2 + + return (composite, (voff1, voff2), (hoff1, hoff2)) + + +def save_h5(dict_to_save, filename): + """Saves dictionary to HDF5 file""" + + with h5py.File(filename, "w") as f: + for key in dict_to_save: + f.create_dataset(key, data=dict_to_save[key]) + + +def load_h5(filename): + """Loads dictionary from hdf5 file""" + + dict_to_load = {} + try: + with h5py.File(filename, "r") as f: + keys = [key for key in f.keys()] + for key in keys: + dict_to_load[key] = f[key][()] + except Exception as e: + print("Following error occured when loading h5 file {}: {}".format(filename, e)) + return dict_to_load + + +########################################## + + +def load_image(image_path, use_color_image=False, input_width=512, crop_center=True, force_rgb=False): + """ + Loads image and do preprocessing. + + Parameters + ---------- + image_path: Fullpath to the image. + use_color_image: Flag to read color/gray image + input_width: Width of the image for scaling + crop_center: Flag to crop while scaling + force_rgb: Flag to convert color image from BGR to RGB + + Returns + ------- + Tuple of (Color/Gray image, scale_factor) + """ + + # Assuming all images in the directory are color images + image = cv2.imread(image_path) + if not use_color_image: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + elif force_rgb: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Crop center and resize image into something reasonable + scale_factor = 1.0 + if crop_center: + rows, cols = image.shape[:2] + if rows > cols: + cut = (rows - cols) // 2 + img_cropped = image[cut : cut + cols, :] + else: + cut = (cols - rows) // 2 + img_cropped = image[:, cut : cut + rows] + scale_factor = float(input_width) / float(img_cropped.shape[0]) + image = cv2.resize(img_cropped, (input_width, input_width)) + + return (image, scale_factor) + + +def load_depth(depth_path): + return load_h5(depth_path)["depth"] + + +def load_vis(vis_fullpath_list, subset_index=None): + """ + Given fullpath_list load all visibility ranges + """ + vis = [] + if subset_index is None: + for vis_file in vis_fullpath_list: + # Load visibility + vis.append(np.loadtxt(vis_file).flatten().astype("float32")) + else: + for idx in subset_index: + tmp_vis = np.loadtxt(vis_fullpath_list[idx]).flatten().astype("float32") + tmp_vis = tmp_vis[subset_index] + vis.append(tmp_vis) + return vis + + +def load_calib(calib_fullpath_list, subset_index=None): + """Load all calibration files and create a dictionary.""" + + calib = {} + if subset_index is None: + for _calib_file in calib_fullpath_list: + img_name = os.path.splitext(os.path.basename(_calib_file))[0].replace("calibration_", "") + # _calib_file.split( + # '/')[-1].replace('calibration_', '')[:-3] + # # Don't know why, but rstrip .h5 also strips + # # more then necssary sometimes! + # # + # # img_name = _calib_file.split( + # # '/')[-1].replace('calibration_', '').rstrip('.h5') + calib[img_name] = load_h5(_calib_file) + else: + for idx in subset_index: + _calib_file = calib_fullpath_list[idx] + img_name = os.path.splitext(os.path.basename(_calib_file))[0].replace("calibration_", "") + calib[img_name] = load_h5(_calib_file) + return calib + + +def load_h5_valid_image(path, deprecated_images): + return remove_keys(load_h5(path), deprecated_images) + + +def remove_keys(d, key_list): + for key in key_list: + del_key_list = [tmp_key for tmp_key in d.keys() if key in tmp_key] + for del_key in del_key_list: + del d[del_key] + return d + + +##################################################################################### +def get_uuid(cfg): + return cfg.method_dict["config_common"]["json_label"].split("-")[0] + + +def get_eval_path(mode, cfg): + if mode == "feature": + return get_feature_path(cfg) + elif mode == "match": + return get_match_path(cfg) + elif mode == "filter": + return get_filter_path(cfg) + elif mode == "model": + return get_geom_path(cfg) + elif mode == "stereo": + return get_stereo_path(cfg) + elif mode == "multiview": + return get_multiview_path(cfg) + else: + raise ValueError("Unknown job type") + + +def get_eval_file(mode, cfg, job_id=None): + if job_id: + return os.path.join(get_eval_path(mode, cfg), "{}.{}".format(job_id, mode)) + else: + try: + file_list = os.listdir(get_eval_path(mode, cfg)) + valid_file = [file for file in file_list if file.split(".")[-1] == mode] + if len(valid_file) == 0: + return None + elif len(valid_file) == 1: + return os.path.join(get_eval_path(mode, cfg), valid_file[0]) + else: + print("Should never be here") + import IPython + + IPython.embed() + return None + except FileNotFoundError: + os.makedirs(get_eval_path(mode, cfg)) + return None + + +def get_data_path(cfg): + """Returns where the per-dataset results folder is stored. + + TODO: This probably should be done in a neater way. + """ + + # Get data directory for 'set_100' + return os.path.join(cfg.path_data, cfg.dataset, cfg.scene, "set_{}".format(cfg.num_max_set)) + + +def get_base_path(cfg): + """Returns where the per-dataset results folder is stored.""" + + if cfg.is_challenge: + cur_date = "{:%Y%m%d}".format(datetime.now()) + return os.path.join(cfg.path_results, "challenge", get_uuid(cfg), cfg.dataset, cfg.scene) + else: + return os.path.join(cfg.path_results, cfg.dataset, cfg.scene) + + +def get_feature_path(cfg): + """Returns where the keypoints and descriptor results folder is stored. + + Method names converted to lower-case.""" + + common = cfg.method_dict["config_common"] + return os.path.join( + get_base_path(cfg), + "{}_{}_{}".format(common["keypoint"].lower(), common["num_keypoints"], common["descriptor"].lower()), + ) + + +def get_kp_file(cfg): + """Returns the path to the keypoint file.""" + + return os.path.join(get_feature_path(cfg), "keypoints.h5") + + +def get_scale_file(cfg): + """Returns the path to the scale file.""" + + return os.path.join(get_feature_path(cfg), "scales.h5") + + +def get_score_file(cfg): + """Returns the path to the score file.""" + + return os.path.join(get_feature_path(cfg), "scores.h5") + + +def get_angle_file(cfg): + """Returns the path to the angle file.""" + + return os.path.join(get_feature_path(cfg), "angles.h5") + + +def get_affine_file(cfg): + """Returns the path to the angle file.""" + + return os.path.join(get_feature_path(cfg), "affine.h5") + + +def get_desc_file(cfg): + """Returns the path to the descriptor file.""" + + return os.path.join(get_feature_path(cfg), "descriptors.h5") + + +def get_match_name(cfg): + """Return folder name for the matching model. + + Converted to lower-case to avoid conflicts.""" + cur_key = "config_{}_{}".format(cfg.dataset, cfg.task) + + # simply return 'custom_matcher' if it is provided + if cfg.method_dict[cur_key]["use_custom_matches"]: + return cfg.method_dict[cur_key]["custom_matches_name"] + + # consturct matcher name + matcher = cfg.method_dict[cur_key]["matcher"] + + # Make a custom string for the matching folder + label = [] + + # Base name + label += [matcher["method"]] + + # flann/bf + if matcher["flann"]: + label += ["flann"] + else: + label += ["bf"] + + # number of neighbours + label += ["numnn-{}".format(matcher["num_nn"])] + + # distance + label += ["dist-{}".format(matcher["distance"])] + + # 2-way matching + if not matcher["symmetric"]["enabled"]: + label += ["nosym"] + else: + label += ["sym-{}".format(matcher["symmetric"]["reduce"])] + + # filtering + if matcher["filtering"]["type"] == "none": + label += ["nofilter"] + elif matcher["filtering"]["type"].lower() in ["snn_ratio_pairwise", "snn_ratio_vs_last"]: + # Threshold == 1 means no ratio test + # It just makes writing the config files easier + if matcher["filtering"]["threshold"] == 1: + label += ["nofilter"] + else: + label += ["filter-{}-{}".format(matcher["filtering"]["type"], matcher["filtering"]["threshold"])] + elif matcher["filtering"]["type"].lower() == "fginn_ratio_pairwise": + label += [ + "filter-fginn-pairwise-{}-{}".format( + matcher["filtering"]["threshold"], matcher["filtering"]["fginn_radius"] + ) + ] + else: + raise ValueError("Unknown filtering type") + + # distance filtering + if "descriptor_distance_filter" in matcher: + if "threshold" in matcher["descriptor_distance_filter"]: + max_dist = matcher["descriptor_distance_filter"]["threshold"] + label += ["maxdist-{:.03f}".format(max_dist)] + + return "_".join(label).lower() + + +def get_filter_path(cfg): + """Returns folder location for the outlier filter results.""" + + cur_key = "config_{}_{}".format(cfg.dataset, cfg.task) + + # Bypass this when using custom matches + if cfg.method_dict[cur_key]["use_custom_matches"]: + return os.path.join(get_match_path(cfg), "no_filter") + + # Otherwise, depends on the filter method + outlier_filter = cfg.method_dict[cur_key]["outlier_filter"] + if outlier_filter["method"] in ["cne-bp-nd"]: + return os.path.join(get_match_path(cfg), outlier_filter["method"]) + elif outlier_filter["method"] == "none": + return os.path.join(get_match_path(cfg), "no_filter") + else: + raise ValueError("Unknown outlier_filter type") + + +def get_match_path(cfg): + """Returns where the match results folder is stored.""" + return os.path.join(get_feature_path(cfg), get_match_name(cfg)) + + +def get_match_file(cfg): + """Returns the path to the match file.""" + + return os.path.join(get_match_path(cfg), "matches.h5") + + +def get_filter_match_file(cfg): + """Returns the path to the match file after pre-filtering.""" + + return os.path.join(get_filter_path(cfg), "matches.h5") + + +def get_match_cost_file(cfg): + """Returns the path to the match file.""" + + return os.path.join(get_match_path(cfg), "matching_cost.h5") + + +def get_geom_name(cfg): + """Return folder name for the geometry model. + + Converted to lower-case to avoid conflicts.""" + + geom = cfg.method_dict["config_{}_{}".format(cfg.dataset, cfg.task)]["geom"] + method = geom["method"].lower() + + # This entry is a temporary fix + if method in ["cv2-ransac-f", "cv2-usacdef-f", "cv2-usacmagsac-f", "cv2-usacfast-f", "cv2-usacaccurate-f"]: + label = "_".join( + [method, "th", str(geom["threshold"]), "conf", str(geom["confidence"]), "maxiter", str(geom["max_iter"])] + ) + elif method in ["cv2-ransac-e"]: + label = "_".join([method, "th", str(geom["threshold"]), "conf", str(geom["confidence"])]) + elif method in ["cmp-degensac-f", "cmp-degensac-f-laf", "cmp-gc-ransac-e"]: + label = "_".join( + [ + method, + "th", + str(geom["threshold"]), + "conf", + str(geom["confidence"]), + "max_iter", + str(geom["max_iter"]), + "error", + str(geom["error_type"]), + "degencheck", + str(geom["degeneracy_check"]), + ] + ) + elif method in ["cmp-gc-ransac-f", "skimage-ransac-f", "cmp-magsac-f"]: + label = "_".join( + [method, "th", str(geom["threshold"]), "conf", str(geom["confidence"]), "max_iter", str(geom["max_iter"])] + ) + elif method in ["cv2-lmeds-e", "cv2-lmeds-f"]: + label = "_".join([method, "conf", str(geom["confidence"])]) + elif method in ["intel-dfe-f"]: + label = "_".join([method, "th", str(geom["threshold"]), "postprocess", str(geom["postprocess"])]) + elif method in ["cv2-7pt", "cv2-8pt"]: + label = method + else: + raise ValueError("Unknown method for E/F estimation") + + return label.lower() + + +def get_geom_path(cfg): + """Returns where the match results folder is stored.""" + + geom_name = get_geom_name(cfg) + return os.path.join(get_filter_path(cfg), "stereo-fold-{}".format(cfg.run), geom_name) + + +def get_geom_file(cfg): + """Returns the path to the match file.""" + + return os.path.join(get_geom_path(cfg), "essential.h5") + + +def get_geom_inl_file(cfg): + """Returns the path to the match file.""" + return os.path.join(get_geom_path(cfg), "essential_inliers.h5") + + +def get_geom_cost_file(cfg): + """Returns the path to the geom cost file.""" + return os.path.join(get_geom_path(cfg), "geom_cost.h5") + + +def get_cne_temp_path(cfg): + return os.path.join(get_filter_path(cfg), "temp_cne") + + +def get_filter_match_file_for_computing_model(cfg): + filter_match_file = os.path.join(get_filter_path(cfg), "matches_imported_stereo_{}.h5".format(cfg.run)) + if os.path.isfile(filter_match_file): + return filter_match_file + else: + return get_filter_match_file(cfg) + + +def get_filter_match_file(cfg): + return os.path.join(get_filter_path(cfg), "matches_inlier.h5") + + +def get_filter_cost_file(cfg): + return os.path.join(get_filter_path(cfg), "matches_inlier_cost.h5") + + +def get_cne_data_dump_path(cfg): + return os.path.join(get_cne_temp_path(cfg), "data_dump") + + +def get_stereo_path(cfg): + """Returns the path to where the stereo results are stored.""" + + return os.path.join(get_geom_path(cfg), "set_{}".format(cfg.num_max_set)) + + +def get_stereo_pose_file(cfg, th=None): + """Returns the path to where the stereo pose file.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_pose_errors{}.h5".format(label)) + + +def get_repeatability_score_file(cfg, th=None): + """Returns the path to the repeatability file.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "repeatability_score_file{}.h5".format(label)) + + +def get_stereo_epipolar_pre_match_file(cfg, th=None): + """Returns the path to the match file.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_epipolar_pre_match_errors{}.h5".format(label)) + + +def get_stereo_epipolar_refined_match_file(cfg, th=None): + """Returns the path to the filtered match file.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_epipolar_refined_match_errors{}.h5".format(label)) + + +def get_stereo_epipolar_final_match_file(cfg, th=None): + """Returns the path to the match file after RANSAC.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_epipolar_final_match_errors{}.h5".format(label)) + + +def get_stereo_depth_projection_pre_match_file(cfg, th=None): + """Returns the path to the errors file for input matches.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_projection_errors_pre_match{}.h5".format(label)) + + +def get_stereo_depth_projection_refined_match_file(cfg, th=None): + """Returns the path to the errors file for filtered matches.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_projection_errors_refined_match{}.h5".format(label)) + + +def get_stereo_depth_projection_final_match_file(cfg, th=None): + """Returns the path to the errors file for final matches.""" + + label = "" if th is None else "-th-{:s}".format(th) + return os.path.join(get_stereo_path(cfg), "stereo_projection_errors_final_match{}.h5".format(label)) + + +def get_colmap_path(cfg): + """Returns the path to colmap results for individual bag.""" + + return os.path.join(get_multiview_path(cfg), "bag_size_{}".format(cfg.bag_size), "bag_id_{:05d}".format(cfg.bag_id)) + + +def get_multiview_path(cfg): + """Returns the path to multiview folder.""" + + return os.path.join(get_filter_path(cfg), "multiview-fold-{}".format(cfg.run)) + + +def get_colmap_mark_file(cfg): + """Returns the path to colmap flag.""" + + return os.path.join(get_colmap_path(cfg), "colmap_has_run") + + +def get_colmap_pose_file(cfg): + """Returns the path to colmap pose files.""" + + return os.path.join(get_colmap_path(cfg), "colmap_pose_errors.h5") + + +def get_colmap_output_path(cfg): + """Returns the path to colmap outputs.""" + + return os.path.join(get_colmap_path(cfg), "colmap") + + +def get_colmap_temp_path(cfg): + """Returns the path to colmap working path.""" + + # TODO: Do we want to use slurm temp directory? + return os.path.join(get_colmap_path(cfg), "temp_colmap") + + +def parse_file_to_list(file_name, data_dir): + """ + Parses filenames from the given text file using the `data_dir` + + :param file_name: File with list of file names + :param data_dir: Full path location appended to the filename + + :return: List of full paths to the file names + """ + + fullpath_list = [] + with open(file_name, "r") as f: + while True: + # Read a single line + line = f.readline() + # Check the `line` type + if not isinstance(line, str): + line = line.decode("utf-8") + if not line: + break + # Strip `\n` at the end and append to the `fullpath_list` + fullpath_list.append(os.path.join(data_dir, line.rstrip("\n"))) + return fullpath_list + + +def get_fullpath_list(data_dir, key): + """ + Returns the full-path lists to image info in `data_dir` + + :param data_dir: Path to the location of dataset + :param key: Which item to retrieve from + + :return: Tuple containing fullpath lists for the key item + """ + # Read the list of images, homography and geometry + list_file = os.path.join(data_dir, "{}.txt".format(key)) + + # Parse files to fullpath lists + fullpath_list = parse_file_to_list(list_file, data_dir) + + return fullpath_list + + +def get_item_name_list(fullpath_list): + """Returns each item name in the full path list, excluding the extension""" + + return [os.path.splitext(os.path.basename(_s))[0] for _s in fullpath_list] + + +def get_stereo_viz_folder(cfg): + """Returns the path to the stereo visualizations folder.""" + + base = os.path.join(cfg.method_dict["config_common"]["json_label"].lower(), cfg.dataset, cfg.scene, "stereo") + + return os.path.join(cfg.path_visualization, "png", base), os.path.join(cfg.path_visualization, "jpg", base) + + +def get_colmap_viz_folder(cfg): + """Returns the path to the multiview visualizations folder.""" + + base = os.path.join(cfg.method_dict["config_common"]["json_label"].lower(), cfg.dataset, cfg.scene, "multiview") + + return os.path.join(cfg.path_visualization, "png", base), os.path.join(cfg.path_visualization, "jpg", base) + + +def get_pairs_per_threshold(data_dir): + pairs = {} + for th in np.arange(0, 1, 0.1): + pairs["{:0.1f}".format(th)] = np.load("{}/new-vis-pairs/keys-th-{:0.1f}.npy".format(data_dir, th)) + return pairs diff --git a/vggsfm/vggsfm/models/__init__.py b/vggsfm/vggsfm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..72dc11714f312b149b03ed0d598010d16c457559 --- /dev/null +++ b/vggsfm/vggsfm/models/__init__.py @@ -0,0 +1,9 @@ +from .vggsfm import VGGSfM + + +from .track_modules.blocks import BasicEncoder, ShallowEncoder +from .track_modules.base_track_predictor import BaseTrackerPredictor + +from .track_predictor import TrackerPredictor +from .camera_predictor import CameraPredictor +from .triangulator import Triangulator diff --git a/vggsfm/vggsfm/models/camera_predictor.py b/vggsfm/vggsfm/models/camera_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9bf76d1f4f2072ec37692f292769f09c646398 --- /dev/null +++ b/vggsfm/vggsfm/models/camera_predictor.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from collections import defaultdict +from dataclasses import field, dataclass + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.utils import instantiate +from einops import rearrange, repeat +from typing import Any, Dict, List, Optional, Tuple, Union, Callable + + +from .modules import AttnBlock, CrossAttnBlock, Mlp, ResidualBlock + +from .utils import get_2d_sincos_pos_embed, PoseEmbedding, pose_encoding_to_camera, camera_to_pose_encoding + + +logger = logging.getLogger(__name__) + + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class CameraPredictor(nn.Module): + def __init__( + self, + hidden_size=768, + num_heads=8, + mlp_ratio=4, + z_dim: int = 768, + down_size=336, + att_depth=8, + trunk_depth=4, + backbone="dinov2b", + pose_encoding_type="absT_quaR_OneFL", + cfg=None, + ): + super().__init__() + self.cfg = cfg + + self.att_depth = att_depth + self.down_size = down_size + self.pose_encoding_type = pose_encoding_type + + if self.pose_encoding_type == "absT_quaR_OneFL": + self.target_dim = 8 + if self.pose_encoding_type == "absT_quaR_logFL": + self.target_dim = 9 + + self.backbone = self.get_backbone(backbone) + + for param in self.backbone.parameters(): + param.requires_grad = False + + self.input_transform = Mlp(in_features=z_dim, out_features=hidden_size, drop=0) + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + # sine and cosine embed for camera parameters + self.embed_pose = PoseEmbedding( + target_dim=self.target_dim, n_harmonic_functions=(hidden_size // self.target_dim) // 2, append_input=False + ) + + self.pose_token = nn.Parameter(torch.zeros(1, 1, 1, hidden_size)) # register + + self.pose_branch = Mlp( + in_features=hidden_size, hidden_features=hidden_size * 2, out_features=hidden_size + self.target_dim, drop=0 + ) + + self.ffeat_updater = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.GELU()) + + self.self_att = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(self.att_depth) + ] + ) + + self.cross_att = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(self.att_depth)] + ) + + self.trunk = nn.Sequential( + *[ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(trunk_depth) + ] + ) + + self.gamma = 0.8 + + nn.init.normal_(self.pose_token, std=1e-6) + + for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer(name, torch.FloatTensor(value).view(1, 3, 1, 1), persistent=False) + + def forward(self, reshaped_image, preliminary_cameras=None, iters=4, batch_size=None, rgb_feat_init=None): + """ + reshaped_image: Bx3xHxW. The values of reshaped_image are within [0, 1] + preliminary_cameras: cameras in opencv coordinate. + """ + + if rgb_feat_init is None: + # Get the 2D image features + rgb_feat, B, S, C = self.get_2D_image_features(reshaped_image, batch_size) + else: + rgb_feat = rgb_feat_init + B, S, C = rgb_feat.shape + + if preliminary_cameras is not None: + # Init the pred_pose_enc by preliminary_cameras + pred_pose_enc = ( + camera_to_pose_encoding(preliminary_cameras, pose_encoding_type=self.pose_encoding_type) + .reshape(B, S, -1) + .to(rgb_feat.dtype) + ) + else: + # Or you can use random init for the poses + pred_pose_enc = torch.zeros(B, S, self.target_dim).to(rgb_feat.device) + + rgb_feat_init = rgb_feat.clone() + + for iter_num in range(iters): + pred_pose_enc = pred_pose_enc.detach() + + # Embed the camera parameters and add to rgb_feat + pose_embed = self.embed_pose(pred_pose_enc) + rgb_feat = rgb_feat + pose_embed + + # Run trunk transformers on rgb_feat + rgb_feat = self.trunk(rgb_feat) + + # Predict the delta feat and pose encoding at each iteration + delta = self.pose_branch(rgb_feat) + delta_pred_pose_enc = delta[..., : self.target_dim] + delta_feat = delta[..., self.target_dim :] + + rgb_feat = self.ffeat_updater(self.norm(delta_feat)) + rgb_feat + + pred_pose_enc = pred_pose_enc + delta_pred_pose_enc + + # Residual connection + rgb_feat = (rgb_feat + rgb_feat_init) / 2 + + # Pose encoding to Cameras + pred_cameras = pose_encoding_to_camera( + pred_pose_enc, pose_encoding_type=self.pose_encoding_type, to_OpenCV=True + ) + pose_predictions = { + "pred_pose_enc": pred_pose_enc, + "pred_cameras": pred_cameras, + "rgb_feat_init": rgb_feat_init, + } + + return pose_predictions + + def get_backbone(self, backbone): + """ + Load the backbone model. + """ + if backbone == "dinov2s": + return torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg") + elif backbone == "dinov2b": + return torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg") + else: + raise NotImplementedError(f"Backbone '{backbone}' not implemented") + + def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor: + return (img - self._resnet_mean) / self._resnet_std + + def get_2D_image_features(self, reshaped_image, batch_size): + # Get the 2D image features + if reshaped_image.shape[-1] != self.down_size: + reshaped_image = F.interpolate( + reshaped_image, (self.down_size, self.down_size), mode="bilinear", align_corners=True + ) + + with torch.no_grad(): + reshaped_image = self._resnet_normalize_image(reshaped_image) + rgb_feat = self.backbone(reshaped_image, is_training=True) + # B x P x C + rgb_feat = rgb_feat["x_norm_patchtokens"] + + rgb_feat = self.input_transform(rgb_feat) + rgb_feat = self.norm(rgb_feat) + + rgb_feat = rearrange(rgb_feat, "(b s) p c -> b s p c", b=batch_size) + + B, S, P, C = rgb_feat.shape + patch_num = int(math.sqrt(P)) + + # add embedding of 2D spaces + pos_embed = get_2d_sincos_pos_embed(C, grid_size=(patch_num, patch_num)).permute(0, 2, 3, 1)[None] + pos_embed = pos_embed.reshape(1, 1, patch_num * patch_num, C).to(rgb_feat.device) + + rgb_feat = rgb_feat + pos_embed + + # register for pose + pose_token = self.pose_token.expand(B, S, -1, -1) + rgb_feat = torch.cat([pose_token, rgb_feat], dim=-2) + + B, S, P, C = rgb_feat.shape + + for idx in range(self.att_depth): + # self attention + rgb_feat = rearrange(rgb_feat, "b s p c -> (b s) p c", b=B, s=S) + rgb_feat = self.self_att[idx](rgb_feat) + rgb_feat = rearrange(rgb_feat, "(b s) p c -> b s p c", b=B, s=S) + + feat_0 = rgb_feat[:, 0] + feat_others = rgb_feat[:, 1:] + + # cross attention + feat_others = rearrange(feat_others, "b m p c -> b (m p) c", m=S - 1, p=P) + feat_others = self.cross_att[idx](feat_others, feat_0) + + feat_others = rearrange(feat_others, "b (m p) c -> b m p c", m=S - 1, p=P) + rgb_feat = torch.cat([rgb_feat[:, 0:1], feat_others], dim=1) + + rgb_feat = rgb_feat[:, :, 0] + + return rgb_feat, B, S, C diff --git a/vggsfm/vggsfm/models/modules.py b/vggsfm/vggsfm/models/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..31a0fd82935a6903b9858d09a568a0e481d19e12 --- /dev/null +++ b/vggsfm/vggsfm/models/modules.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class ResidualBlock(nn.Module): + """ + ResidualBlock: construct a block of two conv layers with residual connections + """ + + def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class AttnBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, + mlp_ratio=4.0, + **block_kwargs + ): + """ + Self attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, mask=None): + # Prepare the mask for PyTorch's attention (it expects a different format) + # attn_mask = mask if mask is not None else None + # Normalize before attention + x = self.norm1(x) + + # PyTorch's MultiheadAttention returns attn_output, attn_output_weights + # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) + + attn_output, _ = self.attn(x, x, x) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): + """ + Cross attention block + """ + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.cross_attn = nn.MultiheadAttention( + embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs + ) + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) + + def forward(self, x, context, mask=None): + # Normalize inputs + x = self.norm1(x) + context = self.norm_context(context) + + # Apply cross attention + # Note: nn.MultiheadAttention returns attn_output, attn_output_weights + attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) + + # Add & Norm + x = x + attn_output + x = x + self.mlp(self.norm2(x)) + return x diff --git a/vggsfm/vggsfm/models/track_modules/__init__.py b/vggsfm/vggsfm/models/track_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vggsfm/vggsfm/models/track_modules/base_track_predictor.py b/vggsfm/vggsfm/models/track_modules/base_track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..650f129886e21259e0ba46f9e2e3f7faefbcd118 --- /dev/null +++ b/vggsfm/vggsfm/models/track_modules/base_track_predictor.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from .blocks import EfficientUpdateFormer, CorrBlock, EfficientCorrBlock +from ..utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed + + +class BaseTrackerPredictor(nn.Module): + def __init__( + self, + stride=4, + corr_levels=5, + corr_radius=4, + latent_dim=128, + hidden_size=384, + use_spaceatt=True, + depth=6, + fine=False, + cfg=None, + ): + super(BaseTrackerPredictor, self).__init__() + """ + The base template to create a track predictor + + Modified from https://github.com/facebookresearch/co-tracker/ + """ + + self.cfg = cfg + + self.stride = stride + self.latent_dim = latent_dim + self.corr_levels = corr_levels + self.corr_radius = corr_radius + self.hidden_size = hidden_size + self.fine = fine + + self.flows_emb_dim = latent_dim // 2 + self.transformer_dim = self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2 + + self.efficient_corr = cfg.MODEL.TRACK.efficient_corr + + if self.fine: + # TODO this is the old dummy code, will remove this when we train next model + self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5 + else: + self.transformer_dim += (4 - self.transformer_dim % 4) % 4 + + space_depth = depth if use_spaceatt else 0 + time_depth = depth + + self.updateformer = EfficientUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=self.transformer_dim, + hidden_size=self.hidden_size, + output_dim=self.latent_dim + 2, + mlp_ratio=4.0, + add_space_attn=use_spaceatt, + ) + + self.norm = nn.GroupNorm(1, self.latent_dim) + + # A linear layer to update track feats at each iteration + self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) + + if not self.fine: + self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) + + def forward(self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1): + """ + query_points: B x N x 2, the number of batches, tracks, and xy + fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. + note HH and WW is the size of feature maps instead of original images + """ + B, N, D = query_points.shape + B, S, C, HH, WW = fmaps.shape + + assert D == 2 + + # Scale the input query_points because we may downsample the images + # by down_ratio or self.stride + # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map + # its query_points should be query_points/4 + if down_ratio > 1: + query_points = query_points / float(down_ratio) + query_points = query_points / float(self.stride) + + # Init with coords as the query points + # It means the search will start from the position of query points at the reference frames + coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) + + # Sample/extract the features of the query points in the query frame + query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) + + # init track feats by query feats + track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C + # back up the init coords + coords_backup = coords.clone() + + # Construct the correlation block + if self.efficient_corr: + fcorr_fn = EfficientCorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + else: + fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) + + coord_preds = [] + + # Iterative Refinement + for itr in range(iters): + # Detach the gradients from the last iteration + # (in my experience, not very important for performance) + coords = coords.detach() + + # Compute the correlation (check the implementation of CorrBlock) + if self.efficient_corr: + fcorrs = fcorr_fn.sample(coords, track_feats) + else: + fcorr_fn.corr(track_feats) + fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim + + corrdim = fcorrs.shape[3] + + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim) + + # Movement of current coords relative to query points + flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) + + flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) + + # (In my trials, it is also okay to just add the flows_emb instead of concat) + flows_emb = torch.cat([flows_emb, flows], dim=-1) + + track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) + + # Concatenate them as the input for the transformers + transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) + + if transformer_input.shape[2] < self.transformer_dim: + # pad the features to match the dimension + pad_dim = self.transformer_dim - transformer_input.shape[2] + pad = torch.zeros_like(flows_emb[..., 0:pad_dim]) + transformer_input = torch.cat([transformer_input, pad], dim=2) + + # 2D positional embed + # TODO: this can be much simplified + pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) + sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) + sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) + + x = transformer_input + sampled_pos_emb + + # B, N, S, C + x = rearrange(x, "(b n) s d -> b n s d", b=B) + + # Compute the delta coordinates and delta track features + delta = self.updateformer(x) + # BN, S, C + delta = rearrange(delta, " b n s d -> (b n) s d", b=B) + delta_coords_ = delta[:, :, :2] + delta_feats_ = delta[:, :, 2:] + + track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) + delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) + + # Update the track features + track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_ + track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC + + # B x S x N x 2 + coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) + + # Force coord0 as query + # because we assume the query points should not be changed + coords[:, 0] = coords_backup[:, 0] + + # The predicted tracks are in the original image scale + if down_ratio > 1: + coord_preds.append(coords * self.stride * down_ratio) + else: + coord_preds.append(coords * self.stride) + + # B, S, N + if not self.fine: + vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) + vis_e = torch.sigmoid(vis_e) + else: + vis_e = None + + if return_feat: + return coord_preds, vis_e, track_feats, query_track_feat + else: + return coord_preds, vis_e diff --git a/vggsfm/vggsfm/models/track_modules/blocks.py b/vggsfm/vggsfm/models/track_modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..01db94f8d8cf19d2312c6d653e6061c4f24cc398 --- /dev/null +++ b/vggsfm/vggsfm/models/track_modules/blocks.py @@ -0,0 +1,399 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Modified from https://github.com/facebookresearch/co-tracker/ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from typing import Callable +import collections +from torch import Tensor +from itertools import repeat + +from ..utils import bilinear_sampler + +from ..modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=4, cfg=None): + super(BasicEncoder, self).__init__() + + self.stride = stride + self.norm_fn = "instance" + self.in_planes = output_dim // 2 + + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=7, stride=2, padding=3, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(output_dim // 2, stride=1) + self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) + self.layer3 = self._make_layer(output_dim, stride=2) + self.layer4 = self._make_layer(output_dim, stride=2) + + self.conv2 = nn.Conv2d( + output_dim * 3 + output_dim // 4, output_dim * 2, kernel_size=3, padding=1, padding_mode="zeros" + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.InstanceNorm2d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + + a = _bilinear_intepolate(a, self.stride, H, W) + b = _bilinear_intepolate(b, self.stride, H, W) + c = _bilinear_intepolate(c, self.stride, H, W) + d = _bilinear_intepolate(d, self.stride, H, W) + + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + return x + + +class ShallowEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance", cfg=None): + super(ShallowEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = output_dim + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, self.in_planes, kernel_size=3, stride=2, padding=1, padding_mode="zeros") + self.relu1 = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(output_dim, stride=2) + + self.layer2 = self._make_layer(output_dim, stride=2) + self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + self.in_planes = dim + + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + return layer1 + + def forward(self, x): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + tmp = self.layer1(x) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = self.layer2(tmp) + x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True) + tmp = None + x = self.conv2(x) + x + + x = F.interpolate(x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True) + + return x + + +def _bilinear_intepolate(x, stride, H, W): + return F.interpolate(x, (H // stride, W // stride), mode="bilinear", align_corners=True) + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, input_tensor, mask=None): + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + flow = self.flow_head(tokens) + return flow + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): + B, S, C, H, W = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H, W + self.padding_mode = padding_mode + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.multiple_track_feats = multiple_track_feats + + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + *_, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode) + corrs = corrs.view(B, S, N, -1) + + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2 + return out + + def corr(self, targets): + B, S, N, C = targets.shape + if self.multiple_track_feats: + targets_split = targets.split(C // self.num_levels, dim=-1) + B, S, N, C = targets_split[0].shape + + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for i, fmaps in enumerate(self.fmaps_pyramid): + *_, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) + if self.multiple_track_feats: + fmap1 = targets_split[i] + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) + + +class EfficientCorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4): + B, S, C, H, W = fmaps.shape + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.fmaps_pyramid.append(fmaps) + for i in range(self.num_levels - 1): + fmaps_ = fmaps.reshape(B * S, C, H, W) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + _, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords, target): + r = self.radius + device = coords.device + B, S, N, D = coords.shape + assert D == 2 + target = target.permute(0, 1, 3, 2).unsqueeze(-1) + + out_pyramid = [] + + for i in range(self.num_levels): + pyramid = self.fmaps_pyramid[i] + C, H, W = pyramid.shape[2:] + centroid_lvl = ( + torch.cat([torch.zeros_like(coords[..., :1], device=device), coords], dim=-1).reshape(B * S, N, 1, 1, 3) + / 2**i + ) + + dx = torch.linspace(-r, r, 2 * r + 1, device=device) + dy = torch.linspace(-r, r, 2 * r + 1, device=device) + xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij") + zgrid = torch.zeros_like(xgrid, device=device) + delta = torch.stack([zgrid, xgrid, ygrid], axis=-1) + delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3) + coords_lvl = centroid_lvl + delta_lvl + pyramid_sample = bilinear_sampler(pyramid.reshape(B * S, C, 1, H, W), coords_lvl) + + corr = torch.sum(target * pyramid_sample.reshape(B, S, C, N, -1), dim=2) + corr = corr / torch.sqrt(torch.tensor(C).float()) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 + return out + + +# # Inside the tracker forward funciton: +# if self.efficient_corr: +# corr_block = EfficientCorrBlock( +# fmaps, +# num_levels=4, +# radius=3, +# padding_mode="border", +# ) +# else: +# corr_block = CorrBlock( +# fmaps, +# num_levels=4, +# radius=3, +# padding_mode="border", +# ) +# if self.efficient_corr: +# fcorrs = corr_block.sample(coords, track_feat) +# else: +# corr_block.corr(track_feat) +# fcorrs = corr_block.sample(coords) diff --git a/vggsfm/vggsfm/models/track_modules/refine_track.py b/vggsfm/vggsfm/models/track_modules/refine_track.py new file mode 100644 index 0000000000000000000000000000000000000000..3936486b1effb316fbfe3bad914f2e9ce9f6428a --- /dev/null +++ b/vggsfm/vggsfm/models/track_modules/refine_track.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from PIL import Image +import os +from typing import Union, Tuple +from kornia.utils.grid import create_meshgrid +from kornia.geometry.subpix import dsnt + + +def refine_track( + images, fine_fnet, fine_tracker, coarse_pred, compute_score=False, pradius=15, sradius=2, fine_iters=6, cfg=None +): + """ + Refines the tracking of images using a fine track predictor and a fine feature network. + Check https://arxiv.org/abs/2312.04563 for more details. + + Args: + images (torch.Tensor): The images to be tracked. + fine_fnet (nn.Module): The fine feature network. + fine_tracker (nn.Module): The fine track predictor. + coarse_pred (torch.Tensor): The coarse predictions of tracks. + compute_score (bool, optional): Whether to compute the score. Defaults to False. + pradius (int, optional): The radius of a patch. Defaults to 15. + sradius (int, optional): The search radius. Defaults to 2. + cfg (dict, optional): The configuration dictionary. Defaults to None. + + Returns: + torch.Tensor: The refined tracks. + torch.Tensor, optional: The score. + """ + + # coarse_pred shape: BxSxNx2, + # where B is the batch, S is the video/images length, and N is the number of tracks + # now we are going to extract patches with the center at coarse_pred + # Please note that the last dimension indicates x and y, and hence has a dim number of 2 + B, S, N, _ = coarse_pred.shape + _, _, _, H, W = images.shape + + # Given the raidus of a patch, compute the patch size + psize = pradius * 2 + 1 + + # Note that we assume the first frame is the query frame + # so the 2D locations of the first frame are the query points + query_points = coarse_pred[:, 0] + + # Given 2D positions, we can use grid_sample to extract patches + # but it takes too much memory. + # Instead, we use the floored track xy to sample patches. + + # For example, if the query point xy is (128.16, 252.78), + # and the patch size is (31, 31), + # our goal is to extract the content of a rectangle + # with left top: (113.16, 237.78) + # and right bottom: (143.16, 267.78). + # However, we record the floored left top: (113, 237) + # and the offset (0.16, 0.78) + # Then what we need is just unfolding the images like in CNN, + # picking the content at [(113, 237), (143, 267)]. + # Such operations are highly optimized at pytorch + # (well if you really want to use interpolation, check the function extract_glimpse() below) + + with torch.no_grad(): + content_to_extract = images.reshape(B * S, 3, H, W) + C_in = content_to_extract.shape[1] + + # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html + # for the detailed explanation of unfold() + # Here it runs sliding windows (psize x psize) to build patches + # The shape changes from + # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize + # where Psize is the size of patch + content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1) + + # Floor the coarse predictions to get integers and save the fractional/decimal + track_int = coarse_pred.floor().int() + track_frac = coarse_pred - track_int + + # Note the points represent the center of patches + # now we get the location of the top left corner of patches + # because the ouput of pytorch unfold are indexed by top left corner + topleft = track_int - pradius + topleft_BSN = topleft.clone() + + # clamp the values so that we will not go out of indexes + # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W). + # You need to seperately clamp x and y if H!=W + topleft = topleft.clamp(0, H - psize) + + # Reshape from BxSxNx2 -> (B*S)xNx2 + topleft = topleft.reshape(B * S, N, 2) + + # Prepare batches for indexing, shape: (B*S)xN + batch_indices = torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device) + + # Extract image patches based on top left corners + # extracted_patches: (B*S) x N x C_in x Psize x Psize + extracted_patches = content_to_extract[batch_indices, :, topleft[..., 1], topleft[..., 0]] + + # Feed patches to fine fent for features + patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize)) + + C_out = patch_feat.shape[1] + + # Refine the coarse tracks by fine_tracker + + # reshape back to B x S x N x C_out x Psize x Psize + patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize) + patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q") + + # Prepare for the query points for fine tracker + # They are relative to the patch left top corner, + # instead of the image top left corner now + # patch_query_points: N x 1 x 2 + # only 1 here because for each patch we only have 1 query point + patch_query_points = track_frac[:, 0] + pradius + patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1) + + # Feed the PATCH query points and tracks into fine tracker + fine_pred_track_lists, _, _, query_point_feat = fine_tracker( + query_points=patch_query_points, fmaps=patch_feat, iters=fine_iters, return_feat=True + ) + + # relative the patch top left + fine_pred_track = fine_pred_track_lists[-1].clone() + + # From (relative to the patch top left) to (relative to the image top left) + for idx in range(len(fine_pred_track_lists)): + fine_level = rearrange(fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N) + fine_level = fine_level.squeeze(-2) + fine_level = fine_level + topleft_BSN + fine_pred_track_lists[idx] = fine_level + + # relative to the image top left + refined_tracks = fine_pred_track_lists[-1].clone() + refined_tracks[:, 0] = query_points + + score = None + + if compute_score: + score = compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out) + + return refined_tracks, score + + +def compute_score_fn(query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out): + """ + Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps, + given the query point features and reference frame feature maps + """ + + # query_point_feat initial shape: B x N x C_out, + # query_point_feat indicates the feat at the coorponsing query points + # Therefore we don't have S dimension here + query_point_feat = query_point_feat.reshape(B, N, C_out) + # reshape and expand to B x (S-1) x N x C_out + query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1) + # and reshape to (B*(S-1)*N) x C_out + query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out) + + # Radius and size for computing the score + ssize = sradius * 2 + 1 + + # Reshape, you know it, so many reshaping operations + patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N) + + # Again, we unfold the patches to smaller patches + # so that we can then focus on smaller patches + # patch_feat_unfold shape: + # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize + # well a bit scary, but actually not + patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1) + + # Do the same stuffs above, i.e., the same as extracting patches + fine_prediction_floor = fine_pred_track.floor().int() + fine_level_floor_topleft = fine_prediction_floor - sradius + + # Clamp to ensure the smaller patch is valid + fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize) + fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2) + + # Prepare the batch indices and xy locations + + batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN + batch_indices_score = batch_indices_score.reshape(-1).to(patch_feat_unfold.device) # B*S*N + y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices + x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices + + reference_frame_feat = patch_feat_unfold.reshape( + B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize + ) + + # Note again, according to pytorch convention + # x_indices cooresponds to [..., 1] and y_indices cooresponds to [..., 0] + reference_frame_feat = reference_frame_feat[batch_indices_score, :, x_indices, y_indices] + reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize) + # pick the frames other than the first one, so we have S-1 frames here + reference_frame_feat = reference_frame_feat[:, 1:].reshape(B * (S - 1) * N, C_out, ssize * ssize) + + # Compute similarity + sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat) + softmax_temp = 1.0 / C_out**0.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) + # 2D heatmaps + heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize + + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] + grid_normalized = create_meshgrid(ssize, ssize, normalized_coordinates=True, device=heatmap.device).reshape( + 1, -1, 2 + ) + + var = torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1) - coords_normalized**2 + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # clamp needed for numerical stability + + score = std.reshape(B, S - 1, N) + # set score as 1 for the query frame + score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1) + + return score + + +def extract_glimpse( + tensor: torch.Tensor, size: Tuple[int, int], offsets, mode="bilinear", padding_mode="zeros", debug=False, orib=None +): + B, C, W, H = tensor.shape + + h, w = size + xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0 + ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0 + + vy, vx = torch.meshgrid(ys, xs) + grid = torch.stack([vx, vy], dim=-1) # h, w, 2 + grid = grid[None] + + B, N, _ = offsets.shape + + offsets = offsets.reshape((B * N), 1, 1, 2) + offsets_grid = offsets + grid + + # normalised grid to [-1, 1] + offsets_grid = (offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])) / offsets_grid.new_tensor([W / 2, H / 2]) + + # BxCxHxW -> Bx1xCxHxW + tensor = tensor[:, None] + + # Bx1xCxHxW -> BxNxCxHxW + tensor = tensor.expand(-1, N, -1, -1, -1) + + # BxNxCxHxW -> (B*N)xCxHxW + tensor = tensor.reshape((B * N), C, W, H) + + sampled = torch.nn.functional.grid_sample( + tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode + ) + + # NOTE: I am not sure it should be h, w or w, h here + # but okay for sqaures + sampled = sampled.reshape(B, N, C, h, w) + + return sampled diff --git a/vggsfm/vggsfm/models/track_predictor.py b/vggsfm/vggsfm/models/track_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..0583230baea79f71ec1b532822dcecfcb5002c86 --- /dev/null +++ b/vggsfm/vggsfm/models/track_predictor.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from torch import nn, einsum +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce + +from hydra.utils import instantiate +from .track_modules.refine_track import refine_track + + +class TrackerPredictor(nn.Module): + def __init__(self, COARSE, FINE, stride=4, corr_levels=5, corr_radius=4, latent_dim=128, cfg=None, **extra_args): + super(TrackerPredictor, self).__init__() + """ + COARSE and FINE are the dicts to construct the modules + + Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor, + check track_modules/base_track_predictor.py + + Both coarse_fnet and fine_fnet are constructed as a 2D CNN network + check track_modules/blocks.py for BasicEncoder and ShallowEncoder + """ + + self.cfg = cfg + + # coarse predictor + self.coarse_down_ratio = COARSE.down_ratio + self.coarse_fnet = instantiate(COARSE.FEATURENET, _recursive_=False, stride=COARSE.stride, cfg=cfg) + self.coarse_predictor = instantiate(COARSE.PREDICTOR, _recursive_=False, stride=COARSE.stride, cfg=cfg) + + # fine predictor, forced to use stride = 1 + self.fine_fnet = instantiate(FINE.FEATURENET, _recursive_=False, stride=1, cfg=cfg) + self.fine_predictor = instantiate(FINE.PREDICTOR, _recursive_=False, stride=1, cfg=cfg) + + def forward(self, images, query_points, fmaps=None, coarse_iters=6): + """ + images: images as rgb, in the range of [0, 1], with a shape of B x S x 3 x H x W + query_points: 2D xy of query points, relative to top left, with a shape of B x N x 2 + + """ + + if fmaps is None: + fmaps = self.process_images_to_fmaps(images) + + # coarse prediction + coarse_pred_track_lists, pred_vis = self.coarse_predictor( + query_points=query_points, fmaps=fmaps, iters=coarse_iters, down_ratio=self.coarse_down_ratio + ) + coarse_pred_track = coarse_pred_track_lists[-1] + + # refine the coarse prediction + fine_pred_track, pred_score = refine_track( + images, self.fine_fnet, self.fine_predictor, coarse_pred_track, compute_score=True, cfg=self.cfg + ) + + return fine_pred_track, coarse_pred_track, pred_vis, pred_score + + def process_images_to_fmaps(self, images): + """ + This function processes images for inference. + + Args: + images (np.array): The images to be processed. + + Returns: + np.array: The processed images. + """ + batch_num, frame_num, image_dim, height, width = images.shape + assert batch_num == 1, "now we only support processing one scene during inference" + reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width) + if self.coarse_down_ratio > 1: + # whether or not scale down the input images to save memory + fmaps = self.coarse_fnet( + F.interpolate( + reshaped_image, scale_factor=1 / self.coarse_down_ratio, mode="bilinear", align_corners=True + ) + ) + else: + fmaps = self.coarse_fnet(reshaped_image) + fmaps = fmaps.reshape(batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]) + + return fmaps diff --git a/vggsfm/vggsfm/models/triangulator.py b/vggsfm/vggsfm/models/triangulator.py new file mode 100644 index 0000000000000000000000000000000000000000..a46148fa365dd02a3f9da990f8294265e974f94c --- /dev/null +++ b/vggsfm/vggsfm/models/triangulator.py @@ -0,0 +1,366 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.utils import instantiate + +import logging +from collections import defaultdict +from dataclasses import field, dataclass +from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from einops import rearrange, repeat +import copy +import kornia +import pycolmap +from torch.cuda.amp import autocast + +# from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from minipytorch3d.cameras import CamerasBase, PerspectiveCameras + + +# ##################### +# from ..two_view_geo.utils import inlier_by_fundamental +from ..utils.triangulation import ( + create_intri_matrix, + triangulate_by_pair, + init_BA, + init_refine_pose, + refine_pose, + triangulate_tracks, + global_BA, + iterative_global_BA, +) +from ..utils.triangulation_helpers import filter_all_points3D + +from .utils import get_EFP + + +class Triangulator(nn.Module): + def __init__(self, cfg=None): + super().__init__() + """ + The module for triangulation and BA adjustment + + NOTE After VGGSfM v1.1, we remove the learnable parameters of Triangulator + """ + self.cfg = cfg + + def forward( + self, + pred_cameras, + pred_tracks, + pred_vis, + images, + preliminary_dict, + pred_score=None, + fmat_thres=0.5, + init_max_reproj_error=0.5, + BA_iters=2, + max_reproj_error=4, + init_tri_angle_thres=16, + min_valid_track_length=3, + image_paths=None, + crop_params=None, + cfg=None, + ): + """ + Conduct triangulation and bundle adjustment given + the input pred_cameras, pred_tracks, pred_vis, and pred_score + + We use the pred_cameras from camera_predictor but it can be any init. + Please note pred_tracks are defined in pixels. + """ + # for safety + torch.cuda.empty_cache() + + device = pred_tracks.device + + with autocast(dtype=torch.float32): + B, S, _, H, W = images.shape + _, _, N, _ = pred_tracks.shape + + assert B == 1 # The released implementation now only supports batch=1 during inference + + image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device) + # extrinsics: B x S x 3 x 4 + # intrinsics: B x S x 3 x 3 + # focal_length, principal_point : B x S x 2 + + extrinsics, intrinsics, _, _ = get_EFP(pred_cameras, image_size, B, S) + + extrinsics = extrinsics.double() + inlier_fmat = preliminary_dict["fmat_inlier_mask"] + + # Remove B dim + # To simplify the code, now we only support B==1 during inference + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + pred_tracks = pred_tracks[0] + pred_vis = pred_vis[0] + pred_score = pred_score[0] + inlier_fmat = inlier_fmat[0] + + tracks_normalized = normalize_tracks(pred_tracks, intrinsics) + # Visibility inlier + inlier_vis = pred_vis > 0.05 # TODO: avoid hardcoded + inlier_vis = inlier_vis[1:] + + # Intersection of inlier_fmat and inlier_vis + inlier_geo_vis = torch.logical_and(inlier_fmat, inlier_vis) + + # For initialization + # we first triangulate a point cloud for each pair of query-reference images, + # i.e., we have S-1 3D point clouds + # points_3d_pair: S-1 x N x 3 + points_3d_pair, cheirality_mask_pair, triangle_value_pair = triangulate_by_pair( + extrinsics[None], tracks_normalized[None] + ) + + # Check which point cloud can provide sufficient inliers + # that pass the triangulation angle and cheirality check + # Pick the highest inlier_geo_vis one as the initial point cloud + trial_count = 0 + while trial_count < 5: + # If no success, relax the constraint + # try at most 5 times + triangle_mask = triangle_value_pair >= init_tri_angle_thres + inlier_total = torch.logical_and(inlier_geo_vis, cheirality_mask_pair) + inlier_total = torch.logical_and(inlier_total, triangle_mask) + inlier_num_per_frame = inlier_total.sum(dim=-1) + + max_num_inlier = inlier_num_per_frame.max() + max_num_inlier_ratio = max_num_inlier / N + + # We accept a pair only when the numer of inliers and the ratio + # is higher than a thres + if (max_num_inlier >= 100) and (max_num_inlier_ratio >= 0.25): + break + + if init_tri_angle_thres < 2: + break + + init_tri_angle_thres = init_tri_angle_thres // 2 + trial_count += 1 + + # Conduct BA on the init point cloud and init pair + points3D_init, extrinsics, intrinsics, track_init_mask, reconstruction, init_idx = init_BA( + extrinsics, + intrinsics, + pred_tracks, + points_3d_pair, + inlier_total, + image_size, + init_max_reproj_error=init_max_reproj_error, + ) + + # Given we have a well-conditioned point cloud, + # we can optimize all the cameras by absolute pose refinement as in + # https://github.com/colmap/colmap/blob/4ced4a5bc72fca93a2ffaea2c7e193bc62537416/src/colmap/estimators/pose.cc#L207 + # Basically it is a bundle adjustment without optmizing 3D points + # It is fine even this step fails + + extrinsics, intrinsics, valid_intri_mask = init_refine_pose( + extrinsics, + intrinsics, + inlier_geo_vis, + points3D_init, + pred_tracks, + track_init_mask, + image_size, + init_idx, + ) + + points3D, extrinsics, intrinsics, valid_tracks, reconstruction = self.triangulate_tracks_and_BA( + pred_tracks, intrinsics, extrinsics, pred_vis, pred_score, image_size, device, min_valid_track_length,max_reproj_error + ) + + if cfg.robust_refine > 0: + for refine_idx in range(cfg.robust_refine): + # Helpful for some turnable videos + inlier_vis_all = pred_vis > 0.05 + + force_estimate = refine_idx == (cfg.robust_refine - 1) + + extrinsics, intrinsics, valid_intri_mask = refine_pose( + extrinsics, + intrinsics, + inlier_vis_all, + points3D, + pred_tracks, + valid_tracks, + image_size, + force_estimate=force_estimate, + ) + + points3D, extrinsics, intrinsics, valid_tracks, reconstruction = self.triangulate_tracks_and_BA( + pred_tracks, + intrinsics, + extrinsics, + pred_vis, + pred_score, + image_size, + device, + min_valid_track_length, + max_reproj_error, + ) + + # try: + ba_options = pycolmap.BundleAdjustmentOptions() + ba_options.print_summary = False + + print(f"Running iterative BA by {BA_iters} times") + for BA_iter in range(BA_iters): + if BA_iter == (BA_iters - 1): + ba_options.print_summary = True + lastBA = True + else: + lastBA = False + + try: + ( + points3D, + extrinsics, + intrinsics, + valid_tracks, + BA_inlier_masks, + reconstruction, + ) = iterative_global_BA( + pred_tracks, + intrinsics, + extrinsics, + pred_vis, + pred_score, + valid_tracks, + points3D, + image_size, + lastBA=lastBA, + min_valid_track_length=min_valid_track_length, + max_reproj_error=max_reproj_error, + ba_options=ba_options, + ) + max_reproj_error = max_reproj_error // 2 + if max_reproj_error <= 1: + max_reproj_error = 1 + except: + print(f"Oh BA fails at iter {BA_iter}! Careful") + + rot_BA = extrinsics[:, :3, :3] + trans_BA = extrinsics[:, :3, 3] + + # find the invalid predictions + scale = image_size.max() + valid_intri_mask = torch.logical_and(intrinsics[:, 0, 0] >= 0.1 * scale, intrinsics[:, 0, 0] <= 30 * scale) + valid_trans_mask = (trans_BA.abs() <= 30).all(-1) + valid_frame_mask = torch.logical_and(valid_intri_mask, valid_trans_mask) + + for pyimageid in reconstruction.images: + # scale from resized image size to the real size + # rename the images to the original names + pyimage = reconstruction.images[pyimageid] + pycamera = reconstruction.cameras[pyimage.camera_id] + + pyimage.name = image_paths[pyimageid] + + pred_params = copy.deepcopy(pycamera.params) + real_image_size = crop_params[0, pyimageid][:2] + real_focal = real_image_size.max() / cfg.img_size * pred_params[0] + + real_pp = real_image_size.cpu().numpy() // 2 + + pred_params[0] = real_focal + pred_params[1:3] = real_pp + pycamera.params = pred_params + pycamera.width = real_image_size[0] + pycamera.height = real_image_size[1] + + if cfg.extract_color: + from vggsfm.models.utils import sample_features4d + + pred_track_rgb = sample_features4d(images.squeeze(0), pred_tracks) + valid_track_rgb = pred_track_rgb[:, valid_tracks] + + sum_rgb = (BA_inlier_masks.float()[..., None] * valid_track_rgb).sum(dim=0) + points3D_rgb = sum_rgb / BA_inlier_masks.sum(dim=0)[:, None] + else: + points3D_rgb = None + + # From OpenCV/COLMAP to PyTorch3D + rot_BA = rot_BA.clone().permute(0, 2, 1) + trans_BA = trans_BA.clone() + trans_BA[:, :2] *= -1 + rot_BA[:, :, :2] *= -1 + BA_cameras_PT3D = PerspectiveCameras(R=rot_BA, T=trans_BA, device=device) + + if cfg.filter_invalid_frame: + BA_cameras_PT3D = BA_cameras_PT3D[valid_frame_mask] + extrinsics = extrinsics[valid_frame_mask] + intrinsics = intrinsics[valid_frame_mask] + invalid_ids = torch.nonzero(~valid_frame_mask).squeeze(1) + invalid_ids = invalid_ids.cpu().numpy().tolist() + for invalid_id in invalid_ids: + reconstruction.deregister_image(invalid_id) + + return BA_cameras_PT3D, extrinsics, intrinsics, points3D, points3D_rgb, reconstruction, valid_frame_mask + + def triangulate_tracks_and_BA( + self, pred_tracks, intrinsics, extrinsics, pred_vis, pred_score, image_size, device, min_valid_track_length, max_reproj_error=4 + ): + """ """ + # Normalize the tracks + tracks_normalized_refined = normalize_tracks(pred_tracks, intrinsics) + + # Conduct triangulation to all the frames + # We adopt LORANSAC here again + + best_triangulated_points, best_inlier_num, best_inlier_mask = triangulate_tracks( + extrinsics, tracks_normalized_refined, track_vis=pred_vis, track_score=pred_score, max_ransac_iters=128 + ) + # Determine valid tracks based on inlier numbers + valid_tracks = best_inlier_num >= min_valid_track_length + # Perform global bundle adjustment + points3D, extrinsics, intrinsics, reconstruction = global_BA( + best_triangulated_points, + valid_tracks, + pred_tracks, + best_inlier_mask, + extrinsics, + intrinsics, + image_size, + device, + ) + + valid_poins3D_mask = filter_all_points3D( + points3D, pred_tracks[:, valid_tracks], extrinsics, intrinsics, check_triangle=False, max_reproj_error=max_reproj_error + ) + points3D = points3D[valid_poins3D_mask] + + valid_tracks_tmp = valid_tracks.clone() + valid_tracks_tmp[valid_tracks] = valid_poins3D_mask + valid_tracks = valid_tracks_tmp.clone() + + return points3D, extrinsics, intrinsics, valid_tracks, reconstruction + + +def normalize_tracks(pred_tracks, intrinsics): + """ + Normalize predicted tracks based on camera intrinsics. + Args: + intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. + pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. + Returns: + torch.Tensor: Normalized tracks tensor. + """ + principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) + focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) + tracks_normalized = (pred_tracks - principal_point) / focal_length + return tracks_normalized diff --git a/vggsfm/vggsfm/models/utils.py b/vggsfm/vggsfm/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..021e7521b8d138de73f58c13daf620cba19e8bd0 --- /dev/null +++ b/vggsfm/vggsfm/models/utils.py @@ -0,0 +1,384 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from https://github.com/facebookresearch/PoseDiffusion +# and https://github.com/facebookresearch/co-tracker/tree/main + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Union +from einops import rearrange, repeat + + +from minipytorch3d.harmonic_embedding import HarmonicEmbedding + +from minipytorch3d.cameras import CamerasBase, PerspectiveCameras +from minipytorch3d.rotation_conversions import matrix_to_quaternion, quaternion_to_matrix + + +# from pytorch3d.renderer import HarmonicEmbedding +# from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +# from pytorch3d.transforms.rotation_conversions import matrix_to_quaternion, quaternion_to_matrix + +from ..utils.metric import closed_form_inverse_OpenCV +from ..utils.triangulation import create_intri_matrix + +EPS = 1e-9 + + +def get_EFP(pred_cameras, image_size, B, S, default_focal=False): + """ + Converting PyTorch3D cameras to extrinsics, intrinsics matrix + + Return extrinsics, intrinsics, focal_length, principal_point + """ + scale = image_size.min() + + focal_length = pred_cameras.focal_length + + principal_point = torch.zeros_like(focal_length) + + focal_length = focal_length * scale / 2 + principal_point = (image_size[None] - principal_point * scale) / 2 + + Rots = pred_cameras.R.clone() + Trans = pred_cameras.T.clone() + + extrinsics = torch.cat([Rots, Trans[..., None]], dim=-1) + + # reshape + extrinsics = extrinsics.reshape(B, S, 3, 4) + focal_length = focal_length.reshape(B, S, 2) + principal_point = principal_point.reshape(B, S, 2) + + # only one dof focal length + if default_focal: + focal_length[:] = scale + else: + focal_length = focal_length.mean(dim=-1, keepdim=True).expand(-1, -1, 2) + focal_length = focal_length.clamp(0.2 * scale, 5 * scale) + + intrinsics = create_intri_matrix(focal_length, principal_point) + return extrinsics, intrinsics, focal_length, principal_point + + +def pose_encoding_to_camera( + pose_encoding, + pose_encoding_type="absT_quaR_logFL", + log_focal_length_bias=1.8, + min_focal_length=0.1, + max_focal_length=30, + return_dict=False, + to_OpenCV=True, +): + """ + Args: + pose_encoding: A tensor of shape `BxNxC`, containing a batch of + `BxN` `C`-dimensional pose encodings. + pose_encoding_type: The type of pose encoding, + """ + pose_encoding_reshaped = pose_encoding.reshape(-1, pose_encoding.shape[-1]) # Reshape to BNxC + + if pose_encoding_type == "absT_quaR_logFL": + # 3 for absT, 4 for quaR, 2 for absFL + abs_T = pose_encoding_reshaped[:, :3] + quaternion_R = pose_encoding_reshaped[:, 3:7] + R = quaternion_to_matrix(quaternion_R) + log_focal_length = pose_encoding_reshaped[:, 7:9] + # log_focal_length_bias was the hyperparameter + # to ensure the mean of logFL close to 0 during training + # Now converted back + focal_length = (log_focal_length + log_focal_length_bias).exp() + # clamp to avoid weird fl values + focal_length = torch.clamp(focal_length, min=min_focal_length, max=max_focal_length) + elif pose_encoding_type == "absT_quaR_OneFL": + # 3 for absT, 4 for quaR, 1 for absFL + # [absolute translation, quaternion rotation, normalized focal length] + abs_T = pose_encoding_reshaped[:, :3] + quaternion_R = pose_encoding_reshaped[:, 3:7] + R = quaternion_to_matrix(quaternion_R) + focal_length = pose_encoding_reshaped[:, 7:8] + focal_length = torch.clamp(focal_length, min=min_focal_length, max=max_focal_length) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + if to_OpenCV: + ### From Pytorch3D coordinate to OpenCV coordinate: + # I hate coordinate conversion + R = R.clone() + abs_T = abs_T.clone() + R[:, :, :2] *= -1 + abs_T[:, :2] *= -1 + R = R.permute(0, 2, 1) + + extrinsics_4x4 = torch.eye(4, 4).to(R.dtype).to(R.device) + extrinsics_4x4 = extrinsics_4x4[None].repeat(len(R), 1, 1) + + extrinsics_4x4[:, :3, :3] = R.clone() + extrinsics_4x4[:, :3, 3] = abs_T.clone() + + rel_transform = closed_form_inverse_OpenCV(extrinsics_4x4[0:1]) + rel_transform = rel_transform.expand(len(extrinsics_4x4), -1, -1) + + # relative to the first camera + # NOTE it is extrinsics_4x4 x rel_transform instead of rel_transform x extrinsics_4x4 + # this is different in opencv / pytorch3d convention + extrinsics_4x4 = torch.bmm(extrinsics_4x4, rel_transform) + + R = extrinsics_4x4[:, :3, :3].clone() + abs_T = extrinsics_4x4[:, :3, 3].clone() + + if return_dict: + return {"focal_length": focal_length, "R": R, "T": abs_T} + + pred_cameras = PerspectiveCameras(focal_length=focal_length, R=R, T=abs_T, device=R.device) + return pred_cameras + + +def camera_to_pose_encoding( + camera, pose_encoding_type="absT_quaR_logFL", log_focal_length_bias=1.8, min_focal_length=0.1, max_focal_length=30 +): + """ + Inverse to pose_encoding_to_camera + """ + if pose_encoding_type == "absT_quaR_logFL": + # Convert rotation matrix to quaternion + quaternion_R = matrix_to_quaternion(camera.R) + + # Calculate log_focal_length + log_focal_length = ( + torch.log(torch.clamp(camera.focal_length, min=min_focal_length, max=max_focal_length)) + - log_focal_length_bias + ) + + # Concatenate to form pose_encoding + pose_encoding = torch.cat([camera.T, quaternion_R, log_focal_length], dim=-1) + + elif pose_encoding_type == "absT_quaR_OneFL": + # [absolute translation, quaternion rotation, normalized focal length] + quaternion_R = matrix_to_quaternion(camera.R) + focal_length = (torch.clamp(camera.focal_length, min=min_focal_length, max=max_focal_length))[..., 0:1] + pose_encoding = torch.cat([camera.T, quaternion_R, focal_length], dim=-1) + else: + raise ValueError(f"Unknown pose encoding {pose_encoding_type}") + + return pose_encoding + + +class PoseEmbedding(nn.Module): + def __init__(self, target_dim, n_harmonic_functions=10, append_input=True): + super().__init__() + + self._emb_pose = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, append_input=append_input) + + self.out_dim = self._emb_pose.get_output_dim(target_dim) + + def forward(self, pose_encoding): + e_pose_encoding = self._emb_pose(pose_encoding) + return e_pose_encoding + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: + """ + This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. + It is a wrapper of get_2d_sincos_pos_embed_from_grid. + Args: + - embed_dim: The embedding dimension. + - grid_size: The grid size. + Returns: + - pos_embed: The generated 2D positional embedding. + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. + + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor([2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/vggsfm/vggsfm/models/vggsfm.py b/vggsfm/vggsfm/models/vggsfm.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b148d28d69a64ee438b1632dae5798f02418a3 --- /dev/null +++ b/vggsfm/vggsfm/models/vggsfm.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Any, Dict, List, Optional, Tuple, Union + +from hydra.utils import instantiate + + +class VGGSfM(nn.Module): + def __init__(self, TRACK: Dict, CAMERA: Dict, TRIANGULAE: Dict, cfg=None): + """ + Initializes a VGGSfM model + + TRACK, CAMERA, TRIANGULAE are the dicts to construct the model modules + cfg is the whole hydra config + """ + super().__init__() + + self.cfg = cfg + + # models.TrackerPredictor + self.track_predictor = instantiate(TRACK, _recursive_=False, cfg=cfg) + + # models.CameraPredictor + self.camera_predictor = instantiate(CAMERA, _recursive_=False, cfg=cfg) + + # models.Triangulator + self.triangulator = instantiate(TRIANGULAE, _recursive_=False, cfg=cfg) diff --git a/vggsfm/vggsfm/two_view_geo/essential.py b/vggsfm/vggsfm/two_view_geo/essential.py new file mode 100644 index 0000000000000000000000000000000000000000..c7905d3b6f50cc695b2e2b4cdb2687606ecce3bd --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/essential.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# The code structure learned from https://github.com/kornia/kornia +# Some funtions adapted from https://github.com/kornia/kornia +# The minimal solvers learned from https://github.com/colmap/colmap + + +import torch +from torch.cuda.amp import autocast +from kornia.geometry import solvers +from kornia.core import eye, ones_like, stack, where, zeros +from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_SHAPE, KORNIA_CHECK_SHAPE + + +from typing import Optional, Tuple + + +from .utils import ( + generate_samples, + calculate_residual_indicator, + local_refinement, + _torch_svd_cast, + sampson_epipolar_distance_batched, +) + + +def decompose_essential_matrix(E_mat: torch.Tensor): + r"""Decompose an essential matrix to possible rotations and translation. + + This function decomposes the essential matrix E using svd decomposition + and give the possible solutions: :math:`R1, R2, t`. + + Args: + E_mat: The essential matrix in the form of :math:`(*, 3, 3)`. + + Returns: + A tuple containing the first and second possible rotation matrices and the translation vector. + The shape of the tensors with be same input :math:`[(*, 3, 3), (*, 3, 3), (*, 3, 1)]`. + """ + + with autocast(dtype=torch.double): + if not (len(E_mat.shape) >= 2 and E_mat.shape[-2:]): + raise AssertionError(E_mat.shape) + + # decompose matrix by its singular values + U, _, V = _torch_svd_cast(E_mat) + Vt = V.transpose(-2, -1) + + mask = ones_like(E_mat) + mask[..., -1:] *= -1.0 # fill last column with negative values + + maskt = mask.transpose(-2, -1) + + # avoid singularities + U = where((torch.det(U) < 0.0)[..., None, None], U * mask, U) + Vt = where((torch.det(Vt) < 0.0)[..., None, None], Vt * maskt, Vt) + + W = cross_product_matrix(torch.tensor([[0.0, 0.0, 1.0]]).type_as(E_mat)) + W[..., 2, 2] += 1.0 + + # reconstruct rotations and retrieve translation vector + U_W_Vt = U @ W @ Vt + U_Wt_Vt = U @ W.transpose(-2, -1) @ Vt + + # return values + R1 = U_W_Vt + R2 = U_Wt_Vt + T = U[..., -1:] + + # compbine and returns the four possible solutions + Rs = stack([R1, R1, R2, R2], dim=1) + Ts = stack([T, -T, T, -T], dim=1) + + return Rs, Ts[..., 0] + + +def cross_product_matrix(x: torch.Tensor) -> torch.Tensor: + r"""Return the cross_product_matrix symmetric matrix of a vector. + + Args: + x: The input vector to construct the matrix in the shape :math:`(*, 3)`. + + Returns: + The constructed cross_product_matrix symmetric matrix with shape :math:`(*, 3, 3)`. + """ + if not x.shape[-1] == 3: + raise AssertionError(x.shape) + # get vector compononens + x0 = x[..., 0] + x1 = x[..., 1] + x2 = x[..., 2] + + # construct the matrix, reshape to 3x3 and return + zeros = torch.zeros_like(x0) + cross_product_matrix_flat = stack([zeros, -x2, x1, x2, zeros, -x0, -x1, x0, zeros], dim=-1) + shape_ = x.shape[:-1] + (3, 3) + return cross_product_matrix_flat.view(*shape_) + + +def estimate_essential(points1, points2, focal_length, principal_point, max_ransac_iters=1024, max_error=4, lo_num=50): + """ + Estimate essential matrix by 5 point algorithm with LORANSAC + + points1, points2: Pytorch Tensor, BxNx2 + + best_emat: Bx3x3 + """ + with autocast(dtype=torch.double): + # normalize points for 5p + principal_point = principal_point.unsqueeze(1) + focal_length = focal_length.unsqueeze(1) + + points1 = (points1 - principal_point[..., :2]) / focal_length[..., :2] + points2 = (points2 - principal_point[..., 2:]) / focal_length[..., 2:] + + max_error = max_error / focal_length.mean(dim=-1, keepdim=True) + max_thres = max_error**2 + + B, N, _ = points1.shape + point_per_sample = 5 # 5p algorithm + + # randomly sample 5 point set by max_ransac_iters times + # ransac_idx: torch matirx Nx5 + ransac_idx = generate_samples(N, max_ransac_iters, point_per_sample) + left = points1[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + right = points2[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + + # 5p algorithm will provide 10 potential answers + # so the shape of emat_ransac is + # B x (max_ransac_iters*10) x 3 x 3 + #################################################################################### + emat_ransac = run_5point(left, right) + emat_ransac = emat_ransac.reshape(B, max_ransac_iters, 10, 3, 3).reshape(B, max_ransac_iters * 10, 3, 3) + + residuals = sampson_epipolar_distance_batched(points1, points2, emat_ransac, squared=True) + + inlier_mask = residuals <= max_thres + inlier_num = inlier_mask.sum(dim=-1) + + _, sorted_indices = torch.sort(inlier_num, dim=1, descending=True) + + # Local Refinement by + # 5p algorithm with inliers + emat_lo = local_refinement( + run_5point, points1, points2, inlier_mask, sorted_indices, lo_num=lo_num, essential=True + ) + + emat_lo = emat_lo.reshape(B, 10 * lo_num, 3, 3) + + # choose the one with the higher inlier number and smallest (valid) residual + all_emat = torch.cat([emat_ransac, emat_lo], dim=1) + residuals_all = sampson_epipolar_distance_batched(points1, points2, all_emat, squared=True) + + residual_indicator, inlier_num_all, inlier_mask_all = calculate_residual_indicator(residuals_all, max_thres) + + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + best_e_indices = torch.argmax(residual_indicator, dim=1) + + best_emat = all_emat[batch_index[:, 0], best_e_indices] + best_inlier_num = inlier_num_all[batch_index[:, 0], best_e_indices] + best_inlier_mask = inlier_mask_all[batch_index[:, 0], best_e_indices] + + return best_emat, best_inlier_num, best_inlier_mask + + +def run_5point( + points1: torch.Tensor, + points2: torch.Tensor, + masks: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, +) -> torch.Tensor: + r"""Compute the essential matrix using the 5-point algorithm from Nister. + + The linear system is solved by Nister's 5-point algorithm [@nister2004efficient], + and the solver implemented referred to [@barath2020magsac++][@wei2023generalized]. + + Args: + points1: A set of carlibrated points in the first image with a tensor shape :math:`(B, N, 2), N>=8`. + points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=8`. + weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`. + + Returns: + the computed essential matrix with shape :math:`(B, 3, 3)`. + """ + with autocast(dtype=torch.double): + KORNIA_CHECK_SHAPE(points1, ["B", "N", "2"]) + KORNIA_CHECK_SAME_SHAPE(points1, points2) + KORNIA_CHECK(points1.shape[1] >= 5, "Number of points should be >=5") + + if masks is None: + masks = ones_like(points1[..., 0]) + + if weights is not None: + KORNIA_CHECK_SAME_SHAPE(points1[:, :, 0], weights) + + batch_size, _, _ = points1.shape + x1, y1 = torch.chunk(points1, dim=-1, chunks=2) # Bx1xN + x2, y2 = torch.chunk(points2, dim=-1, chunks=2) # Bx1xN + ones = ones_like(x1) + + # build equations system and find null space. + # [x * x', x * y', x, y * x', y * y', y, x', y', 1] + # BxNx9 + X = torch.cat([x1 * x2, x1 * y2, x1, y1 * x2, y1 * y2, y1, x2, y2, ones], dim=-1) + + # if masks is not valid, force the cooresponding rows (points) to all zeros + if masks is not None: + X = X * masks.unsqueeze(-1) + + # apply the weights to the linear system + if weights is None: + X = X.transpose(-2, -1) @ X + else: + w_diag = torch.diag_embed(weights) + X = X.transpose(-2, -1) @ w_diag @ X + + # compute eigenvectors and retrieve the one with the smallest eigenvalue, using SVD + # turn off the grad check due to the unstable gradients from SVD. + # several close to zero values of eigenvalues. + _, _, V = _torch_svd_cast(X) # torch.svd + + # use Nister's method to solve essential matrix + + E_Nister = null_to_Nister_solution(V, batch_size) + return E_Nister + + +def fun_select(null_mat, i: int, j: int, ratio=3) -> torch.Tensor: + return null_mat[:, ratio * j + i] + + +def null_to_Nister_solution(V, batch_size): + null_ = V[:, :, -4:] # the last four rows + nullSpace = V.transpose(-1, -2)[:, -4:, :] + + coeffs = zeros(batch_size, 10, 20, device=null_.device, dtype=null_.dtype) + d = zeros(batch_size, 60, device=null_.device, dtype=null_.dtype) + + coeffs[:, 9] = ( + solvers.multiply_deg_two_one_poly( + solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 2)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 1)), + fun_select(null_, 2, 0), + ) + + solvers.multiply_deg_two_one_poly( + solvers.multiply_deg_one_poly(fun_select(null_, 0, 2), fun_select(null_, 1, 0)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 2)), + fun_select(null_, 2, 1), + ) + + solvers.multiply_deg_two_one_poly( + solvers.multiply_deg_one_poly(fun_select(null_, 0, 0), fun_select(null_, 1, 1)) + - solvers.multiply_deg_one_poly(fun_select(null_, 0, 1), fun_select(null_, 1, 0)), + fun_select(null_, 2, 2), + ) + ) + + indices = torch.tensor([[0, 10, 20], [10, 40, 30], [20, 30, 50]]) + + # Compute EE^T (Eqn. 20 in the paper) + for i in range(3): + for j in range(3): + d[:, indices[i, j] : indices[i, j] + 10] = ( + solvers.multiply_deg_one_poly(fun_select(null_, i, 0), fun_select(null_, j, 0)) + + solvers.multiply_deg_one_poly(fun_select(null_, i, 1), fun_select(null_, j, 1)) + + solvers.multiply_deg_one_poly(fun_select(null_, i, 2), fun_select(null_, j, 2)) + ) + + for i in range(10): + t = 0.5 * (d[:, indices[0, 0] + i] + d[:, indices[1, 1] + i] + d[:, indices[2, 2] + i]) + d[:, indices[0, 0] + i] -= t + d[:, indices[1, 1] + i] -= t + d[:, indices[2, 2] + i] -= t + + cnt = 0 + for i in range(3): + for j in range(3): + row = ( + solvers.multiply_deg_two_one_poly(d[:, indices[i, 0] : indices[i, 0] + 10], fun_select(null_, 0, j)) + + solvers.multiply_deg_two_one_poly(d[:, indices[i, 1] : indices[i, 1] + 10], fun_select(null_, 1, j)) + + solvers.multiply_deg_two_one_poly(d[:, indices[i, 2] : indices[i, 2] + 10], fun_select(null_, 2, j)) + ) + coeffs[:, cnt] = row + cnt += 1 + + b = coeffs[:, :, 10:] + + # NOTE Some operations are filtered here + singular_filter = torch.linalg.matrix_rank(coeffs[:, :, :10]) >= torch.max( + torch.linalg.matrix_rank(coeffs), ones_like(torch.linalg.matrix_rank(coeffs[:, :, :10])) * 10 + ) + + if len(singular_filter) == 0: + return torch.eye(3, dtype=coeffs.dtype, device=coeffs.device)[None].expand(batch_size, 10, -1, -1).clone() + + eliminated_mat = torch.linalg.solve(coeffs[singular_filter, :, :10], b[singular_filter]) + + coeffs_ = torch.cat((coeffs[singular_filter, :, :10], eliminated_mat), dim=-1) + + batch_size_filtered = coeffs_.shape[0] + A = zeros(batch_size_filtered, 3, 13, device=coeffs_.device, dtype=coeffs_.dtype) + + for i in range(3): + A[:, i, 0] = 0.0 + A[:, i : i + 1, 1:4] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 10:13] + A[:, i : i + 1, 0:3] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 10:13] + A[:, i, 4] = 0.0 + A[:, i : i + 1, 5:8] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 13:16] + A[:, i : i + 1, 4:7] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 13:16] + A[:, i, 8] = 0.0 + A[:, i : i + 1, 9:13] = coeffs_[:, 4 + 2 * i : 5 + 2 * i, 16:20] + A[:, i : i + 1, 8:12] -= coeffs_[:, 5 + 2 * i : 6 + 2 * i, 16:20] + + # Bx11 + cs = solvers.determinant_to_polynomial(A) + E_models = [] + + # A: Bx3x13 + # nullSpace: Bx4x9 + + C = zeros((batch_size_filtered, 10, 10), device=cs.device, dtype=cs.dtype) + eye_mat = eye(C[0, 0:-1, 0:-1].shape[0], device=cs.device, dtype=cs.dtype) + C[:, 0:-1, 1:] = eye_mat + + cs_de = cs[:, -1].unsqueeze(-1) + cs_de = torch.where(cs_de == 0, 1e-8, cs_de) + C[:, -1, :] = -cs[:, :-1] / cs_de + + roots = torch.real(torch.linalg.eigvals(C)) + + roots_unsqu = roots.unsqueeze(1) + Bs = stack( + ( + A[:, :3, :1] * (roots_unsqu**3) + + A[:, :3, 1:2] * roots_unsqu.square() + + A[:, 0:3, 2:3] * roots_unsqu + + A[:, 0:3, 3:4], + A[:, 0:3, 4:5] * (roots_unsqu**3) + + A[:, 0:3, 5:6] * roots_unsqu.square() + + A[:, 0:3, 6:7] * roots_unsqu + + A[:, 0:3, 7:8], + ), + dim=1, + ) + Bs = Bs.transpose(1, -1) + + bs = ( + ( + A[:, 0:3, 8:9] * (roots_unsqu**4) + + A[:, 0:3, 9:10] * (roots_unsqu**3) + + A[:, 0:3, 10:11] * roots_unsqu.square() + + A[:, 0:3, 11:12] * roots_unsqu + + A[:, 0:3, 12:13] + ) + .transpose(1, 2) + .unsqueeze(-1) + ) + + xzs = torch.matmul(torch.inverse(Bs[:, :, 0:2, 0:2]), bs[:, :, 0:2]) + + mask = (abs(Bs[:, 2].unsqueeze(1) @ xzs - bs[:, 2].unsqueeze(1)) > 1e-3).flatten() + + # mask: bx10x1x1 + mask = ( + abs(torch.matmul(Bs[:, :, 2, :].unsqueeze(2), xzs) - bs[:, :, 2, :].unsqueeze(2)) > 1e-3 + ) # .flatten(start_dim=1) + # bx10 + mask = mask.squeeze(3).squeeze(2) + + if torch.any(mask): + q_batch, r_batch = torch.linalg.qr(Bs[mask]) + xyz_to_feed = torch.linalg.solve(r_batch, torch.matmul(q_batch.transpose(-1, -2), bs[mask])) + xzs[mask] = xyz_to_feed + + nullSpace_filtered = nullSpace[singular_filter] + + Es = ( + nullSpace_filtered[:, 0:1] * (-xzs[:, :, 0]) + + nullSpace_filtered[:, 1:2] * (-xzs[:, :, 1]) + + nullSpace_filtered[:, 2:3] * roots.unsqueeze(-1) + + nullSpace_filtered[:, 3:4] + ) + + inv = 1.0 / torch.sqrt((-xzs[:, :, 0]) ** 2 + (-xzs[:, :, 1]) ** 2 + roots.unsqueeze(-1) ** 2 + 1.0) + Es *= inv + + Es = Es.view(batch_size_filtered, -1, 3, 3).transpose(-1, -2) + E_return = torch.eye(3, dtype=Es.dtype, device=Es.device)[None].expand(batch_size, 10, -1, -1).clone() + E_return[singular_filter] = Es + + return E_return diff --git a/vggsfm/vggsfm/two_view_geo/estimate_preliminary.py b/vggsfm/vggsfm/two_view_geo/estimate_preliminary.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5cc0b91bfd76ca64350d5e0e8c725166d5091a --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/estimate_preliminary.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from minipytorch3d.cameras import CamerasBase, PerspectiveCameras + +# from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +# from pytorch3d.transforms import se3_exp_map, se3_log_map, Transform3d, so3_relative_angle + + +from torch.cuda.amp import autocast + +from .fundamental import estimate_fundamental, essential_from_fundamental +from .homography import estimate_homography, decompose_homography_matrix +from .essential import estimate_essential, decompose_essential_matrix +from .utils import get_default_intri, remove_cheirality + +# TODO remove the .. and . that may look confusing +from ..utils.metric import closed_form_inverse + +try: + import poselib + + print("Poselib is available") +except: + print("Poselib is not installed. Please disable use_poselib") + + +def estimate_preliminary_cameras_poselib( + tracks, + tracks_vis, + width, + height, + tracks_score=None, + max_error=0.5, + max_ransac_iters=20000, + predict_essential=False, + lo_num=None, + predict_homo=False, + loopresidual=False, +): + B, S, N, _ = tracks.shape + + query_points = tracks[:, 0:1].reshape(B, N, 2) + reference_points = tracks[:, 1:].reshape(B * (S - 1), N, 2) + + valid_mask = (tracks_vis >= 0.05)[:, 1:].reshape(B * (S - 1), N) + + fmat = [] + inlier_mask = [] + for idx in range(len(reference_points)): + kps_left = query_points[0].cpu().numpy() + kps_right = reference_points[idx].cpu().numpy() + + cur_inlier_mask = valid_mask[idx].cpu().numpy() + + kps_left = kps_left[cur_inlier_mask] + kps_right = kps_right[cur_inlier_mask] + + cur_fmat, info = poselib.estimate_fundamental( + kps_left, + kps_right, + { + "max_epipolar_error": max_error, + "max_iterations": max_ransac_iters, + "min_iterations": 1000, + "real_focal_check": True, + "progressive_sampling": False, + }, + ) + + cur_inlier_mask[cur_inlier_mask] = np.array(info["inliers"]) + + fmat.append(cur_fmat) + inlier_mask.append(cur_inlier_mask) + + fmat = torch.from_numpy(np.array(fmat)).to(query_points.device) + inlier_mask = torch.from_numpy(np.array(inlier_mask)).to(query_points.device) + + preliminary_dict = {"fmat": fmat[None], "fmat_inlier_mask": inlier_mask[None]} + + return None, preliminary_dict + + +def estimate_preliminary_cameras( + tracks, + tracks_vis, + width, + height, + tracks_score=None, + max_error=0.5, + lo_num=300, + max_ransac_iters=4096, + predict_essential=False, + predict_homo=False, + loopresidual=False, +): + # TODO: also clean the code for predict_essential and predict_homo + + with autocast(dtype=torch.double): + # batch_num, frame_num, point_num + B, S, N, _ = tracks.shape + + # We have S-1 reference frame per batch + query_points = tracks[:, 0:1].expand(-1, S - 1, -1, -1).reshape(B * (S - 1), N, 2) + reference_points = tracks[:, 1:].reshape(B * (S - 1), N, 2) + + # Filter out some matches based on track vis and score + + valid_mask = (tracks_vis >= 0.05)[:, 1:].reshape(B * (S - 1), N) + + if tracks_score is not None: + valid_tracks_score_mask = (tracks_score >= 0.5)[:, 1:].reshape(B * (S - 1), N) + valid_mask = torch.logical_and(valid_mask, valid_tracks_score_mask) + + # Estimate Fundamental Matrix by Batch + # fmat: (B*(S-1))x3x3 + fmat, fmat_inlier_num, fmat_inlier_mask, fmat_residuals = estimate_fundamental( + query_points, + reference_points, + max_error=max_error, + lo_num=lo_num, + max_ransac_iters=max_ransac_iters, + valid_mask=valid_mask, + loopresidual=loopresidual, + return_residuals=True, + ) + + # kmat1, kmat2: (B*(S-1))x3x3 + kmat1, kmat2, fl, pp = build_default_kmat( + width, height, B, S, N, device=query_points.device, dtype=query_points.dtype + ) + + emat_fromf, _, _ = essential_from_fundamental(fmat, kmat1, kmat2) + + R_emat_fromf, t_emat_fromf = decompose_essential_matrix(emat_fromf) + R_emat_fromf, t_emat_fromf = remove_cheirality( + R_emat_fromf, t_emat_fromf, query_points, reference_points, fl, pp + ) + + # TODO: clean the code for R_hmat, t_hmat, R_emat, t_emat and add them here + R_preliminary = R_emat_fromf + t_preliminary = t_emat_fromf + + R_preliminary = R_preliminary.reshape(B, S - 1, 3, 3) + t_preliminary = t_preliminary.reshape(B, S - 1, 3) + + # pad for the first camera + R_pad = torch.eye(3, device=tracks.device, dtype=tracks.dtype)[None].repeat(B, 1, 1).unsqueeze(1) + t_pad = torch.zeros(3, device=tracks.device, dtype=tracks.dtype)[None].repeat(B, 1).unsqueeze(1) + + R_preliminary = torch.cat([R_pad, R_preliminary], dim=1).reshape(B * S, 3, 3) + t_preliminary = torch.cat([t_pad, t_preliminary], dim=1).reshape(B * S, 3) + + R_opencv = R_preliminary.clone() + t_opencv = t_preliminary.clone() + + # From OpenCV/COLMAP camera convention to PyTorch3D + # TODO: Remove the usage of PyTorch3D convention in all the codebase + # So that we don't need to do such conventions any more + R_preliminary = R_preliminary.clone().permute(0, 2, 1) + t_preliminary = t_preliminary.clone() + t_preliminary[:, :2] *= -1 + R_preliminary[:, :, :2] *= -1 + + pred_cameras = PerspectiveCameras(R=R_preliminary, T=t_preliminary, device=R_preliminary.device) + + with autocast(dtype=torch.double): + # Optional in the future + # make all the cameras relative to the first one + pred_se3 = pred_cameras.get_world_to_view_transform().get_matrix() + + rel_transform = closed_form_inverse(pred_se3[0:1, :, :]) + rel_transform = rel_transform.expand(B * S, -1, -1) + + pred_se3_rel = torch.bmm(rel_transform, pred_se3) + pred_se3_rel[..., :3, 3] = 0.0 + pred_se3_rel[..., 3, 3] = 1.0 + + pred_cameras.R = pred_se3_rel[:, :3, :3].clone() + pred_cameras.T = pred_se3_rel[:, 3, :3].clone() + + # Record them in case we may need later + fmat = fmat.reshape(B, S - 1, 3, 3) + fmat_inlier_mask = fmat_inlier_mask.reshape(B, S - 1, -1) + kmat1 = kmat1.reshape(B, S - 1, 3, 3) + R_opencv = R_opencv.reshape(B, S, 3, 3) + t_opencv = t_opencv.reshape(B, S, 3) + + fmat_residuals = fmat_residuals.reshape(B, S - 1, -1) + + preliminary_dict = { + "fmat": fmat, + "fmat_inlier_mask": fmat_inlier_mask, + "R_opencv": R_opencv, + "t_opencv": t_opencv, + "default_intri": kmat1, + "emat_fromf": emat_fromf, + "fmat_residuals": fmat_residuals, + } + + return pred_cameras, preliminary_dict + + +def build_default_kmat(width, height, B, S, N, device=None, dtype=None): + # focal length is set as max(width, height) + # principal point is set as (width//2, height//2,) + + fl, pp, _ = get_default_intri(width, height, device, dtype) + + # :2 for left frame, 2: for right frame + fl = torch.ones((B, S - 1, 4), device=device, dtype=dtype) * fl + pp = torch.cat([pp, pp])[None][None].expand(B, S - 1, -1) + + fl = fl.reshape(B * (S - 1), 4) + pp = pp.reshape(B * (S - 1), 4) + + # build kmat + kmat1 = torch.eye(3, device=device, dtype=dtype)[None].repeat(B * (S - 1), 1, 1) + kmat2 = torch.eye(3, device=device, dtype=dtype)[None].repeat(B * (S - 1), 1, 1) + + # assign them to the corresponding locations of kmats + kmat1[:, [0, 1], [0, 1]] = fl[:, :2] + kmat1[:, [0, 1], 2] = pp[:, :2] + + kmat2[:, [0, 1], [0, 1]] = fl[:, 2:] + kmat2[:, [0, 1], 2] = pp[:, 2:] + + return kmat1, kmat2, fl, pp diff --git a/vggsfm/vggsfm/two_view_geo/fundamental.py b/vggsfm/vggsfm/two_view_geo/fundamental.py new file mode 100644 index 0000000000000000000000000000000000000000..796a36b524005e36a5f8fb2307122afd40f905d9 --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/fundamental.py @@ -0,0 +1,410 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Literal, Optional, Tuple + +import torch + + +import math +from torch.cuda.amp import autocast + +from kornia.core.check import KORNIA_CHECK_SHAPE +from kornia.geometry.solvers import solve_cubic + +from kornia.geometry.epipolar.fundamental import normalize_points, normalize_transformation +from kornia.core import Tensor, concatenate, ones_like, stack, where, zeros + +import numpy as np + +from .utils import ( + generate_samples, + sampson_epipolar_distance_batched, + calculate_residual_indicator, + normalize_points_masked, + local_refinement, + _torch_svd_cast, + sampson_epipolar_distance_forloop_wrapper, +) + + +# The code structure learned from https://github.com/kornia/kornia +# Some funtions adapted from https://github.com/kornia/kornia +# The minimal solvers learned from https://github.com/colmap/colmap + + +def estimate_fundamental( + points1, + points2, + max_ransac_iters=4096, + max_error=1, + lo_num=300, + valid_mask=None, + squared=True, + second_refine=True, + loopresidual=False, + return_residuals=False, +): + """ + Given 2D correspondences, + this function estimate fundamental matrix by 7pt/8pt algo + LORANSAC. + + points1, points2: Pytorch Tensor, BxNx2 + + best_fmat: Bx3x3 + """ + max_thres = max_error**2 if squared else max_error + + # points1, points2: BxNx2 + B, N, _ = points1.shape + point_per_sample = 7 # 7p algorithm require 7 pairs + + # randomly sample 7 point set by max_ransac_iters times + # ransac_idx: torch matirx Nx7 + ransac_idx = generate_samples(N, max_ransac_iters, point_per_sample) + left = points1[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + right = points2[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + + # Note that, we have (B*max_ransac_iters) 7-point sets + # Each 7-point set will lead to 3 potential answers by 7p algorithm (check run_7point for details) + # Therefore the number of 3x3 matrix fmat_ransac is (B*max_ransac_iters*3) + # We reshape it to B x (max_ransac_iters*3) x 3 x3 + fmat_ransac = run_7point(left, right) + fmat_ransac = fmat_ransac.reshape(B, max_ransac_iters, 3, 3, 3).reshape(B, max_ransac_iters * 3, 3, 3) + + # Not sure why but the computation of sampson errors takes a lot of GPU memory + # Since it is very fast, users can use a for loop to reduce the peak GPU usage + # if necessary + if loopresidual: + sampson_fn = sampson_epipolar_distance_forloop_wrapper + else: + sampson_fn = sampson_epipolar_distance_batched + + residuals = sampson_fn(points1, points2, fmat_ransac, squared=squared) + if loopresidual: + torch.cuda.empty_cache() + + # If we know some matches are invalid, + # we can simply force its corresponding errors as a huge value + if valid_mask is not None: + valid_mask_tmp = valid_mask[:, None].expand(-1, residuals.shape[1], -1) + residuals[~valid_mask_tmp] = 1e6 + + # Compute the number of inliers + # and sort the candidate fmats based on it + inlier_mask = residuals <= max_thres + inlier_num = inlier_mask.sum(dim=-1) + sorted_values, sorted_indices = torch.sort(inlier_num, dim=1, descending=True) + + # Conduct local refinement by 8p algorithm + # Basically, for a well-conditioned candidate fmat from 7p algorithm + # we can compute all of its inliers + # and then feed these inliers to 8p algorithm + fmat_lo = local_refinement(run_8point, points1, points2, inlier_mask, sorted_indices, lo_num=lo_num) + if loopresidual: + torch.cuda.empty_cache() + residuals_lo = sampson_fn(points1, points2, fmat_lo, squared=squared) + if loopresidual: + torch.cuda.empty_cache() + + if second_refine: + # We can do this again to the predictd fmats from last run of 8p algorithm + # Usually it is not necessary but let's put it here + lo_more = lo_num // 2 + inlier_mask_lo = residuals_lo <= max_thres + inlier_num_lo = inlier_mask_lo.sum(dim=-1) + sorted_values_lo, sorted_indices_lo = torch.sort(inlier_num_lo, dim=1, descending=True) + fmat_lo_second = local_refinement( + run_8point, points1, points2, inlier_mask_lo, sorted_indices_lo, lo_num=lo_more + ) + if loopresidual: + torch.cuda.empty_cache() + residuals_lo_second = sampson_fn(points1, points2, fmat_lo_second, squared=squared) + if loopresidual: + torch.cuda.empty_cache() + fmat_lo = torch.cat([fmat_lo, fmat_lo_second], dim=1) + residuals_lo = torch.cat([residuals_lo, residuals_lo_second], dim=1) + lo_num += lo_more + + if valid_mask is not None: + valid_mask_tmp = valid_mask[:, None].expand(-1, residuals_lo.shape[1], -1) + residuals_lo[~valid_mask_tmp] = 1e6 + + # Get all the predicted fmats + # choose the one with the highest inlier number and smallest (valid) residual + + all_fmat = torch.cat([fmat_ransac, fmat_lo], dim=1) + residuals_all = torch.cat([residuals, residuals_lo], dim=1) + + residual_indicator, inlier_num_all, inlier_mask_all = calculate_residual_indicator( + residuals_all, max_thres, debug=True + ) + + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + + # Find the index of the best fmat + best_f_indices = torch.argmax(residual_indicator, dim=1) + best_fmat = all_fmat[batch_index[:, 0], best_f_indices] + best_inlier_num = inlier_num_all[batch_index[:, 0], best_f_indices] + best_inlier_mask = inlier_mask_all[batch_index[:, 0], best_f_indices] + + if return_residuals: + best_residuals = residuals_all[batch_index[:, 0], best_f_indices] + return best_fmat, best_inlier_num, best_inlier_mask, best_residuals + + return best_fmat, best_inlier_num, best_inlier_mask + + +def essential_from_fundamental( + fmat, + kmat1, + kmat2, + points1=None, + points2=None, + focal_length=None, + principal_point=None, + max_error=4, + squared=True, + compute_residual=False, +): + """Get Essential matrix from Fundamental and Camera matrices. + + Uses the method from Hartley/Zisserman 9.6 pag 257 (formula 9.12). + + Args: + F_mat: The fundamental matrix with shape of :math:`(*, 3, 3)`. + K1: The camera matrix from first camera with shape :math:`(*, 3, 3)`. + K2: The camera matrix from second camera with shape :math:`(*, 3, 3)`. + + Returns: + The essential matrix with shape :math:`(*, 3, 3)`. + """ + + with autocast(dtype=torch.float32): + emat_from_fmat = kmat2.transpose(-2, -1) @ fmat @ kmat1 + + if compute_residual: + principal_point = principal_point.unsqueeze(1) + focal_length = focal_length.unsqueeze(1) + + points1 = (points1 - principal_point[..., :2]) / focal_length[..., :2] + points2 = (points2 - principal_point[..., 2:]) / focal_length[..., 2:] + + max_error = max_error / focal_length.mean(dim=-1, keepdim=True) + + max_thres = max_error**2 if squared else max_error + + B, N, _ = points1.shape + + if kmat1 is None: + raise NotImplementedError + + residuals = sampson_epipolar_distance_batched( + points1, points2, emat_from_fmat.unsqueeze(1), squared=squared + ) + + inlier_mask = residuals <= max_thres + + inlier_num = inlier_mask.sum(dim=-1).squeeze(1) + inlier_mask = inlier_mask.squeeze(1) + else: + inlier_num = None + inlier_mask = None + + return emat_from_fmat, inlier_num, inlier_mask + + +################################################################## +# 8P # +################################################################## + + +def run_8point( + points1: Tensor, points2: Tensor, masks: Optional[Tensor] = None, weights: Optional[Tensor] = None +) -> Tensor: + r"""Compute the fundamental matrix using the DLT formulation. + + The linear system is solved by using the Weighted Least Squares Solution for the 8 Points algorithm. + + Args: + points1: A set of points in the first image with a tensor shape :math:`(B, N, 2), N>=8`. + points2: A set of points in the second image with a tensor shape :math:`(B, N, 2), N>=8`. + weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`. + + Returns: + the computed fundamental matrix with shape :math:`(B, 3, 3)`. + + Adapted from https://github.com/kornia/kornia/blob/b0995bdce3b04a11d39e86853bb1de9a2a438ca2/kornia/geometry/epipolar/fundamental.py#L169 + + Refer to Hartley and Zisserman, Multiple View Geometry, algorithm 11.1, page 282 for more details + + """ + with autocast(dtype=torch.float32): + # NOTE: DO NOT use bf16 when related to SVD + if points1.shape != points2.shape: + raise AssertionError(points1.shape, points2.shape) + if points1.shape[1] < 8: + raise AssertionError(points1.shape) + if weights is not None: + if not (len(weights.shape) == 2 and weights.shape[1] == points1.shape[1]): + raise AssertionError(weights.shape) + + if masks is None: + masks = ones_like(points1[..., 0]) + + points1_norm, transform1 = normalize_points_masked(points1, masks=masks) + points2_norm, transform2 = normalize_points_masked(points2, masks=masks) + + x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2) # Bx1xN + x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2) # Bx1xN + + ones = ones_like(x1) + + # build equations system and solve DLT + # [x * x', x * y', x, y * x', y * y', y, x', y', 1] + + X = torch.cat([x2 * x1, x2 * y1, x2, y2 * x1, y2 * y1, y2, x1, y1, ones], dim=-1) # BxNx9 + + # if masks is not valid, force the cooresponding rows (points) to all zeros + if masks is not None: + X = X * masks.unsqueeze(-1) + + # apply the weights to the linear system + if weights is None: + X = X.transpose(-2, -1) @ X + else: + w_diag = torch.diag_embed(weights) + X = X.transpose(-2, -1) @ w_diag @ X + + # compute eigevectors and retrieve the one with the smallest eigenvalue + _, _, V = _torch_svd_cast(X) + F_mat = V[..., -1].view(-1, 3, 3) + + # reconstruct and force the matrix to have rank2 + U, S, V = _torch_svd_cast(F_mat) + rank_mask = torch.tensor([1.0, 1.0, 0.0], device=F_mat.device, dtype=F_mat.dtype) + + F_projected = U @ (torch.diag_embed(S * rank_mask) @ V.transpose(-2, -1)) + F_est = transform2.transpose(-2, -1) @ (F_projected @ transform1) + + return normalize_transformation(F_est) # , points1_norm, points2_norm + + +################################################################## +# 7P # +################################################################## + + +def run_7point(points1: Tensor, points2: Tensor) -> Tensor: + # with autocast(dtype=torch.double): + with autocast(dtype=torch.float32): + # NOTE: DO NOT use bf16 when related to SVD + r"""Compute the fundamental matrix using the 7-point algorithm. + + Args: + points1: A set of points in the first image with a tensor shape :math:`(B, N, 2)`. + points2: A set of points in the second image with a tensor shape :math:`(B, N, 2)`. + + Returns: + the computed fundamental matrix with shape :math:`(B, 3*m, 3), Valid values of m are 1, 2 or 3` + + Adapted from: + https://github.com/kornia/kornia/blob/b0995bdce3b04a11d39e86853bb1de9a2a438ca2/kornia/geometry/epipolar/fundamental.py#L76 + + which is based on the following paper: + Zhengyou Zhang and T. Kanade, Determining the Epipolar Geometry and its + Uncertainty: A Review, International Journal of Computer Vision, 1998. + http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.33.4540 + + """ + KORNIA_CHECK_SHAPE(points1, ["B", "7", "2"]) + KORNIA_CHECK_SHAPE(points2, ["B", "7", "2"]) + + batch_size = points1.shape[0] + + points1_norm, transform1 = normalize_points(points1) + points2_norm, transform2 = normalize_points(points2) + + x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2) # Bx1xN + x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2) # Bx1xN + + ones = ones_like(x1) + # form a linear system: which represents + # the equation (x2[i], 1)*F*(x1[i], 1) = 0 + X = concatenate([x2 * x1, x2 * y1, x2, y2 * x1, y2 * y1, y2, x1, y1, ones], -1) # BxNx9 + + # X * fmat = 0 is singular (7 equations for 9 variables) + # solving for nullspace of X to get two F + _, _, v = _torch_svd_cast(X) + + # last two singular vector as a basic of the space + f1 = v[..., 7].view(-1, 3, 3) + f2 = v[..., 8].view(-1, 3, 3) + + # lambda*f1 + mu*f2 is an arbitrary fundamental matrix + # f ~ lambda*f1 + (1 - lambda)*f2 + # det(f) = det(lambda*f1 + (1-lambda)*f2), find lambda + # form a cubic equation + # finding the coefficients of cubic polynomial (coeffs) + + coeffs = zeros((batch_size, 4), device=v.device, dtype=v.dtype) + + f1_det = torch.linalg.det(f1) + f2_det = torch.linalg.det(f2) + + f1_det_invalid = f1_det == 0 + f2_det_invalid = f2_det == 0 + + # ignore the samples that failed for det checking + if f1_det_invalid.any(): + f1[f1_det_invalid] = torch.eye(3).to(f1.device).to(f1.dtype) + + if f2_det_invalid.any(): + f2[f2_det_invalid] = torch.eye(3).to(f2.device).to(f2.dtype) + + coeffs[:, 0] = f1_det + coeffs[:, 1] = torch.einsum("bii->b", f2 @ torch.inverse(f1)) * f1_det + coeffs[:, 2] = torch.einsum("bii->b", f1 @ torch.inverse(f2)) * f2_det + coeffs[:, 3] = f2_det + + # solve the cubic equation, there can be 1 to 3 roots + roots = solve_cubic(coeffs) + + fmatrix = zeros((batch_size, 3, 3, 3), device=v.device, dtype=v.dtype) + valid_root_mask = (torch.count_nonzero(roots, dim=1) < 3) | (torch.count_nonzero(roots, dim=1) > 1) + + _lambda = roots + _mu = torch.ones_like(_lambda) + + _s = f1[valid_root_mask, 2, 2].unsqueeze(dim=1) * roots[valid_root_mask] + f2[valid_root_mask, 2, 2].unsqueeze( + dim=1 + ) + _s_non_zero_mask = ~torch.isclose(_s, torch.tensor(0.0, device=v.device, dtype=v.dtype)) + + _mu[_s_non_zero_mask] = 1.0 / _s[_s_non_zero_mask] + _lambda[_s_non_zero_mask] = _lambda[_s_non_zero_mask] * _mu[_s_non_zero_mask] + + f1_expanded = f1.unsqueeze(1).expand(batch_size, 3, 3, 3) + f2_expanded = f2.unsqueeze(1).expand(batch_size, 3, 3, 3) + + fmatrix[valid_root_mask] = ( + f1_expanded[valid_root_mask] * _lambda[valid_root_mask, :, None, None] + + f2_expanded[valid_root_mask] * _mu[valid_root_mask, :, None, None] + ) + + mat_ind = zeros(3, 3, dtype=torch.bool) + mat_ind[2, 2] = True + fmatrix[_s_non_zero_mask, mat_ind] = 1.0 + fmatrix[~_s_non_zero_mask, mat_ind] = 0.0 + + trans1_exp = transform1[valid_root_mask].unsqueeze(1).expand(-1, fmatrix.shape[2], -1, -1) + trans2_exp = transform2[valid_root_mask].unsqueeze(1).expand(-1, fmatrix.shape[2], -1, -1) + + bf16_happy = torch.matmul(trans2_exp.transpose(-2, -1), torch.matmul(fmatrix[valid_root_mask], trans1_exp)) + fmatrix[valid_root_mask] = bf16_happy.float() + + return normalize_transformation(fmatrix) diff --git a/vggsfm/vggsfm/two_view_geo/homography.py b/vggsfm/vggsfm/two_view_geo/homography.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8a8150b1715d7c9f725b53d3ba43949cbdbbeb --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/homography.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Adapted from https://github.com/kornia +from typing import Literal, Optional, Tuple + +import torch + +from kornia.core import Tensor, concatenate, ones_like, stack, where, zeros +from kornia.core.check import KORNIA_CHECK_SHAPE +from kornia.geometry.conversions import convert_points_from_homogeneous, convert_points_to_homogeneous +from kornia.geometry.linalg import transform_points +from kornia.geometry.solvers import solve_cubic +from kornia.utils._compat import torch_version_ge + +import math +from torch.cuda.amp import autocast + +from .utils import ( + generate_samples, + calculate_residual_indicator, + normalize_points_masked, + local_refinement, + _torch_svd_cast, + oneway_transfer_error_batched, +) + +from kornia.core.check import KORNIA_CHECK_IS_TENSOR + +import warnings + +from kornia.utils import _extract_device_dtype, safe_inverse_with_mask, safe_solve_with_mask + +from kornia.geometry.homography import oneway_transfer_error + + +# The code structure learned from https://github.com/kornia/kornia +# Some funtions adapted from https://github.com/kornia/kornia +# The minimal solvers learned from https://github.com/colmap/colmap + + +def estimate_homography(points1, points2, max_ransac_iters=1024, max_error=4, lo_num=50): + max_thres = max_error**2 + # points1, points2: BxNx2 + B, N, _ = points1.shape + point_per_sample = 4 # 4p algorithm + + ransac_idx = generate_samples(N, max_ransac_iters, point_per_sample) + left = points1[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + right = points2[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + + hmat_ransac = run_homography_dlt(left.float(), right.float()) + hmat_ransac = hmat_ransac.reshape(B, max_ransac_iters, 3, 3) + + residuals = oneway_transfer_error_batched(points1, points2, hmat_ransac, squared=True) + + inlier_mask = residuals <= max_thres + + inlier_num = inlier_mask.sum(dim=-1) + + sorted_values, sorted_indices = torch.sort(inlier_num, dim=1, descending=True) + + hmat_lo = local_refinement(run_homography_dlt, points1, points2, inlier_mask, sorted_indices, lo_num=lo_num) + + # choose the one with the higher inlier number and smallest (valid) residual + all_hmat = torch.cat([hmat_ransac, hmat_lo], dim=1) + residuals_all = oneway_transfer_error_batched(points1, points2, all_hmat, squared=True) + residual_indicator, inlier_num_all, inlier_mask_all = calculate_residual_indicator(residuals_all, max_thres) + + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + best_indices = torch.argmax(residual_indicator, dim=1) + + best_hmat = all_hmat[batch_index[:, 0], best_indices] + best_inlier_num = inlier_num_all[batch_index[:, 0], best_indices] + best_inlier_mask = inlier_mask_all[batch_index[:, 0], best_indices] + + return best_hmat, best_inlier_num, best_inlier_mask + + +def run_homography_dlt( + points1: torch.Tensor, + points2: torch.Tensor, + masks=None, + weights: Optional[torch.Tensor] = None, + solver: str = "svd", + colmap_style=False, +) -> torch.Tensor: + r"""Compute the homography matrix using the DLT formulation. + + The linear system is solved by using the Weighted Least Squares Solution for the 4 Points algorithm. + + Args: + points1: A set of points in the first image with a tensor shape :math:`(B, N, 2)`. + points2: A set of points in the second image with a tensor shape :math:`(B, N, 2)`. + weights: Tensor containing the weights per point correspondence with a shape of :math:`(B, N)`. + solver: variants: svd, lu. + + + Returns: + the computed homography matrix with shape :math:`(B, 3, 3)`. + """ + # with autocast(dtype=torch.double): + with autocast(dtype=torch.float32): + if points1.shape != points2.shape: + raise AssertionError(points1.shape) + if points1.shape[1] < 4: + raise AssertionError(points1.shape) + KORNIA_CHECK_SHAPE(points1, ["B", "N", "2"]) + KORNIA_CHECK_SHAPE(points2, ["B", "N", "2"]) + + device, dtype = _extract_device_dtype([points1, points2]) + + eps: float = 1e-8 + + if masks is None: + masks = ones_like(points1[..., 0]) + + points1_norm, transform1 = normalize_points_masked(points1, masks=masks, colmap_style=colmap_style) + points2_norm, transform2 = normalize_points_masked(points2, masks=masks, colmap_style=colmap_style) + + x1, y1 = torch.chunk(points1_norm, dim=-1, chunks=2) # BxNx1 + x2, y2 = torch.chunk(points2_norm, dim=-1, chunks=2) # BxNx1 + ones, zeros = torch.ones_like(x1), torch.zeros_like(x1) + + # DIAPO 11: https://www.uio.no/studier/emner/matnat/its/nedlagte-emner/UNIK4690/v16/forelesninger/lecture_4_3-estimating-homographies-from-feature-correspondences.pdf # noqa: E501 + + if colmap_style: + # should be the same + ax = torch.cat([-x1, -y1, -ones, zeros, zeros, zeros, x1 * x2, y1 * x2, x2], dim=-1) + ay = torch.cat([zeros, zeros, zeros, -x1, -y1, -ones, x1 * y2, y1 * y2, y2], dim=-1) + else: + ax = torch.cat([zeros, zeros, zeros, -x1, -y1, -ones, y2 * x1, y2 * y1, y2], dim=-1) + ay = torch.cat([x1, y1, ones, zeros, zeros, zeros, -x2 * x1, -x2 * y1, -x2], dim=-1) + + # if masks is not valid, force the cooresponding rows (points) to all zeros + if masks is not None: + masks = masks.unsqueeze(-1) + ax = ax * masks + ay = ay * masks + + A = torch.cat((ax, ay), dim=-1).reshape(ax.shape[0], -1, ax.shape[-1]) + + if weights is None: + # All points are equally important + A = A.transpose(-2, -1) @ A + else: + # We should use provided weights + if not (len(weights.shape) == 2 and weights.shape == points1.shape[:2]): + raise AssertionError(weights.shape) + w_diag = torch.diag_embed(weights.unsqueeze(dim=-1).repeat(1, 1, 2).reshape(weights.shape[0], -1)) + A = A.transpose(-2, -1) @ w_diag @ A + + if solver == "svd": + try: + _, _, V = _torch_svd_cast(A) + except RuntimeError: + warnings.warn("SVD did not converge", RuntimeWarning) + return torch.empty((points1_norm.size(0), 3, 3), device=device, dtype=dtype) + H = V[..., -1].view(-1, 3, 3) + else: + raise NotImplementedError + + H = transform2.inverse() @ (H @ transform1) + H_norm = H / (H[..., -1:, -1:] + eps) + return H_norm + + +def normalize_to_unit(M: Tensor, eps: float = 1e-8) -> Tensor: + r"""Normalize a given transformation matrix. + + The function trakes the transformation matrix and normalize so that the value in + the last row and column is one. + + Args: + M: The transformation to be normalized of any shape with a minimum size of 2x2. + eps: small value to avoid unstabilities during the backpropagation. + + Returns: + the normalized transformation matrix with same shape as the input. + """ + if len(M.shape) < 2: + raise AssertionError(M.shape) + norm_val = M.norm(dim=-1, keepdim=True) + return where(norm_val.abs() > eps, M / (norm_val + eps), M) + + +############ decompose + + +def decompose_homography_matrix(H, left, right, K1, K2): + # WE FORCE FLOAT64 here to avoid the problem in SVD + B, _, _ = H.shape # Assuming H is Bx3x3 + H = H.double() + K1 = K1.double() + K2 = K2.double() + + with autocast(dtype=torch.double): + # Adjust calibration removal for batched input + K2_inv = torch.linalg.inv(K2) # Assuming K2 is Bx3x3 + H_normalized = torch.matmul(torch.matmul(K2_inv, H), K1) + + # Adjust scale removal for batched input + _, s, _ = torch.linalg.svd(H_normalized) + s_mid = s[:, 1].unsqueeze(1).unsqueeze(2) + H_normalized /= s_mid + + # Ensure that we always return rotations, and never reflections + det_H = torch.linalg.det(H_normalized) + H_normalized[det_H < 0] *= -1.0 + + I_3 = torch.eye(3, device=H.device).unsqueeze(0).expand(B, 3, 3) + S = torch.matmul(H_normalized.transpose(-2, -1), H_normalized) - I_3 + + kMinInfinityNorm = 1e-3 + rotation_only_mask = torch.linalg.norm(S, ord=float("inf"), dim=(-2, -1)) < kMinInfinityNorm + + M00 = compute_opposite_of_minor(S, 0, 0) + M11 = compute_opposite_of_minor(S, 1, 1) + M22 = compute_opposite_of_minor(S, 2, 2) + + rtM00 = torch.sqrt(M00) + rtM11 = torch.sqrt(M11) + rtM22 = torch.sqrt(M22) + + M01 = compute_opposite_of_minor(S, 0, 1) + M12 = compute_opposite_of_minor(S, 1, 2) + M02 = compute_opposite_of_minor(S, 0, 2) + + e12 = torch.sign(M12) + e02 = torch.sign(M02) + e01 = torch.sign(M01) + + nS = torch.stack([S[:, 0, 0].abs(), S[:, 1, 1].abs(), S[:, 2, 2].abs()], dim=1) + idx = torch.argmax(nS, dim=1) + + np1, np2 = compute_np1_np2(idx, S, rtM22, rtM11, rtM00, e12, e02, e01) + + traceS = torch.einsum("bii->b", S) # Batched trace + v = 2.0 * torch.sqrt(1.0 + traceS - M00 - M11 - M22) + + ESii = torch.sign(torch.stack([S[i, idx[i], idx[i]] for i in range(B)])) + + r_2 = 2 + traceS + v + nt_2 = 2 + traceS - v + + r = torch.sqrt(r_2) + n_t = torch.sqrt(nt_2) + + # normalize there + np1_valid_mask = torch.linalg.norm(np1, dim=-1) != 0 + np1_valid_scale = torch.linalg.norm(np1[np1_valid_mask], dim=-1) + np1[np1_valid_mask] = np1[np1_valid_mask] / np1_valid_scale.unsqueeze(1) + + np2_valid_mask = torch.linalg.norm(np2, dim=-1) != 0 + np2_valid_scale = torch.linalg.norm(np2[np2_valid_mask], dim=-1) + np2[np2_valid_mask] = np2[np2_valid_mask] / np2_valid_scale.unsqueeze(1) + + half_nt = 0.5 * n_t + esii_t_r = ESii * r + t1_star = half_nt.unsqueeze(-1) * (esii_t_r.unsqueeze(-1) * np2 - n_t.unsqueeze(-1) * np1) + t2_star = half_nt.unsqueeze(-1) * (esii_t_r.unsqueeze(-1) * np1 - n_t.unsqueeze(-1) * np2) + + R1 = compute_homography_rotation(H_normalized, t1_star, np1, v) + t1 = torch.bmm(R1, t1_star.unsqueeze(-1)).squeeze(-1) + + R2 = compute_homography_rotation(H_normalized, t2_star, np2, v) + t2 = torch.bmm(R2, t2_star.unsqueeze(-1)).squeeze(-1) + + # normalize to norm-1 vector + t1 = normalize_to_unit(t1) + t2 = normalize_to_unit(t2) + + R_return = torch.cat([R1[:, None], R1[:, None], R2[:, None], R2[:, None]], dim=1) + t_return = torch.cat([t1[:, None], -t1[:, None], t2[:, None], -t2[:, None]], dim=1) + + np_return = torch.cat([-np1[:, None], np1[:, None], -np2[:, None], np2[:, None]], dim=1) + + return R_return, t_return, np_return + + +def compute_homography_rotation(H_normalized, tstar, n, v): + B, _, _ = H_normalized.shape + identity_matrix = torch.eye(3, device=H_normalized.device).unsqueeze(0).repeat(B, 1, 1) + outer_product = tstar.unsqueeze(2) * n.unsqueeze(1) + R = H_normalized @ (identity_matrix - (2.0 / v.unsqueeze(-1).unsqueeze(-1)) * outer_product) + return R + + +def compute_np1_np2(idx, S, rtM22, rtM11, rtM00, e12, e02, e01): + B = S.shape[0] + np1 = torch.zeros(B, 3, dtype=S.dtype, device=S.device) + np2 = torch.zeros(B, 3, dtype=S.dtype, device=S.device) + + # Masks for selecting indices + idx0 = idx == 0 + idx1 = idx == 1 + idx2 = idx == 2 + + # Compute np1 and np2 for idx == 0 + np1[idx0, 0], np2[idx0, 0] = S[idx0, 0, 0], S[idx0, 0, 0] + np1[idx0, 1], np2[idx0, 1] = S[idx0, 0, 1] + rtM22[idx0], S[idx0, 0, 1] - rtM22[idx0] + np1[idx0, 2], np2[idx0, 2] = S[idx0, 0, 2] + e12[idx0] * rtM11[idx0], S[idx0, 0, 2] - e12[idx0] * rtM11[idx0] + + # Compute np1 and np2 for idx == 1 + np1[idx1, 0], np2[idx1, 0] = S[idx1, 0, 1] + rtM22[idx1], S[idx1, 0, 1] - rtM22[idx1] + np1[idx1, 1], np2[idx1, 1] = S[idx1, 1, 1], S[idx1, 1, 1] + np1[idx1, 2], np2[idx1, 2] = S[idx1, 1, 2] - e02[idx1] * rtM00[idx1], S[idx1, 1, 2] + e02[idx1] * rtM00[idx1] + + # Compute np1 and np2 for idx == 2 + np1[idx2, 0], np2[idx2, 0] = S[idx2, 0, 2] + e01[idx2] * rtM11[idx2], S[idx2, 0, 2] - e01[idx2] * rtM11[idx2] + np1[idx2, 1], np2[idx2, 1] = S[idx2, 1, 2] + rtM00[idx2], S[idx2, 1, 2] - rtM00[idx2] + np1[idx2, 2], np2[idx2, 2] = S[idx2, 2, 2], S[idx2, 2, 2] + + return np1, np2 + + +def compute_opposite_of_minor(matrix, row, col): + col1 = 1 if col == 0 else 0 + col2 = 1 if col == 2 else 2 + row1 = 1 if row == 0 else 0 + row2 = 1 if row == 2 else 2 + return matrix[:, row1, col2] * matrix[:, row2, col1] - matrix[:, row1, col1] * matrix[:, row2, col2] diff --git a/vggsfm/vggsfm/two_view_geo/perspective_n_points.py b/vggsfm/vggsfm/two_view_geo/perspective_n_points.py new file mode 100644 index 0000000000000000000000000000000000000000..7634b30ea4f3d2995a97f64ab70aa0db2844fde8 --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/perspective_n_points.py @@ -0,0 +1,393 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +This file contains Efficient PnP algorithm for Perspective-n-Points problem. +It finds a camera position (defined by rotation `R` and translation `T`) that +minimizes re-projection error between the given 3D points `x` and +the corresponding uncalibrated 2D points `y`. +""" + +import warnings +from typing import NamedTuple, Optional + +import torch +import torch.nn.functional as F + +try: + from pytorch3d.ops import points_alignment, utils as oputil +except: + print("pytorch3d pnp not available") + + +class EpnpSolution(NamedTuple): + x_cam: torch.Tensor + R: torch.Tensor + T: torch.Tensor + err_2d: torch.Tensor + err_3d: torch.Tensor + + +def _define_control_points(x, weight, storage_opts=None): + """ + Returns control points that define barycentric coordinates + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + storage_opts: dict of keyword arguments to the tensor constructor. + """ + storage_opts = storage_opts or {} + x_mean = oputil.wmean(x, weight) + c_world = F.pad(torch.eye(3, **storage_opts), (0, 0, 0, 1), value=0.0).expand_as(x[:, :4, :]) + return c_world + x_mean + + +def _compute_alphas(x, c_world): + """ + Computes barycentric coordinates of x in the frame c_world. + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + c_world: control points in world coordinates. + """ + x = F.pad(x, (0, 1), value=1.0) + c = F.pad(c_world, (0, 1), value=1.0) + return torch.matmul(x, torch.inverse(c)) # B x N x 4 + + +def _build_M(y, alphas, weight): + """Returns the matrix defining the reprojection equations. + Args: + y: projected points in camera coordinates of size B x N x 2 + alphas: barycentric coordinates of size B x N x 4 + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + """ + bs, n, _ = y.size() + + # prepend t with the column of v's + def prepad(t, v): + return F.pad(t, (1, 0), value=v) + + if weight is not None: + # weight the alphas in order to get a correctly weighted version of M + alphas = alphas * weight[:, :, None] + + # outer left-multiply by alphas + def lm_alphas(t): + return torch.matmul(alphas[..., None], t).reshape(bs, n, 12) + + M = torch.cat( + ( + lm_alphas(prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0)), # u constraints + lm_alphas(prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0)), # v constraints + ), + dim=-1, + ).reshape(bs, -1, 12) + + return M + + +def _null_space(m, kernel_dim): + """Finds the null space (kernel) basis of the matrix + Args: + m: the batch of input matrices, B x N x 12 + kernel_dim: number of dimensions to approximate the kernel + Returns: + * a batch of null space basis vectors + of size B x 4 x 3 x kernel_dim + * a batch of spectral values where near-0s correspond to actual + kernel vectors, of size B x kernel_dim + """ + mTm = torch.bmm(m.transpose(1, 2), m) + s, v = torch.linalg.eigh(mTm) + return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim] + + +def _reproj_error(y_hat, y, weight, eps=1e-9): + """Projects estimated 3D points and computes the reprojection error + Args: + y_hat: a batch of predicted 2D points in homogeneous coordinates + y: a batch of ground-truth 2D points + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + Returns: + Optionally weighted RMSE of difference between y and y_hat. + """ + y_hat = y_hat / torch.clamp(y_hat[..., 2:], eps) + dist = ((y - y_hat[..., :2]) ** 2).sum(dim=-1, keepdim=True) ** 0.5 + return oputil.wmean(dist, weight)[:, 0, 0] + + +def _algebraic_error(x_w_rotated, x_cam, weight): + """Computes the residual of Umeyama in 3D. + Args: + x_w_rotated: The given 3D points rotated with the predicted camera. + x_cam: the lifted 2D points y + weight: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + Returns: + Optionally weighted MSE of difference between x_w_rotated and x_cam. + """ + dist = ((x_w_rotated - x_cam) ** 2).sum(dim=-1, keepdim=True) + return oputil.wmean(dist, weight)[:, 0, 0] + + +def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9): + """Given a solution, adjusts the scale and flip + Args: + c_cam: control points in camera coordinates + alphas: barycentric coordinates of the points + x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`. + weights: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + eps: epsilon to threshold negative `z` values + """ + # position of reference points in camera coordinates + x_cam = torch.matmul(alphas, c_cam) + + x_cam = x_cam * (1.0 - 2.0 * (oputil.wmean(x_cam[..., 2:], weight) < 0).float()) + if torch.any(x_cam[..., 2:] < -eps): + neg_rate = oputil.wmean((x_cam[..., 2:] < 0).float(), weight, dim=(0, 1)).item() + warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0)) + + R, T, s = points_alignment.corresponding_points_alignment(x_world, x_cam, weight, estimate_scale=True) + s = s.clamp(eps) + x_cam = x_cam / s[:, None, None] + T = T / s[:, None] + x_w_rotated = torch.matmul(x_world, R) + T[:, None, :] + err_2d = _reproj_error(x_w_rotated, y, weight) + err_3d = _algebraic_error(x_w_rotated, x_cam, weight) + + return EpnpSolution(x_cam, R, T, err_2d, err_3d) + + +def _gen_pairs(input, dim=-2, reducer=lambda a, b: ((a - b) ** 2).sum(dim=-1)): + """Generates all pairs of different rows and then applies the reducer + Args: + input: a tensor + dim: a dimension to generate pairs across + reducer: a function of generated pair of rows to apply (beyond just concat) + Returns: + for default args, for A x B x C input, will output A x (B choose 2) + """ + n = input.size()[dim] + range = torch.arange(n) + idx = torch.combinations(range).to(input).long() + left = input.index_select(dim, idx[:, 0]) + right = input.index_select(dim, idx[:, 1]) + return reducer(left, right) + + +def _kernel_vec_distances(v): + """Computes the coefficients for linearization of the quadratic system + to match all pairwise distances between 4 control points (dim=1). + The last dimension corresponds to the coefficients for quadratic terms + Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors. + Arg: + v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4 + Returns: + a tensor of B x 6 x [(D choose 2) + D]; + for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]. + """ + dv = _gen_pairs(v, dim=-3, reducer=lambda a, b: a - b) # B x 6 x 3 x D + + # we should take dot-product of all (i,j), i < j, with coeff 2 + rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=lambda a, b: (a * b).sum(dim=-2)) + # this should produce B x 6 x (D choose 2) tensor + + # we should take dot-product of all (i,i) + rows_ii = (dv**2).sum(dim=-2) + # this should produce B x 6 x D tensor + + return torch.cat((rows_ii, rows_2ij), dim=-1) + + +def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx): + """Solves an over-determined linear system for selected LHS columns. + A batched version of `torch.lstsq`. + Args: + rhs: right-hand side vectors + lhs: left-hand side matrices + lhs_col_idx: a slice of columns in lhs + Returns: + a least-squares solution for lhs * X = rhs + """ + lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long()) + return torch.matmul(torch.pinverse(lhs), rhs[:, :, None]) + + +def _binary_sign(t): + return (t >= 0).to(t) * 2.0 - 1.0 + + +def _find_null_space_coords_1(kernel_dsts, cw_dst, eps=1e-9): + """Solves case 1 from the paper [1]; solve for 4 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6]) + + beta = beta * _binary_sign(beta[:, :1, :]) + return beta / torch.clamp(beta[:, :1, :] ** 0.5, eps) + + +def _find_null_space_coords_2(kernel_dsts, cw_dst): + """Solves case 2 from the paper; solve for 3 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1]) + + coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :]) + coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ((beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)).float() + + return torch.cat((coord_0, coord_1, torch.zeros_like(beta[:, :2, :])), dim=1) + + +def _find_null_space_coords_3(kernel_dsts, cw_dst, eps=1e-9): + """Solves case 3 from the paper; solve for 5 coefficients: + [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34] + ^ ^ ^ ^ ^ + Args: + kernel_dsts: distances between kernel vectors + cw_dst: distances between control points + Returns: + coefficients to weight kernel vectors + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7]) + + coord_0 = (beta[:, :1, :].abs() ** 0.5) * _binary_sign(beta[:, 1:2, :]) + coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * ((beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)).float() + coord_2 = beta[:, 3:4, :] / torch.clamp(coord_0[:, :1, :], eps) + + return torch.cat((coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1) + + +def efficient_pnp( + x: torch.Tensor, + y: torch.Tensor, + masks: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + skip_quadratic_eq: bool = False, +) -> EpnpSolution: + """ + Implements Efficient PnP algorithm [1] for Perspective-n-Points problem: + finds a camera position (defined by rotation `R` and translation `T`) that + minimizes re-projection error between the given 3D points `x` and + the corresponding uncalibrated 2D points `y`, i.e. solves + + `y[i] = Proj(x[i] R[i] + T[i])` + + in the least-squares sense, where `i` are indices within the batch, and + `Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`. + In the noise-less case, 4 points are enough to find the solution as long + as they are not co-planar. + + Args: + x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`. + y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`. + weights: Batch of non-negative weights of + shape `(minibatch, num_point)`. `None` means equal weights. + skip_quadratic_eq: If True, assumes the solution space for the + linear system is one-dimensional, i.e. takes the scaled eigenvector + that corresponds to the smallest eigenvalue as a solution. + If False, finds the candidate coordinates in the potentially + 4D null space by approximately solving the systems of quadratic + equations. The best candidate is chosen by examining the 2D + re-projection error. While this option finds a better solution, + especially when the number of points is small or perspective + distortions are low (the points are far away), it may be more + difficult to back-propagate through. + + Returns: + `EpnpSolution` namedtuple containing elements: + **x_cam**: Batch of transformed points `x` that is used to find + the camera parameters, of shape `(minibatch, num_points, 3)`. + In the general (noisy) case, they are not exactly equal to + `x[i] R[i] + T[i]` but are some affine transform of `x[i]`s. + **R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`. + **T**: Batch of translation vectors of shape `(minibatch, 3)`. + **err_2d**: Batch of mean 2D re-projection errors of shape + `(minibatch,)`. Specifically, if `yhat` is the re-projection for + the `i`-th batch element, it returns `sum_j norm(yhat_j - y_j)` + where `j` iterates over points and `norm` denotes the L2 norm. + **err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`. + Specifically, those are squared distances between `x_world` and + estimated points on the rays defined by `y`. + + [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009). + EPnP: An Accurate O(n) solution to the PnP problem. + International Journal of Computer Vision. + https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/ + """ + # define control points in a world coordinate system (centered on the 3d + # points centroid); 4 x 3 + # TODO: more stable when initialised with the center and eigenvectors! + weights = masks + + c_world = _define_control_points(x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device}) + + # find the linear combination of the control points to represent the 3d points + alphas = _compute_alphas(x, c_world) + + M = _build_M(y, alphas, weights) + + # import pdb;pdb.set_trace() + + # Compute kernel M + kernel, spectrum = _null_space(M, 4) + + c_world_distances = _gen_pairs(c_world) + kernel_dsts = _kernel_vec_distances(kernel) + + betas = ( + [] + if skip_quadratic_eq + else [ + fnsc(kernel_dsts, c_world_distances) + for fnsc in [_find_null_space_coords_1, _find_null_space_coords_2, _find_null_space_coords_3] + ] + ) + + c_cam_variants = [kernel] + [torch.matmul(kernel, beta[:, None, :, :]) for beta in betas] + + solutions = [_compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights) for c_cam in c_cam_variants] + + sol_zipped = EpnpSolution(*(torch.stack(list(col)) for col in zip(*solutions))) + best = torch.argmin(sol_zipped.err_2d, dim=0) + + def gather1d(source, idx): + # reduces the dim=1 by picking the slices in a 1D tensor idx + # in other words, it is batched index_select. + return source.gather(0, idx.reshape(1, -1, *([1] * (len(source.shape) - 2))).expand_as(source[:1]))[0] + + return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped]) diff --git a/vggsfm/vggsfm/two_view_geo/pnp.py b/vggsfm/vggsfm/two_view_geo/pnp.py new file mode 100644 index 0000000000000000000000000000000000000000..714938f980070ca787fa74f831abc586ccb3ea75 --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/pnp.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Literal, Optional, Tuple + +import torch + +from kornia.core import Tensor, concatenate, ones_like, stack, where, zeros +from kornia.core.check import KORNIA_CHECK_SHAPE +from kornia.geometry.linalg import transform_points +from kornia.geometry.solvers import solve_cubic +from kornia.utils._compat import torch_version_ge + +import math +from torch.cuda.amp import autocast + + +import numpy as np +import kornia +from .perspective_n_points import efficient_pnp +from kornia.geometry.calibration.pnp import solve_pnp_dlt + + +from .utils import ( + generate_samples, + sampson_epipolar_distance_batched, + calculate_residual_indicator, + normalize_points_masked, + local_refinement, + _torch_svd_cast, +) + + +def conduct_pnp(points3D, points2D, intrinsics, max_ransac_iters=1024, max_error=8, lo_num=50, f_trials=51): + """ + Solve PnP problem by 6p algorithm + ePnP + LORANSAC + points2D and intrinsics is defined in pixel + """ + + max_thres = max_error**2 + + oriB = points3D.shape[0] + if f_trials > 0: + # Search for potential focal lengths + # Refer to + # https://github.com/colmap/colmap/blob/0ea2d5ceee1360bba427b2ef61f1351e59a46f91/src/colmap/estimators/pose.cc#L87 + # for more details + + B, P, _ = points3D.shape + f_factors = generate_focal_factors(f_trials - 1) + f_factors = torch.FloatTensor(f_factors).to(points2D.device) + + points3D = points3D[:, None].expand(-1, f_trials, -1, -1) + points2D = points2D[:, None].expand(-1, f_trials, -1, -1) + intrinsics = intrinsics[:, None].expand(-1, f_trials, -1, -1).clone() + intrinsics[:, :, 0, 0] = intrinsics[:, :, 0, 0] * f_factors[None, :] + intrinsics[:, :, 1, 1] = intrinsics[:, :, 1, 1] * f_factors[None, :] + + points3D = points3D.reshape(B * f_trials, P, 3) + points2D = points2D.reshape(B * f_trials, P, 2) + intrinsics = intrinsics.reshape(B * f_trials, 3, 3) + else: + f_trials = 1 + + # update B + B, P, _ = points3D.shape + + point_per_sample = 6 # 7p algorithm + ransac_idx = generate_samples(P, max_ransac_iters, point_per_sample) + + points3D_ransac = points3D[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 3) + points2D_ransac = points2D[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + intrinsics_ransac = intrinsics[:, None].expand(-1, max_ransac_iters, -1, -1).reshape(B * max_ransac_iters, 3, 3) + pred_world_to_cam = solve_pnp_dlt(points3D_ransac, points2D_ransac, intrinsics_ransac) + + pred_world_to_cam_4x4 = kornia.eye_like(4, pred_world_to_cam) + pred_world_to_cam_4x4[:, :3, :] = pred_world_to_cam + + points3D_expand = points3D[:, None].expand(-1, max_ransac_iters, -1, -1).reshape(B * max_ransac_iters, P, 3) + points2D_expand = points2D[:, None].expand(-1, max_ransac_iters, -1, -1).reshape(B * max_ransac_iters, P, 2) + cam_points = kornia.geometry.transform_points(pred_world_to_cam_4x4, points3D_expand) + + img_points = kornia.geometry.project_points(cam_points, intrinsics_ransac[:, None]) + + che_invalid = cam_points[..., -1] <= 0 + residuals = (img_points - points2D_expand).norm(dim=-1) ** 2 + residuals[che_invalid] = 1e6 # fails for che Chirality + + inlier_mask = residuals <= max_thres + + inlier_mask = inlier_mask.reshape(B, max_ransac_iters, P) + inlier_num = inlier_mask.sum(dim=-1) + + sorted_values, sorted_indices = torch.sort(inlier_num, dim=1, descending=True) + + focal_length = intrinsics[:, [0, 1], [0, 1]] + principal_point = intrinsics[:, [0, 1], [2, 2]] + points2D_normalized = (points2D - principal_point[:, None]) / focal_length[:, None] + + # LORANSAC refinement + transform_lo = local_refinement( + efficient_pnp, points3D, points2D_normalized, inlier_mask, sorted_indices, lo_num=lo_num, skip_resize=True + ) + + pred_world_to_cam_4x4_lo = kornia.eye_like(4, transform_lo.R) + pred_world_to_cam_4x4_lo[:, :3, :3] = transform_lo.R.permute(0, 2, 1) + pred_world_to_cam_4x4_lo[:, :3, 3] = transform_lo.T + + all_pmat = pred_world_to_cam_4x4_lo.reshape(B, lo_num, 4, 4) + + all_pmat_num = all_pmat.shape[1] + # all + points3D_expand = points3D[:, None].expand(-1, all_pmat_num, -1, -1).reshape(B * all_pmat_num, P, 3) + points2D_expand = points2D[:, None].expand(-1, all_pmat_num, -1, -1).reshape(B * all_pmat_num, P, 2) + intrinsics_all = intrinsics[:, None].expand(-1, all_pmat_num, -1, -1).reshape(B * all_pmat_num, 3, 3) + + cam_points = kornia.geometry.transform_points(all_pmat.reshape(B * all_pmat_num, 4, 4), points3D_expand) + img_points = kornia.geometry.project_points(cam_points, intrinsics_all[:, None]) + + residuals_all = (img_points - points2D_expand).norm(dim=-1) ** 2 + + che_invalid_all = cam_points[..., -1] <= 0 + residuals_all[che_invalid_all] = 1e6 # fails for che Chirality + + residuals_all = residuals_all.reshape(B, all_pmat_num, P) + residuals_all = residuals_all.reshape(oriB, f_trials, all_pmat_num, P).reshape(oriB, f_trials * all_pmat_num, P) + + residual_indicator, inlier_num_all, inlier_mask_all = calculate_residual_indicator( + residuals_all, max_thres, debug=True + ) + + # update B back to original B + B = residual_indicator.shape[0] + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + + best_p_indices = torch.argmax(residual_indicator, dim=1) + + all_pmat = all_pmat.reshape(B, f_trials, all_pmat_num, 4, 4).reshape(B, f_trials * all_pmat_num, 4, 4) + all_intri = intrinsics_all.reshape(B, f_trials, all_pmat_num, 3, 3).reshape(B, f_trials * all_pmat_num, 3, 3) + + best_pmat = all_pmat[batch_index[:, 0], best_p_indices] + best_intri = all_intri[batch_index[:, 0], best_p_indices] + + best_inlier_num = inlier_num_all[batch_index[:, 0], best_p_indices] + best_inlier_mask = inlier_mask_all[batch_index[:, 0], best_p_indices] + + return best_pmat, best_intri, best_inlier_num, best_inlier_mask + + +def generate_focal_factors(num_focal_length_samples=10, max_focal_length_ratio=5, min_focal_length_ratio=0.2): + focal_length_factors = [] + fstep = 1.0 / num_focal_length_samples + fscale = max_focal_length_ratio - min_focal_length_ratio + focal = 0.0 + for i in range(num_focal_length_samples): + focal_length_factors.append(min_focal_length_ratio + fscale * focal * focal) + focal += fstep + focal_length_factors.append(1.0) + return focal_length_factors diff --git a/vggsfm/vggsfm/two_view_geo/utils.py b/vggsfm/vggsfm/two_view_geo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..709d3208e19c4a36d9506800505f19866baa3acc --- /dev/null +++ b/vggsfm/vggsfm/two_view_geo/utils.py @@ -0,0 +1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Adapted from https://github.com/kornia + +from typing import Literal, Optional, Tuple +import numpy as np +import torch +import cv2 +import math + +# Importing Kornia core functionalities +from kornia.core import Tensor, concatenate, ones_like, stack, where, zeros +from kornia.core.check import KORNIA_CHECK_SHAPE, KORNIA_CHECK_IS_TENSOR + +# Importing Kornia geometry functionalities +from kornia.geometry.conversions import convert_points_from_homogeneous, convert_points_to_homogeneous +from kornia.geometry.linalg import transform_points +from kornia.geometry.solvers import solve_cubic +from kornia.geometry.epipolar.fundamental import normalize_points, normalize_transformation + +# Importing PyTorch functionalities +from torch.cuda.amp import autocast + +# Importing Kornia utils +from kornia.utils._compat import torch_version_ge + + +def generate_samples(N, target_num, sample_num, expand_ratio=2): + """ + This function generates random samples of indices without duplicates. + + Parameters: + N (int): The upper limit for generating random integers. + max_num_trials (int): The maximum number of trials for generating samples. + sample_num (int): The number of samples to generate. + + Returns: + np.array: An array of indices without duplicates. + """ + sample_idx = np.random.randint(0, N, size=(target_num * expand_ratio, sample_num)) + sorted_array = np.sort(sample_idx, axis=1) + diffs = np.diff(sorted_array, axis=1) + has_duplicates = (diffs == 0).any(axis=1) + indices_wo_duplicates = np.where(~has_duplicates)[0] + sample_idx_safe = sample_idx[indices_wo_duplicates][:target_num] + + return sample_idx_safe + + +def calculate_residual_indicator(residuals, max_residual, debug=False, check=False, nanvalue=1e6): + with autocast(dtype=torch.double): + B, S, N = residuals.shape + inlier_mask = residuals <= max_residual + + inlier_num = inlier_mask.sum(dim=-1) + + # only consider the residuals of inliers, BxSxN + residual_indicator = inlier_mask.float() * residuals + # the average residual for inliers + residual_indicator = residual_indicator.sum(dim=-1) / inlier_num + # remove zero dividing + residual_indicator = torch.nan_to_num(residual_indicator, nan=nanvalue, posinf=nanvalue, neginf=nanvalue) + # we want the min average one, but don't want it to change the choice of inlier num + thres = residual_indicator.max() + 1e-6 + + residual_indicator = (thres - residual_indicator) / thres + # choose the one with the higher inlier number and smallest (valid) residual + residual_indicator = residual_indicator.double() + inlier_num.double() + + return residual_indicator, inlier_num, inlier_mask + + +def sampson_epipolar_distance_batched( + pts1: Tensor, pts2: Tensor, Fm: Tensor, squared: bool = True, eps: float = 1e-8, debug=False, evaluation=False +) -> Tensor: + """Return Sampson distance for correspondences given the fundamental matrices. + + Args: + pts1: correspondences from the left images with shape :math:`(B, N, (2|3))`. + pts2: correspondences from the right images with shape :math:`(B, N, (2|3))`. + Fm: Batch of fundamental matrices with shape :math:`(B, K, 3, 3)`. + squared: if True (default), the squared distance is returned. + eps: Small constant for safe sqrt. + + Returns: + the computed Sampson distance with shape :math:`(B, K, N)`. + """ + # TODO: check why this would take a high GPU memory + + if not isinstance(Fm, Tensor): + raise TypeError(f"Fm type is not a torch.Tensor. Got {type(Fm)}") + + if Fm.shape[-2:] != (3, 3): + raise ValueError(f"Fm must be a (B, K, 3, 3) tensor. Got {Fm.shape}") + + dtype = pts1.dtype + efficient_dtype = torch.float32 + + with autocast(dtype=efficient_dtype): + if pts1.shape[-1] == 2: + pts1 = convert_points_to_homogeneous(pts1) + + if pts2.shape[-1] == 2: + pts2 = convert_points_to_homogeneous(pts2) + + # Expand pts1 and pts2 to match Fm's batch and K dimensions for broadcasting + B, K, _, _ = Fm.shape + N = pts1.shape[1] + + pts1_expanded = pts1[:, None, :, :].expand(-1, K, -1, -1) # Shape: (B, K, N, 3) + pts2_expanded = pts2[:, None, :, :].expand(-1, K, -1, -1) # Shape: (B, K, N, 3) + + Fm = Fm.to(efficient_dtype) + F_t = Fm.transpose(-2, -1) # Shape: (B, K, 3, 3) + + # pts1_expanded @ F_t + line1_in_2 = torch.einsum("bkij,bkjn->bkin", pts1_expanded, F_t) # Shape: (B, K, N, 3) + if evaluation: + torch.cuda.empty_cache() + line2_in_1 = torch.einsum("bkij,bkjn->bkin", pts2_expanded, Fm) # Shape: (B, K, N, 3) + if evaluation: + torch.cuda.empty_cache() + + numerator = (pts2_expanded * line1_in_2).sum(dim=-1).pow(2) # Shape: (B, K, N) + denominator = line1_in_2[..., :2].norm(2, dim=-1).pow(2) + line2_in_1[..., :2].norm(2, dim=-1).pow( + 2 + ) # Shape: (B, K, N) + + out = numerator / denominator + + out = out.to(dtype) + if debug: + return numerator, denominator, out, line1_in_2, line2_in_1 + + if squared: + return out + return (out + eps).sqrt() + + +def normalize_points_masked( + points: Tensor, masks: Tensor, eps: float = 1e-8, colmap_style=False +) -> Tuple[Tensor, Tensor]: + """ + Normalizes points using a boolean mask to exclude certain points. + + Args: + points: Tensor containing the points to be normalized with shape (B, N, 2). + masks: Bool tensor indicating which points to include with shape (B, N). + eps: epsilon value to avoid numerical instabilities. + + Returns: + Tuple containing the normalized points in the shape (B, N, 2) and the transformation matrix in the shape (B, 3, 3). + """ + if len(points.shape) != 3 or points.shape[-1] != 2: + raise ValueError(f"Expected points with shape (B, N, 2), got {points.shape}") + + if masks is None: + masks = ones_like(points[..., 0]) + + if masks.shape != points.shape[:-1]: + raise ValueError(f"Expected masks with shape {points.shape[:-1]}, got {masks.shape}") + + # Convert masks to float and apply it + mask_f = masks.float().unsqueeze(-1) # BxNx1 + masked_points = points * mask_f + + # Compute mean only over masked (non-zero) points + num_valid_points = mask_f.sum(dim=1, keepdim=True) # Bx1x1 + x_mean = masked_points.sum(dim=1, keepdim=True) / (num_valid_points + eps) # Bx1x2 + + diffs = masked_points - x_mean # BxNx2, Apply mask before subtraction to zero-out invalid points + + if colmap_style: + sum_squared_diffs = (diffs**2).sum(dim=-1).sum(dim=-1) # Shape: (B, N) + mean_squared_diffs = sum_squared_diffs / (num_valid_points[:, 0, 0] + eps) # Shape: (B,) + rms_mean_dist = torch.sqrt(mean_squared_diffs) # Shape: (B,) + rms_mean_dist = torch.clamp(rms_mean_dist, min=eps) + scale = torch.sqrt(torch.tensor(2.0)) / rms_mean_dist # Shape: (B,) + else: + # Compute scale only over masked points + scale = (diffs.norm(dim=-1, p=2) * masks).sum(dim=-1) / (num_valid_points[:, 0, 0] + eps) # B + scale = torch.sqrt(torch.tensor(2.0)) / (scale + eps) # B + + # Prepare transformation matrix components + ones = torch.ones_like(scale) + zeros = torch.zeros_like(scale) + + transform = stack( + [scale, zeros, -scale * x_mean[..., 0, 0], zeros, scale, -scale * x_mean[..., 0, 1], zeros, zeros, ones], dim=-1 + ) # Bx3x3 + + transform = transform.view(-1, 3, 3) # Bx3x3 + points_norm = transform_points(transform, points) # BxNx2 + + return points_norm, transform + + +def local_refinement( + local_estimator, points1, points2, inlier_mask, sorted_indices, lo_num=50, essential=False, skip_resize=False +): + # Running local refinement by local_estimator based on inlier_mask + # as in LORANSAC + + B, N, _ = points1.shape + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + + points1_expand = points1.unsqueeze(1).expand(-1, lo_num, -1, -1) + points2_expand = points2.unsqueeze(1).expand(-1, lo_num, -1, -1) + + # The sets selected for local refinement + lo_indices = sorted_indices[:, :lo_num] + + # Find the points that would be used for local_estimator + lo_mask = inlier_mask[batch_index, lo_indices] + lo_points1 = torch.zeros_like(points1_expand) + lo_points1[lo_mask] = points1_expand[lo_mask] + lo_points2 = torch.zeros_like(points2_expand) + lo_points2[lo_mask] = points2_expand[lo_mask] + + lo_points1 = lo_points1.reshape(B * lo_num, N, -1) + lo_points2 = lo_points2.reshape(B * lo_num, N, -1) + lo_mask = lo_mask.reshape(B * lo_num, N) + + pred_mat = local_estimator(lo_points1, lo_points2, masks=lo_mask) + + if skip_resize: + return pred_mat + + if essential: + return pred_mat.reshape(B, lo_num, 10, 3, 3) + + return pred_mat.reshape(B, lo_num, 3, 3) + + +def inlier_by_fundamental(fmat, tracks, max_error=0.5): + """ + Given tracks and fundamental matrix, compute the inlier mask for each 2D match + """ + + B, S, N, _ = tracks.shape + left = tracks[:, 0:1].expand(-1, S - 1, -1, -1).reshape(B * (S - 1), N, 2) + right = tracks[:, 1:].reshape(B * (S - 1), N, 2) + + fmat = fmat.reshape(B * (S - 1), 3, 3) + + max_thres = max_error**2 + + residuals = sampson_epipolar_distance_batched(left, right, fmat[:, None], squared=True) + + residuals = residuals[:, 0] + + inlier_mask = residuals <= max_thres + + inlier_mask = inlier_mask.reshape(B, S - 1, -1) + return inlier_mask + + +def remove_cheirality(R, t, points1, points2, focal_length=None, principal_point=None): + # TODO: merge this function with triangulation utils + with autocast(dtype=torch.double): + if focal_length is not None: + principal_point = principal_point.unsqueeze(1) + focal_length = focal_length.unsqueeze(1) + + points1 = (points1 - principal_point[..., :2]) / focal_length[..., :2] + points2 = (points2 - principal_point[..., 2:]) / focal_length[..., 2:] + + B, cheirality_dim, _, _ = R.shape + Bche = B * cheirality_dim + _, N, _ = points1.shape + points1_expand = points1[:, None].expand(-1, cheirality_dim, -1, -1) + points2_expand = points2[:, None].expand(-1, cheirality_dim, -1, -1) + points1_expand = points1_expand.reshape(Bche, N, 2) + points2_expand = points2_expand.reshape(Bche, N, 2) + + cheirality_num, points3D = check_cheirality_batch( + R.reshape(Bche, 3, 3), t.reshape(Bche, 3), points1_expand, points2_expand + ) + cheirality_num = cheirality_num.reshape(B, cheirality_dim) + + cheirality_idx = torch.argmax(cheirality_num, dim=1) + + batch_idx = torch.arange(B) + R_cheirality = R[batch_idx, cheirality_idx] + t_cheirality = t[batch_idx, cheirality_idx] + + return R_cheirality, t_cheirality + + +def triangulate_point_batch(cam1_from_world, cam2_from_world, points1, points2): + # TODO: merge this function with triangulation utils + + B, N, _ = points1.shape + A = torch.zeros(B, N, 4, 4, dtype=points1.dtype, device=points1.device) + + A[:, :, 0, :] = points1[:, :, 0, None] * cam1_from_world[:, None, 2, :] - cam1_from_world[:, None, 0, :] + A[:, :, 1, :] = points1[:, :, 1, None] * cam1_from_world[:, None, 2, :] - cam1_from_world[:, None, 1, :] + A[:, :, 2, :] = points2[:, :, 0, None] * cam2_from_world[:, None, 2, :] - cam2_from_world[:, None, 0, :] + A[:, :, 3, :] = points2[:, :, 1, None] * cam2_from_world[:, None, 2, :] - cam2_from_world[:, None, 1, :] + + # Perform SVD on A + _, _, Vh = torch.linalg.svd(A.view(-1, 4, 4), full_matrices=True) + V = Vh.transpose(-2, -1) # Transpose Vh to get V + + # Extract the last column of V for each batch and point, then reshape to the original batch and points shape + X = V[..., -1].view(B, N, 4) + return X[..., :3] / X[..., 3, None] + + +def calculate_depth_batch(proj_matrices, points3D): + # TODO: merge this function with triangulation utils + + # proj_matrices: Bx3x4 + # points3D: BxNx3 + B, N, _ = points3D.shape + points3D_homo = torch.cat((points3D, torch.ones(B, N, 1, dtype=points3D.dtype, device=points3D.device)), dim=-1) + points2D_homo = torch.einsum("bij,bkj->bki", proj_matrices, points3D_homo) + return points2D_homo[..., 2] + + +def check_cheirality_batch(R, t, points1, points2): + # TODO: merge this function with triangulation utils + + B, N, _ = points1.shape + assert points1.shape == points2.shape + + proj_matrix1 = torch.eye(3, 4, dtype=R.dtype, device=R.device).expand(B, -1, -1) + proj_matrix2 = torch.zeros(B, 3, 4, dtype=R.dtype, device=R.device) + proj_matrix2[:, :, :3] = R + proj_matrix2[:, :, 3] = t + + kMinDepth = torch.finfo(R.dtype).eps + max_depth = 1000.0 * torch.linalg.norm(R.transpose(-2, -1) @ t[:, :, None], dim=1) + + points3D = triangulate_point_batch(proj_matrix1, proj_matrix2, points1, points2) + + depths1 = calculate_depth_batch(proj_matrix1, points3D) + depths2 = calculate_depth_batch(proj_matrix2, points3D) + + valid_depths = (depths1 > kMinDepth) & (depths1 < max_depth) & (depths2 > kMinDepth) & (depths2 < max_depth) + + valid_nums = valid_depths.sum(dim=-1) + return valid_nums, points3D + + +###################################################################################################### + + +def sampson_epipolar_distance_forloop_wrapper( + pts1: Tensor, pts2: Tensor, Fm: Tensor, squared: bool = True, eps: float = 1e-8, debug=False +) -> Tensor: + """Wrapper function for sampson_epipolar_distance_batched to loop over B dimension. + + Args: + pts1: correspondences from the left images with shape :math:`(B, N, (2|3))`. + pts2: correspondences from the right images with shape :math:`(B, N, (2|3))`. + Fm: Batch of fundamental matrices with shape :math:`(B, K, 3, 3)`. + squared: if True (default), the squared distance is returned. + eps: Small constant for safe sqrt. + + Returns: + the computed Sampson distance with shape :math:`(B, K, N)`. + """ + B = Fm.shape[0] + output_list = [] + + for b in range(B): + output = sampson_epipolar_distance_batched( + pts1[b].unsqueeze(0), + pts2[b].unsqueeze(0), + Fm[b].unsqueeze(0), + squared=squared, + eps=eps, + debug=debug, + evaluation=True, + ) + output_list.append(output) + + return torch.cat(output_list, dim=0) + + +def get_default_intri(width, height, device, dtype, ratio=1.0): + # assume same focal length for hw + max_size = max(width, height) + focal_length = max_size * ratio + + principal_point = [width / 2, height / 2] + + K = torch.tensor( + [[focal_length, 0, width / 2], [0, focal_length, height / 2], [0, 0, 1]], device=device, dtype=dtype + ) + + return ( + torch.tensor(focal_length, device=device, dtype=dtype), + torch.tensor(principal_point, device=device, dtype=dtype), + K, + ) + + +def _torch_svd_cast(input: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """Helper function to make torch.svd work with other than fp32/64. + + The function torch.svd is only implemented for fp32/64 which makes + impossible to be used by fp16 or others. What this function does, is cast + input data type to fp32, apply torch.svd, and cast back to the input dtype. + + NOTE: in torch 1.8.1 this function is recommended to use as torch.linalg.svd + """ + out1, out2, out3H = torch.linalg.svd(input) + if torch_version_ge(1, 11): + out3 = out3H.mH + else: + out3 = out3H.transpose(-1, -2) + return (out1, out2, out3) + + +def oneway_transfer_error_batched( + pts1: Tensor, pts2: Tensor, H: Tensor, squared: bool = True, eps: float = 1e-8 +) -> Tensor: + r"""Return transfer error in image 2 for correspondences given the homography matrix. + + Args: + pts1: correspondences from the left images with shape + (B, N, 2 or 3). If they are homogeneous, converted automatically. + pts2: correspondences from the right images with shape + (B, N, 2 or 3). If they are homogeneous, converted automatically. + H: Homographies with shape :math:`(B, K, 3, 3)`. + squared: if True (default), the squared distance is returned. + eps: Small constant for safe sqrt. + + Returns: + the computed distance with shape :math:`(B, K, N)`. + """ + + # From Hartley and Zisserman, Error in one image (4.6) + # dist = \sum_{i} ( d(x', Hx)**2) + + if pts1.shape[-1] == 2: + pts1 = convert_points_to_homogeneous(pts1) + + B, K, _, _ = H.shape + N = pts1.shape[1] + + pts1_expanded = pts1[:, None, :, :].expand(-1, K, -1, -1) # Shape: (B, K, N, 3) + + H_transpose = H.permute(0, 1, 3, 2) + + pts1_in_2_h = torch.einsum("bkij,bkjn->bkin", pts1_expanded, H_transpose) + + pts1_in_2 = convert_points_from_homogeneous(pts1_in_2_h) + pts2_expanded = pts2[:, None, :, :].expand(-1, K, -1, -1) # Shape: (B, K, N, 2) + + error_squared = (pts1_in_2 - pts2_expanded).pow(2).sum(dim=-1) + + if squared: + return error_squared + return (error_squared + eps).sqrt() diff --git a/vggsfm/vggsfm/utils/metric.py b/vggsfm/vggsfm/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..ad41aff7f9333234f040db89cd60d87784b2f10e --- /dev/null +++ b/vggsfm/vggsfm/utils/metric.py @@ -0,0 +1,325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +import numpy as np +import torch + +from minipytorch3d.rotation_conversions import matrix_to_quaternion, quaternion_to_matrix + + +def quaternion_from_matrix(matrix, isprecise=False): + """Return quaternion from rotation matrix. + + If isprecise is True, the input matrix is assumed to be a precise rotation + matrix and a faster algorithm is used. + + >>> q = quaternion_from_matrix(numpy.identity(4), True) + >>> numpy.allclose(q, [1, 0, 0, 0]) + True + >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) + >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) + True + >>> R = rotation_matrix(0.123, (1, 2, 3)) + >>> q = quaternion_from_matrix(R, True) + >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) + True + >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], + ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) + True + >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], + ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] + >>> q = quaternion_from_matrix(R) + >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) + True + >>> R = random_rotation_matrix() + >>> q = quaternion_from_matrix(R) + >>> is_same_transform(R, quaternion_matrix(q)) + True + >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) + >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False), + ... quaternion_from_matrix(R, isprecise=True)) + True + + """ + + M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4] + if isprecise: + q = np.empty((4,)) + t = np.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 1, 2, 3 + if M[1, 1] > M[0, 0]: + i, j, k = 2, 3, 1 + if M[2, 2] > M[i, i]: + i, j, k = 3, 1, 2 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + + # symmetric matrix K + K = np.array( + [ + [m00 - m11 - m22, 0.0, 0.0, 0.0], + [m01 + m10, m11 - m00 - m22, 0.0, 0.0], + [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0], + [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22], + ] + ) + K /= 3.0 + + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = np.linalg.eigh(K) + q = V[[3, 0, 1, 2], np.argmax(w)] + + if q[0] < 0.0: + np.negative(q, q) + + return q + + +def camera_to_rel_deg(pred_cameras, gt_cameras, device, batch_size): + """ + Calculate relative rotation and translation angles between predicted and ground truth cameras. + + Args: + - pred_cameras: Predicted camera. + - gt_cameras: Ground truth camera. + - accelerator: The device for moving tensors to GPU or others. + - batch_size: Number of data samples in one batch. + + Returns: + - rel_rotation_angle_deg, rel_translation_angle_deg: Relative rotation and translation angles in degrees. + """ + + with torch.no_grad(): + # Convert cameras to 4x4 SE3 transformation matrices + gt_se3 = gt_cameras.get_world_to_view_transform().get_matrix() + pred_se3 = pred_cameras.get_world_to_view_transform().get_matrix() + + # Generate pairwise indices to compute relative poses + pair_idx_i1, pair_idx_i2 = batched_all_pairs(batch_size, gt_se3.shape[0] // batch_size) + pair_idx_i1 = pair_idx_i1.to(device) + + # Compute relative camera poses between pairs + # We use closed_form_inverse to avoid potential numerical loss by torch.inverse() + # This is possible because of SE3 + relative_pose_gt = closed_form_inverse(gt_se3[pair_idx_i1]).bmm(gt_se3[pair_idx_i2]) + relative_pose_pred = closed_form_inverse(pred_se3[pair_idx_i1]).bmm(pred_se3[pair_idx_i2]) + + # Compute the difference in rotation and translation + # between the ground truth and predicted relative camera poses + rel_rangle_deg = rotation_angle(relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]) + rel_tangle_deg = translation_angle(relative_pose_gt[:, 3, :3], relative_pose_pred[:, 3, :3]) + + return rel_rangle_deg, rel_tangle_deg + + +def calculate_auc_np(r_error, t_error, max_threshold=30): + """ + Calculate the Area Under the Curve (AUC) for the given error arrays. + + :param r_error: numpy array representing R error values (Degree). + :param t_error: numpy array representing T error values (Degree). + :param max_threshold: maximum threshold value for binning the histogram. + :return: cumulative sum of normalized histogram of maximum error values. + """ + + # Concatenate the error arrays along a new axis + error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1) + + # Compute the maximum error value for each pair + max_errors = np.max(error_matrix, axis=1) + + # Define histogram bins + bins = np.arange(max_threshold + 1) + + # Calculate histogram of maximum error values + histogram, _ = np.histogram(max_errors, bins=bins) + + # Normalize the histogram + num_pairs = float(len(max_errors)) + normalized_histogram = histogram.astype(float) / num_pairs + + # Compute and return the cumulative sum of the normalized histogram + return np.mean(np.cumsum(normalized_histogram)), normalized_histogram + + +def calculate_auc(r_error, t_error, max_threshold=30, return_list=False): + """ + Calculate the Area Under the Curve (AUC) for the given error arrays using PyTorch. + + :param r_error: torch.Tensor representing R error values (Degree). + :param t_error: torch.Tensor representing T error values (Degree). + :param max_threshold: maximum threshold value for binning the histogram. + :return: cumulative sum of normalized histogram of maximum error values. + """ + + # Concatenate the error tensors along a new axis + error_matrix = torch.stack((r_error, t_error), dim=1) + + # Compute the maximum error value for each pair + max_errors, _ = torch.max(error_matrix, dim=1) + + # Define histogram bins + bins = torch.arange(max_threshold + 1) + + # Calculate histogram of maximum error values + histogram = torch.histc(max_errors, bins=max_threshold + 1, min=0, max=max_threshold) + + # Normalize the histogram + num_pairs = float(max_errors.size(0)) + normalized_histogram = histogram / num_pairs + + if return_list: + return torch.cumsum(normalized_histogram, dim=0).mean(), normalized_histogram + # Compute and return the cumulative sum of the normalized histogram + return torch.cumsum(normalized_histogram, dim=0).mean() + + +def batched_all_pairs(B, N): + # B, N = se3.shape[:2] + i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1) + i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]] + + return i1, i2 + + +def closed_form_inverse_OpenCV(se3, R=None, T=None): + """ + Computes the inverse of each 4x4 SE3 matrix in the batch. + + Args: + - se3 (Tensor): Nx4x4 tensor of SE3 matrices. + + Returns: + - Tensor: Nx4x4 tensor of inverted SE3 matrices. + + + | R t | + | 0 1 | + --> + | R^T -R^T t| + | 0 1 | + """ + if R is None: + R = se3[:, :3, :3] + + if T is None: + T = se3[:, :3, 3:] + + # Compute the transpose of the rotation + R_transposed = R.transpose(1, 2) + + # -R^T t + top_right = -R_transposed.bmm(T) + + inverted_matrix = torch.eye(4, 4)[None].repeat(len(se3), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix + + +def closed_form_inverse(se3, R=None, T=None): + """ + Computes the inverse of each 4x4 SE3 matrix in the batch. + This function assumes PyTorch3D coordinate. + + + Args: + - se3 (Tensor): Nx4x4 tensor of SE3 matrices. + + Returns: + - Tensor: Nx4x4 tensor of inverted SE3 matrices. + """ + if R is None: + R = se3[:, :3, :3] + + if T is None: + T = se3[:, 3:, :3] + + # NOTE THIS ASSUMES PYTORCH3D CAMERA COORDINATE + + # Compute the transpose of the rotation + R_transposed = R.transpose(1, 2) + + # Compute the left part of the inverse transformation + left_bottom = -T.bmm(R_transposed) + left_combined = torch.cat((R_transposed, left_bottom), dim=1) + + # Keep the right-most column as it is + right_col = se3[:, :, 3:].detach().clone() + inverted_matrix = torch.cat((left_combined, right_col), dim=-1) + + return inverted_matrix + + +def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15): + ######### + q_pred = matrix_to_quaternion(rot_pred) + q_gt = matrix_to_quaternion(rot_gt) + + loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps) + err_q = torch.arccos(1 - 2 * loss_q) + + rel_rangle_deg = err_q * 180 / np.pi + + if batch_size is not None: + rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1) + + return rel_rangle_deg + + +def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True): + # tvec_gt, tvec_pred (B, 3,) + rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred) + rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi + + if ambiguity: + rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs()) + + if batch_size is not None: + rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1) + + return rel_tangle_deg + + +def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6): + """Normalize the translation vectors and compute the angle between them.""" + t_norm = torch.norm(t, dim=1, keepdim=True) + t = t / (t_norm + eps) + + t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) + t_gt = t_gt / (t_gt_norm + eps) + + loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) + err_t = torch.acos(torch.sqrt(1 - loss_t)) + + err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err + return err_t diff --git a/vggsfm/vggsfm/utils/tensor_to_pycolmap.py b/vggsfm/vggsfm/utils/tensor_to_pycolmap.py new file mode 100644 index 0000000000000000000000000000000000000000..91b26c34489ab641263adc95692059e627f0e3e4 --- /dev/null +++ b/vggsfm/vggsfm/utils/tensor_to_pycolmap.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import pycolmap + + +def batch_matrix_to_pycolmap( + points3d, extrinsics, intrinsics, tracks, masks, image_size, max_points3D_val=300, camera_type="simple_pinhole" +): + """ + Convert Batched Pytorch Tensors to PyCOLMAP + + Check https://github.com/colmap/pycolmap for more details about its format + """ + + # points3d: Px2 + # extrinsics: Nx3x4 + # intrinsics: Nx3x3 + # tracks: NxPx2 + # masks: NxP + # image_size: 2, assume all the frames have been padded to the same size + # where N is the number of frames and P is the number of tracks + + N, P, _ = tracks.shape + assert len(extrinsics) == N + assert len(intrinsics) == N + assert len(points3d) == P + assert image_size.shape[0] == 2 + + extrinsics = extrinsics.cpu().numpy() + intrinsics = intrinsics.cpu().numpy() + tracks = tracks.cpu().numpy() + masks = masks.cpu().numpy() + points3d = points3d.cpu().numpy() + image_size = image_size.cpu().numpy() + + # Reconstruction object, following the format of PyCOLMAP/COLMAP + reconstruction = pycolmap.Reconstruction() + + inlier_num = masks.sum(0) + valid_mask = inlier_num >= 2 # a track is invalid if without two inliers + valid_idx = np.nonzero(valid_mask)[0] + + # Only add 3D points that have sufficient 2D points + for vidx in valid_idx: + reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), np.zeros(3)) + + num_points3D = len(valid_idx) + + # frame idx + for fidx in range(N): + # set camera + if camera_type == "simple_radial": + pycolmap_intri_radial = np.array( + [intrinsics[fidx][0, 0], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2], 0] + ) + camera = pycolmap.Camera( + model="SIMPLE_RADIAL", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_radial, + camera_id=fidx, + ) + else: + pycolmap_intri_pinhole = np.array([intrinsics[fidx][0, 0], intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]) + camera = pycolmap.Camera( + model="SIMPLE_PINHOLE", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_pinhole, + camera_id=fidx, + ) + + # add camera + reconstruction.add_camera(camera) + + # set image + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3] + ) # Rot and Trans + image = pycolmap.Image(id=fidx, name=f"image_{fidx}", camera_id=camera.camera_id, cam_from_world=cam_from_world) + image.registered = True + + points2D_list = [] + + # NOTE point3D_id start by 1 + for point3D_id in range(1, num_points3D + 1): + original_track_idx = valid_idx[point3D_id - 1] + + if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all(): + if masks[fidx][original_track_idx]: + # It seems we don't need +0.5 for BA + point2D_xy = tracks[fidx][original_track_idx] + # Please note when adding the Point2D object + # It not only requires the 2D xy location, but also the id to 3D point + points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id)) + + # add element + track = reconstruction.points3D[point3D_id].track + point2D_idx = point3D_id - 1 + track.add_element(fidx, point2D_idx) + + try: + image.points2D = pycolmap.ListPoint2D(points2D_list) + except: + print(f"frame {fidx} is out of BA") + + # add image + reconstruction.add_image(image) + + return reconstruction + + +def pycolmap_to_batch_matrix(reconstruction, device="cuda"): + """ + Inversion to batch_matrix_to_pycolmap, nothing but picking them back + """ + + num_images = len(reconstruction.images) + max_points3D_id = max(reconstruction.point3D_ids()) + points3D = np.zeros((max_points3D_id, 3)) + + for point3D_id in reconstruction.points3D: + points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz + points3D = torch.from_numpy(points3D).to(device) + + extrinsics = torch.from_numpy( + np.stack([reconstruction.images[i].cam_from_world.matrix() for i in range(num_images)]) + ) + extrinsics = extrinsics.to(device) + + intrinsics = torch.from_numpy(np.stack([reconstruction.cameras[i].calibration_matrix() for i in range(num_images)])) + intrinsics = intrinsics.to(device) + + return points3D, extrinsics, intrinsics diff --git a/vggsfm/vggsfm/utils/triangulation.py b/vggsfm/vggsfm/utils/triangulation.py new file mode 100644 index 0000000000000000000000000000000000000000..bde875d68c3f809de8ace82273b00b81d088b1ec --- /dev/null +++ b/vggsfm/vggsfm/utils/triangulation.py @@ -0,0 +1,813 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import numpy as np +import pycolmap + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from torch.cuda.amp import autocast + +from .tensor_to_pycolmap import batch_matrix_to_pycolmap, pycolmap_to_batch_matrix + +from .triangulation_helpers import ( + triangulate_multi_view_point_batched, + filter_all_points3D, + project_3D_points, + calculate_normalized_angular_error_batched, + calculate_triangulation_angle_batched, + calculate_triangulation_angle_exhaustive, + calculate_triangulation_angle, + create_intri_matrix, + prepare_ba_options, + generate_combinations, + local_refinement_tri, +) + +from ..two_view_geo.utils import calculate_depth_batch, calculate_residual_indicator + + +def triangulate_by_pair(extrinsics, tracks_normalized, eps=1e-12): + """ + Given B x S x 3 x 4 extrinsics and B x S x N x 2 tracks_normalized, + triangulate point clouds for B*(S-1) query-reference pairs + + Return: + + points_3d_pair: B*(S-1) x N x 3 + cheirality_mask: B*(S-1) x N + triangles: B*(S-1) x N + """ + B, S, N, _ = tracks_normalized.shape + + # build pair wise extrinsics and matches + extrinsics_left = extrinsics[:, 0:1].expand(-1, S - 1, -1, -1) + extrinsics_right = extrinsics[:, 1:] + extrinsics_pair = torch.cat([extrinsics_left.unsqueeze(2), extrinsics_right.unsqueeze(2)], dim=2) + + tracks_normalized_left = tracks_normalized[:, 0:1].expand(-1, S - 1, -1, -1) + tracks_normalized_right = tracks_normalized[:, 1:] + tracks_normalized_pair = torch.cat( + [tracks_normalized_left.unsqueeze(2), tracks_normalized_right.unsqueeze(2)], dim=2 + ) + + extrinsics_pair = extrinsics_pair.reshape(B * (S - 1), 2, 3, 4) + tracks_normalized_pair = tracks_normalized_pair.reshape(B * (S - 1), 2, N, 2) + + # triangulate + points_3d_pair, cheirality_mask = triangulate_multi_view_point_from_tracks(extrinsics_pair, tracks_normalized_pair) + + # check triangulation angles + # B*(S-1)x3x1 + # Learned from + # https://github.com/colmap/colmap/blob/c0d8926841cf6325eb031c873eaedb95204a1845/src/colmap/geometry/triangulation.cc#L155 + rot_left = extrinsics_pair[:, 0, :3, :3] + t_left = extrinsics_pair[:, 0, :3, 3:4] + project_center1 = torch.bmm(-rot_left.transpose(-2, -1), t_left) + + rot_right = extrinsics_pair[:, 1, :3, :3] + t_right = extrinsics_pair[:, 1, :3, 3:4] + project_center2 = torch.bmm(-rot_right.transpose(-2, -1), t_right) + + baseline_length_squared = (project_center2 - project_center1).norm(dim=1) ** 2 # B*(S-1)x1 + ray_length_squared1 = (points_3d_pair - project_center1.transpose(-2, -1)).norm(dim=-1) ** 2 # BxN + ray_length_squared2 = (points_3d_pair - project_center2.transpose(-2, -1)).norm(dim=-1) ** 2 # BxN + + denominator = 2.0 * torch.sqrt(ray_length_squared1 * ray_length_squared2) + nominator = ray_length_squared1 + ray_length_squared2 - baseline_length_squared + + # if denominator is zero, angle is zero + # so we set nominator and denominator as one acos(1) = 0 + nonvalid = denominator <= eps + nominator = torch.where(nonvalid, torch.ones_like(nominator), nominator) + denominator = torch.where(nonvalid, torch.ones_like(denominator), denominator) + + cos_angle = nominator / denominator + cos_angle = torch.clamp(cos_angle, -1.0, 1.0) + + # rad to deg + triangles = torch.abs(torch.acos(cos_angle)) + + # take the min of (angle, pi-angle) + # to avoid the effect of acute angles (far away points) and obtuse angles (close points) + triangles = torch.min(triangles, torch.pi - triangles) + triangles = triangles * (180.0 / torch.pi) + + return points_3d_pair, cheirality_mask, triangles + + +def init_BA(extrinsics, intrinsics, tracks, points_3d_pair, inlier, image_size, init_max_reproj_error=0.5): + """ + This function first optimizes the init point cloud + and the cameras of its corresponding frames by BA, + + Input: + extrinsics: Sx3x4 + intrinsics: Sx3x3 + tracks: SxPx2 + points_3d_pair: (S-1)xPx3 + inlier: (S-1)xP + """ + + # Find the frame that has the highest inlier (inlier for triangulation angle and cheirality check) + # Please note the init_idx was defined in 0 to S-1 + init_idx = torch.argmax(inlier.sum(dim=-1)).item() + + # init_indices include the query frame and the frame with highest inlier number (i.e., the init pair) + # init_idx+1 means shifting to the range of 0 to S + # TODO: consider initializing by not only a pair, but more frames + init_indices = [0, init_idx + 1] + + # Pick the camera parameters for the init pair + toBA_extrinsics = extrinsics[init_indices] + toBA_intrinsics = intrinsics[init_indices] + toBA_tracks = tracks[init_indices] + + # points_3d_pair and inlier has a shape of (S-1), *, *, ... + toBA_points3D = points_3d_pair[init_idx] + toBA_masks = inlier[init_idx].unsqueeze(0) + # all the points are assumed valid at query + # TODO: remove this assumption in the next version + toBA_masks = torch.cat([torch.ones_like(toBA_masks), toBA_masks], dim=0) + + # Only if a track has more than 2 inliers, + # it is viewed as valid + toBA_valid_track_mask = toBA_masks.sum(dim=0) >= 2 + toBA_masks = toBA_masks[:, toBA_valid_track_mask] + toBA_points3D = toBA_points3D[toBA_valid_track_mask] + toBA_tracks = toBA_tracks[:, toBA_valid_track_mask] + + # Convert PyTorch tensors to the format of Pycolmap + # Prepare for the Bundle Adjustment Optimization + # NOTE although we use pycolmap for BA here, but any BA library should be able to achieve the same result + reconstruction = batch_matrix_to_pycolmap( + toBA_points3D, toBA_extrinsics, toBA_intrinsics, toBA_tracks, toBA_masks, image_size + ) + + # Prepare BA options + ba_options = prepare_ba_options() + + # Conduct BA + pycolmap.bundle_adjustment(reconstruction, ba_options) + + reconstruction.normalize(5.0, 0.1, 0.9, True) + + # Get the optimized 3D points, extrinsics, and intrinsics + points3D_opt, extrinsics_opt, intrinsics_opt = pycolmap_to_batch_matrix( + reconstruction, device=toBA_extrinsics.device + ) + + # Filter those invalid 3D points + valid_poins3D_mask = filter_all_points3D( + points3D_opt, + toBA_tracks, + extrinsics_opt, + intrinsics_opt, + check_triangle=False, + max_reproj_error=init_max_reproj_error, + ) + points3D_opt = points3D_opt[valid_poins3D_mask] + + # If a 3D point is invalid, all of its 2D matches are invalid + filtered_valid_track_mask = toBA_valid_track_mask.clone() + filtered_valid_track_mask[toBA_valid_track_mask] = valid_poins3D_mask + + # Replace the original cameras by the optimized ones + extrinsics[init_indices] = extrinsics_opt.to(extrinsics.dtype) + intrinsics[init_indices] = intrinsics_opt.to(intrinsics.dtype) + + # NOTE: filtered_valid_track_mask or toBA_valid_track_mask? + return points3D_opt, extrinsics, intrinsics, filtered_valid_track_mask, reconstruction, init_idx + + +def refine_pose( + extrinsics, + intrinsics, + inlier, + points3D, + tracks, + valid_track_mask, + image_size, + max_reproj_error=12, + camera_type="simple_pinhole", + force_estimate=False, +): + # extrinsics: Sx3x4 + # intrinsics: Sx3x3 + # inlier: SxP + # points3D: P' x 3 + # tracks: SxPx2 + # valid_track_mask: P + + S, _, _ = extrinsics.shape + _, P, _ = tracks.shape + + assert len(intrinsics) == S + assert inlier.shape[0] == S + assert inlier.shape[1] == P + assert len(valid_track_mask) == P + + empty_mask = points3D.abs().sum(-1) <= 0 + if empty_mask.sum() > 0: + non_empty_mask = ~empty_mask + tmp_mask = valid_track_mask.clone() + tmp_mask[valid_track_mask] = non_empty_mask + valid_track_mask = tmp_mask.clone() + points3D = points3D[non_empty_mask] + + tracks2D = tracks[:, valid_track_mask] + + # compute reprojection error + projected_points2D, projected_points_cam = project_3D_points( + points3D, extrinsics, intrinsics, return_points_cam=True + ) + + reproj_error = (projected_points2D - tracks2D).norm(dim=-1) ** 2 # sqaure + # ensure all the points stay in front of the cameras + reproj_error[projected_points_cam[:, -1] <= 0] = 1e9 + + reproj_inlier = reproj_error <= (max_reproj_error**2) + + inlier_nongeo = inlier[:, valid_track_mask] + inlier_absrefine = torch.logical_and(inlier_nongeo, reproj_inlier) + + inlier_nongeo = inlier_nongeo.cpu().numpy() + inlier_absrefine = inlier_absrefine.cpu().numpy() + # P' x 3 + points3D = points3D.cpu().numpy() + # S x P' x 2 + tracks2D = tracks2D.cpu().numpy() + + estoptions = pycolmap.AbsolutePoseEstimationOptions() + estoptions.estimate_focal_length = True + estoptions.ransac.max_error = max_reproj_error + + refoptions = pycolmap.AbsolutePoseRefinementOptions() + refoptions.refine_focal_length = True + refoptions.refine_extra_params = True + refoptions.print_summary = False + + refined_extrinsics = [] + refined_intrinsics = [] + + scale = image_size.max() + + for ridx in range(S): + if camera_type == "simple_radial": + pycolmap_intri_radial = np.array( + [intrinsics[ridx][0, 0].cpu(), intrinsics[ridx][0, 2].cpu(), intrinsics[ridx][1, 2].cpu(), 0] + ) + pycamera = pycolmap.Camera( + model="SIMPLE_RADIAL", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_radial, + camera_id=ridx, + ) + else: + pycolmap_intri_pinhole = np.array( + [intrinsics[ridx][0, 0].cpu(), intrinsics[ridx][0, 2].cpu(), intrinsics[ridx][1, 2].cpu()] + ) + pycamera = pycolmap.Camera( + model="SIMPLE_PINHOLE", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_pinhole, + camera_id=ridx, + ) + + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[ridx][:3, :3].cpu()), extrinsics[ridx][:3, 3].cpu() + ) # Rot and Trans + points2D = tracks2D[ridx] + inlier_mask = inlier_absrefine[ridx] + + estimate_abs_pose = False + + if inlier_mask.sum() > 100: + answer = pycolmap.pose_refinement(cam_from_world, points2D, points3D, inlier_mask, pycamera, refoptions) + cam_from_world = answer["cam_from_world"] + + intri_mat = pycamera.calibration_matrix() + focal = intri_mat[0, 0] + if (focal < 0.1 * scale) or (focal > 30 * scale): + # invalid focal length + estimate_abs_pose = True + else: + estimate_abs_pose = True + print(f"Frame {ridx} only has {inlier_mask.sum()} geo_vis inliers") + + if estimate_abs_pose and force_estimate: + inlier_mask = inlier_nongeo[ridx] + if inlier_mask.sum() > 50: + print(f"Estimating absolute poses by visible matches for frame {ridx}") + estanswer = pycolmap.absolute_pose_estimation( + points2D[inlier_mask], points3D[inlier_mask], pycamera, estoptions, refoptions + ) + if estanswer is None: + estanswer = pycolmap.absolute_pose_estimation(points2D, points3D, pycamera, estoptions, refoptions) + else: + print(f"Warning! Estimating absolute poses by non visible matches for frame {ridx}") + estanswer = pycolmap.absolute_pose_estimation(points2D, points3D, pycamera, estoptions, refoptions) + + if estanswer is not None: + cam_from_world = estanswer["cam_from_world"] + + extri_mat = cam_from_world.matrix() + intri_mat = pycamera.calibration_matrix() + + refined_extrinsics.append(extri_mat) + refined_intrinsics.append(intri_mat) + + # get the optimized cameras + refined_extrinsics = torch.from_numpy(np.stack(refined_extrinsics)).to(tracks.device) + refined_intrinsics = torch.from_numpy(np.stack(refined_intrinsics)).to(tracks.device) + + valid_intri_mask = torch.logical_and( + refined_intrinsics[:, 0, 0] >= 0.1 * scale, refined_intrinsics[:, 0, 0] <= 30 * scale + ) + valid_trans_mask = (refined_extrinsics[:, :, 3].abs() <= 30).all(-1) + + valid_frame_mask = torch.logical_and(valid_intri_mask, valid_trans_mask) + + if (~valid_frame_mask).sum() > 0: + print("some frames are invalid after BA refinement") + refined_extrinsics[~valid_frame_mask] = extrinsics[~valid_frame_mask].to(refined_extrinsics.dtype) + refined_intrinsics[~valid_frame_mask] = intrinsics[~valid_frame_mask].to(refined_extrinsics.dtype) + + return refined_extrinsics, refined_intrinsics, valid_frame_mask + + +def init_refine_pose( + extrinsics, + intrinsics, + inlier, + points3D, + tracks, + valid_track_mask_init, + image_size, + init_idx, + max_reproj_error=12, + second_refine=False, + camera_type="simple_pinhole", +): + """ + Refine the extrinsics and intrinsics by points3D and tracks, + which conducts bundle adjustment but does not modify points3D + """ + # extrinsics: Sx3x4 + # intrinsics: Sx3x3 + # inlier: (S-1)xP + # points3D: P' x 3 + # tracks: SxPx2 + # valid_track_mask_init: P + + S, _, _ = extrinsics.shape + _, P, _ = tracks.shape + + assert len(intrinsics) == S + assert inlier.shape[0] == (S - 1) + assert inlier.shape[1] == P + assert len(valid_track_mask_init) == P + + # TODO check this + # remove all zeros points3D + # non_empty_mask = points3D.abs().sum(-1) >0 + # valid_track_mask_tmp = valid_track_mask_init.clone() + # valid_track_mask_tmp[valid_track_mask_init] = non_empty_mask + # valid_track_mask_init = valid_track_mask_tmp.clone() + + # Prepare the inlier mask + inlier_absrefine = torch.cat([torch.ones_like(inlier[0:1]), inlier], dim=0) + inlier_absrefine = inlier_absrefine[:, valid_track_mask_init] + inlier_absrefine = inlier_absrefine.cpu().numpy() + + # P' x 3 + points3D = points3D.cpu().numpy() + # S x P' x 2 + tracks2D = tracks[:, valid_track_mask_init].cpu().numpy() + + refoptions = pycolmap.AbsolutePoseRefinementOptions() + refoptions.refine_focal_length = True + refoptions.refine_extra_params = True + refoptions.print_summary = False + + refined_extrinsics = [] + refined_intrinsics = [] + + for ridx in range(S): + if camera_type == "simple_radial": + pycolmap_intri_radial = np.array( + [intrinsics[ridx][0, 0].cpu(), intrinsics[ridx][0, 2].cpu(), intrinsics[ridx][1, 2].cpu(), 0] + ) + pycamera = pycolmap.Camera( + model="SIMPLE_RADIAL", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_radial, + camera_id=ridx, + ) + else: + pycolmap_intri_pinhole = np.array( + [intrinsics[ridx][0, 0].cpu(), intrinsics[ridx][0, 2].cpu(), intrinsics[ridx][1, 2].cpu()] + ) + pycamera = pycolmap.Camera( + model="SIMPLE_PINHOLE", + width=image_size[0], + height=image_size[1], + params=pycolmap_intri_pinhole, + camera_id=ridx, + ) + + cam_from_world = pycolmap.Rigid3d( + pycolmap.Rotation3d(extrinsics[ridx][:3, :3].cpu()), extrinsics[ridx][:3, 3].cpu() + ) # Rot and Trans + points2D = tracks2D[ridx] + inlier_mask = inlier_absrefine[ridx] + + if ridx != 0 and ridx != (init_idx + 1): + # ridx==0 or ==(init_idx+1) means they are init pair, no need to optimize again + if inlier_mask.sum() > 50: + # If too few inliers, ignore it + # Bundle adjustment without optimizing 3D point + answer = pycolmap.pose_refinement(cam_from_world, points2D, points3D, inlier_mask, pycamera, refoptions) + cam_from_world = answer["cam_from_world"] + else: + print("This frame only has inliers:", inlier_mask.sum()) + + if second_refine: + # refine a second time by filtering out some points with a high reprojection error + extri_mat = cam_from_world.matrix() + intri_mat = pycamera.calibration_matrix() + homo_points3D = np.hstack((points3D, np.ones((points3D.shape[0], 1)))) + + projection = extri_mat @ homo_points3D.transpose(-1, -2) + projection_2D = intri_mat @ projection + projection_2D = projection_2D[:2] / projection_2D[-1] + + residual = projection_2D.transpose(-1, -2) - points2D + sqrt_error_per_point = np.sqrt(np.sum(residual**2, axis=-1)) + + inlier_mask_reproj = sqrt_error_per_point <= 1 + inlier_mask_refine = inlier_mask & inlier_mask_reproj + + refoptions.gradient_tolerance = 10 + answer = pycolmap.pose_refinement( + cam_from_world, points2D, points3D, inlier_mask_refine, pycamera, refoptions + ) + refoptions.gradient_tolerance = 1 + cam_from_world = answer["cam_from_world"] + + extri_mat = cam_from_world.matrix() + intri_mat = pycamera.calibration_matrix() + refined_extrinsics.append(extri_mat) + refined_intrinsics.append(intri_mat) + + # get the optimized cameras + refined_extrinsics = torch.from_numpy(np.stack(refined_extrinsics)).to(tracks.device) + refined_intrinsics = torch.from_numpy(np.stack(refined_intrinsics)).to(tracks.device) + + scale = image_size.max() + + valid_intri_mask = torch.logical_and( + refined_intrinsics[:, 0, 0] >= 0.1 * scale, refined_intrinsics[:, 0, 0] <= 30 * scale + ) + valid_trans_mask = (refined_extrinsics[:, :, 3].abs() <= 30).all(-1) + + valid_frame_mask = torch.logical_and(valid_intri_mask, valid_trans_mask) + + if (~valid_frame_mask).sum() > 0: + print("some frames are invalid after BA refinement") + refined_extrinsics[~valid_frame_mask] = extrinsics[~valid_frame_mask].to(refined_extrinsics.dtype) + refined_intrinsics[~valid_frame_mask] = intrinsics[~valid_frame_mask].to(refined_extrinsics.dtype) + + return refined_extrinsics, refined_intrinsics, valid_frame_mask + + +def triangulate_multi_view_point_from_tracks(cams_from_world, tracks, mask=None): + with autocast(dtype=torch.float32): + B, S, _, _ = cams_from_world.shape + _, _, N, _ = tracks.shape # B S N 2 + tracks = tracks.permute(0, 2, 1, 3) + + tracks = tracks.reshape(B * N, S, 2) + if mask is not None: + mask = mask.permute(0, 2, 1).reshape(B * N, S) + + cams_from_world = cams_from_world[:, None].expand(-1, N, -1, -1, -1) + cams_from_world = cams_from_world.reshape(B * N, S, 3, 4) + + points3d, invalid_cheirality_mask = triangulate_multi_view_point_batched( + cams_from_world, tracks, mask, check_cheirality=True + ) + + points3d = points3d.reshape(B, N, 3) + invalid_cheirality_mask = invalid_cheirality_mask.reshape(B, N) + cheirality_mask = ~invalid_cheirality_mask + return points3d, cheirality_mask + + +def triangulate_tracks( + extrinsics, + tracks_normalized, + max_ransac_iters=256, + lo_num=50, + max_angular_error=2, + min_tri_angle=1.5, + track_vis=None, + track_score=None, +): + """ + This function conduct triangulation over all the input frames + + It adopts LORANSAC, which means + (1) first triangulate 3d points by random 2-view pairs + (2) compute the inliers of these triangulated points + (3) do re-triangulation using the inliers + (4) check the ones with most inliers + """ + max_rad_error = max_angular_error * (torch.pi / 180) + + with autocast(dtype=torch.float32): + tracks_normalized = tracks_normalized.transpose(0, 1) + B, S, _ = tracks_normalized.shape + extrinsics_expand = extrinsics[None].expand(B, -1, -1, -1) + + point_per_sample = 2 # first triangulate points by 2 points + + ransac_idx = generate_combinations(S) + ransac_idx = torch.from_numpy(ransac_idx).to(extrinsics.device) + + # Prevent max_ransac_iters from being unnecessarily high + if max_ransac_iters > len(ransac_idx): + max_ransac_iters = len(ransac_idx) + else: + ransac_idx = ransac_idx[torch.randperm(len(ransac_idx))[:max_ransac_iters]] + lo_num = lo_num if max_ransac_iters >= lo_num else max_ransac_iters + + # Prepare the input + points_ransac = tracks_normalized[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 2) + extrinsics_ransac = extrinsics_expand[:, ransac_idx].view(B * max_ransac_iters, point_per_sample, 3, 4) + + # triangulated_points: (B * max_ransac_iters) x 3 + # tri_angles: (B * max_ransac_iters) x (point_per_sample * point_per_sample) + # invalid_che_mask: (B * max_ransac_iters) + triangulated_points, tri_angles, invalid_che_mask = triangulate_multi_view_point_batched( + extrinsics_ransac, points_ransac, compute_tri_angle=True, check_cheirality=True + ) + + triangulated_points = triangulated_points.reshape(B, max_ransac_iters, 3) + invalid_che_mask = invalid_che_mask.reshape(B, max_ransac_iters) + + # if any of the pair fits the minimum triangulation angle, we view it as valid + tri_masks = (tri_angles >= min_tri_angle).any(dim=-1) + invalid_tri_mask = (~tri_masks).reshape(B, max_ransac_iters) + + # a point is invalid if it does not meet the minimum triangulation angle or fails the cheirality test + # B x max_ransac_iters + + invalid_mask = torch.logical_or(invalid_tri_mask, invalid_che_mask) + + # Please note angular error is not triangulation angle + # For a quick understanding, + # angular error: lower the better + # triangulation angle: higher the better (within a reasonable range) + angular_error, _ = calculate_normalized_angular_error_batched( + tracks_normalized.transpose(0, 1), triangulated_points.permute(1, 0, 2), extrinsics + ) + # max_ransac_iters x S x B -> B x max_ransac_iters x S + angular_error = angular_error.permute(2, 0, 1) + + # If some tracks are invalid, give them a very high error + angular_error[invalid_mask] = angular_error[invalid_mask] + torch.pi + + # Also, we hope the tracks also meet the visibility and score requirement + # logical_or: invalid if does not meet any requirement + if track_score is not None: + invalid_vis_conf_mask = torch.logical_or(track_vis <= 0.05, track_score <= 0.5) + else: + invalid_vis_conf_mask = track_vis <= 0.05 + + invalid_vis_conf_mask = invalid_vis_conf_mask.permute(1, 0) + angular_error[invalid_vis_conf_mask[:, None].expand(-1, max_ransac_iters, -1)] += torch.pi + + # wow, finally, inlier + inlier_mask = (angular_error) <= (max_rad_error) + + ############################################################################# + # LOCAL REFINEMENT + + # Triangulate based on the inliers + # and compute the errors + lo_triangulated_points, lo_tri_angles, lo_angular_error = local_refine_and_compute_error( + tracks_normalized, + extrinsics, + extrinsics_expand, + inlier_mask, + lo_num, + min_tri_angle, + invalid_vis_conf_mask, + max_rad_error=max_rad_error, + ) + + # Refine it again + # if you want to, you can repeat the local refine more and more + lo_num_sec = 10 + if lo_num <= lo_num_sec: + lo_num_sec = lo_num + + lo_inlier_mask = (lo_angular_error) <= (max_rad_error) + lo_triangulated_points_2, lo_tri_angles_2, lo_angular_error_2 = local_refine_and_compute_error( + tracks_normalized, + extrinsics, + extrinsics_expand, + lo_inlier_mask, + lo_num_sec, + min_tri_angle, + invalid_vis_conf_mask, + max_rad_error=max_rad_error, + ) + + # combine the first and second local refinement results + lo_num += lo_num_sec + lo_triangulated_points = torch.cat([lo_triangulated_points, lo_triangulated_points_2], dim=1) + lo_angular_error = torch.cat([lo_angular_error, lo_angular_error_2], dim=1) + lo_tri_angles = torch.cat([lo_tri_angles, lo_tri_angles_2], dim=1) + ############################################################################# + + all_triangulated_points = torch.cat([triangulated_points, lo_triangulated_points], dim=1) + all_angular_error = torch.cat([angular_error, lo_angular_error], dim=1) + + residual_indicator, inlier_num_all, inlier_mask_all = calculate_residual_indicator( + all_angular_error, max_rad_error, check=True, nanvalue=2 * torch.pi + ) + + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + + best_indices = torch.argmax(residual_indicator, dim=1) + + # Pick the triangulated points with most inliers + best_triangulated_points = all_triangulated_points[batch_index[:, 0], best_indices] + best_inlier_num = inlier_num_all[batch_index[:, 0], best_indices] + best_inlier_mask = inlier_mask_all[batch_index[:, 0], best_indices] + + return best_triangulated_points, best_inlier_num, best_inlier_mask + + +def local_refine_and_compute_error( + tracks_normalized, + extrinsics, + extrinsics_expand, + inlier_mask, + lo_num, + min_tri_angle, + invalid_vis_conf_mask, + max_rad_error, +): + B, S, _ = tracks_normalized.shape + + inlier_num = inlier_mask.sum(dim=-1) + sorted_values, sorted_indices = torch.sort(inlier_num, dim=1, descending=True) + + # local refinement + lo_triangulated_points, lo_tri_angles, lo_invalid_che_mask = local_refinement_tri( + tracks_normalized, extrinsics_expand, inlier_mask, sorted_indices, lo_num=lo_num + ) + + lo_tri_masks = (lo_tri_angles >= min_tri_angle).any(dim=-1) + lo_invalid_tri_mask = (~lo_tri_masks).reshape(B, lo_num) + + lo_invalid_mask = torch.logical_or(lo_invalid_tri_mask, lo_invalid_che_mask) + + lo_angular_error, _ = calculate_normalized_angular_error_batched( + tracks_normalized.transpose(0, 1), lo_triangulated_points.permute(1, 0, 2), extrinsics + ) + lo_angular_error = lo_angular_error.permute(2, 0, 1) + + # avoid nan and inf + lo_angular_error = torch.nan_to_num( + lo_angular_error, nan=100 * torch.pi, posinf=100 * torch.pi, neginf=100 * torch.pi + ) + + # penalty to invalid points + lo_angular_error[lo_invalid_mask] = lo_angular_error[lo_invalid_mask] + torch.pi + + lo_angular_error[invalid_vis_conf_mask[:, None].expand(-1, lo_num, -1)] += torch.pi + + return lo_triangulated_points, lo_tri_angles, lo_angular_error + + +def global_BA( + triangulated_points, + valid_tracks, + pred_tracks, + inlier_mask, + extrinsics, + intrinsics, + image_size, + device, + camera_type="simple_pinhole", +): + ba_options = prepare_ba_options() + + # triangulated_points + BA_points = triangulated_points[valid_tracks] + BA_tracks = pred_tracks[:, valid_tracks] + BA_inlier_masks = inlier_mask[valid_tracks].transpose(0, 1) + reconstruction = batch_matrix_to_pycolmap( + BA_points, extrinsics, intrinsics, BA_tracks, BA_inlier_masks, image_size, camera_type=camera_type + ) + pycolmap.bundle_adjustment(reconstruction, ba_options) + + reconstruction.normalize(5.0, 0.1, 0.9, True) + + points3D_opt, extrinsics, intrinsics = pycolmap_to_batch_matrix(reconstruction, device=device) + + return points3D_opt, extrinsics, intrinsics, reconstruction + + +def iterative_global_BA( + pred_tracks, + intrinsics, + extrinsics, + pred_vis, + pred_score, + valid_tracks, + points3D_opt, + image_size, + lastBA=False, + min_valid_track_length=2, + max_reproj_error=1, + ba_options=None, +): + # normalize points from pixel + principal_point_refined = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) + focal_length_refined = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) + tracks_normalized_refined = (pred_tracks - principal_point_refined) / focal_length_refined + + # triangulate tracks by LORANSAC + best_triangulated_points, best_inlier_num, best_inlier_mask = triangulate_tracks( + extrinsics, tracks_normalized_refined, track_vis=pred_vis, track_score=pred_score, max_ransac_iters=128 + ) + + best_triangulated_points[valid_tracks] = points3D_opt + + # well do we need this? best_inlier_mask may be enough already + valid_poins3D_mask, filtered_inlier_mask = filter_all_points3D( + best_triangulated_points, + pred_tracks, + extrinsics, + intrinsics, + max_reproj_error=max_reproj_error, + return_detail=True, + ) + + valid_tracks = filtered_inlier_mask.sum(dim=0) >= min_valid_track_length + BA_points = best_triangulated_points[valid_tracks] + BA_tracks = pred_tracks[:, valid_tracks] + BA_inlier_masks = filtered_inlier_mask[:, valid_tracks] + + if ba_options is None: + ba_options = pycolmap.BundleAdjustmentOptions() + + reconstruction = batch_matrix_to_pycolmap( + BA_points, extrinsics, intrinsics, BA_tracks, BA_inlier_masks, image_size, camera_type="simple_pinhole" + ) + pycolmap.bundle_adjustment(reconstruction, ba_options) + + reconstruction.normalize(5.0, 0.1, 0.9, True) + + points3D_opt, extrinsics, intrinsics = pycolmap_to_batch_matrix(reconstruction, device=pred_tracks.device) + + valid_poins3D_mask, filtered_inlier_mask = filter_all_points3D( + points3D_opt, + pred_tracks[:, valid_tracks], + extrinsics, + intrinsics, + max_reproj_error=max_reproj_error, + return_detail=True, + ) + + valid_tracks_afterBA = filtered_inlier_mask.sum(dim=0) >= min_valid_track_length + valid_tracks_tmp = valid_tracks.clone() + valid_tracks_tmp[valid_tracks] = valid_tracks_afterBA + valid_tracks = valid_tracks_tmp.clone() + points3D_opt = points3D_opt[valid_tracks_afterBA] + BA_inlier_masks = filtered_inlier_mask[:, valid_tracks_afterBA] + + if lastBA: + print("Saving in a colmap format") + BA_tracks = pred_tracks[:, valid_tracks] + reconstruction = batch_matrix_to_pycolmap( + points3D_opt, extrinsics, intrinsics, BA_tracks, BA_inlier_masks, image_size, camera_type="simple_pinhole" + ) + + return points3D_opt, extrinsics, intrinsics, valid_tracks, BA_inlier_masks, reconstruction diff --git a/vggsfm/vggsfm/utils/triangulation_helpers.py b/vggsfm/vggsfm/utils/triangulation_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..faff8748b236f4f948cd47a93f535dc8b00e0bbc --- /dev/null +++ b/vggsfm/vggsfm/utils/triangulation_helpers.py @@ -0,0 +1,419 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import pycolmap + +from torch.cuda.amp import autocast +from itertools import combinations + + +def triangulate_multi_view_point_batched( + cams_from_world, points, mask=None, compute_tri_angle=False, check_cheirality=False +): + # cams_from_world: BxNx3x4 + # points: BxNx2 + + B, N, _ = points.shape + assert ( + cams_from_world.shape[0] == B and cams_from_world.shape[1] == N + ), "The number of cameras and points must be equal for each batch." + + # Convert points to homogeneous coordinates and normalize + points_homo = torch.cat( + (points, torch.ones(B, N, 1, dtype=cams_from_world.dtype, device=cams_from_world.device)), dim=-1 + ) + points_norm = points_homo / torch.norm(points_homo, dim=-1, keepdim=True) + + # Compute the outer product of each point with itself + outer_products = torch.einsum("bni,bnj->bnij", points_norm, points_norm) + + # Compute the term for each camera-point pair + terms = cams_from_world - torch.einsum("bnij,bnik->bnjk", outer_products, cams_from_world) + + if mask is not None: + terms = terms * mask[:, :, None, None] + + A = torch.einsum("bnij,bnik->bjk", terms, terms) + + # Compute eigenvalues and eigenvectors + try: + _, eigenvectors = torch.linalg.eigh(A) + except: + print("Meet CUSOLVER_STATUS_INVALID_VALUE ERROR during torch.linalg.eigh()") + print("SWITCH TO torch.linalg.eig()") + _, eigenvectors = torch.linalg.eig(A) + eigenvectors = torch.real(eigenvectors) + + # Select the first eigenvector + first_eigenvector = eigenvectors[:, :, 0] + + # Perform homogeneous normalization: divide by the last component + first_eigenvector_hnormalized = first_eigenvector / first_eigenvector[..., -1:] + + # Return the first eigenvector normalized to make its first component 1 + triangulated_points = first_eigenvector_hnormalized[..., :-1] + + if check_cheirality: + points3D_homogeneous = torch.cat( + [triangulated_points, torch.ones_like(triangulated_points[..., 0:1])], dim=1 + ) # Nx4 + + points3D_homogeneous = points3D_homogeneous.unsqueeze(1).unsqueeze(-1) + points_cam = torch.matmul(cams_from_world, points3D_homogeneous).squeeze(-1) + + invalid_cheirality_mask = points_cam[..., -1] <= 0 + invalid_cheirality_mask = invalid_cheirality_mask.any(dim=1) + + if compute_tri_angle: + triangles = calculate_triangulation_angle_batched(cams_from_world, triangulated_points) + + if check_cheirality and compute_tri_angle: + return triangulated_points, triangles, invalid_cheirality_mask + + if compute_tri_angle: + return triangulated_points, triangles + + if check_cheirality: + return triangulated_points, invalid_cheirality_mask + + return triangulated_points + + +def filter_all_points3D( + points3D, + points2D, + extrinsics, + intrinsics, + max_reproj_error=4, + min_tri_angle=1.5, + check_triangle=True, + return_detail=False, + hard_max=100, +): + """ + Filter 3D points based on reprojection error and triangulation angle error. + + Adapted from https://github.com/colmap/colmap/blob/0ea2d5ceee1360bba427b2ef61f1351e59a46f91/src/colmap/sfm/incremental_mapper.cc#L828 + + """ + # points3D Px3 + # points2D BxPx2 + # extrinsics Bx3x4 + # intrinsics Bx3x3 + + # compute reprojection error + projected_points2D, projected_points_cam = project_3D_points( + points3D, extrinsics, intrinsics, return_points_cam=True + ) + + reproj_error = (projected_points2D - points2D).norm(dim=-1) ** 2 # sqaure + # ensure all the points stay in front of the cameras + reproj_error[projected_points_cam[:, -1] <= 0] = 1e6 + + inlier = reproj_error <= (max_reproj_error**2) + valid_track_length = inlier.sum(dim=0) + + valid_track_mask = valid_track_length >= 2 # at least two frames to form a track + + if hard_max > 0: + valid_value_mask = (points3D.abs() <= hard_max).all(-1) + valid_track_mask = torch.logical_and(valid_track_mask, valid_value_mask) + + if check_triangle: + # update points3D + points3D = points3D[valid_track_mask] + inlier = inlier[:, valid_track_mask] + # https://github.com/colmap/colmap/blob/0ea2d5ceee1360bba427b2ef61f1351e59a46f91/src/colmap/geometry/triangulation.cc#L130 + + B = len(extrinsics) + + triangles = calculate_triangulation_angle_exhaustive(extrinsics, points3D) + + # only when both the pair are within reporjection thres, + # the triangles can be counted + inlier_row = inlier[:, None].expand(-1, B, -1).reshape(B * B, -1) + inlier_col = inlier[None].expand(B, -1, -1).reshape(B * B, -1) + inlier_grid = torch.logical_and(inlier_row, inlier_col) + + triangles_valid_mask = torch.logical_and((triangles >= min_tri_angle), inlier_grid) + + # if any pair meets the standard, it is okay + triangles_valid_any = triangles_valid_mask.sum(dim=0) > 0 + + triangles_valid_any_full_size = torch.zeros_like(valid_track_mask) + triangles_valid_any_full_size[valid_track_mask] = triangles_valid_any + + return_mask = torch.logical_and(triangles_valid_any_full_size, valid_track_mask) + else: + return_mask = valid_track_mask + + if check_triangle and return_detail: + inlier_detail = reproj_error <= (max_reproj_error**2) + inlier_detail = triangles_valid_any_full_size[None] * inlier_detail + return return_mask, inlier_detail + + return return_mask + + +def project_3D_points(points3D, extrinsics, intrinsics=None, return_points_cam=False, default=0, only_points_cam=False): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + points3D (torch.Tensor): 3D points of shape Px3. + extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + Returns: + torch.Tensor: Transformed 2D points of shape BxNx2. + """ + with autocast(dtype=torch.double): + N = points3D.shape[0] # Number of points + B = extrinsics.shape[0] # Batch size, i.e., number of cameras + points3D_homogeneous = torch.cat([points3D, torch.ones_like(points3D[..., 0:1])], dim=1) # Nx4 + # Reshape for batch processing + points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(B, -1, -1) # BxNx4 + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2)) + + if only_points_cam: + return points_cam + + # Step 2: Apply intrinsic parameters + # Intrinsic multiplication requires a transpose to match dimensions (Bx3x3 * Bx3xN -> Bx3xN) + points2D_homogeneous = torch.bmm(intrinsics, points_cam) # Still Bx3xN + points2D_homogeneous = points2D_homogeneous.transpose(1, 2) # BxNx3 + points2D = points2D_homogeneous[..., :2] / points2D_homogeneous[..., 2:3] # BxNx2 + # Performs safe division, replacing NaNs with a default value + points2D[torch.isnan(points2D)] = default + if return_points_cam: + return points2D, points_cam + return points2D + + +def calculate_normalized_angular_error_batched(point2D, point3D, cam_from_world, to_degree=False): + """ + Please note the normalized angular error is different from triangulation angle + """ + # point2D: BxNx2 + # point3D: PxNx3 + # cam_from_world: Bx3x4 + + B, N, _ = point2D.shape + P, _, _ = point3D.shape + assert len(cam_from_world) == B + + # homogeneous + point2D_homo = torch.cat([point2D, torch.ones_like(point2D[..., 0:1])], dim=-1) + point3D_homo = torch.cat([point3D, torch.ones_like(point3D[..., 0:1])], dim=-1) + + point3D_homo_tran = point3D_homo.transpose(-1, -2) + + ray1 = point2D_homo + ray2 = cam_from_world[None].expand(P, -1, -1, -1) @ point3D_homo_tran[:, None].expand(-1, B, -1, -1) + + ray1 = F.normalize(ray1, dim=-1) + # PxBxNx3 + ray2 = F.normalize(ray2.transpose(-1, -2), dim=-1) + + ray1 = ray1[None].expand(P, -1, -1, -1) + cos_angle = (ray1 * ray2).sum(dim=-1) + cos_angle = torch.clamp(cos_angle, -1.0, 1.0) + triangles = torch.acos(cos_angle) + + if to_degree: + triangles = triangles * (180.0 / torch.pi) + + return triangles, cos_angle + + +def calculate_triangulation_angle_batched(extrinsics, points3D, eps=1e-12): + # points3D: Bx3 + # extrinsics: BxSx3x4 + + B, S, _, _ = extrinsics.shape + assert len(points3D) == B + + R = extrinsics[:, :, :, :3] # B x S x 3 x 3 + t = extrinsics[:, :, :, 3] # B x S x 3 + + proj_centers = -(R.transpose(-1, -2) @ t.unsqueeze(-1)).squeeze(-1) + # tmp = -R.transpose(-1, -2)[0].bmm(t.unsqueeze(-1)[0]) + + proj_center1 = proj_centers[:, :, None].expand(-1, -1, S, -1) + proj_center2 = proj_centers[:, None].expand(-1, S, -1, -1) + + # Bx(S*S)x3 + # TODO not using S*S any more, instead of C( ) + proj_center1 = proj_center1.reshape(B, S * S, 3) + proj_center2 = proj_center2.reshape(B, S * S, 3) + + # Bx(S*S) + baseline_length_squared = (proj_center1 - proj_center2).norm(dim=-1) ** 2 + + # Bx(S*S) + ray_length_squared1 = (points3D[:, None] - proj_center1).norm(dim=-1) ** 2 + ray_length_squared2 = (points3D[:, None] - proj_center2).norm(dim=-1) ** 2 + + denominator = 2.0 * torch.sqrt(ray_length_squared1 * ray_length_squared2) + nominator = ray_length_squared1 + ray_length_squared2 - baseline_length_squared + # if denominator is zero, angle is zero + # so we set nominator and denominator as one + # acos(1) = 0 + nonvalid = denominator <= eps + nominator = torch.where(nonvalid, torch.ones_like(nominator), nominator) + denominator = torch.where(nonvalid, torch.ones_like(denominator), denominator) + cos_angle = nominator / denominator + cos_angle = torch.clamp(cos_angle, -1.0, 1.0) + triangles = torch.abs(torch.acos(cos_angle)) + triangles = torch.min(triangles, torch.pi - triangles) + triangles = triangles * (180.0 / torch.pi) + + return triangles + + +def calculate_triangulation_angle_exhaustive(extrinsics, points3D): + # points3D: Px3 + # extrinsics: Bx3x4 + + R = extrinsics[:, :, :3] # B x 3 x 3 + t = extrinsics[:, :, 3] # B x 3 + # Compute projection centers + proj_centers = -torch.bmm(R.transpose(1, 2), t.unsqueeze(-1)).squeeze(-1) + B = len(proj_centers) + + # baseline_length_squared = (proj_centers[:,None] - proj_centers[None]).norm(dim=-1) ** 2 + proj_center1 = proj_centers[:, None].expand(-1, B, -1) + proj_center2 = proj_centers[None].expand(B, -1, -1) + proj_center1 = proj_center1.reshape(B * B, 3) + proj_center2 = proj_center2.reshape(B * B, 3) + + triangles = calculate_triangulation_angle(proj_center1, proj_center2, points3D) + + return triangles + + +def calculate_triangulation_angle(proj_center1, proj_center2, point3D, eps=1e-12): + # proj_center1: Bx3 + # proj_center2: Bx3 + # point3D: Px3 + # returned: (B*B)xP, in degree + + # B + baseline_length_squared = (proj_center1 - proj_center2).norm(dim=-1) ** 2 # B*(S-1)x1 + + # BxP + ray_length_squared1 = (point3D[None] - proj_center1[:, None]).norm(dim=-1) ** 2 + ray_length_squared2 = (point3D[None] - proj_center2[:, None]).norm(dim=-1) ** 2 + + denominator = 2.0 * torch.sqrt(ray_length_squared1 * ray_length_squared2) + nominator = ray_length_squared1 + ray_length_squared2 - baseline_length_squared.unsqueeze(-1) + # if denominator is zero, angle is zero + # so we set nominator and denominator as one + # acos(1) = 0 + nonvalid = denominator <= eps + nominator = torch.where(nonvalid, torch.ones_like(nominator), nominator) + denominator = torch.where(nonvalid, torch.ones_like(denominator), denominator) + cos_angle = nominator / denominator + cos_angle = torch.clamp(cos_angle, -1.0, 1.0) + triangles = torch.abs(torch.acos(cos_angle)) + triangles = torch.min(triangles, torch.pi - triangles) + triangles = triangles * (180.0 / torch.pi) + return triangles + + +def create_intri_matrix(focal_length, principal_point): + """ + Creates a intri matrix from focal length and principal point. + + Args: + focal_length (torch.Tensor): A Bx2 or BxSx2 tensor containing the focal lengths (fx, fy) for each image. + principal_point (torch.Tensor): A Bx2 or BxSx2 tensor containing the principal point coordinates (cx, cy) for each image. + + Returns: + torch.Tensor: A Bx3x3 or BxSx3x3 tensor containing the camera matrix for each image. + """ + + if len(focal_length.shape) == 2: + B = focal_length.shape[0] + intri_matrix = torch.zeros(B, 3, 3, dtype=focal_length.dtype, device=focal_length.device) + intri_matrix[:, 0, 0] = focal_length[:, 0] + intri_matrix[:, 1, 1] = focal_length[:, 1] + intri_matrix[:, 2, 2] = 1.0 + intri_matrix[:, 0, 2] = principal_point[:, 0] + intri_matrix[:, 1, 2] = principal_point[:, 1] + else: + B, S = focal_length.shape[0], focal_length.shape[1] + intri_matrix = torch.zeros(B, S, 3, 3, dtype=focal_length.dtype, device=focal_length.device) + intri_matrix[:, :, 0, 0] = focal_length[:, :, 0] + intri_matrix[:, :, 1, 1] = focal_length[:, :, 1] + intri_matrix[:, :, 2, 2] = 1.0 + intri_matrix[:, :, 0, 2] = principal_point[:, :, 0] + intri_matrix[:, :, 1, 2] = principal_point[:, :, 1] + + return intri_matrix + + +def prepare_ba_options(): + ba_options_tmp = pycolmap.BundleAdjustmentOptions() + ba_options_tmp.solver_options.function_tolerance *= 10 + ba_options_tmp.solver_options.gradient_tolerance *= 10 + ba_options_tmp.solver_options.parameter_tolerance *= 10 + + ba_options_tmp.solver_options.max_num_iterations = 50 + ba_options_tmp.solver_options.max_linear_solver_iterations = 200 + ba_options_tmp.print_summary = False + return ba_options_tmp + + +def generate_combinations(N): + # Create an array of numbers from 0 to N-1 + indices = np.arange(N) + # Generate all C(N, 2) combinations + comb = list(combinations(indices, 2)) + # Convert list of tuples into a NumPy array + comb_array = np.array(comb) + return comb_array + + +def local_refinement_tri(points1, extrinsics, inlier_mask, sorted_indices, lo_num=50): + """ + Local Refinement for triangulation + """ + B, N, _ = points1.shape + batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num) + + points1_expand = points1.unsqueeze(1).expand(-1, lo_num, -1, -1) + extrinsics_expand = extrinsics.unsqueeze(1).expand(-1, lo_num, -1, -1, -1) + + # The sets selected for local refinement + lo_indices = sorted_indices[:, :lo_num] + + # Find the points that would be used for local_estimator + lo_mask = inlier_mask[batch_index, lo_indices] + lo_points1 = torch.zeros_like(points1_expand) + lo_points1[lo_mask] = points1_expand[lo_mask] + + lo_points1 = lo_points1.reshape(B * lo_num, N, -1) + lo_mask = lo_mask.reshape(B * lo_num, N) + lo_extrinsics = extrinsics_expand.reshape(B * lo_num, N, 3, 4) + + # triangulate the inliers + triangulated_points, tri_angles, invalid_che_mask = triangulate_multi_view_point_batched( + lo_extrinsics, lo_points1, mask=lo_mask, compute_tri_angle=True, check_cheirality=True + ) + + triangulated_points = triangulated_points.reshape(B, lo_num, 3) + tri_angles = tri_angles.reshape(B, lo_num, -1) + + invalid_che_mask = invalid_che_mask.reshape(B, lo_num) + + return triangulated_points, tri_angles, invalid_che_mask diff --git a/vggsfm/vggsfm/utils/utils.py b/vggsfm/vggsfm/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..23ff7637e4d17d5aca75f3a88a39dd75f6c8a45f --- /dev/null +++ b/vggsfm/vggsfm/utils/utils.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import json +import warnings + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +from accelerate.utils import set_seed as accelerate_set_seed, PrecisionType + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from .metric import closed_form_inverse + + +def calculate_index_mappings(query_index, S, device=None): + """ + Construct an order that we can switch [query_index] and [0] + so that the content of query_index would be placed at [0] + """ + new_order = torch.arange(S) + new_order[0] = query_index + new_order[query_index] = 0 + if device is not None: + new_order = new_order.to(device) + return new_order + + +def switch_tensor_order(tensors, order, dim=1): + """ + Switch the tensor among the specific dimension + """ + return [torch.index_select(tensor, dim, order) for tensor in tensors] + + +def set_seed_and_print(seed): + accelerate_set_seed(seed, device_specific=True) + print(f"----------Seed is set to {np.random.get_state()[1][0]} now----------") + + +def transform_camera_relative_to_first(pred_cameras, batch_size): + pred_se3 = pred_cameras.get_world_to_view_transform().get_matrix() + rel_transform = closed_form_inverse(pred_se3[0:1, :, :]) + rel_transform = rel_transform.expand(batch_size, -1, -1) + + pred_se3_rel = torch.bmm(rel_transform, pred_se3) + pred_se3_rel[..., :3, 3] = 0.0 + pred_se3_rel[..., 3, 3] = 1.0 + + pred_cameras.R = pred_se3_rel[:, :3, :3].clone() + pred_cameras.T = pred_se3_rel[:, 3, :3].clone() + return pred_cameras + + +def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0): + # Number of points + distance_matrix = distance_matrix.clamp(min=0) + + N = distance_matrix.size(0) + + # Initialize + # Start from the first point (arbitrary choice) + selected_indices = [most_common_frame_index] + # Track the minimum distances to the selected set + check_distances = distance_matrix[selected_indices] + + while len(selected_indices) < num_samples: + # Find the farthest point from the current set of selected points + farthest_point = torch.argmax(check_distances) + selected_indices.append(farthest_point.item()) + + check_distances = distance_matrix[farthest_point] + # the ones already selected would not selected any more + check_distances[selected_indices] = 0 + + # Break the loop if all points have been selected + if len(selected_indices) == N: + break + + return selected_indices + + +def visual_query_points(images, query_index, query_points): + """ + Processes an image by converting it to BGR color space, drawing circles at specified points, + and saving the image to a file. + Args: + images (torch.Tensor): A batch of images in the shape (N, C, H, W). + query_index (int): The index of the image in the batch to process. + query_points (list of tuples): List of (x, y) tuples where circles should be drawn. + Returns: + None + """ + # Convert the image from RGB to BGR + image_cv2 = cv2.cvtColor( + (images[:, query_index].squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2BGR + ) + + # Draw circles at the specified query points + for x, y in query_points[0]: + image_cv2 = cv2.circle(image_cv2, (int(x), int(y)), 4, (0, 255, 0), -1) + + # Save the processed image to a file + cv2.imwrite("image_cv2.png", image_cv2)