UP
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .buddy/docker-build.fixed.yml +6 -0
- .circleci/config.yml +37 -0
- .flake8 +3 -0
- .gitattributes +6 -0
- .gitattributes.bak +3 -0
- .github/ISSUE_TEMPLATE/bug_report.md +37 -0
- .github/workflows/stale.yml +29 -0
- .gitignore +15 -0
- LICENSE +661 -0
- README.md +156 -0
- clip/__init__.py +1 -0
- clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- clip/clip.py +241 -0
- clip/clipseg.py +538 -0
- clip/model.py +436 -0
- clip/simple_tokenizer.py +132 -0
- clip/vitseg.py +286 -0
- config_colab.yaml +14 -0
- docs/screenshot.png +3 -0
- installer/installer.py +87 -0
- installer/windows_run.bat +99 -0
- mypy.ini +7 -0
- requirements.txt +19 -0
- roop-unleashed.ipynb +208 -0
- roop/FaceSet.py +20 -0
- roop/ProcessEntry.py +7 -0
- roop/ProcessMgr.py +702 -0
- roop/ProcessOptions.py +13 -0
- roop/__init__.py +0 -0
- roop/capturer.py +30 -0
- roop/core.py +378 -0
- roop/face_util.py +306 -0
- roop/ffmpeg_writer.py +218 -0
- roop/globals.py +53 -0
- roop/metadata.py +2 -0
- roop/processors/Enhance_CodeFormer.py +75 -0
- roop/processors/Enhance_DMDNet.py +898 -0
- roop/processors/Enhance_GFPGAN.py +77 -0
- roop/processors/Enhance_GPEN.py +63 -0
- roop/processors/Enhance_RestoreFormerPPlus.py +64 -0
- roop/processors/FaceSwapInsightFace.py +69 -0
- roop/processors/Frame_Colorizer.py +70 -0
- roop/processors/Frame_Filter.py +105 -0
- roop/processors/Frame_Masking.py +71 -0
- roop/processors/Frame_Upscale.py +131 -0
- roop/processors/Mask_Clip2Seg.py +94 -0
- roop/processors/Mask_XSeg.py +60 -0
- roop/processors/__init__.py +0 -0
- roop/template_parser.py +23 -0
- roop/typing.py +9 -0
.buddy/docker-build.fixed.yml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
- pipeline: "docker-build"
|
2 |
+
events:
|
3 |
+
- type: "PUSH"
|
4 |
+
refs:
|
5 |
+
- "refs/heads/main"
|
6 |
+
fail_on_prepare_env_warning: true
|
.circleci/config.yml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This config was automatically generated from your source code
|
2 |
+
# Stacks detected: cicd:github-actions:.github/workflows,deps:python:.
|
3 |
+
version: 2.1
|
4 |
+
orbs:
|
5 |
+
python: circleci/python@2
|
6 |
+
jobs:
|
7 |
+
test-python:
|
8 |
+
# Install dependencies and run tests
|
9 |
+
docker:
|
10 |
+
- image: cimg/python:3.8-node
|
11 |
+
steps:
|
12 |
+
- checkout
|
13 |
+
- python/install-packages
|
14 |
+
- run:
|
15 |
+
name: Run tests
|
16 |
+
command: pytest --junitxml=junit.xml || ((($? == 5)) && echo 'Did not find any tests to run.')
|
17 |
+
- store_test_results:
|
18 |
+
path: junit.xml
|
19 |
+
deploy:
|
20 |
+
# This is an example deploy job, not actually used by the workflow
|
21 |
+
docker:
|
22 |
+
- image: cimg/base:stable
|
23 |
+
steps:
|
24 |
+
# Replace this with steps to deploy to users
|
25 |
+
- run:
|
26 |
+
name: deploy
|
27 |
+
command: '#e.g. ./deploy.sh'
|
28 |
+
- run:
|
29 |
+
name: found github actions config
|
30 |
+
command: ':'
|
31 |
+
workflows:
|
32 |
+
build-and-test:
|
33 |
+
jobs:
|
34 |
+
- test-python
|
35 |
+
# - deploy:
|
36 |
+
# requires:
|
37 |
+
# - test-python
|
.flake8
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
select = E3, E4, F
|
3 |
+
per-file-ignores = roop/core.py:E402
|
.gitattributes
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.github/examples/snapshot.mp4 filter=lfs diff=lfs merge=lfs -text
|
2 |
+
examples/snapshot.mp4 filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
6 |
+
|
.gitattributes.bak
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.github/examples/snapshot.mp4 filter=lfs diff=lfs merge=lfs -text
|
2 |
+
examples/snapshot.mp4 filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Bug report
|
3 |
+
about: Create a report to help us improve
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Describe the bug**
|
11 |
+
A clear and concise description of what the bug is.
|
12 |
+
|
13 |
+
**To Reproduce**
|
14 |
+
Steps to reproduce the behavior:
|
15 |
+
1. Go to '...'
|
16 |
+
2. Click on '....'
|
17 |
+
3. Scroll down to '....'
|
18 |
+
4. See error
|
19 |
+
|
20 |
+
**Details**
|
21 |
+
What OS are you using?
|
22 |
+
- [ ] Linux
|
23 |
+
- [ ] Linux in WSL
|
24 |
+
- [ ] Windows
|
25 |
+
- [ ] Mac
|
26 |
+
|
27 |
+
Are you using a GPU?
|
28 |
+
- [ ] No. CPU FTW
|
29 |
+
- [ ] NVIDIA
|
30 |
+
- [ ] AMD
|
31 |
+
- [ ] Intel
|
32 |
+
- [ ] Mac
|
33 |
+
|
34 |
+
**Which version of roop unleashed are you using?**
|
35 |
+
|
36 |
+
**Screenshots**
|
37 |
+
If applicable, add screenshots to help explain your problem.
|
.github/workflows/stale.yml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
2 |
+
#
|
3 |
+
# You can adjust the behavior by modifying this file.
|
4 |
+
# For more information, see:
|
5 |
+
# https://github.com/actions/stale
|
6 |
+
name: Mark stale issues and pull requests
|
7 |
+
|
8 |
+
on:
|
9 |
+
schedule:
|
10 |
+
- cron: '32 0 * * *'
|
11 |
+
|
12 |
+
jobs:
|
13 |
+
stale:
|
14 |
+
|
15 |
+
runs-on: ubuntu-latest
|
16 |
+
permissions:
|
17 |
+
issues: write
|
18 |
+
pull-requests: write
|
19 |
+
|
20 |
+
steps:
|
21 |
+
- uses: actions/stale@v5
|
22 |
+
with:
|
23 |
+
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
24 |
+
stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
|
25 |
+
stale-pr-message: 'This PR is stale because it has been open 45 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
|
26 |
+
close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.'
|
27 |
+
days-before-stale: 30
|
28 |
+
days-before-close: 5
|
29 |
+
days-before-pr-close: -1
|
.gitignore
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.vs
|
2 |
+
.idea
|
3 |
+
models
|
4 |
+
temp
|
5 |
+
__pycache__
|
6 |
+
*.pth
|
7 |
+
/start.bat
|
8 |
+
/env
|
9 |
+
.vscode
|
10 |
+
output
|
11 |
+
temp
|
12 |
+
config.yaml
|
13 |
+
run.bat
|
14 |
+
venv
|
15 |
+
start.sh
|
LICENSE
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU Affero General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works, specifically designed to ensure
|
12 |
+
cooperation with the community in the case of network server software.
|
13 |
+
|
14 |
+
The licenses for most software and other practical works are designed
|
15 |
+
to take away your freedom to share and change the works. By contrast,
|
16 |
+
our General Public Licenses are intended to guarantee your freedom to
|
17 |
+
share and change all versions of a program--to make sure it remains free
|
18 |
+
software for all its users.
|
19 |
+
|
20 |
+
When we speak of free software, we are referring to freedom, not
|
21 |
+
price. Our General Public Licenses are designed to make sure that you
|
22 |
+
have the freedom to distribute copies of free software (and charge for
|
23 |
+
them if you wish), that you receive source code or can get it if you
|
24 |
+
want it, that you can change the software or use pieces of it in new
|
25 |
+
free programs, and that you know you can do these things.
|
26 |
+
|
27 |
+
Developers that use our General Public Licenses protect your rights
|
28 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
29 |
+
you this License which gives you legal permission to copy, distribute
|
30 |
+
and/or modify the software.
|
31 |
+
|
32 |
+
A secondary benefit of defending all users' freedom is that
|
33 |
+
improvements made in alternate versions of the program, if they
|
34 |
+
receive widespread use, become available for other developers to
|
35 |
+
incorporate. Many developers of free software are heartened and
|
36 |
+
encouraged by the resulting cooperation. However, in the case of
|
37 |
+
software used on network servers, this result may fail to come about.
|
38 |
+
The GNU General Public License permits making a modified version and
|
39 |
+
letting the public access it on a server without ever releasing its
|
40 |
+
source code to the public.
|
41 |
+
|
42 |
+
The GNU Affero General Public License is designed specifically to
|
43 |
+
ensure that, in such cases, the modified source code becomes available
|
44 |
+
to the community. It requires the operator of a network server to
|
45 |
+
provide the source code of the modified version running there to the
|
46 |
+
users of that server. Therefore, public use of a modified version, on
|
47 |
+
a publicly accessible server, gives the public access to the source
|
48 |
+
code of the modified version.
|
49 |
+
|
50 |
+
An older license, called the Affero General Public License and
|
51 |
+
published by Affero, was designed to accomplish similar goals. This is
|
52 |
+
a different license, not a version of the Affero GPL, but Affero has
|
53 |
+
released a new version of the Affero GPL which permits relicensing under
|
54 |
+
this license.
|
55 |
+
|
56 |
+
The precise terms and conditions for copying, distribution and
|
57 |
+
modification follow.
|
58 |
+
|
59 |
+
TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
0. Definitions.
|
62 |
+
|
63 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
64 |
+
|
65 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
66 |
+
works, such as semiconductor masks.
|
67 |
+
|
68 |
+
"The Program" refers to any copyrightable work licensed under this
|
69 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
70 |
+
"recipients" may be individuals or organizations.
|
71 |
+
|
72 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
73 |
+
in a fashion requiring copyright permission, other than the making of an
|
74 |
+
exact copy. The resulting work is called a "modified version" of the
|
75 |
+
earlier work or a work "based on" the earlier work.
|
76 |
+
|
77 |
+
A "covered work" means either the unmodified Program or a work based
|
78 |
+
on the Program.
|
79 |
+
|
80 |
+
To "propagate" a work means to do anything with it that, without
|
81 |
+
permission, would make you directly or secondarily liable for
|
82 |
+
infringement under applicable copyright law, except executing it on a
|
83 |
+
computer or modifying a private copy. Propagation includes copying,
|
84 |
+
distribution (with or without modification), making available to the
|
85 |
+
public, and in some countries other activities as well.
|
86 |
+
|
87 |
+
To "convey" a work means any kind of propagation that enables other
|
88 |
+
parties to make or receive copies. Mere interaction with a user through
|
89 |
+
a computer network, with no transfer of a copy, is not conveying.
|
90 |
+
|
91 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
92 |
+
to the extent that it includes a convenient and prominently visible
|
93 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
94 |
+
tells the user that there is no warranty for the work (except to the
|
95 |
+
extent that warranties are provided), that licensees may convey the
|
96 |
+
work under this License, and how to view a copy of this License. If
|
97 |
+
the interface presents a list of user commands or options, such as a
|
98 |
+
menu, a prominent item in the list meets this criterion.
|
99 |
+
|
100 |
+
1. Source Code.
|
101 |
+
|
102 |
+
The "source code" for a work means the preferred form of the work
|
103 |
+
for making modifications to it. "Object code" means any non-source
|
104 |
+
form of a work.
|
105 |
+
|
106 |
+
A "Standard Interface" means an interface that either is an official
|
107 |
+
standard defined by a recognized standards body, or, in the case of
|
108 |
+
interfaces specified for a particular programming language, one that
|
109 |
+
is widely used among developers working in that language.
|
110 |
+
|
111 |
+
The "System Libraries" of an executable work include anything, other
|
112 |
+
than the work as a whole, that (a) is included in the normal form of
|
113 |
+
packaging a Major Component, but which is not part of that Major
|
114 |
+
Component, and (b) serves only to enable use of the work with that
|
115 |
+
Major Component, or to implement a Standard Interface for which an
|
116 |
+
implementation is available to the public in source code form. A
|
117 |
+
"Major Component", in this context, means a major essential component
|
118 |
+
(kernel, window system, and so on) of the specific operating system
|
119 |
+
(if any) on which the executable work runs, or a compiler used to
|
120 |
+
produce the work, or an object code interpreter used to run it.
|
121 |
+
|
122 |
+
The "Corresponding Source" for a work in object code form means all
|
123 |
+
the source code needed to generate, install, and (for an executable
|
124 |
+
work) run the object code and to modify the work, including scripts to
|
125 |
+
control those activities. However, it does not include the work's
|
126 |
+
System Libraries, or general-purpose tools or generally available free
|
127 |
+
programs which are used unmodified in performing those activities but
|
128 |
+
which are not part of the work. For example, Corresponding Source
|
129 |
+
includes interface definition files associated with source files for
|
130 |
+
the work, and the source code for shared libraries and dynamically
|
131 |
+
linked subprograms that the work is specifically designed to require,
|
132 |
+
such as by intimate data communication or control flow between those
|
133 |
+
subprograms and other parts of the work.
|
134 |
+
|
135 |
+
The Corresponding Source need not include anything that users
|
136 |
+
can regenerate automatically from other parts of the Corresponding
|
137 |
+
Source.
|
138 |
+
|
139 |
+
The Corresponding Source for a work in source code form is that
|
140 |
+
same work.
|
141 |
+
|
142 |
+
2. Basic Permissions.
|
143 |
+
|
144 |
+
All rights granted under this License are granted for the term of
|
145 |
+
copyright on the Program, and are irrevocable provided the stated
|
146 |
+
conditions are met. This License explicitly affirms your unlimited
|
147 |
+
permission to run the unmodified Program. The output from running a
|
148 |
+
covered work is covered by this License only if the output, given its
|
149 |
+
content, constitutes a covered work. This License acknowledges your
|
150 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
151 |
+
|
152 |
+
You may make, run and propagate covered works that you do not
|
153 |
+
convey, without conditions so long as your license otherwise remains
|
154 |
+
in force. You may convey covered works to others for the sole purpose
|
155 |
+
of having them make modifications exclusively for you, or provide you
|
156 |
+
with facilities for running those works, provided that you comply with
|
157 |
+
the terms of this License in conveying all material for which you do
|
158 |
+
not control copyright. Those thus making or running the covered works
|
159 |
+
for you must do so exclusively on your behalf, under your direction
|
160 |
+
and control, on terms that prohibit them from making any copies of
|
161 |
+
your copyrighted material outside their relationship with you.
|
162 |
+
|
163 |
+
Conveying under any other circumstances is permitted solely under
|
164 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
165 |
+
makes it unnecessary.
|
166 |
+
|
167 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
168 |
+
|
169 |
+
No covered work shall be deemed part of an effective technological
|
170 |
+
measure under any applicable law fulfilling obligations under article
|
171 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
172 |
+
similar laws prohibiting or restricting circumvention of such
|
173 |
+
measures.
|
174 |
+
|
175 |
+
When you convey a covered work, you waive any legal power to forbid
|
176 |
+
circumvention of technological measures to the extent such circumvention
|
177 |
+
is effected by exercising rights under this License with respect to
|
178 |
+
the covered work, and you disclaim any intention to limit operation or
|
179 |
+
modification of the work as a means of enforcing, against the work's
|
180 |
+
users, your or third parties' legal rights to forbid circumvention of
|
181 |
+
technological measures.
|
182 |
+
|
183 |
+
4. Conveying Verbatim Copies.
|
184 |
+
|
185 |
+
You may convey verbatim copies of the Program's source code as you
|
186 |
+
receive it, in any medium, provided that you conspicuously and
|
187 |
+
appropriately publish on each copy an appropriate copyright notice;
|
188 |
+
keep intact all notices stating that this License and any
|
189 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
190 |
+
keep intact all notices of the absence of any warranty; and give all
|
191 |
+
recipients a copy of this License along with the Program.
|
192 |
+
|
193 |
+
You may charge any price or no price for each copy that you convey,
|
194 |
+
and you may offer support or warranty protection for a fee.
|
195 |
+
|
196 |
+
5. Conveying Modified Source Versions.
|
197 |
+
|
198 |
+
You may convey a work based on the Program, or the modifications to
|
199 |
+
produce it from the Program, in the form of source code under the
|
200 |
+
terms of section 4, provided that you also meet all of these conditions:
|
201 |
+
|
202 |
+
a) The work must carry prominent notices stating that you modified
|
203 |
+
it, and giving a relevant date.
|
204 |
+
|
205 |
+
b) The work must carry prominent notices stating that it is
|
206 |
+
released under this License and any conditions added under section
|
207 |
+
7. This requirement modifies the requirement in section 4 to
|
208 |
+
"keep intact all notices".
|
209 |
+
|
210 |
+
c) You must license the entire work, as a whole, under this
|
211 |
+
License to anyone who comes into possession of a copy. This
|
212 |
+
License will therefore apply, along with any applicable section 7
|
213 |
+
additional terms, to the whole of the work, and all its parts,
|
214 |
+
regardless of how they are packaged. This License gives no
|
215 |
+
permission to license the work in any other way, but it does not
|
216 |
+
invalidate such permission if you have separately received it.
|
217 |
+
|
218 |
+
d) If the work has interactive user interfaces, each must display
|
219 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
220 |
+
interfaces that do not display Appropriate Legal Notices, your
|
221 |
+
work need not make them do so.
|
222 |
+
|
223 |
+
A compilation of a covered work with other separate and independent
|
224 |
+
works, which are not by their nature extensions of the covered work,
|
225 |
+
and which are not combined with it such as to form a larger program,
|
226 |
+
in or on a volume of a storage or distribution medium, is called an
|
227 |
+
"aggregate" if the compilation and its resulting copyright are not
|
228 |
+
used to limit the access or legal rights of the compilation's users
|
229 |
+
beyond what the individual works permit. Inclusion of a covered work
|
230 |
+
in an aggregate does not cause this License to apply to the other
|
231 |
+
parts of the aggregate.
|
232 |
+
|
233 |
+
6. Conveying Non-Source Forms.
|
234 |
+
|
235 |
+
You may convey a covered work in object code form under the terms
|
236 |
+
of sections 4 and 5, provided that you also convey the
|
237 |
+
machine-readable Corresponding Source under the terms of this License,
|
238 |
+
in one of these ways:
|
239 |
+
|
240 |
+
a) Convey the object code in, or embodied in, a physical product
|
241 |
+
(including a physical distribution medium), accompanied by the
|
242 |
+
Corresponding Source fixed on a durable physical medium
|
243 |
+
customarily used for software interchange.
|
244 |
+
|
245 |
+
b) Convey the object code in, or embodied in, a physical product
|
246 |
+
(including a physical distribution medium), accompanied by a
|
247 |
+
written offer, valid for at least three years and valid for as
|
248 |
+
long as you offer spare parts or customer support for that product
|
249 |
+
model, to give anyone who possesses the object code either (1) a
|
250 |
+
copy of the Corresponding Source for all the software in the
|
251 |
+
product that is covered by this License, on a durable physical
|
252 |
+
medium customarily used for software interchange, for a price no
|
253 |
+
more than your reasonable cost of physically performing this
|
254 |
+
conveying of source, or (2) access to copy the
|
255 |
+
Corresponding Source from a network server at no charge.
|
256 |
+
|
257 |
+
c) Convey individual copies of the object code with a copy of the
|
258 |
+
written offer to provide the Corresponding Source. This
|
259 |
+
alternative is allowed only occasionally and noncommercially, and
|
260 |
+
only if you received the object code with such an offer, in accord
|
261 |
+
with subsection 6b.
|
262 |
+
|
263 |
+
d) Convey the object code by offering access from a designated
|
264 |
+
place (gratis or for a charge), and offer equivalent access to the
|
265 |
+
Corresponding Source in the same way through the same place at no
|
266 |
+
further charge. You need not require recipients to copy the
|
267 |
+
Corresponding Source along with the object code. If the place to
|
268 |
+
copy the object code is a network server, the Corresponding Source
|
269 |
+
may be on a different server (operated by you or a third party)
|
270 |
+
that supports equivalent copying facilities, provided you maintain
|
271 |
+
clear directions next to the object code saying where to find the
|
272 |
+
Corresponding Source. Regardless of what server hosts the
|
273 |
+
Corresponding Source, you remain obligated to ensure that it is
|
274 |
+
available for as long as needed to satisfy these requirements.
|
275 |
+
|
276 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
277 |
+
you inform other peers where the object code and Corresponding
|
278 |
+
Source of the work are being offered to the general public at no
|
279 |
+
charge under subsection 6d.
|
280 |
+
|
281 |
+
A separable portion of the object code, whose source code is excluded
|
282 |
+
from the Corresponding Source as a System Library, need not be
|
283 |
+
included in conveying the object code work.
|
284 |
+
|
285 |
+
A "User Product" is either (1) a "consumer product", which means any
|
286 |
+
tangible personal property which is normally used for personal, family,
|
287 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
288 |
+
into a dwelling. In determining whether a product is a consumer product,
|
289 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
290 |
+
product received by a particular user, "normally used" refers to a
|
291 |
+
typical or common use of that class of product, regardless of the status
|
292 |
+
of the particular user or of the way in which the particular user
|
293 |
+
actually uses, or expects or is expected to use, the product. A product
|
294 |
+
is a consumer product regardless of whether the product has substantial
|
295 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
296 |
+
the only significant mode of use of the product.
|
297 |
+
|
298 |
+
"Installation Information" for a User Product means any methods,
|
299 |
+
procedures, authorization keys, or other information required to install
|
300 |
+
and execute modified versions of a covered work in that User Product from
|
301 |
+
a modified version of its Corresponding Source. The information must
|
302 |
+
suffice to ensure that the continued functioning of the modified object
|
303 |
+
code is in no case prevented or interfered with solely because
|
304 |
+
modification has been made.
|
305 |
+
|
306 |
+
If you convey an object code work under this section in, or with, or
|
307 |
+
specifically for use in, a User Product, and the conveying occurs as
|
308 |
+
part of a transaction in which the right of possession and use of the
|
309 |
+
User Product is transferred to the recipient in perpetuity or for a
|
310 |
+
fixed term (regardless of how the transaction is characterized), the
|
311 |
+
Corresponding Source conveyed under this section must be accompanied
|
312 |
+
by the Installation Information. But this requirement does not apply
|
313 |
+
if neither you nor any third party retains the ability to install
|
314 |
+
modified object code on the User Product (for example, the work has
|
315 |
+
been installed in ROM).
|
316 |
+
|
317 |
+
The requirement to provide Installation Information does not include a
|
318 |
+
requirement to continue to provide support service, warranty, or updates
|
319 |
+
for a work that has been modified or installed by the recipient, or for
|
320 |
+
the User Product in which it has been modified or installed. Access to a
|
321 |
+
network may be denied when the modification itself materially and
|
322 |
+
adversely affects the operation of the network or violates the rules and
|
323 |
+
protocols for communication across the network.
|
324 |
+
|
325 |
+
Corresponding Source conveyed, and Installation Information provided,
|
326 |
+
in accord with this section must be in a format that is publicly
|
327 |
+
documented (and with an implementation available to the public in
|
328 |
+
source code form), and must require no special password or key for
|
329 |
+
unpacking, reading or copying.
|
330 |
+
|
331 |
+
7. Additional Terms.
|
332 |
+
|
333 |
+
"Additional permissions" are terms that supplement the terms of this
|
334 |
+
License by making exceptions from one or more of its conditions.
|
335 |
+
Additional permissions that are applicable to the entire Program shall
|
336 |
+
be treated as though they were included in this License, to the extent
|
337 |
+
that they are valid under applicable law. If additional permissions
|
338 |
+
apply only to part of the Program, that part may be used separately
|
339 |
+
under those permissions, but the entire Program remains governed by
|
340 |
+
this License without regard to the additional permissions.
|
341 |
+
|
342 |
+
When you convey a copy of a covered work, you may at your option
|
343 |
+
remove any additional permissions from that copy, or from any part of
|
344 |
+
it. (Additional permissions may be written to require their own
|
345 |
+
removal in certain cases when you modify the work.) You may place
|
346 |
+
additional permissions on material, added by you to a covered work,
|
347 |
+
for which you have or can give appropriate copyright permission.
|
348 |
+
|
349 |
+
Notwithstanding any other provision of this License, for material you
|
350 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
351 |
+
that material) supplement the terms of this License with terms:
|
352 |
+
|
353 |
+
a) Disclaiming warranty or limiting liability differently from the
|
354 |
+
terms of sections 15 and 16 of this License; or
|
355 |
+
|
356 |
+
b) Requiring preservation of specified reasonable legal notices or
|
357 |
+
author attributions in that material or in the Appropriate Legal
|
358 |
+
Notices displayed by works containing it; or
|
359 |
+
|
360 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
361 |
+
requiring that modified versions of such material be marked in
|
362 |
+
reasonable ways as different from the original version; or
|
363 |
+
|
364 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
365 |
+
authors of the material; or
|
366 |
+
|
367 |
+
e) Declining to grant rights under trademark law for use of some
|
368 |
+
trade names, trademarks, or service marks; or
|
369 |
+
|
370 |
+
f) Requiring indemnification of licensors and authors of that
|
371 |
+
material by anyone who conveys the material (or modified versions of
|
372 |
+
it) with contractual assumptions of liability to the recipient, for
|
373 |
+
any liability that these contractual assumptions directly impose on
|
374 |
+
those licensors and authors.
|
375 |
+
|
376 |
+
All other non-permissive additional terms are considered "further
|
377 |
+
restrictions" within the meaning of section 10. If the Program as you
|
378 |
+
received it, or any part of it, contains a notice stating that it is
|
379 |
+
governed by this License along with a term that is a further
|
380 |
+
restriction, you may remove that term. If a license document contains
|
381 |
+
a further restriction but permits relicensing or conveying under this
|
382 |
+
License, you may add to a covered work material governed by the terms
|
383 |
+
of that license document, provided that the further restriction does
|
384 |
+
not survive such relicensing or conveying.
|
385 |
+
|
386 |
+
If you add terms to a covered work in accord with this section, you
|
387 |
+
must place, in the relevant source files, a statement of the
|
388 |
+
additional terms that apply to those files, or a notice indicating
|
389 |
+
where to find the applicable terms.
|
390 |
+
|
391 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
392 |
+
form of a separately written license, or stated as exceptions;
|
393 |
+
the above requirements apply either way.
|
394 |
+
|
395 |
+
8. Termination.
|
396 |
+
|
397 |
+
You may not propagate or modify a covered work except as expressly
|
398 |
+
provided under this License. Any attempt otherwise to propagate or
|
399 |
+
modify it is void, and will automatically terminate your rights under
|
400 |
+
this License (including any patent licenses granted under the third
|
401 |
+
paragraph of section 11).
|
402 |
+
|
403 |
+
However, if you cease all violation of this License, then your
|
404 |
+
license from a particular copyright holder is reinstated (a)
|
405 |
+
provisionally, unless and until the copyright holder explicitly and
|
406 |
+
finally terminates your license, and (b) permanently, if the copyright
|
407 |
+
holder fails to notify you of the violation by some reasonable means
|
408 |
+
prior to 60 days after the cessation.
|
409 |
+
|
410 |
+
Moreover, your license from a particular copyright holder is
|
411 |
+
reinstated permanently if the copyright holder notifies you of the
|
412 |
+
violation by some reasonable means, this is the first time you have
|
413 |
+
received notice of violation of this License (for any work) from that
|
414 |
+
copyright holder, and you cure the violation prior to 30 days after
|
415 |
+
your receipt of the notice.
|
416 |
+
|
417 |
+
Termination of your rights under this section does not terminate the
|
418 |
+
licenses of parties who have received copies or rights from you under
|
419 |
+
this License. If your rights have been terminated and not permanently
|
420 |
+
reinstated, you do not qualify to receive new licenses for the same
|
421 |
+
material under section 10.
|
422 |
+
|
423 |
+
9. Acceptance Not Required for Having Copies.
|
424 |
+
|
425 |
+
You are not required to accept this License in order to receive or
|
426 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
427 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
428 |
+
to receive a copy likewise does not require acceptance. However,
|
429 |
+
nothing other than this License grants you permission to propagate or
|
430 |
+
modify any covered work. These actions infringe copyright if you do
|
431 |
+
not accept this License. Therefore, by modifying or propagating a
|
432 |
+
covered work, you indicate your acceptance of this License to do so.
|
433 |
+
|
434 |
+
10. Automatic Licensing of Downstream Recipients.
|
435 |
+
|
436 |
+
Each time you convey a covered work, the recipient automatically
|
437 |
+
receives a license from the original licensors, to run, modify and
|
438 |
+
propagate that work, subject to this License. You are not responsible
|
439 |
+
for enforcing compliance by third parties with this License.
|
440 |
+
|
441 |
+
An "entity transaction" is a transaction transferring control of an
|
442 |
+
organization, or substantially all assets of one, or subdividing an
|
443 |
+
organization, or merging organizations. If propagation of a covered
|
444 |
+
work results from an entity transaction, each party to that
|
445 |
+
transaction who receives a copy of the work also receives whatever
|
446 |
+
licenses to the work the party's predecessor in interest had or could
|
447 |
+
give under the previous paragraph, plus a right to possession of the
|
448 |
+
Corresponding Source of the work from the predecessor in interest, if
|
449 |
+
the predecessor has it or can get it with reasonable efforts.
|
450 |
+
|
451 |
+
You may not impose any further restrictions on the exercise of the
|
452 |
+
rights granted or affirmed under this License. For example, you may
|
453 |
+
not impose a license fee, royalty, or other charge for exercise of
|
454 |
+
rights granted under this License, and you may not initiate litigation
|
455 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
456 |
+
any patent claim is infringed by making, using, selling, offering for
|
457 |
+
sale, or importing the Program or any portion of it.
|
458 |
+
|
459 |
+
11. Patents.
|
460 |
+
|
461 |
+
A "contributor" is a copyright holder who authorizes use under this
|
462 |
+
License of the Program or a work on which the Program is based. The
|
463 |
+
work thus licensed is called the contributor's "contributor version".
|
464 |
+
|
465 |
+
A contributor's "essential patent claims" are all patent claims
|
466 |
+
owned or controlled by the contributor, whether already acquired or
|
467 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
468 |
+
by this License, of making, using, or selling its contributor version,
|
469 |
+
but do not include claims that would be infringed only as a
|
470 |
+
consequence of further modification of the contributor version. For
|
471 |
+
purposes of this definition, "control" includes the right to grant
|
472 |
+
patent sublicenses in a manner consistent with the requirements of
|
473 |
+
this License.
|
474 |
+
|
475 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
476 |
+
patent license under the contributor's essential patent claims, to
|
477 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
478 |
+
propagate the contents of its contributor version.
|
479 |
+
|
480 |
+
In the following three paragraphs, a "patent license" is any express
|
481 |
+
agreement or commitment, however denominated, not to enforce a patent
|
482 |
+
(such as an express permission to practice a patent or covenant not to
|
483 |
+
sue for patent infringement). To "grant" such a patent license to a
|
484 |
+
party means to make such an agreement or commitment not to enforce a
|
485 |
+
patent against the party.
|
486 |
+
|
487 |
+
If you convey a covered work, knowingly relying on a patent license,
|
488 |
+
and the Corresponding Source of the work is not available for anyone
|
489 |
+
to copy, free of charge and under the terms of this License, through a
|
490 |
+
publicly available network server or other readily accessible means,
|
491 |
+
then you must either (1) cause the Corresponding Source to be so
|
492 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
493 |
+
patent license for this particular work, or (3) arrange, in a manner
|
494 |
+
consistent with the requirements of this License, to extend the patent
|
495 |
+
license to downstream recipients. "Knowingly relying" means you have
|
496 |
+
actual knowledge that, but for the patent license, your conveying the
|
497 |
+
covered work in a country, or your recipient's use of the covered work
|
498 |
+
in a country, would infringe one or more identifiable patents in that
|
499 |
+
country that you have reason to believe are valid.
|
500 |
+
|
501 |
+
If, pursuant to or in connection with a single transaction or
|
502 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
503 |
+
covered work, and grant a patent license to some of the parties
|
504 |
+
receiving the covered work authorizing them to use, propagate, modify
|
505 |
+
or convey a specific copy of the covered work, then the patent license
|
506 |
+
you grant is automatically extended to all recipients of the covered
|
507 |
+
work and works based on it.
|
508 |
+
|
509 |
+
A patent license is "discriminatory" if it does not include within
|
510 |
+
the scope of its coverage, prohibits the exercise of, or is
|
511 |
+
conditioned on the non-exercise of one or more of the rights that are
|
512 |
+
specifically granted under this License. You may not convey a covered
|
513 |
+
work if you are a party to an arrangement with a third party that is
|
514 |
+
in the business of distributing software, under which you make payment
|
515 |
+
to the third party based on the extent of your activity of conveying
|
516 |
+
the work, and under which the third party grants, to any of the
|
517 |
+
parties who would receive the covered work from you, a discriminatory
|
518 |
+
patent license (a) in connection with copies of the covered work
|
519 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
520 |
+
for and in connection with specific products or compilations that
|
521 |
+
contain the covered work, unless you entered into that arrangement,
|
522 |
+
or that patent license was granted, prior to 28 March 2007.
|
523 |
+
|
524 |
+
Nothing in this License shall be construed as excluding or limiting
|
525 |
+
any implied license or other defenses to infringement that may
|
526 |
+
otherwise be available to you under applicable patent law.
|
527 |
+
|
528 |
+
12. No Surrender of Others' Freedom.
|
529 |
+
|
530 |
+
If conditions are imposed on you (whether by court order, agreement or
|
531 |
+
otherwise) that contradict the conditions of this License, they do not
|
532 |
+
excuse you from the conditions of this License. If you cannot convey a
|
533 |
+
covered work so as to satisfy simultaneously your obligations under this
|
534 |
+
License and any other pertinent obligations, then as a consequence you may
|
535 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
536 |
+
to collect a royalty for further conveying from those to whom you convey
|
537 |
+
the Program, the only way you could satisfy both those terms and this
|
538 |
+
License would be to refrain entirely from conveying the Program.
|
539 |
+
|
540 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
541 |
+
|
542 |
+
Notwithstanding any other provision of this License, if you modify the
|
543 |
+
Program, your modified version must prominently offer all users
|
544 |
+
interacting with it remotely through a computer network (if your version
|
545 |
+
supports such interaction) an opportunity to receive the Corresponding
|
546 |
+
Source of your version by providing access to the Corresponding Source
|
547 |
+
from a network server at no charge, through some standard or customary
|
548 |
+
means of facilitating copying of software. This Corresponding Source
|
549 |
+
shall include the Corresponding Source for any work covered by version 3
|
550 |
+
of the GNU General Public License that is incorporated pursuant to the
|
551 |
+
following paragraph.
|
552 |
+
|
553 |
+
Notwithstanding any other provision of this License, you have
|
554 |
+
permission to link or combine any covered work with a work licensed
|
555 |
+
under version 3 of the GNU General Public License into a single
|
556 |
+
combined work, and to convey the resulting work. The terms of this
|
557 |
+
License will continue to apply to the part which is the covered work,
|
558 |
+
but the work with which it is combined will remain governed by version
|
559 |
+
3 of the GNU General Public License.
|
560 |
+
|
561 |
+
14. Revised Versions of this License.
|
562 |
+
|
563 |
+
The Free Software Foundation may publish revised and/or new versions of
|
564 |
+
the GNU Affero General Public License from time to time. Such new versions
|
565 |
+
will be similar in spirit to the present version, but may differ in detail to
|
566 |
+
address new problems or concerns.
|
567 |
+
|
568 |
+
Each version is given a distinguishing version number. If the
|
569 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
570 |
+
Public License "or any later version" applies to it, you have the
|
571 |
+
option of following the terms and conditions either of that numbered
|
572 |
+
version or of any later version published by the Free Software
|
573 |
+
Foundation. If the Program does not specify a version number of the
|
574 |
+
GNU Affero General Public License, you may choose any version ever published
|
575 |
+
by the Free Software Foundation.
|
576 |
+
|
577 |
+
If the Program specifies that a proxy can decide which future
|
578 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
579 |
+
public statement of acceptance of a version permanently authorizes you
|
580 |
+
to choose that version for the Program.
|
581 |
+
|
582 |
+
Later license versions may give you additional or different
|
583 |
+
permissions. However, no additional obligations are imposed on any
|
584 |
+
author or copyright holder as a result of your choosing to follow a
|
585 |
+
later version.
|
586 |
+
|
587 |
+
15. Disclaimer of Warranty.
|
588 |
+
|
589 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
590 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
591 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
592 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
593 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
594 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
595 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
596 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
597 |
+
|
598 |
+
16. Limitation of Liability.
|
599 |
+
|
600 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
601 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
602 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
603 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
604 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
605 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
606 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
607 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
608 |
+
SUCH DAMAGES.
|
609 |
+
|
610 |
+
17. Interpretation of Sections 15 and 16.
|
611 |
+
|
612 |
+
If the disclaimer of warranty and limitation of liability provided
|
613 |
+
above cannot be given local legal effect according to their terms,
|
614 |
+
reviewing courts shall apply local law that most closely approximates
|
615 |
+
an absolute waiver of all civil liability in connection with the
|
616 |
+
Program, unless a warranty or assumption of liability accompanies a
|
617 |
+
copy of the Program in return for a fee.
|
618 |
+
|
619 |
+
END OF TERMS AND CONDITIONS
|
620 |
+
|
621 |
+
How to Apply These Terms to Your New Programs
|
622 |
+
|
623 |
+
If you develop a new program, and you want it to be of the greatest
|
624 |
+
possible use to the public, the best way to achieve this is to make it
|
625 |
+
free software which everyone can redistribute and change under these terms.
|
626 |
+
|
627 |
+
To do so, attach the following notices to the program. It is safest
|
628 |
+
to attach them to the start of each source file to most effectively
|
629 |
+
state the exclusion of warranty; and each file should have at least
|
630 |
+
the "copyright" line and a pointer to where the full notice is found.
|
631 |
+
|
632 |
+
<one line to give the program's name and a brief idea of what it does.>
|
633 |
+
Copyright (C) <year> <name of author>
|
634 |
+
|
635 |
+
This program is free software: you can redistribute it and/or modify
|
636 |
+
it under the terms of the GNU Affero General Public License as published
|
637 |
+
by the Free Software Foundation, either version 3 of the License, or
|
638 |
+
(at your option) any later version.
|
639 |
+
|
640 |
+
This program is distributed in the hope that it will be useful,
|
641 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
642 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
643 |
+
GNU Affero General Public License for more details.
|
644 |
+
|
645 |
+
You should have received a copy of the GNU Affero General Public License
|
646 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
647 |
+
|
648 |
+
Also add information on how to contact you by electronic and paper mail.
|
649 |
+
|
650 |
+
If your software can interact with users remotely through a computer
|
651 |
+
network, you should also make sure that it provides a way for users to
|
652 |
+
get its source. For example, if your program is a web application, its
|
653 |
+
interface could display a "Source" link that leads users to an archive
|
654 |
+
of the code. There are many ways you could offer source, and different
|
655 |
+
solutions will be better for different programs; see section 13 for the
|
656 |
+
specific requirements.
|
657 |
+
|
658 |
+
You should also get your employer (if you work as a programmer) or school,
|
659 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
660 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
661 |
+
<https://www.gnu.org/licenses/>.
|
README.md
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# roop-unleashed
|
2 |
+
|
3 |
+
[Changelog](#changelog) • [Usage](#usage) • [Wiki](https://github.com/C0untFloyd/roop-unleashed/wiki)
|
4 |
+
|
5 |
+
|
6 |
+
Uncensored Deepfakes for images and videos without training and an easy-to-use GUI.
|
7 |
+
|
8 |
+
|
9 |
+
![Screen](https://github.com/C0untFloyd/roop-unleashed/assets/131583554/6ee6860d-efbe-4337-8c62-a67598863637)
|
10 |
+
|
11 |
+
### Features
|
12 |
+
|
13 |
+
- Platform-independant Browser GUI
|
14 |
+
- Selection of multiple input/output faces in one go
|
15 |
+
- Many different swapping modes, first detected, face selections, by gender
|
16 |
+
- Batch processing of images/videos
|
17 |
+
- Masking of face occluders using text prompts or automatically
|
18 |
+
- Optional Face Upscaler/Restoration using different enhancers
|
19 |
+
- Preview swapping from different video frames
|
20 |
+
- Live Fake Cam using your webcam
|
21 |
+
- Extras Tab for cutting videos etc.
|
22 |
+
- Settings - storing configuration for next session
|
23 |
+
- Theme Support
|
24 |
+
|
25 |
+
and lots more...
|
26 |
+
|
27 |
+
|
28 |
+
## Disclaimer
|
29 |
+
|
30 |
+
This project is for technical and academic use only.
|
31 |
+
Users of this software are expected to use this software responsibly while abiding the local law. If a face of a real person is being used, users are suggested to get consent from the concerned person and clearly mention that it is a deepfake when posting content online. Developers of this software will not be responsible for actions of end-users.
|
32 |
+
**Please do not apply it to illegal and unethical scenarios.**
|
33 |
+
|
34 |
+
In the event of violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability
|
35 |
+
|
36 |
+
### Installation
|
37 |
+
|
38 |
+
Please refer to the [wiki](https://github.com/C0untFloyd/roop-unleashed/wiki).
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
### Usage
|
44 |
+
|
45 |
+
- Windows: run the `windows_run.bat` from the Installer.
|
46 |
+
- Linux: `python run.py`
|
47 |
+
|
48 |
+
<a target="_blank" href="https://colab.research.google.com/github/C0untFloyd/roop-unleashed/blob/main/roop-unleashed.ipynb">
|
49 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
50 |
+
</a>
|
51 |
+
|
52 |
+
|
53 |
+
Additional commandline arguments are currently unsupported and settings should be done via the UI.
|
54 |
+
|
55 |
+
> Note: When you run this program for the first time, it will download some models roughly ~2Gb in size.
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
### Changelog
|
61 |
+
|
62 |
+
**22.04.2024** v3.9.0
|
63 |
+
|
64 |
+
- Bugfix: Face detection bounding box corrupt values at weird angles
|
65 |
+
- Rewrote mask previewing to work with every model
|
66 |
+
- Switching mask engines toggles text interactivity
|
67 |
+
- Clearing target files, resets face selection dropdown
|
68 |
+
- Massive rewrite of swapping architecture, needed for xseg implementation
|
69 |
+
- Added DFL Xseg Support for partial face occlusion
|
70 |
+
- Face masking only runs when there is a face detected
|
71 |
+
- Removed unnecessary toggle checkbox for text masking
|
72 |
+
|
73 |
+
|
74 |
+
**22.03.2024** v3.6.5
|
75 |
+
|
76 |
+
- Bugfix: Installer pulling latest update on first installation
|
77 |
+
- Bugfix: Regression issue, blurring/erosion missing from face swap
|
78 |
+
- Exposed erosion and blur amounts to UI
|
79 |
+
- Using same values for manual masking too
|
80 |
+
|
81 |
+
|
82 |
+
**20.03.2024** v3.6.3
|
83 |
+
|
84 |
+
- Bugfix: Workaround for Gradio Slider Change Bug
|
85 |
+
- Bugfix: CSS Styling to fix Gradio Image Height Bug
|
86 |
+
- Made face swapping mask offsets resolution independant
|
87 |
+
- Show offset mask as overlay
|
88 |
+
- Changed layout for masking
|
89 |
+
|
90 |
+
|
91 |
+
**18.03.2024** v3.6.0
|
92 |
+
|
93 |
+
- Updated to Gradio 4.21.0 - requiring many changes under the hood
|
94 |
+
- New manual masking (draw the mask yourself)
|
95 |
+
- Extras Tab, streamlined cutting/joining videos
|
96 |
+
- Re-added face selection by gender (on-demand loading, default turned off)
|
97 |
+
- Removed unnecessary activate live-cam option
|
98 |
+
- Added time info to preview frame and changed frame slider event to allow faster changes
|
99 |
+
|
100 |
+
|
101 |
+
**10.03.2024** v3.5.5
|
102 |
+
|
103 |
+
- Bugfix: Installer Path Env
|
104 |
+
- Bugfix: file attributes
|
105 |
+
- Video processing checks for presence of ffmpeg and displays warning if not found
|
106 |
+
- Removed gender + age detection to speed up processing. Option removed from UI
|
107 |
+
- Replaced restoreformer with restoreformer++
|
108 |
+
- Live Cam recoded to run separate from virtual cam and without blocking controls
|
109 |
+
- Swapping with only 1 target face allows selecting from several input faces
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
**08.01.2024** v3.5.0
|
114 |
+
|
115 |
+
- Bugfix: wrong access options when creating folders
|
116 |
+
- New auto rotation of horizontal faces, fixing bad landmark positions (expanded on ![PR 364](https://github.com/C0untFloyd/roop-unleashed/pull/364))
|
117 |
+
- Simple VR Option for stereo Images/Movies, best used in selected face mode
|
118 |
+
- Added RestoreFormer Enhancer - https://github.com/wzhouxiff/RestoreFormer
|
119 |
+
- Bumped up package versions for onnx/Torch etc.
|
120 |
+
|
121 |
+
|
122 |
+
**16.10.2023** v3.3.4
|
123 |
+
|
124 |
+
**11.8.2023** v2.7.0
|
125 |
+
|
126 |
+
Initial Gradio Version - old TkInter Version now deprecated
|
127 |
+
|
128 |
+
- Re-added unified padding to face enhancers
|
129 |
+
- Fixed DMDNet for all resolutions
|
130 |
+
- Selecting target face now automatically switches swapping mode to selected
|
131 |
+
- GPU providers are correctly set using the GUI (needs restart currently)
|
132 |
+
- Local output folder can be opened from page
|
133 |
+
- Unfinished extras functions disabled for now
|
134 |
+
- Installer checks out specific commit, allowing to go back to first install
|
135 |
+
- Updated readme for new gradio version
|
136 |
+
- Updated Colab
|
137 |
+
|
138 |
+
|
139 |
+
# Acknowledgements
|
140 |
+
|
141 |
+
Lots of ideas, code or pre-trained models borrowed from the following projects:
|
142 |
+
|
143 |
+
https://github.com/deepinsight/insightface<br />
|
144 |
+
https://github.com/s0md3v/roop<br />
|
145 |
+
https://github.com/AUTOMATIC1111/stable-diffusion-webui<br />
|
146 |
+
https://github.com/Hillobar/Rope<br />
|
147 |
+
https://github.com/TencentARC/GFPGAN<br />
|
148 |
+
https://github.com/kadirnar/codeformer-pip<br />
|
149 |
+
https://github.com/csxmli2016/DMDNet<br />
|
150 |
+
https://github.com/glucauze/sd-webui-faceswaplab<br />
|
151 |
+
https://github.com/ykk648/face_power<br />
|
152 |
+
|
153 |
+
<br />
|
154 |
+
<br />
|
155 |
+
Thanks to all developers!
|
156 |
+
|
clip/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .clip import *
|
clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
clip/clip.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Union, List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from .model import build_model
|
13 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
14 |
+
|
15 |
+
try:
|
16 |
+
from torchvision.transforms import InterpolationMode
|
17 |
+
BICUBIC = InterpolationMode.BICUBIC
|
18 |
+
except ImportError:
|
19 |
+
BICUBIC = Image.BICUBIC
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
__all__ = ["available_models", "load", "tokenize"]
|
24 |
+
_tokenizer = _Tokenizer()
|
25 |
+
|
26 |
+
_MODELS = {
|
27 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
28 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
29 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
30 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
31 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
32 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
33 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
34 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
35 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
def _download(url: str, root: str):
|
40 |
+
os.makedirs(root, exist_ok=True)
|
41 |
+
filename = os.path.basename(url)
|
42 |
+
|
43 |
+
expected_sha256 = url.split("/")[-2]
|
44 |
+
download_target = os.path.join(root, filename)
|
45 |
+
|
46 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
47 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
48 |
+
|
49 |
+
if os.path.isfile(download_target):
|
50 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
51 |
+
return download_target
|
52 |
+
else:
|
53 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
54 |
+
|
55 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
56 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
57 |
+
while True:
|
58 |
+
buffer = source.read(8192)
|
59 |
+
if not buffer:
|
60 |
+
break
|
61 |
+
|
62 |
+
output.write(buffer)
|
63 |
+
loop.update(len(buffer))
|
64 |
+
|
65 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
66 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
67 |
+
|
68 |
+
return download_target
|
69 |
+
|
70 |
+
|
71 |
+
def _convert_image_to_rgb(image):
|
72 |
+
return image.convert("RGB")
|
73 |
+
|
74 |
+
|
75 |
+
def _transform(n_px):
|
76 |
+
return Compose([
|
77 |
+
Resize(n_px, interpolation=BICUBIC),
|
78 |
+
CenterCrop(n_px),
|
79 |
+
_convert_image_to_rgb,
|
80 |
+
ToTensor(),
|
81 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
82 |
+
])
|
83 |
+
|
84 |
+
|
85 |
+
def available_models() -> List[str]:
|
86 |
+
"""Returns the names of available CLIP models"""
|
87 |
+
return list(_MODELS.keys())
|
88 |
+
|
89 |
+
|
90 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
91 |
+
"""Load a CLIP model
|
92 |
+
|
93 |
+
Parameters
|
94 |
+
----------
|
95 |
+
name : str
|
96 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
97 |
+
|
98 |
+
device : Union[str, torch.device]
|
99 |
+
The device to put the loaded model
|
100 |
+
|
101 |
+
jit : bool
|
102 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
103 |
+
|
104 |
+
download_root: str
|
105 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
106 |
+
|
107 |
+
Returns
|
108 |
+
-------
|
109 |
+
model : torch.nn.Module
|
110 |
+
The CLIP model
|
111 |
+
|
112 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
113 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
114 |
+
"""
|
115 |
+
if name in _MODELS:
|
116 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
117 |
+
elif os.path.isfile(name):
|
118 |
+
model_path = name
|
119 |
+
else:
|
120 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
121 |
+
|
122 |
+
with open(model_path, 'rb') as opened_file:
|
123 |
+
try:
|
124 |
+
# loading JIT archive
|
125 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
126 |
+
state_dict = None
|
127 |
+
except RuntimeError:
|
128 |
+
# loading saved state dict
|
129 |
+
if jit:
|
130 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
131 |
+
jit = False
|
132 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
133 |
+
|
134 |
+
if not jit:
|
135 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
136 |
+
if str(device) == "cpu":
|
137 |
+
model.float()
|
138 |
+
return model, _transform(model.visual.input_resolution)
|
139 |
+
|
140 |
+
# patch the device names
|
141 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
142 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
143 |
+
|
144 |
+
def _node_get(node: torch._C.Node, key: str):
|
145 |
+
"""Gets attributes of a node which is polymorphic over return type.
|
146 |
+
|
147 |
+
From https://github.com/pytorch/pytorch/pull/82628
|
148 |
+
"""
|
149 |
+
sel = node.kindOf(key)
|
150 |
+
return getattr(node, sel)(key)
|
151 |
+
|
152 |
+
def patch_device(module):
|
153 |
+
try:
|
154 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
155 |
+
except RuntimeError:
|
156 |
+
graphs = []
|
157 |
+
|
158 |
+
if hasattr(module, "forward1"):
|
159 |
+
graphs.append(module.forward1.graph)
|
160 |
+
|
161 |
+
for graph in graphs:
|
162 |
+
for node in graph.findAllNodes("prim::Constant"):
|
163 |
+
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
|
164 |
+
node.copyAttributes(device_node)
|
165 |
+
|
166 |
+
model.apply(patch_device)
|
167 |
+
patch_device(model.encode_image)
|
168 |
+
patch_device(model.encode_text)
|
169 |
+
|
170 |
+
# patch dtype to float32 on CPU
|
171 |
+
if str(device) == "cpu":
|
172 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
173 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
174 |
+
float_node = float_input.node()
|
175 |
+
|
176 |
+
def patch_float(module):
|
177 |
+
try:
|
178 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
179 |
+
except RuntimeError:
|
180 |
+
graphs = []
|
181 |
+
|
182 |
+
if hasattr(module, "forward1"):
|
183 |
+
graphs.append(module.forward1.graph)
|
184 |
+
|
185 |
+
for graph in graphs:
|
186 |
+
for node in graph.findAllNodes("aten::to"):
|
187 |
+
inputs = list(node.inputs())
|
188 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
189 |
+
if _node_get(inputs[i].node(), "value") == 5:
|
190 |
+
inputs[i].node().copyAttributes(float_node)
|
191 |
+
|
192 |
+
model.apply(patch_float)
|
193 |
+
patch_float(model.encode_image)
|
194 |
+
patch_float(model.encode_text)
|
195 |
+
|
196 |
+
model.float()
|
197 |
+
|
198 |
+
return model, _transform(model.input_resolution.item())
|
199 |
+
|
200 |
+
|
201 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
202 |
+
"""
|
203 |
+
Returns the tokenized representation of given input string(s)
|
204 |
+
|
205 |
+
Parameters
|
206 |
+
----------
|
207 |
+
texts : Union[str, List[str]]
|
208 |
+
An input string or a list of input strings to tokenize
|
209 |
+
|
210 |
+
context_length : int
|
211 |
+
The context length to use; all CLIP models use 77 as the context length
|
212 |
+
|
213 |
+
truncate: bool
|
214 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
215 |
+
|
216 |
+
Returns
|
217 |
+
-------
|
218 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
219 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
220 |
+
"""
|
221 |
+
if isinstance(texts, str):
|
222 |
+
texts = [texts]
|
223 |
+
|
224 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
225 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
226 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
227 |
+
#if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
228 |
+
# result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
229 |
+
#else:
|
230 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
231 |
+
|
232 |
+
for i, tokens in enumerate(all_tokens):
|
233 |
+
if len(tokens) > context_length:
|
234 |
+
if truncate:
|
235 |
+
tokens = tokens[:context_length]
|
236 |
+
tokens[-1] = eot_token
|
237 |
+
else:
|
238 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
239 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
240 |
+
|
241 |
+
return result
|
clip/clipseg.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from os.path import basename, dirname, join, isfile
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as nnf
|
6 |
+
from torch.nn.modules.activation import ReLU
|
7 |
+
|
8 |
+
|
9 |
+
def get_prompt_list(prompt):
|
10 |
+
if prompt == 'plain':
|
11 |
+
return ['{}']
|
12 |
+
elif prompt == 'fixed':
|
13 |
+
return ['a photo of a {}.']
|
14 |
+
elif prompt == 'shuffle':
|
15 |
+
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
16 |
+
elif prompt == 'shuffle+':
|
17 |
+
return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
18 |
+
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
19 |
+
'a bad photo of a {}.', 'a photo of the {}.']
|
20 |
+
else:
|
21 |
+
raise ValueError('Invalid value for prompt')
|
22 |
+
|
23 |
+
|
24 |
+
def forward_multihead_attention(x, b, with_aff=False, attn_mask=None):
|
25 |
+
"""
|
26 |
+
Simplified version of multihead attention (taken from torch source code but without tons of if clauses).
|
27 |
+
The mlp and layer norm come from CLIP.
|
28 |
+
x: input.
|
29 |
+
b: multihead attention module.
|
30 |
+
"""
|
31 |
+
|
32 |
+
x_ = b.ln_1(x)
|
33 |
+
q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1)
|
34 |
+
tgt_len, bsz, embed_dim = q.size()
|
35 |
+
|
36 |
+
head_dim = embed_dim // b.attn.num_heads
|
37 |
+
scaling = float(head_dim) ** -0.5
|
38 |
+
|
39 |
+
q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
40 |
+
k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
41 |
+
v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1)
|
42 |
+
|
43 |
+
q = q * scaling
|
44 |
+
|
45 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2
|
46 |
+
if attn_mask is not None:
|
47 |
+
|
48 |
+
|
49 |
+
attn_mask_type, attn_mask = attn_mask
|
50 |
+
n_heads = attn_output_weights.size(0) // attn_mask.size(0)
|
51 |
+
attn_mask = attn_mask.repeat(n_heads, 1)
|
52 |
+
|
53 |
+
if attn_mask_type == 'cls_token':
|
54 |
+
# the mask only affects similarities compared to the readout-token.
|
55 |
+
attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...]
|
56 |
+
# attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0]
|
57 |
+
|
58 |
+
if attn_mask_type == 'all':
|
59 |
+
# print(attn_output_weights.shape, attn_mask[:, None].shape)
|
60 |
+
attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None]
|
61 |
+
|
62 |
+
|
63 |
+
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
|
64 |
+
|
65 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
66 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
67 |
+
attn_output = b.attn.out_proj(attn_output)
|
68 |
+
|
69 |
+
x = x + attn_output
|
70 |
+
x = x + b.mlp(b.ln_2(x))
|
71 |
+
|
72 |
+
if with_aff:
|
73 |
+
return x, attn_output_weights
|
74 |
+
else:
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class CLIPDenseBase(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
import clip
|
84 |
+
|
85 |
+
# prec = torch.FloatTensor
|
86 |
+
self.clip_model, _ = clip.load(version, device='cpu', jit=False)
|
87 |
+
self.model = self.clip_model.visual
|
88 |
+
|
89 |
+
# if not None, scale conv weights such that we obtain n_tokens.
|
90 |
+
self.n_tokens = n_tokens
|
91 |
+
|
92 |
+
for p in self.clip_model.parameters():
|
93 |
+
p.requires_grad_(False)
|
94 |
+
|
95 |
+
# conditional
|
96 |
+
if reduce_cond is not None:
|
97 |
+
self.reduce_cond = nn.Linear(512, reduce_cond)
|
98 |
+
for p in self.reduce_cond.parameters():
|
99 |
+
p.requires_grad_(False)
|
100 |
+
else:
|
101 |
+
self.reduce_cond = None
|
102 |
+
|
103 |
+
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
104 |
+
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
105 |
+
|
106 |
+
self.reduce = nn.Linear(768, reduce_dim)
|
107 |
+
|
108 |
+
self.prompt_list = get_prompt_list(prompt)
|
109 |
+
|
110 |
+
# precomputed prompts
|
111 |
+
import pickle
|
112 |
+
if isfile('precomputed_prompt_vectors.pickle'):
|
113 |
+
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
114 |
+
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
115 |
+
else:
|
116 |
+
self.precomputed_prompts = dict()
|
117 |
+
|
118 |
+
def rescaled_pos_emb(self, new_size):
|
119 |
+
assert len(new_size) == 2
|
120 |
+
|
121 |
+
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
122 |
+
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
123 |
+
return torch.cat([self.model.positional_embedding[:1], b])
|
124 |
+
|
125 |
+
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
126 |
+
|
127 |
+
|
128 |
+
with torch.no_grad():
|
129 |
+
|
130 |
+
inp_size = x_inp.shape[2:]
|
131 |
+
|
132 |
+
if self.n_tokens is not None:
|
133 |
+
stride2 = x_inp.shape[2] // self.n_tokens
|
134 |
+
conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True)
|
135 |
+
x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation)
|
136 |
+
else:
|
137 |
+
x = self.model.conv1(x_inp) # shape = [*, width, grid, grid]
|
138 |
+
|
139 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
140 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
141 |
+
|
142 |
+
x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
143 |
+
|
144 |
+
standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197
|
145 |
+
|
146 |
+
if x.shape[1] != standard_n_tokens:
|
147 |
+
new_shape = int(math.sqrt(x.shape[1]-1))
|
148 |
+
x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:]
|
149 |
+
else:
|
150 |
+
x = x + self.model.positional_embedding.to(x.dtype)
|
151 |
+
|
152 |
+
x = self.model.ln_pre(x)
|
153 |
+
|
154 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
155 |
+
|
156 |
+
activations, affinities = [], []
|
157 |
+
for i, res_block in enumerate(self.model.transformer.resblocks):
|
158 |
+
|
159 |
+
if mask is not None:
|
160 |
+
mask_layer, mask_type, mask_tensor = mask
|
161 |
+
if mask_layer == i or mask_layer == 'all':
|
162 |
+
# import ipdb; ipdb.set_trace()
|
163 |
+
size = int(math.sqrt(x.shape[0] - 1))
|
164 |
+
|
165 |
+
attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size))
|
166 |
+
|
167 |
+
else:
|
168 |
+
attn_mask = None
|
169 |
+
else:
|
170 |
+
attn_mask = None
|
171 |
+
|
172 |
+
x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask)
|
173 |
+
|
174 |
+
if i in extract_layers:
|
175 |
+
affinities += [aff_per_head]
|
176 |
+
|
177 |
+
#if self.n_tokens is not None:
|
178 |
+
# activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)]
|
179 |
+
#else:
|
180 |
+
activations += [x]
|
181 |
+
|
182 |
+
if len(extract_layers) > 0 and i == max(extract_layers) and skip:
|
183 |
+
print('early skip')
|
184 |
+
break
|
185 |
+
|
186 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
187 |
+
x = self.model.ln_post(x[:, 0, :])
|
188 |
+
|
189 |
+
if self.model.proj is not None:
|
190 |
+
x = x @ self.model.proj
|
191 |
+
|
192 |
+
return x, activations, affinities
|
193 |
+
|
194 |
+
def sample_prompts(self, words, prompt_list=None):
|
195 |
+
|
196 |
+
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
197 |
+
|
198 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
199 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
200 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
201 |
+
|
202 |
+
def get_cond_vec(self, conditional, batch_size):
|
203 |
+
# compute conditional from a single string
|
204 |
+
if conditional is not None and type(conditional) == str:
|
205 |
+
cond = self.compute_conditional(conditional)
|
206 |
+
cond = cond.repeat(batch_size, 1)
|
207 |
+
|
208 |
+
# compute conditional from string list/tuple
|
209 |
+
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
210 |
+
assert len(conditional) == batch_size
|
211 |
+
cond = self.compute_conditional(conditional)
|
212 |
+
|
213 |
+
# use conditional directly
|
214 |
+
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
215 |
+
cond = conditional
|
216 |
+
|
217 |
+
# compute conditional from image
|
218 |
+
elif conditional is not None and type(conditional) == torch.Tensor:
|
219 |
+
with torch.no_grad():
|
220 |
+
cond, _, _ = self.visual_forward(conditional)
|
221 |
+
else:
|
222 |
+
raise ValueError('invalid conditional')
|
223 |
+
return cond
|
224 |
+
|
225 |
+
def compute_conditional(self, conditional):
|
226 |
+
import clip
|
227 |
+
|
228 |
+
dev = next(self.parameters()).device
|
229 |
+
|
230 |
+
if type(conditional) in {list, tuple}:
|
231 |
+
text_tokens = clip.tokenize(conditional).to(dev)
|
232 |
+
cond = self.clip_model.encode_text(text_tokens)
|
233 |
+
else:
|
234 |
+
if conditional in self.precomputed_prompts:
|
235 |
+
cond = self.precomputed_prompts[conditional].float().to(dev)
|
236 |
+
else:
|
237 |
+
text_tokens = clip.tokenize([conditional]).to(dev)
|
238 |
+
cond = self.clip_model.encode_text(text_tokens)[0]
|
239 |
+
|
240 |
+
if self.shift_vector is not None:
|
241 |
+
return cond + self.shift_vector
|
242 |
+
else:
|
243 |
+
return cond
|
244 |
+
|
245 |
+
|
246 |
+
def clip_load_untrained(version):
|
247 |
+
assert version == 'ViT-B/16'
|
248 |
+
from clip.model import CLIP
|
249 |
+
from clip.clip import _MODELS, _download
|
250 |
+
model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval()
|
251 |
+
state_dict = model.state_dict()
|
252 |
+
|
253 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
254 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
255 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
256 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
257 |
+
image_resolution = vision_patch_size * grid_size
|
258 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
259 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
260 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
261 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
262 |
+
transformer_heads = transformer_width // 64
|
263 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
264 |
+
|
265 |
+
return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
|
266 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
|
267 |
+
|
268 |
+
|
269 |
+
class CLIPDensePredT(CLIPDenseBase):
|
270 |
+
|
271 |
+
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
272 |
+
extra_blocks=0, reduce_cond=None, fix_shift=False,
|
273 |
+
learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False,
|
274 |
+
add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False):
|
275 |
+
|
276 |
+
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
277 |
+
# device = 'cpu'
|
278 |
+
|
279 |
+
self.extract_layers = extract_layers
|
280 |
+
self.cond_layer = cond_layer
|
281 |
+
self.limit_to_clip_only = limit_to_clip_only
|
282 |
+
self.process_cond = None
|
283 |
+
self.rev_activations = rev_activations
|
284 |
+
|
285 |
+
depth = len(extract_layers)
|
286 |
+
|
287 |
+
if add_calibration:
|
288 |
+
self.calibration_conds = 1
|
289 |
+
|
290 |
+
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
291 |
+
|
292 |
+
self.add_activation1 = True
|
293 |
+
|
294 |
+
self.version = version
|
295 |
+
|
296 |
+
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
297 |
+
|
298 |
+
if fix_shift:
|
299 |
+
# self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False)
|
300 |
+
self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False)
|
301 |
+
# self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False)
|
302 |
+
else:
|
303 |
+
self.shift_vector = None
|
304 |
+
|
305 |
+
if trans_conv is None:
|
306 |
+
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
307 |
+
else:
|
308 |
+
# explicitly define transposed conv kernel size
|
309 |
+
trans_conv_ks = (trans_conv, trans_conv)
|
310 |
+
|
311 |
+
if not complex_trans_conv:
|
312 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
313 |
+
else:
|
314 |
+
assert trans_conv_ks[0] == trans_conv_ks[1]
|
315 |
+
|
316 |
+
tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4)
|
317 |
+
|
318 |
+
self.trans_conv = nn.Sequential(
|
319 |
+
nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1),
|
320 |
+
nn.ReLU(),
|
321 |
+
nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]),
|
322 |
+
nn.ReLU(),
|
323 |
+
nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]),
|
324 |
+
)
|
325 |
+
|
326 |
+
# self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
327 |
+
|
328 |
+
assert len(self.extract_layers) == depth
|
329 |
+
|
330 |
+
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
331 |
+
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
332 |
+
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
333 |
+
|
334 |
+
# refinement and trans conv
|
335 |
+
|
336 |
+
if learn_trans_conv_only:
|
337 |
+
for p in self.parameters():
|
338 |
+
p.requires_grad_(False)
|
339 |
+
|
340 |
+
for p in self.trans_conv.parameters():
|
341 |
+
p.requires_grad_(True)
|
342 |
+
|
343 |
+
self.prompt_list = get_prompt_list(prompt)
|
344 |
+
|
345 |
+
|
346 |
+
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
347 |
+
|
348 |
+
assert type(return_features) == bool
|
349 |
+
|
350 |
+
inp_image = inp_image.to(self.model.positional_embedding.device)
|
351 |
+
|
352 |
+
if mask is not None:
|
353 |
+
raise ValueError('mask not supported')
|
354 |
+
|
355 |
+
# x_inp = normalize(inp_image)
|
356 |
+
x_inp = inp_image
|
357 |
+
|
358 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
359 |
+
|
360 |
+
cond = self.get_cond_vec(conditional, bs)
|
361 |
+
|
362 |
+
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
363 |
+
|
364 |
+
activation1 = activations[0]
|
365 |
+
activations = activations[1:]
|
366 |
+
|
367 |
+
_activations = activations[::-1] if not self.rev_activations else activations
|
368 |
+
|
369 |
+
a = None
|
370 |
+
for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)):
|
371 |
+
|
372 |
+
if a is not None:
|
373 |
+
a = reduce(activation) + a
|
374 |
+
else:
|
375 |
+
a = reduce(activation)
|
376 |
+
|
377 |
+
if i == self.cond_layer:
|
378 |
+
if self.reduce_cond is not None:
|
379 |
+
cond = self.reduce_cond(cond)
|
380 |
+
|
381 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
382 |
+
|
383 |
+
a = block(a)
|
384 |
+
|
385 |
+
for block in self.extra_blocks:
|
386 |
+
a = a + block(a)
|
387 |
+
|
388 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
389 |
+
|
390 |
+
size = int(math.sqrt(a.shape[2]))
|
391 |
+
|
392 |
+
a = a.view(bs, a.shape[1], size, size)
|
393 |
+
|
394 |
+
a = self.trans_conv(a)
|
395 |
+
|
396 |
+
if self.n_tokens is not None:
|
397 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True)
|
398 |
+
|
399 |
+
if self.upsample_proj is not None:
|
400 |
+
a = self.upsample_proj(a)
|
401 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
402 |
+
|
403 |
+
if return_features:
|
404 |
+
return a, visual_q, cond, [activation1] + activations
|
405 |
+
else:
|
406 |
+
return a,
|
407 |
+
|
408 |
+
|
409 |
+
|
410 |
+
class CLIPDensePredTMasked(CLIPDensePredT):
|
411 |
+
|
412 |
+
def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4,
|
413 |
+
prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False,
|
414 |
+
refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None):
|
415 |
+
|
416 |
+
super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim,
|
417 |
+
n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond,
|
418 |
+
fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only,
|
419 |
+
limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration,
|
420 |
+
n_tokens=n_tokens)
|
421 |
+
|
422 |
+
def visual_forward_masked(self, img_s, seg_s):
|
423 |
+
return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s))
|
424 |
+
|
425 |
+
def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False):
|
426 |
+
|
427 |
+
if seg_s is None:
|
428 |
+
cond = cond_or_img_s
|
429 |
+
else:
|
430 |
+
img_s = cond_or_img_s
|
431 |
+
|
432 |
+
with torch.no_grad():
|
433 |
+
cond, _, _ = self.visual_forward_masked(img_s, seg_s)
|
434 |
+
|
435 |
+
return super().forward(img_q, cond, return_features=return_features)
|
436 |
+
|
437 |
+
|
438 |
+
|
439 |
+
class CLIPDenseBaseline(CLIPDenseBase):
|
440 |
+
|
441 |
+
def __init__(self, version='ViT-B/32', cond_layer=0,
|
442 |
+
extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed',
|
443 |
+
reduce_cond=None, limit_to_clip_only=False, n_tokens=None):
|
444 |
+
|
445 |
+
super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens)
|
446 |
+
device = 'cpu'
|
447 |
+
|
448 |
+
# self.cond_layer = cond_layer
|
449 |
+
self.extract_layer = extract_layer
|
450 |
+
self.limit_to_clip_only = limit_to_clip_only
|
451 |
+
self.shift_vector = None
|
452 |
+
|
453 |
+
self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version]
|
454 |
+
|
455 |
+
assert reduce2_dim is not None
|
456 |
+
|
457 |
+
self.reduce2 = nn.Sequential(
|
458 |
+
nn.Linear(reduce_dim, reduce2_dim),
|
459 |
+
nn.ReLU(),
|
460 |
+
nn.Linear(reduce2_dim, reduce_dim)
|
461 |
+
)
|
462 |
+
|
463 |
+
trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version]
|
464 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
465 |
+
|
466 |
+
|
467 |
+
def forward(self, inp_image, conditional=None, return_features=False):
|
468 |
+
|
469 |
+
inp_image = inp_image.to(self.model.positional_embedding.device)
|
470 |
+
|
471 |
+
# x_inp = normalize(inp_image)
|
472 |
+
x_inp = inp_image
|
473 |
+
|
474 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
475 |
+
|
476 |
+
cond = self.get_cond_vec(conditional, bs)
|
477 |
+
|
478 |
+
visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer])
|
479 |
+
|
480 |
+
a = activations[0]
|
481 |
+
a = self.reduce(a)
|
482 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
483 |
+
|
484 |
+
if self.reduce2 is not None:
|
485 |
+
a = self.reduce2(a)
|
486 |
+
|
487 |
+
# the original model would execute a transformer block here
|
488 |
+
|
489 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
490 |
+
|
491 |
+
size = int(math.sqrt(a.shape[2]))
|
492 |
+
|
493 |
+
a = a.view(bs, a.shape[1], size, size)
|
494 |
+
a = self.trans_conv(a)
|
495 |
+
|
496 |
+
if return_features:
|
497 |
+
return a, visual_q, cond, activations
|
498 |
+
else:
|
499 |
+
return a,
|
500 |
+
|
501 |
+
|
502 |
+
class CLIPSegMultiLabel(nn.Module):
|
503 |
+
|
504 |
+
def __init__(self, model) -> None:
|
505 |
+
super().__init__()
|
506 |
+
|
507 |
+
from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC
|
508 |
+
|
509 |
+
self.pascal_classes = VOC
|
510 |
+
|
511 |
+
from clip.clipseg import CLIPDensePredT
|
512 |
+
from general_utils import load_model
|
513 |
+
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
|
514 |
+
self.clipseg = load_model(model, strict=False)
|
515 |
+
|
516 |
+
self.clipseg.eval()
|
517 |
+
|
518 |
+
def forward(self, x):
|
519 |
+
|
520 |
+
bs = x.shape[0]
|
521 |
+
out = torch.ones(21, bs, 352, 352).to(x.device) * -10
|
522 |
+
|
523 |
+
for class_id, class_name in enumerate(self.pascal_classes):
|
524 |
+
|
525 |
+
fac = 3 if class_name == 'background' else 1
|
526 |
+
|
527 |
+
with torch.no_grad():
|
528 |
+
pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac
|
529 |
+
|
530 |
+
out[class_id] += pred
|
531 |
+
|
532 |
+
|
533 |
+
out = out.permute(1, 0, 2, 3)
|
534 |
+
|
535 |
+
return out
|
536 |
+
|
537 |
+
# construct output tensor
|
538 |
+
|
clip/model.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.relu1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.relu2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.relu3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
37 |
+
("-1", nn.AvgPool2d(stride)),
|
38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
40 |
+
]))
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor):
|
43 |
+
identity = x
|
44 |
+
|
45 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
46 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
47 |
+
out = self.avgpool(out)
|
48 |
+
out = self.bn3(self.conv3(out))
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
identity = self.downsample(x)
|
52 |
+
|
53 |
+
out += identity
|
54 |
+
out = self.relu3(out)
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class AttentionPool2d(nn.Module):
|
59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
60 |
+
super().__init__()
|
61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
66 |
+
self.num_heads = num_heads
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
72 |
+
x, _ = F.multi_head_attention_forward(
|
73 |
+
query=x[:1], key=x, value=x,
|
74 |
+
embed_dim_to_check=x.shape[-1],
|
75 |
+
num_heads=self.num_heads,
|
76 |
+
q_proj_weight=self.q_proj.weight,
|
77 |
+
k_proj_weight=self.k_proj.weight,
|
78 |
+
v_proj_weight=self.v_proj.weight,
|
79 |
+
in_proj_weight=None,
|
80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
81 |
+
bias_k=None,
|
82 |
+
bias_v=None,
|
83 |
+
add_zero_attn=False,
|
84 |
+
dropout_p=0,
|
85 |
+
out_proj_weight=self.c_proj.weight,
|
86 |
+
out_proj_bias=self.c_proj.bias,
|
87 |
+
use_separate_proj_weight=True,
|
88 |
+
training=self.training,
|
89 |
+
need_weights=False
|
90 |
+
)
|
91 |
+
return x.squeeze(0)
|
92 |
+
|
93 |
+
|
94 |
+
class ModifiedResNet(nn.Module):
|
95 |
+
"""
|
96 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
97 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
98 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
99 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
103 |
+
super().__init__()
|
104 |
+
self.output_dim = output_dim
|
105 |
+
self.input_resolution = input_resolution
|
106 |
+
|
107 |
+
# the 3-layer stem
|
108 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
109 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
110 |
+
self.relu1 = nn.ReLU(inplace=True)
|
111 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
112 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
113 |
+
self.relu2 = nn.ReLU(inplace=True)
|
114 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
115 |
+
self.bn3 = nn.BatchNorm2d(width)
|
116 |
+
self.relu3 = nn.ReLU(inplace=True)
|
117 |
+
self.avgpool = nn.AvgPool2d(2)
|
118 |
+
|
119 |
+
# residual layers
|
120 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
121 |
+
self.layer1 = self._make_layer(width, layers[0])
|
122 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
123 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
124 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
125 |
+
|
126 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
127 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
128 |
+
|
129 |
+
def _make_layer(self, planes, blocks, stride=1):
|
130 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
131 |
+
|
132 |
+
self._inplanes = planes * Bottleneck.expansion
|
133 |
+
for _ in range(1, blocks):
|
134 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
135 |
+
|
136 |
+
return nn.Sequential(*layers)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
def stem(x):
|
140 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
141 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
142 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
143 |
+
x = self.avgpool(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
x = x.type(self.conv1.weight.dtype)
|
147 |
+
x = stem(x)
|
148 |
+
x = self.layer1(x)
|
149 |
+
x = self.layer2(x)
|
150 |
+
x = self.layer3(x)
|
151 |
+
x = self.layer4(x)
|
152 |
+
x = self.attnpool(x)
|
153 |
+
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class LayerNorm(nn.LayerNorm):
|
158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
159 |
+
|
160 |
+
def forward(self, x: torch.Tensor):
|
161 |
+
orig_type = x.dtype
|
162 |
+
ret = super().forward(x.type(torch.float32))
|
163 |
+
return ret.type(orig_type)
|
164 |
+
|
165 |
+
|
166 |
+
class QuickGELU(nn.Module):
|
167 |
+
def forward(self, x: torch.Tensor):
|
168 |
+
return x * torch.sigmoid(1.702 * x)
|
169 |
+
|
170 |
+
|
171 |
+
class ResidualAttentionBlock(nn.Module):
|
172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
176 |
+
self.ln_1 = LayerNorm(d_model)
|
177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
179 |
+
("gelu", QuickGELU()),
|
180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
181 |
+
]))
|
182 |
+
self.ln_2 = LayerNorm(d_model)
|
183 |
+
self.attn_mask = attn_mask
|
184 |
+
|
185 |
+
def attention(self, x: torch.Tensor):
|
186 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
187 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor):
|
190 |
+
x = x + self.attention(self.ln_1(x))
|
191 |
+
x = x + self.mlp(self.ln_2(x))
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class Transformer(nn.Module):
|
196 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
197 |
+
super().__init__()
|
198 |
+
self.width = width
|
199 |
+
self.layers = layers
|
200 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
201 |
+
|
202 |
+
def forward(self, x: torch.Tensor):
|
203 |
+
return self.resblocks(x)
|
204 |
+
|
205 |
+
|
206 |
+
class VisionTransformer(nn.Module):
|
207 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
208 |
+
super().__init__()
|
209 |
+
self.input_resolution = input_resolution
|
210 |
+
self.output_dim = output_dim
|
211 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
212 |
+
|
213 |
+
scale = width ** -0.5
|
214 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
215 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
216 |
+
self.ln_pre = LayerNorm(width)
|
217 |
+
|
218 |
+
self.transformer = Transformer(width, layers, heads)
|
219 |
+
|
220 |
+
self.ln_post = LayerNorm(width)
|
221 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
222 |
+
|
223 |
+
def forward(self, x: torch.Tensor):
|
224 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
225 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
226 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
227 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
228 |
+
x = x + self.positional_embedding.to(x.dtype)
|
229 |
+
x = self.ln_pre(x)
|
230 |
+
|
231 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
232 |
+
x = self.transformer(x)
|
233 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
234 |
+
|
235 |
+
x = self.ln_post(x[:, 0, :])
|
236 |
+
|
237 |
+
if self.proj is not None:
|
238 |
+
x = x @ self.proj
|
239 |
+
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class CLIP(nn.Module):
|
244 |
+
def __init__(self,
|
245 |
+
embed_dim: int,
|
246 |
+
# vision
|
247 |
+
image_resolution: int,
|
248 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
249 |
+
vision_width: int,
|
250 |
+
vision_patch_size: int,
|
251 |
+
# text
|
252 |
+
context_length: int,
|
253 |
+
vocab_size: int,
|
254 |
+
transformer_width: int,
|
255 |
+
transformer_heads: int,
|
256 |
+
transformer_layers: int
|
257 |
+
):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.context_length = context_length
|
261 |
+
|
262 |
+
if isinstance(vision_layers, (tuple, list)):
|
263 |
+
vision_heads = vision_width * 32 // 64
|
264 |
+
self.visual = ModifiedResNet(
|
265 |
+
layers=vision_layers,
|
266 |
+
output_dim=embed_dim,
|
267 |
+
heads=vision_heads,
|
268 |
+
input_resolution=image_resolution,
|
269 |
+
width=vision_width
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
vision_heads = vision_width // 64
|
273 |
+
self.visual = VisionTransformer(
|
274 |
+
input_resolution=image_resolution,
|
275 |
+
patch_size=vision_patch_size,
|
276 |
+
width=vision_width,
|
277 |
+
layers=vision_layers,
|
278 |
+
heads=vision_heads,
|
279 |
+
output_dim=embed_dim
|
280 |
+
)
|
281 |
+
|
282 |
+
self.transformer = Transformer(
|
283 |
+
width=transformer_width,
|
284 |
+
layers=transformer_layers,
|
285 |
+
heads=transformer_heads,
|
286 |
+
attn_mask=self.build_attention_mask()
|
287 |
+
)
|
288 |
+
|
289 |
+
self.vocab_size = vocab_size
|
290 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
291 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
292 |
+
self.ln_final = LayerNorm(transformer_width)
|
293 |
+
|
294 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
295 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
296 |
+
|
297 |
+
self.initialize_parameters()
|
298 |
+
|
299 |
+
def initialize_parameters(self):
|
300 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
301 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
302 |
+
|
303 |
+
if isinstance(self.visual, ModifiedResNet):
|
304 |
+
if self.visual.attnpool is not None:
|
305 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
306 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
307 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
308 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
309 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
310 |
+
|
311 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
312 |
+
for name, param in resnet_block.named_parameters():
|
313 |
+
if name.endswith("bn3.weight"):
|
314 |
+
nn.init.zeros_(param)
|
315 |
+
|
316 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
317 |
+
attn_std = self.transformer.width ** -0.5
|
318 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
319 |
+
for block in self.transformer.resblocks:
|
320 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
321 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
322 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
323 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
324 |
+
|
325 |
+
if self.text_projection is not None:
|
326 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
327 |
+
|
328 |
+
def build_attention_mask(self):
|
329 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
330 |
+
# pytorch uses additive attention mask; fill with -inf
|
331 |
+
mask = torch.empty(self.context_length, self.context_length)
|
332 |
+
mask.fill_(float("-inf"))
|
333 |
+
mask.triu_(1) # zero out the lower diagonal
|
334 |
+
return mask
|
335 |
+
|
336 |
+
@property
|
337 |
+
def dtype(self):
|
338 |
+
return self.visual.conv1.weight.dtype
|
339 |
+
|
340 |
+
def encode_image(self, image):
|
341 |
+
return self.visual(image.type(self.dtype))
|
342 |
+
|
343 |
+
def encode_text(self, text):
|
344 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
345 |
+
|
346 |
+
x = x + self.positional_embedding.type(self.dtype)
|
347 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
348 |
+
x = self.transformer(x)
|
349 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
350 |
+
x = self.ln_final(x).type(self.dtype)
|
351 |
+
|
352 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
353 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
354 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
355 |
+
|
356 |
+
return x
|
357 |
+
|
358 |
+
def forward(self, image, text):
|
359 |
+
image_features = self.encode_image(image)
|
360 |
+
text_features = self.encode_text(text)
|
361 |
+
|
362 |
+
# normalized features
|
363 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
364 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
365 |
+
|
366 |
+
# cosine similarity as logits
|
367 |
+
logit_scale = self.logit_scale.exp()
|
368 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
369 |
+
logits_per_text = logits_per_image.t()
|
370 |
+
|
371 |
+
# shape = [global_batch_size, global_batch_size]
|
372 |
+
return logits_per_image, logits_per_text
|
373 |
+
|
374 |
+
|
375 |
+
def convert_weights(model: nn.Module):
|
376 |
+
"""Convert applicable model parameters to fp16"""
|
377 |
+
|
378 |
+
def _convert_weights_to_fp16(l):
|
379 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
380 |
+
l.weight.data = l.weight.data.half()
|
381 |
+
if l.bias is not None:
|
382 |
+
l.bias.data = l.bias.data.half()
|
383 |
+
|
384 |
+
if isinstance(l, nn.MultiheadAttention):
|
385 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
386 |
+
tensor = getattr(l, attr)
|
387 |
+
if tensor is not None:
|
388 |
+
tensor.data = tensor.data.half()
|
389 |
+
|
390 |
+
for name in ["text_projection", "proj"]:
|
391 |
+
if hasattr(l, name):
|
392 |
+
attr = getattr(l, name)
|
393 |
+
if attr is not None:
|
394 |
+
attr.data = attr.data.half()
|
395 |
+
|
396 |
+
model.apply(_convert_weights_to_fp16)
|
397 |
+
|
398 |
+
|
399 |
+
def build_model(state_dict: dict):
|
400 |
+
vit = "visual.proj" in state_dict
|
401 |
+
|
402 |
+
if vit:
|
403 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
404 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
405 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
406 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
407 |
+
image_resolution = vision_patch_size * grid_size
|
408 |
+
else:
|
409 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
410 |
+
vision_layers = tuple(counts)
|
411 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
412 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
413 |
+
vision_patch_size = None
|
414 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
415 |
+
image_resolution = output_width * 32
|
416 |
+
|
417 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
418 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
419 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
420 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
421 |
+
transformer_heads = transformer_width // 64
|
422 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
423 |
+
|
424 |
+
model = CLIP(
|
425 |
+
embed_dim,
|
426 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
427 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
428 |
+
)
|
429 |
+
|
430 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
431 |
+
if key in state_dict:
|
432 |
+
del state_dict[key]
|
433 |
+
|
434 |
+
convert_weights(model)
|
435 |
+
model.load_state_dict(state_dict)
|
436 |
+
return model.eval()
|
clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
clip/vitseg.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from posixpath import basename, dirname, join
|
3 |
+
# import clip
|
4 |
+
from clip.model import convert_weights
|
5 |
+
import torch
|
6 |
+
import json
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as nnf
|
9 |
+
from torch.nn.modules import activation
|
10 |
+
from torch.nn.modules.activation import ReLU
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
14 |
+
|
15 |
+
from torchvision.models import ResNet
|
16 |
+
|
17 |
+
|
18 |
+
def process_prompts(conditional, prompt_list, conditional_map):
|
19 |
+
# DEPRECATED
|
20 |
+
|
21 |
+
# randomly sample a synonym
|
22 |
+
words = [conditional_map[int(i)] for i in conditional]
|
23 |
+
words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words]
|
24 |
+
words = [w.replace('_', ' ') for w in words]
|
25 |
+
|
26 |
+
if prompt_list is not None:
|
27 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
28 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
29 |
+
else:
|
30 |
+
prompts = ['a photo of {}'] * (len(words))
|
31 |
+
|
32 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
33 |
+
|
34 |
+
|
35 |
+
class VITDenseBase(nn.Module):
|
36 |
+
|
37 |
+
def rescaled_pos_emb(self, new_size):
|
38 |
+
assert len(new_size) == 2
|
39 |
+
|
40 |
+
a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
|
41 |
+
b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
|
42 |
+
return torch.cat([self.model.positional_embedding[:1], b])
|
43 |
+
|
44 |
+
def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
|
45 |
+
|
46 |
+
with torch.no_grad():
|
47 |
+
|
48 |
+
x_inp = nnf.interpolate(x_inp, (384, 384))
|
49 |
+
|
50 |
+
x = self.model.patch_embed(x_inp)
|
51 |
+
cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
52 |
+
if self.model.dist_token is None:
|
53 |
+
x = torch.cat((cls_token, x), dim=1)
|
54 |
+
else:
|
55 |
+
x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
56 |
+
x = self.model.pos_drop(x + self.model.pos_embed)
|
57 |
+
|
58 |
+
activations = []
|
59 |
+
for i, block in enumerate(self.model.blocks):
|
60 |
+
x = block(x)
|
61 |
+
|
62 |
+
if i in extract_layers:
|
63 |
+
# permute to be compatible with CLIP
|
64 |
+
activations += [x.permute(1,0,2)]
|
65 |
+
|
66 |
+
x = self.model.norm(x)
|
67 |
+
x = self.model.head(self.model.pre_logits(x[:, 0]))
|
68 |
+
|
69 |
+
# again for CLIP compatibility
|
70 |
+
# x = x.permute(1, 0, 2)
|
71 |
+
|
72 |
+
return x, activations, None
|
73 |
+
|
74 |
+
def sample_prompts(self, words, prompt_list=None):
|
75 |
+
|
76 |
+
prompt_list = prompt_list if prompt_list is not None else self.prompt_list
|
77 |
+
|
78 |
+
prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
|
79 |
+
prompts = [prompt_list[i] for i in prompt_indices]
|
80 |
+
return [promt.format(w) for promt, w in zip(prompts, words)]
|
81 |
+
|
82 |
+
def get_cond_vec(self, conditional, batch_size):
|
83 |
+
# compute conditional from a single string
|
84 |
+
if conditional is not None and type(conditional) == str:
|
85 |
+
cond = self.compute_conditional(conditional)
|
86 |
+
cond = cond.repeat(batch_size, 1)
|
87 |
+
|
88 |
+
# compute conditional from string list/tuple
|
89 |
+
elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str:
|
90 |
+
assert len(conditional) == batch_size
|
91 |
+
cond = self.compute_conditional(conditional)
|
92 |
+
|
93 |
+
# use conditional directly
|
94 |
+
elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2:
|
95 |
+
cond = conditional
|
96 |
+
|
97 |
+
# compute conditional from image
|
98 |
+
elif conditional is not None and type(conditional) == torch.Tensor:
|
99 |
+
with torch.no_grad():
|
100 |
+
cond, _, _ = self.visual_forward(conditional)
|
101 |
+
else:
|
102 |
+
raise ValueError('invalid conditional')
|
103 |
+
return cond
|
104 |
+
|
105 |
+
def compute_conditional(self, conditional):
|
106 |
+
import clip
|
107 |
+
|
108 |
+
dev = next(self.parameters()).device
|
109 |
+
|
110 |
+
if type(conditional) in {list, tuple}:
|
111 |
+
text_tokens = clip.tokenize(conditional).to(dev)
|
112 |
+
cond = self.clip_model.encode_text(text_tokens)
|
113 |
+
else:
|
114 |
+
if conditional in self.precomputed_prompts:
|
115 |
+
cond = self.precomputed_prompts[conditional].float().to(dev)
|
116 |
+
else:
|
117 |
+
text_tokens = clip.tokenize([conditional]).to(dev)
|
118 |
+
cond = self.clip_model.encode_text(text_tokens)[0]
|
119 |
+
|
120 |
+
return cond
|
121 |
+
|
122 |
+
|
123 |
+
class VITDensePredT(VITDenseBase):
|
124 |
+
|
125 |
+
def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
|
126 |
+
depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
|
127 |
+
learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
|
128 |
+
add_calibration=False, process_cond=None, not_pretrained=False):
|
129 |
+
super().__init__()
|
130 |
+
# device = 'cpu'
|
131 |
+
|
132 |
+
self.extract_layers = extract_layers
|
133 |
+
self.cond_layer = cond_layer
|
134 |
+
self.limit_to_clip_only = limit_to_clip_only
|
135 |
+
self.process_cond = None
|
136 |
+
|
137 |
+
if add_calibration:
|
138 |
+
self.calibration_conds = 1
|
139 |
+
|
140 |
+
self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None
|
141 |
+
|
142 |
+
self.add_activation1 = True
|
143 |
+
|
144 |
+
import timm
|
145 |
+
self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
|
146 |
+
self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond)
|
147 |
+
|
148 |
+
for p in self.model.parameters():
|
149 |
+
p.requires_grad_(False)
|
150 |
+
|
151 |
+
import clip
|
152 |
+
self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
|
153 |
+
# del self.clip_model.visual
|
154 |
+
|
155 |
+
|
156 |
+
self.token_shape = (14, 14)
|
157 |
+
|
158 |
+
# conditional
|
159 |
+
if reduce_cond is not None:
|
160 |
+
self.reduce_cond = nn.Linear(512, reduce_cond)
|
161 |
+
for p in self.reduce_cond.parameters():
|
162 |
+
p.requires_grad_(False)
|
163 |
+
else:
|
164 |
+
self.reduce_cond = None
|
165 |
+
|
166 |
+
# self.film = AVAILABLE_BLOCKS['film'](512, 128)
|
167 |
+
self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
168 |
+
self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim)
|
169 |
+
|
170 |
+
# DEPRECATED
|
171 |
+
# self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}
|
172 |
+
|
173 |
+
assert len(self.extract_layers) == depth
|
174 |
+
|
175 |
+
self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
|
176 |
+
self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
|
177 |
+
self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])
|
178 |
+
|
179 |
+
trans_conv_ks = (16, 16)
|
180 |
+
self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)
|
181 |
+
|
182 |
+
# refinement and trans conv
|
183 |
+
|
184 |
+
if learn_trans_conv_only:
|
185 |
+
for p in self.parameters():
|
186 |
+
p.requires_grad_(False)
|
187 |
+
|
188 |
+
for p in self.trans_conv.parameters():
|
189 |
+
p.requires_grad_(True)
|
190 |
+
|
191 |
+
if prompt == 'fixed':
|
192 |
+
self.prompt_list = ['a photo of a {}.']
|
193 |
+
elif prompt == 'shuffle':
|
194 |
+
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
|
195 |
+
elif prompt == 'shuffle+':
|
196 |
+
self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
|
197 |
+
'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
|
198 |
+
'a bad photo of a {}.', 'a photo of the {}.']
|
199 |
+
elif prompt == 'shuffle_clip':
|
200 |
+
from models.clip_prompts import imagenet_templates
|
201 |
+
self.prompt_list = imagenet_templates
|
202 |
+
|
203 |
+
if process_cond is not None:
|
204 |
+
if process_cond == 'clamp' or process_cond[0] == 'clamp':
|
205 |
+
|
206 |
+
val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2
|
207 |
+
|
208 |
+
def clamp_vec(x):
|
209 |
+
return torch.clamp(x, -val, val)
|
210 |
+
|
211 |
+
self.process_cond = clamp_vec
|
212 |
+
|
213 |
+
elif process_cond.endswith('.pth'):
|
214 |
+
|
215 |
+
shift = torch.load(process_cond)
|
216 |
+
def add_shift(x):
|
217 |
+
return x + shift.to(x.device)
|
218 |
+
|
219 |
+
self.process_cond = add_shift
|
220 |
+
|
221 |
+
import pickle
|
222 |
+
precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb'))
|
223 |
+
self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}
|
224 |
+
|
225 |
+
|
226 |
+
def forward(self, inp_image, conditional=None, return_features=False, mask=None):
|
227 |
+
|
228 |
+
assert type(return_features) == bool
|
229 |
+
|
230 |
+
# inp_image = inp_image.to(self.model.positional_embedding.device)
|
231 |
+
|
232 |
+
if mask is not None:
|
233 |
+
raise ValueError('mask not supported')
|
234 |
+
|
235 |
+
# x_inp = normalize(inp_image)
|
236 |
+
x_inp = inp_image
|
237 |
+
|
238 |
+
bs, dev = inp_image.shape[0], x_inp.device
|
239 |
+
|
240 |
+
inp_image_size = inp_image.shape[2:]
|
241 |
+
|
242 |
+
cond = self.get_cond_vec(conditional, bs)
|
243 |
+
|
244 |
+
visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))
|
245 |
+
|
246 |
+
activation1 = activations[0]
|
247 |
+
activations = activations[1:]
|
248 |
+
|
249 |
+
a = None
|
250 |
+
for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):
|
251 |
+
|
252 |
+
if a is not None:
|
253 |
+
a = reduce(activation) + a
|
254 |
+
else:
|
255 |
+
a = reduce(activation)
|
256 |
+
|
257 |
+
if i == self.cond_layer:
|
258 |
+
if self.reduce_cond is not None:
|
259 |
+
cond = self.reduce_cond(cond)
|
260 |
+
|
261 |
+
a = self.film_mul(cond) * a + self.film_add(cond)
|
262 |
+
|
263 |
+
a = block(a)
|
264 |
+
|
265 |
+
for block in self.extra_blocks:
|
266 |
+
a = a + block(a)
|
267 |
+
|
268 |
+
a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens
|
269 |
+
|
270 |
+
size = int(math.sqrt(a.shape[2]))
|
271 |
+
|
272 |
+
a = a.view(bs, a.shape[1], size, size)
|
273 |
+
|
274 |
+
if self.trans_conv is not None:
|
275 |
+
a = self.trans_conv(a)
|
276 |
+
|
277 |
+
if self.upsample_proj is not None:
|
278 |
+
a = self.upsample_proj(a)
|
279 |
+
a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')
|
280 |
+
|
281 |
+
a = nnf.interpolate(a, inp_image_size)
|
282 |
+
|
283 |
+
if return_features:
|
284 |
+
return a, visual_q, cond, [activation1] + activations
|
285 |
+
else:
|
286 |
+
return a,
|
config_colab.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clear_output: true
|
2 |
+
force_cpu: false
|
3 |
+
max_threads: 3
|
4 |
+
memory_limit: 0
|
5 |
+
output_image_format: png
|
6 |
+
output_template: '{file}_{time}'
|
7 |
+
output_video_codec: libx264
|
8 |
+
output_video_format: mp4
|
9 |
+
provider: cuda
|
10 |
+
selected_theme: Default
|
11 |
+
server_name: ''
|
12 |
+
server_port: 0
|
13 |
+
server_share: true
|
14 |
+
video_quality: 14
|
docs/screenshot.png
ADDED
Git LFS Details
|
installer/installer.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import site
|
6 |
+
import subprocess
|
7 |
+
import sys
|
8 |
+
|
9 |
+
|
10 |
+
script_dir = os.getcwd()
|
11 |
+
|
12 |
+
|
13 |
+
def run_cmd(cmd, capture_output=False, env=None):
|
14 |
+
# Run shell commands
|
15 |
+
return subprocess.run(cmd, shell=True, capture_output=capture_output, env=env)
|
16 |
+
|
17 |
+
|
18 |
+
def check_env():
|
19 |
+
# If we have access to conda, we are probably in an environment
|
20 |
+
conda_not_exist = run_cmd("conda", capture_output=True).returncode
|
21 |
+
if conda_not_exist:
|
22 |
+
print("Conda is not installed. Exiting...")
|
23 |
+
sys.exit()
|
24 |
+
|
25 |
+
# Ensure this is a new environment and not the base environment
|
26 |
+
if os.environ["CONDA_DEFAULT_ENV"] == "base":
|
27 |
+
print("Create an environment for this project and activate it. Exiting...")
|
28 |
+
sys.exit()
|
29 |
+
|
30 |
+
|
31 |
+
def install_dependencies():
|
32 |
+
global MY_PATH
|
33 |
+
|
34 |
+
# Install Git and clone repo
|
35 |
+
run_cmd("conda install -y -k git")
|
36 |
+
run_cmd("git clone https://github.com/C0untFloyd/roop-unleashed.git")
|
37 |
+
os.chdir(MY_PATH)
|
38 |
+
run_cmd("git checkout ebf163acdb66de17abf408a86a72d00ddf49480c")
|
39 |
+
# Installs dependencies from requirements.txt
|
40 |
+
run_cmd("python -m pip install -r requirements.txt")
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
def update_dependencies():
|
45 |
+
global MY_PATH
|
46 |
+
|
47 |
+
os.chdir(MY_PATH)
|
48 |
+
# do a hard reset for to update even if there are local changes
|
49 |
+
run_cmd("git fetch --all")
|
50 |
+
run_cmd("git reset --hard origin/main")
|
51 |
+
run_cmd("git pull")
|
52 |
+
# Installs/Updates dependencies from all requirements.txt
|
53 |
+
run_cmd("python -m pip install -r requirements.txt")
|
54 |
+
|
55 |
+
|
56 |
+
def start_app():
|
57 |
+
global MY_PATH
|
58 |
+
|
59 |
+
os.chdir(MY_PATH)
|
60 |
+
# forward commandline arguments
|
61 |
+
sys.argv.pop(0)
|
62 |
+
args = ' '.join(sys.argv)
|
63 |
+
print("Launching App")
|
64 |
+
run_cmd(f'python run.py {args}')
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
global MY_PATH
|
69 |
+
|
70 |
+
MY_PATH = "roop-unleashed"
|
71 |
+
|
72 |
+
|
73 |
+
# Verifies we are in a conda environment
|
74 |
+
check_env()
|
75 |
+
|
76 |
+
# If webui has already been installed, skip and run
|
77 |
+
if not os.path.exists(MY_PATH):
|
78 |
+
install_dependencies()
|
79 |
+
else:
|
80 |
+
# moved update from batch to here, because of batch limitations
|
81 |
+
updatechoice = input("Check for Updates? [y/n]").lower()
|
82 |
+
if updatechoice == "y":
|
83 |
+
update_dependencies()
|
84 |
+
|
85 |
+
# Run the model with webui
|
86 |
+
os.chdir(script_dir)
|
87 |
+
start_app()
|
installer/windows_run.bat
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
REM No CLI arguments supported anymore
|
4 |
+
set COMMANDLINE_ARGS=
|
5 |
+
|
6 |
+
cd /D "%~dp0"
|
7 |
+
|
8 |
+
echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniconda which can not be silently installed under a path with spaces. && goto end
|
9 |
+
|
10 |
+
set PATH=%PATH%;%SystemRoot%\system32
|
11 |
+
|
12 |
+
@rem config
|
13 |
+
set INSTALL_DIR=%cd%\installer_files
|
14 |
+
set CONDA_ROOT_PREFIX=%cd%\installer_files\conda
|
15 |
+
set INSTALL_ENV_DIR=%cd%\installer_files\env
|
16 |
+
set MINICONDA_DOWNLOAD_URL=https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe
|
17 |
+
set FFMPEG_DOWNLOAD_URL=https://github.com/GyanD/codexffmpeg/releases/download/2023-06-21-git-1bcb8a7338/ffmpeg-2023-06-21-git-1bcb8a7338-essentials_build.zip
|
18 |
+
set INSTALL_FFMPEG_DIR=%cd%\installer_files\ffmpeg
|
19 |
+
set INSIGHTFACE_PACKAGE_URL=https://github.com/C0untFloyd/roop-unleashed/releases/download/3.6.6/insightface-0.7.3-cp310-cp310-win_amd64.whl
|
20 |
+
set INSIGHTFACE_PACKAGE_PATH=%INSTALL_DIR%\insightface-0.7.3-cp310-cp310-win_amd64.whl
|
21 |
+
|
22 |
+
set conda_exists=F
|
23 |
+
set ffmpeg_exists=F
|
24 |
+
|
25 |
+
@rem figure out whether git and conda needs to be installed
|
26 |
+
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1
|
27 |
+
if "%ERRORLEVEL%" EQU "0" set conda_exists=T
|
28 |
+
|
29 |
+
@rem Check if FFmpeg is already in PATH
|
30 |
+
where ffmpeg >nul 2>&1
|
31 |
+
if "%ERRORLEVEL%" EQU "0" (
|
32 |
+
echo FFmpeg is already installed.
|
33 |
+
set ffmpeg_exists=T
|
34 |
+
)
|
35 |
+
|
36 |
+
@rem (if necessary) install git and conda into a contained environment
|
37 |
+
|
38 |
+
@rem download conda
|
39 |
+
if "%conda_exists%" == "F" (
|
40 |
+
echo Downloading Miniconda from %MINICONDA_DOWNLOAD_URL% to %INSTALL_DIR%\miniconda_installer.exe
|
41 |
+
mkdir "%INSTALL_DIR%"
|
42 |
+
call curl -Lk "%MINICONDA_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniconda_installer.exe" || ( echo. && echo Miniconda failed to download. && goto end )
|
43 |
+
echo Installing Miniconda to %CONDA_ROOT_PREFIX%
|
44 |
+
start /wait "" "%INSTALL_DIR%\miniconda_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX%
|
45 |
+
|
46 |
+
@rem test the conda binary
|
47 |
+
echo Miniconda version:
|
48 |
+
call "%CONDA_ROOT_PREFIX%\_conda.exe" --version || ( echo. && echo Miniconda not found. && goto end )
|
49 |
+
)
|
50 |
+
|
51 |
+
@rem create the installer env
|
52 |
+
if not exist "%INSTALL_ENV_DIR%" (
|
53 |
+
echo Creating Conda Environment
|
54 |
+
call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10 || ( echo. && echo ERROR: Conda environment creation failed. && goto end )
|
55 |
+
@rem check if conda environment was actually created
|
56 |
+
if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end )
|
57 |
+
@rem activate installer env
|
58 |
+
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end )
|
59 |
+
@rem Download insightface package
|
60 |
+
echo Downloading insightface package from %INSIGHTFACE_PACKAGE_URL% to %INSIGHTFACE_PACKAGE_PATH%
|
61 |
+
call curl -Lk "%INSIGHTFACE_PACKAGE_URL%" > "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package failed to download. && goto end )
|
62 |
+
@rem install insightface package using pip
|
63 |
+
echo Installing insightface package
|
64 |
+
call pip install "%INSIGHTFACE_PACKAGE_PATH%" || ( echo. && echo ERROR: Insightface package installation failed. && goto end )
|
65 |
+
)
|
66 |
+
|
67 |
+
@rem Download and install FFmpeg if not already installed
|
68 |
+
if "%ffmpeg_exists%" == "F" (
|
69 |
+
if not exist "%INSTALL_FFMPEG_DIR%" (
|
70 |
+
echo Downloading ffmpeg from %FFMPEG_DOWNLOAD_URL% to %INSTALL_DIR%
|
71 |
+
call curl -Lk "%FFMPEG_DOWNLOAD_URL%" > "%INSTALL_DIR%\ffmpeg.zip" || ( echo. && echo ffmpeg failed to download. && goto end )
|
72 |
+
call powershell -command "Expand-Archive -Force '%INSTALL_DIR%\ffmpeg.zip' '%INSTALL_DIR%\'"
|
73 |
+
cd "installer_files"
|
74 |
+
setlocal EnableExtensions EnableDelayedExpansion
|
75 |
+
for /f "tokens=*" %%f in ('dir /s /b /ad "ffmpeg\*"') do (
|
76 |
+
ren "%%f" "ffmpeg"
|
77 |
+
)
|
78 |
+
endlocal
|
79 |
+
setx PATH "%INSTALL_FFMPEG_DIR%\bin\;%PATH%"
|
80 |
+
echo To use videos, you need to restart roop after this installation.
|
81 |
+
cd ..
|
82 |
+
)
|
83 |
+
) else (
|
84 |
+
echo Skipping FFmpeg installation as it is already available.
|
85 |
+
)
|
86 |
+
|
87 |
+
@rem setup installer env
|
88 |
+
@rem check if conda environment was actually created
|
89 |
+
if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo ERROR: Conda environment is empty. && goto end )
|
90 |
+
@rem activate installer env
|
91 |
+
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo ERROR: Miniconda hook not found. && goto end )
|
92 |
+
echo Launching roop unleashed
|
93 |
+
call python installer.py %COMMANDLINE_ARGS%
|
94 |
+
|
95 |
+
echo.
|
96 |
+
echo Done!
|
97 |
+
|
98 |
+
:end
|
99 |
+
pause
|
mypy.ini
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[mypy]
|
2 |
+
check_untyped_defs = True
|
3 |
+
disallow_any_generics = True
|
4 |
+
disallow_untyped_calls = True
|
5 |
+
disallow_untyped_defs = True
|
6 |
+
ignore_missing_imports = True
|
7 |
+
strict_optional = False
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
+
|
3 |
+
numpy==1.26.4
|
4 |
+
gradio==4.32.1
|
5 |
+
opencv-python==4.9.0.80
|
6 |
+
onnx==1.16.0
|
7 |
+
insightface==0.7.3
|
8 |
+
psutil==5.9.6
|
9 |
+
torch==2.1.2+cu118; sys_platform != 'darwin'
|
10 |
+
torch==2.1.2; sys_platform == 'darwin'
|
11 |
+
torchvision==0.16.2+cu118; sys_platform != 'darwin'
|
12 |
+
torchvision==0.16.2; sys_platform == 'darwin'
|
13 |
+
onnxruntime==1.17.1; sys_platform == 'darwin' and platform_machine != 'arm64'
|
14 |
+
onnxruntime-silicon==1.16.3; sys_platform == 'darwin' and platform_machine == 'arm64'
|
15 |
+
onnxruntime-gpu==1.17.1; sys_platform != 'darwin'
|
16 |
+
tqdm==4.66.4
|
17 |
+
ftfy
|
18 |
+
regex
|
19 |
+
pyvirtualcam
|
roop-unleashed.ipynb
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"gpuType": "T4",
|
8 |
+
"collapsed_sections": [
|
9 |
+
"UdQ1VHdI8lCf"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
"kernelspec": {
|
13 |
+
"name": "python3",
|
14 |
+
"display_name": "Python 3"
|
15 |
+
},
|
16 |
+
"language_info": {
|
17 |
+
"name": "python"
|
18 |
+
},
|
19 |
+
"accelerator": "GPU"
|
20 |
+
},
|
21 |
+
"cells": [
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"source": [
|
25 |
+
"# Colab for roop-unleashed - Gradio version\n",
|
26 |
+
"https://github.com/C0untFloyd/roop-unleashed\n"
|
27 |
+
],
|
28 |
+
"metadata": {
|
29 |
+
"id": "G9BdiCppV6AS"
|
30 |
+
}
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"source": [
|
35 |
+
"Install CUDA V11.8 on Google Cloud Compute"
|
36 |
+
],
|
37 |
+
"metadata": {
|
38 |
+
"id": "CanIXgLJgaOj"
|
39 |
+
}
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"source": [
|
44 |
+
"!apt-get -y update\n",
|
45 |
+
"!apt-get -y install cuda-toolkit-11-8\n",
|
46 |
+
"import os\n",
|
47 |
+
"os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11/lib64\"\n",
|
48 |
+
"os.environ[\"LD_LIBRARY_PATH\"] += \":\" + \"/usr/local/cuda-11.8/lib64\""
|
49 |
+
],
|
50 |
+
"metadata": {
|
51 |
+
"id": "96GE4UgYg3Ej"
|
52 |
+
},
|
53 |
+
"execution_count": null,
|
54 |
+
"outputs": []
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "markdown",
|
58 |
+
"source": [
|
59 |
+
"Installing & preparing requirements"
|
60 |
+
],
|
61 |
+
"metadata": {
|
62 |
+
"id": "0ZYRNb0AWLLW"
|
63 |
+
}
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {
|
69 |
+
"id": "t1yPuhdySqCq"
|
70 |
+
},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"!git clone https://github.com/C0untFloyd/roop-unleashed.git\n",
|
74 |
+
"%cd roop-unleashed\n",
|
75 |
+
"!mv config_colab.yaml config.yaml\n",
|
76 |
+
"!pip install pip install -r requirements.txt"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "markdown",
|
81 |
+
"source": [
|
82 |
+
"Running roop-unleashed with default config"
|
83 |
+
],
|
84 |
+
"metadata": {
|
85 |
+
"id": "u_4JQiSlV9Fi"
|
86 |
+
}
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"source": [
|
91 |
+
"!python run.py"
|
92 |
+
],
|
93 |
+
"metadata": {
|
94 |
+
"id": "Is6U2huqSzLE"
|
95 |
+
},
|
96 |
+
"execution_count": null,
|
97 |
+
"outputs": []
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "markdown",
|
101 |
+
"source": [
|
102 |
+
"### Download generated images folder\n",
|
103 |
+
"(only needed if you want to zip the generated output)"
|
104 |
+
],
|
105 |
+
"metadata": {
|
106 |
+
"id": "UdQ1VHdI8lCf"
|
107 |
+
}
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "code",
|
111 |
+
"source": [
|
112 |
+
"import shutil\n",
|
113 |
+
"import os\n",
|
114 |
+
"from google.colab import files\n",
|
115 |
+
"\n",
|
116 |
+
"def zip_directory(directory_path, zip_path):\n",
|
117 |
+
" shutil.make_archive(zip_path, 'zip', directory_path)\n",
|
118 |
+
"\n",
|
119 |
+
"# Set the directory path you want to download\n",
|
120 |
+
"directory_path = '/content/roop-unleashed/output'\n",
|
121 |
+
"\n",
|
122 |
+
"# Set the zip file name\n",
|
123 |
+
"zip_filename = 'fake_output.zip'\n",
|
124 |
+
"\n",
|
125 |
+
"# Zip the directory\n",
|
126 |
+
"zip_directory(directory_path, zip_filename)\n",
|
127 |
+
"\n",
|
128 |
+
"# Download the zip file\n",
|
129 |
+
"files.download(zip_filename+'.zip')\n"
|
130 |
+
],
|
131 |
+
"metadata": {
|
132 |
+
"colab": {
|
133 |
+
"base_uri": "https://localhost:8080/",
|
134 |
+
"height": 17
|
135 |
+
},
|
136 |
+
"id": "oYjWveAmw10X",
|
137 |
+
"outputId": "5b4c3650-f951-434a-c650-5525a8a70c1e"
|
138 |
+
},
|
139 |
+
"execution_count": null,
|
140 |
+
"outputs": [
|
141 |
+
{
|
142 |
+
"output_type": "display_data",
|
143 |
+
"data": {
|
144 |
+
"text/plain": [
|
145 |
+
"<IPython.core.display.Javascript object>"
|
146 |
+
],
|
147 |
+
"application/javascript": [
|
148 |
+
"\n",
|
149 |
+
" async function download(id, filename, size) {\n",
|
150 |
+
" if (!google.colab.kernel.accessAllowed) {\n",
|
151 |
+
" return;\n",
|
152 |
+
" }\n",
|
153 |
+
" const div = document.createElement('div');\n",
|
154 |
+
" const label = document.createElement('label');\n",
|
155 |
+
" label.textContent = `Downloading \"${filename}\": `;\n",
|
156 |
+
" div.appendChild(label);\n",
|
157 |
+
" const progress = document.createElement('progress');\n",
|
158 |
+
" progress.max = size;\n",
|
159 |
+
" div.appendChild(progress);\n",
|
160 |
+
" document.body.appendChild(div);\n",
|
161 |
+
"\n",
|
162 |
+
" const buffers = [];\n",
|
163 |
+
" let downloaded = 0;\n",
|
164 |
+
"\n",
|
165 |
+
" const channel = await google.colab.kernel.comms.open(id);\n",
|
166 |
+
" // Send a message to notify the kernel that we're ready.\n",
|
167 |
+
" channel.send({})\n",
|
168 |
+
"\n",
|
169 |
+
" for await (const message of channel.messages) {\n",
|
170 |
+
" // Send a message to notify the kernel that we're ready.\n",
|
171 |
+
" channel.send({})\n",
|
172 |
+
" if (message.buffers) {\n",
|
173 |
+
" for (const buffer of message.buffers) {\n",
|
174 |
+
" buffers.push(buffer);\n",
|
175 |
+
" downloaded += buffer.byteLength;\n",
|
176 |
+
" progress.value = downloaded;\n",
|
177 |
+
" }\n",
|
178 |
+
" }\n",
|
179 |
+
" }\n",
|
180 |
+
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
|
181 |
+
" const a = document.createElement('a');\n",
|
182 |
+
" a.href = window.URL.createObjectURL(blob);\n",
|
183 |
+
" a.download = filename;\n",
|
184 |
+
" div.appendChild(a);\n",
|
185 |
+
" a.click();\n",
|
186 |
+
" div.remove();\n",
|
187 |
+
" }\n",
|
188 |
+
" "
|
189 |
+
]
|
190 |
+
},
|
191 |
+
"metadata": {}
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"output_type": "display_data",
|
195 |
+
"data": {
|
196 |
+
"text/plain": [
|
197 |
+
"<IPython.core.display.Javascript object>"
|
198 |
+
],
|
199 |
+
"application/javascript": [
|
200 |
+
"download(\"download_789eab11-93d2-4880-adf3-6aceee0cc5f9\", \"fake_output.zip.zip\", 80125)"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
"metadata": {}
|
204 |
+
}
|
205 |
+
]
|
206 |
+
}
|
207 |
+
]
|
208 |
+
}
|
roop/FaceSet.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
class FaceSet:
|
4 |
+
faces = []
|
5 |
+
ref_images = []
|
6 |
+
embedding_average = 'None'
|
7 |
+
embeddings_backup = None
|
8 |
+
|
9 |
+
def __init__(self):
|
10 |
+
self.faces = []
|
11 |
+
self.ref_images = []
|
12 |
+
self.embeddings_backup = None
|
13 |
+
|
14 |
+
def AverageEmbeddings(self):
|
15 |
+
if len(self.faces) > 1 and self.embeddings_backup is None:
|
16 |
+
self.embeddings_backup = self.faces[0]['embedding']
|
17 |
+
embeddings = [face.embedding for face in self.faces]
|
18 |
+
|
19 |
+
self.faces[0]['embedding'] = np.mean(embeddings, axis=0)
|
20 |
+
# try median too?
|
roop/ProcessEntry.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ProcessEntry:
|
2 |
+
def __init__(self, filename: str, start: int, end: int, fps: float):
|
3 |
+
self.filename = filename
|
4 |
+
self.finalname = None
|
5 |
+
self.startframe = start
|
6 |
+
self.endframe = end
|
7 |
+
self.fps = fps
|
roop/ProcessMgr.py
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import psutil
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from roop.ProcessOptions import ProcessOptions
|
8 |
+
|
9 |
+
from roop.face_util import get_first_face, get_all_faces, rotate_image_180, rotate_anticlockwise, rotate_clockwise, clamp_cut_values
|
10 |
+
from roop.utilities import compute_cosine_distance, get_device, str_to_class
|
11 |
+
import roop.vr_util as vr
|
12 |
+
|
13 |
+
from typing import Any, List, Callable
|
14 |
+
from roop.typing import Frame, Face
|
15 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
16 |
+
from threading import Thread, Lock
|
17 |
+
from queue import Queue
|
18 |
+
from tqdm import tqdm
|
19 |
+
from roop.ffmpeg_writer import FFMPEG_VideoWriter
|
20 |
+
import roop.globals
|
21 |
+
|
22 |
+
|
23 |
+
# Poor man's enum to be able to compare to int
|
24 |
+
class eNoFaceAction():
|
25 |
+
USE_ORIGINAL_FRAME = 0
|
26 |
+
RETRY_ROTATED = 1
|
27 |
+
SKIP_FRAME = 2
|
28 |
+
SKIP_FRAME_IF_DISSIMILAR = 3
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
|
33 |
+
queue: Queue[str] = Queue()
|
34 |
+
for frame_path in temp_frame_paths:
|
35 |
+
queue.put(frame_path)
|
36 |
+
return queue
|
37 |
+
|
38 |
+
|
39 |
+
def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
|
40 |
+
queues = []
|
41 |
+
for _ in range(queue_per_future):
|
42 |
+
if not queue.empty():
|
43 |
+
queues.append(queue.get())
|
44 |
+
return queues
|
45 |
+
|
46 |
+
|
47 |
+
class ProcessMgr():
|
48 |
+
input_face_datas = []
|
49 |
+
target_face_datas = []
|
50 |
+
|
51 |
+
imagemask = None
|
52 |
+
|
53 |
+
processors = []
|
54 |
+
options : ProcessOptions = None
|
55 |
+
|
56 |
+
num_threads = 1
|
57 |
+
current_index = 0
|
58 |
+
processing_threads = 1
|
59 |
+
buffer_wait_time = 0.1
|
60 |
+
|
61 |
+
lock = Lock()
|
62 |
+
|
63 |
+
frames_queue = None
|
64 |
+
processed_queue = None
|
65 |
+
|
66 |
+
videowriter= None
|
67 |
+
|
68 |
+
progress_gradio = None
|
69 |
+
total_frames = 0
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
plugins = {
|
75 |
+
'faceswap' : 'FaceSwapInsightFace',
|
76 |
+
'mask_clip2seg' : 'Mask_Clip2Seg',
|
77 |
+
'mask_xseg' : 'Mask_XSeg',
|
78 |
+
'codeformer' : 'Enhance_CodeFormer',
|
79 |
+
'gfpgan' : 'Enhance_GFPGAN',
|
80 |
+
'dmdnet' : 'Enhance_DMDNet',
|
81 |
+
'gpen' : 'Enhance_GPEN',
|
82 |
+
'restoreformer++' : 'Enhance_RestoreFormerPPlus',
|
83 |
+
'colorizer' : 'Frame_Colorizer',
|
84 |
+
'filter_generic' : 'Frame_Filter',
|
85 |
+
'removebg' : 'Frame_Masking',
|
86 |
+
'upscale' : 'Frame_Upscale'
|
87 |
+
}
|
88 |
+
|
89 |
+
def __init__(self, progress):
|
90 |
+
if progress is not None:
|
91 |
+
self.progress_gradio = progress
|
92 |
+
|
93 |
+
def reuseOldProcessor(self, name:str):
|
94 |
+
for p in self.processors:
|
95 |
+
if p.processorname == name:
|
96 |
+
return p
|
97 |
+
|
98 |
+
return None
|
99 |
+
|
100 |
+
|
101 |
+
def initialize(self, input_faces, target_faces, options):
|
102 |
+
self.input_face_datas = input_faces
|
103 |
+
self.target_face_datas = target_faces
|
104 |
+
self.options = options
|
105 |
+
devicename = get_device()
|
106 |
+
|
107 |
+
roop.globals.g_desired_face_analysis=["landmark_3d_68", "landmark_2d_106","detection","recognition"]
|
108 |
+
if options.swap_mode == "all_female" or options.swap_mode == "all_male":
|
109 |
+
roop.globals.g_desired_face_analysis.append("genderage")
|
110 |
+
|
111 |
+
for p in self.processors:
|
112 |
+
newp = next((x for x in options.processors.keys() if x == p.processorname), None)
|
113 |
+
if newp is None:
|
114 |
+
p.Release()
|
115 |
+
del p
|
116 |
+
|
117 |
+
newprocessors = []
|
118 |
+
for key, extoption in options.processors.items():
|
119 |
+
p = self.reuseOldProcessor(key)
|
120 |
+
if p is None:
|
121 |
+
classname = self.plugins[key]
|
122 |
+
module = 'roop.processors.' + classname
|
123 |
+
p = str_to_class(module, classname)
|
124 |
+
if p is not None:
|
125 |
+
extoption.update({"devicename": devicename})
|
126 |
+
p.Initialize(extoption)
|
127 |
+
newprocessors.append(p)
|
128 |
+
else:
|
129 |
+
print(f"Not using {module}")
|
130 |
+
self.processors = newprocessors
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
if isinstance(self.options.imagemask, dict) and self.options.imagemask.get("layers") and len(self.options.imagemask["layers"]) > 0:
|
135 |
+
self.options.imagemask = self.options.imagemask.get("layers")[0]
|
136 |
+
# Get rid of alpha
|
137 |
+
self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_RGBA2GRAY)
|
138 |
+
if np.any(self.options.imagemask):
|
139 |
+
mo = self.input_face_datas[0].faces[0].mask_offsets
|
140 |
+
self.options.imagemask = self.blur_area(self.options.imagemask, mo[4], mo[5])
|
141 |
+
self.options.imagemask = self.options.imagemask.astype(np.float32) / 255
|
142 |
+
self.options.imagemask = cv2.cvtColor(self.options.imagemask, cv2.COLOR_GRAY2RGB)
|
143 |
+
else:
|
144 |
+
self.options.imagemask = None
|
145 |
+
|
146 |
+
self.options.frame_processing = False
|
147 |
+
for p in self.processors:
|
148 |
+
if p.type.startswith("frame_"):
|
149 |
+
self.options.frame_processing = True
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
def run_batch(self, source_files, target_files, threads:int = 1):
|
157 |
+
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
158 |
+
self.total_frames = len(source_files)
|
159 |
+
self.num_threads = threads
|
160 |
+
with tqdm(total=self.total_frames, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
|
161 |
+
with ThreadPoolExecutor(max_workers=threads) as executor:
|
162 |
+
futures = []
|
163 |
+
queue = create_queue(source_files)
|
164 |
+
queue_per_future = max(len(source_files) // threads, 1)
|
165 |
+
while not queue.empty():
|
166 |
+
future = executor.submit(self.process_frames, source_files, target_files, pick_queue(queue, queue_per_future), lambda: self.update_progress(progress))
|
167 |
+
futures.append(future)
|
168 |
+
for future in as_completed(futures):
|
169 |
+
future.result()
|
170 |
+
|
171 |
+
|
172 |
+
def process_frames(self, source_files: List[str], target_files: List[str], current_files, update: Callable[[], None]) -> None:
|
173 |
+
for f in current_files:
|
174 |
+
if not roop.globals.processing:
|
175 |
+
return
|
176 |
+
|
177 |
+
# Decode the byte array into an OpenCV image
|
178 |
+
temp_frame = cv2.imdecode(np.fromfile(f, dtype=np.uint8), cv2.IMREAD_COLOR)
|
179 |
+
if temp_frame is not None:
|
180 |
+
if self.options.frame_processing:
|
181 |
+
for p in self.processors:
|
182 |
+
frame = p.Run(temp_frame)
|
183 |
+
resimg = frame
|
184 |
+
else:
|
185 |
+
resimg = self.process_frame(temp_frame)
|
186 |
+
if resimg is not None:
|
187 |
+
i = source_files.index(f)
|
188 |
+
cv2.imwrite(target_files[i], resimg)
|
189 |
+
if update:
|
190 |
+
update()
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
def read_frames_thread(self, cap, frame_start, frame_end, num_threads):
|
195 |
+
num_frame = 0
|
196 |
+
total_num = frame_end - frame_start
|
197 |
+
if frame_start > 0:
|
198 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES,frame_start)
|
199 |
+
|
200 |
+
while True and roop.globals.processing:
|
201 |
+
ret, frame = cap.read()
|
202 |
+
if not ret:
|
203 |
+
break
|
204 |
+
|
205 |
+
self.frames_queue[num_frame % num_threads].put(frame, block=True)
|
206 |
+
num_frame += 1
|
207 |
+
if num_frame == total_num:
|
208 |
+
break
|
209 |
+
|
210 |
+
for i in range(num_threads):
|
211 |
+
self.frames_queue[i].put(None)
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
def process_videoframes(self, threadindex, progress) -> None:
|
216 |
+
while True:
|
217 |
+
frame = self.frames_queue[threadindex].get()
|
218 |
+
if frame is None:
|
219 |
+
self.processing_threads -= 1
|
220 |
+
self.processed_queue[threadindex].put((False, None))
|
221 |
+
return
|
222 |
+
else:
|
223 |
+
if self.options.frame_processing:
|
224 |
+
for p in self.processors:
|
225 |
+
frame = p.Run(frame)
|
226 |
+
resimg = frame
|
227 |
+
else:
|
228 |
+
resimg = self.process_frame(frame)
|
229 |
+
self.processed_queue[threadindex].put((True, resimg))
|
230 |
+
del frame
|
231 |
+
progress()
|
232 |
+
|
233 |
+
|
234 |
+
def write_frames_thread(self):
|
235 |
+
nextindex = 0
|
236 |
+
num_producers = self.num_threads
|
237 |
+
|
238 |
+
while True:
|
239 |
+
process, frame = self.processed_queue[nextindex % self.num_threads].get()
|
240 |
+
nextindex += 1
|
241 |
+
if frame is not None:
|
242 |
+
self.videowriter.write_frame(frame)
|
243 |
+
del frame
|
244 |
+
elif process == False:
|
245 |
+
num_producers -= 1
|
246 |
+
if num_producers < 1:
|
247 |
+
return
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
+
def run_batch_inmem(self, source_video, target_video, frame_start, frame_end, fps, threads:int = 1, skip_audio=False):
|
252 |
+
cap = cv2.VideoCapture(source_video)
|
253 |
+
# frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
254 |
+
frame_count = (frame_end - frame_start) + 1
|
255 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
256 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
257 |
+
|
258 |
+
processed_resolution = None
|
259 |
+
for p in self.processors:
|
260 |
+
if hasattr(p, 'getProcessedResolution'):
|
261 |
+
processed_resolution = p.getProcessedResolution(width, height)
|
262 |
+
print(f"Processed resolution: {processed_resolution}")
|
263 |
+
if processed_resolution is not None:
|
264 |
+
width = processed_resolution[0]
|
265 |
+
height = processed_resolution[1]
|
266 |
+
|
267 |
+
|
268 |
+
self.total_frames = frame_count
|
269 |
+
self.num_threads = threads
|
270 |
+
|
271 |
+
self.processing_threads = self.num_threads
|
272 |
+
self.frames_queue = []
|
273 |
+
self.processed_queue = []
|
274 |
+
for _ in range(threads):
|
275 |
+
self.frames_queue.append(Queue(1))
|
276 |
+
self.processed_queue.append(Queue(1))
|
277 |
+
|
278 |
+
self.videowriter = FFMPEG_VideoWriter(target_video, (width, height), fps, codec=roop.globals.video_encoder, crf=roop.globals.video_quality, audiofile=None)
|
279 |
+
|
280 |
+
readthread = Thread(target=self.read_frames_thread, args=(cap, frame_start, frame_end, threads))
|
281 |
+
readthread.start()
|
282 |
+
|
283 |
+
writethread = Thread(target=self.write_frames_thread)
|
284 |
+
writethread.start()
|
285 |
+
|
286 |
+
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
287 |
+
with tqdm(total=self.total_frames, desc='Processing', unit='frames', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
|
288 |
+
with ThreadPoolExecutor(thread_name_prefix='swap_proc', max_workers=self.num_threads) as executor:
|
289 |
+
futures = []
|
290 |
+
|
291 |
+
for threadindex in range(threads):
|
292 |
+
future = executor.submit(self.process_videoframes, threadindex, lambda: self.update_progress(progress))
|
293 |
+
futures.append(future)
|
294 |
+
|
295 |
+
for future in as_completed(futures):
|
296 |
+
future.result()
|
297 |
+
# wait for the task to complete
|
298 |
+
readthread.join()
|
299 |
+
writethread.join()
|
300 |
+
cap.release()
|
301 |
+
self.videowriter.close()
|
302 |
+
self.frames_queue.clear()
|
303 |
+
self.processed_queue.clear()
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
def update_progress(self, progress: Any = None) -> None:
|
309 |
+
process = psutil.Process(os.getpid())
|
310 |
+
memory_usage = process.memory_info().rss / 1024 / 1024 / 1024
|
311 |
+
progress.set_postfix({
|
312 |
+
'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB',
|
313 |
+
'execution_threads': self.num_threads
|
314 |
+
})
|
315 |
+
progress.update(1)
|
316 |
+
if self.progress_gradio is not None:
|
317 |
+
self.progress_gradio((progress.n, self.total_frames), desc='Processing', total=self.total_frames, unit='frames')
|
318 |
+
|
319 |
+
|
320 |
+
# https://github.com/deepinsight/insightface#third-party-re-implementation-of-arcface
|
321 |
+
# https://github.com/deepinsight/insightface/blob/master/alignment/coordinate_reg/image_infer.py
|
322 |
+
# https://github.com/deepinsight/insightface/issues/1350
|
323 |
+
# https://github.com/linghu8812/tensorrt_inference
|
324 |
+
|
325 |
+
|
326 |
+
def process_frame(self, frame:Frame):
|
327 |
+
if len(self.input_face_datas) < 1 and not self.options.show_face_masking:
|
328 |
+
return frame
|
329 |
+
temp_frame = frame.copy()
|
330 |
+
num_swapped, temp_frame = self.swap_faces(frame, temp_frame)
|
331 |
+
if num_swapped > 0:
|
332 |
+
if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME_IF_DISSIMILAR:
|
333 |
+
if len(self.input_face_datas) > num_swapped:
|
334 |
+
return None
|
335 |
+
return temp_frame
|
336 |
+
if roop.globals.no_face_action == eNoFaceAction.USE_ORIGINAL_FRAME:
|
337 |
+
return frame
|
338 |
+
if roop.globals.no_face_action == eNoFaceAction.SKIP_FRAME:
|
339 |
+
#This only works with in-mem processing, as it simply skips the frame.
|
340 |
+
#For 'extract frames' it simply leaves the unprocessed frame unprocessed and it gets used in the final output by ffmpeg.
|
341 |
+
#If we could delete that frame here, that'd work but that might cause ffmpeg to fail unless the frames are renamed, and I don't think we have the info on what frame it actually is?????
|
342 |
+
#alternatively, it could mark all the necessary frames for deletion, delete them at the end, then rename the remaining frames that might work?
|
343 |
+
return None
|
344 |
+
else:
|
345 |
+
return self.retry_rotated(frame)
|
346 |
+
|
347 |
+
def retry_rotated(self, frame):
|
348 |
+
copyframe = frame.copy()
|
349 |
+
copyframe = rotate_clockwise(copyframe)
|
350 |
+
temp_frame = copyframe.copy()
|
351 |
+
num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame)
|
352 |
+
if num_swapped > 0:
|
353 |
+
return rotate_anticlockwise(temp_frame)
|
354 |
+
|
355 |
+
copyframe = frame.copy()
|
356 |
+
copyframe = rotate_anticlockwise(copyframe)
|
357 |
+
temp_frame = copyframe.copy()
|
358 |
+
num_swapped, temp_frame = self.swap_faces(copyframe, temp_frame)
|
359 |
+
if num_swapped > 0:
|
360 |
+
return rotate_clockwise(temp_frame)
|
361 |
+
del copyframe
|
362 |
+
return frame
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
def swap_faces(self, frame, temp_frame):
|
367 |
+
num_faces_found = 0
|
368 |
+
|
369 |
+
if self.options.swap_mode == "first":
|
370 |
+
face = get_first_face(frame)
|
371 |
+
|
372 |
+
if face is None:
|
373 |
+
return num_faces_found, frame
|
374 |
+
|
375 |
+
num_faces_found += 1
|
376 |
+
temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
|
377 |
+
else:
|
378 |
+
faces = get_all_faces(frame)
|
379 |
+
if faces is None:
|
380 |
+
return num_faces_found, frame
|
381 |
+
|
382 |
+
if self.options.swap_mode == "all":
|
383 |
+
for face in faces:
|
384 |
+
num_faces_found += 1
|
385 |
+
temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
|
386 |
+
del face
|
387 |
+
|
388 |
+
elif self.options.swap_mode == "selected":
|
389 |
+
num_targetfaces = len(self.target_face_datas)
|
390 |
+
use_index = num_targetfaces == 1
|
391 |
+
for i,tf in enumerate(self.target_face_datas):
|
392 |
+
for face in faces:
|
393 |
+
if compute_cosine_distance(tf.embedding, face.embedding) <= self.options.face_distance_threshold:
|
394 |
+
if i < len(self.input_face_datas):
|
395 |
+
if use_index:
|
396 |
+
temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
|
397 |
+
else:
|
398 |
+
temp_frame = self.process_face(i, face, temp_frame)
|
399 |
+
num_faces_found += 1
|
400 |
+
del face
|
401 |
+
if not roop.globals.vr_mode and num_faces_found == num_targetfaces:
|
402 |
+
break
|
403 |
+
elif self.options.swap_mode == "all_female" or self.options.swap_mode == "all_male":
|
404 |
+
gender = 'F' if self.options.swap_mode == "all_female" else 'M'
|
405 |
+
for face in faces:
|
406 |
+
if face.sex == gender:
|
407 |
+
num_faces_found += 1
|
408 |
+
temp_frame = self.process_face(self.options.selected_index, face, temp_frame)
|
409 |
+
del face
|
410 |
+
|
411 |
+
if roop.globals.vr_mode and num_faces_found % 2 > 0:
|
412 |
+
# stereo image, there has to be an even number of faces
|
413 |
+
num_faces_found = 0
|
414 |
+
return num_faces_found, frame
|
415 |
+
if num_faces_found == 0:
|
416 |
+
return num_faces_found, frame
|
417 |
+
|
418 |
+
#maskprocessor = next((x for x in self.processors if x.type == 'mask'), None)
|
419 |
+
|
420 |
+
if self.options.imagemask is not None and self.options.imagemask.shape == frame.shape:
|
421 |
+
temp_frame = self.simple_blend_with_mask(temp_frame, frame, self.options.imagemask)
|
422 |
+
return num_faces_found, temp_frame
|
423 |
+
|
424 |
+
|
425 |
+
def rotation_action(self, original_face:Face, frame:Frame):
|
426 |
+
(height, width) = frame.shape[:2]
|
427 |
+
|
428 |
+
bounding_box_width = original_face.bbox[2] - original_face.bbox[0]
|
429 |
+
bounding_box_height = original_face.bbox[3] - original_face.bbox[1]
|
430 |
+
horizontal_face = bounding_box_width > bounding_box_height
|
431 |
+
|
432 |
+
center_x = width // 2.0
|
433 |
+
start_x = original_face.bbox[0]
|
434 |
+
end_x = original_face.bbox[2]
|
435 |
+
bbox_center_x = start_x + (bounding_box_width // 2.0)
|
436 |
+
|
437 |
+
# need to leverage the array of landmarks as decribed here:
|
438 |
+
# https://github.com/deepinsight/insightface/tree/master/alignment/coordinate_reg
|
439 |
+
# basically, we should be able to check for the relative position of eyes and nose
|
440 |
+
# then use that to determine which way the face is actually facing when in a horizontal position
|
441 |
+
# and use that to determine the correct rotation_action
|
442 |
+
|
443 |
+
forehead_x = original_face.landmark_2d_106[72][0]
|
444 |
+
chin_x = original_face.landmark_2d_106[0][0]
|
445 |
+
|
446 |
+
if horizontal_face:
|
447 |
+
if chin_x < forehead_x:
|
448 |
+
# this is someone lying down with their face like this (:
|
449 |
+
return "rotate_anticlockwise"
|
450 |
+
elif forehead_x < chin_x:
|
451 |
+
# this is someone lying down with their face like this :)
|
452 |
+
return "rotate_clockwise"
|
453 |
+
if bbox_center_x >= center_x:
|
454 |
+
# this is someone lying down with their face in the right hand side of the frame
|
455 |
+
return "rotate_anticlockwise"
|
456 |
+
if bbox_center_x < center_x:
|
457 |
+
# this is someone lying down with their face in the left hand side of the frame
|
458 |
+
return "rotate_clockwise"
|
459 |
+
|
460 |
+
return None
|
461 |
+
|
462 |
+
|
463 |
+
def auto_rotate_frame(self, original_face, frame:Frame):
|
464 |
+
target_face = original_face
|
465 |
+
original_frame = frame
|
466 |
+
|
467 |
+
rotation_action = self.rotation_action(original_face, frame)
|
468 |
+
|
469 |
+
if rotation_action == "rotate_anticlockwise":
|
470 |
+
#face is horizontal, rotating frame anti-clockwise and getting face bounding box from rotated frame
|
471 |
+
frame = rotate_anticlockwise(frame)
|
472 |
+
elif rotation_action == "rotate_clockwise":
|
473 |
+
#face is horizontal, rotating frame clockwise and getting face bounding box from rotated frame
|
474 |
+
frame = rotate_clockwise(frame)
|
475 |
+
|
476 |
+
return target_face, frame, rotation_action
|
477 |
+
|
478 |
+
|
479 |
+
def auto_unrotate_frame(self, frame:Frame, rotation_action):
|
480 |
+
if rotation_action == "rotate_anticlockwise":
|
481 |
+
return rotate_clockwise(frame)
|
482 |
+
elif rotation_action == "rotate_clockwise":
|
483 |
+
return rotate_anticlockwise(frame)
|
484 |
+
|
485 |
+
return frame
|
486 |
+
|
487 |
+
|
488 |
+
|
489 |
+
def process_face(self,face_index, target_face:Face, frame:Frame):
|
490 |
+
from roop.face_util import align_crop
|
491 |
+
|
492 |
+
enhanced_frame = None
|
493 |
+
if(len(self.input_face_datas) > 0):
|
494 |
+
inputface = self.input_face_datas[face_index].faces[0]
|
495 |
+
else:
|
496 |
+
inputface = None
|
497 |
+
|
498 |
+
rotation_action = None
|
499 |
+
if roop.globals.autorotate_faces:
|
500 |
+
# check for sideways rotation of face
|
501 |
+
rotation_action = self.rotation_action(target_face, frame)
|
502 |
+
if rotation_action is not None:
|
503 |
+
(startX, startY, endX, endY) = target_face["bbox"].astype("int")
|
504 |
+
width = endX - startX
|
505 |
+
height = endY - startY
|
506 |
+
offs = int(max(width,height) * 0.25)
|
507 |
+
rotcutframe,startX, startY, endX, endY = self.cutout(frame, startX - offs, startY - offs, endX + offs, endY + offs)
|
508 |
+
if rotation_action == "rotate_anticlockwise":
|
509 |
+
rotcutframe = rotate_anticlockwise(rotcutframe)
|
510 |
+
elif rotation_action == "rotate_clockwise":
|
511 |
+
rotcutframe = rotate_clockwise(rotcutframe)
|
512 |
+
# rotate image and re-detect face to correct wonky landmarks
|
513 |
+
rotface = get_first_face(rotcutframe)
|
514 |
+
if rotface is None:
|
515 |
+
rotation_action = None
|
516 |
+
else:
|
517 |
+
saved_frame = frame.copy()
|
518 |
+
frame = rotcutframe
|
519 |
+
target_face = rotface
|
520 |
+
|
521 |
+
|
522 |
+
|
523 |
+
# if roop.globals.vr_mode:
|
524 |
+
# bbox = target_face.bbox
|
525 |
+
# [orig_width, orig_height, _] = frame.shape
|
526 |
+
|
527 |
+
# # Convert bounding box to ints
|
528 |
+
# x1, y1, x2, y2 = map(int, bbox)
|
529 |
+
|
530 |
+
# # Determine the center of the bounding box
|
531 |
+
# x_center = (x1 + x2) / 2
|
532 |
+
# y_center = (y1 + y2) / 2
|
533 |
+
|
534 |
+
# # Normalize coordinates to range [-1, 1]
|
535 |
+
# x_center_normalized = x_center / (orig_width / 2) - 1
|
536 |
+
# y_center_normalized = y_center / (orig_width / 2) - 1
|
537 |
+
|
538 |
+
# # Convert normalized coordinates to spherical (theta, phi)
|
539 |
+
# theta = x_center_normalized * 180 # Theta ranges from -180 to 180 degrees
|
540 |
+
# phi = -y_center_normalized * 90 # Phi ranges from -90 to 90 degrees
|
541 |
+
|
542 |
+
# img = vr.GetPerspective(frame, 90, theta, phi, 1280, 1280) # Generate perspective image
|
543 |
+
|
544 |
+
fake_frame = None
|
545 |
+
aligned_img, M = align_crop(frame, target_face.kps, 128)
|
546 |
+
fake_frame = aligned_img
|
547 |
+
swap_frame = aligned_img
|
548 |
+
target_face.matrix = M
|
549 |
+
for p in self.processors:
|
550 |
+
if p.type == 'swap':
|
551 |
+
if inputface is not None:
|
552 |
+
for _ in range(0,self.options.num_swap_steps):
|
553 |
+
swap_frame = p.Run(inputface, target_face, swap_frame)
|
554 |
+
fake_frame = swap_frame
|
555 |
+
scale_factor = 0.0
|
556 |
+
elif p.type == 'mask':
|
557 |
+
fake_frame = self.process_mask(p, aligned_img, fake_frame)
|
558 |
+
else:
|
559 |
+
enhanced_frame, scale_factor = p.Run(self.input_face_datas[face_index], target_face, fake_frame)
|
560 |
+
|
561 |
+
upscale = 512
|
562 |
+
orig_width = fake_frame.shape[1]
|
563 |
+
|
564 |
+
fake_frame = cv2.resize(fake_frame, (upscale, upscale), cv2.INTER_CUBIC)
|
565 |
+
mask_offsets = (0,0,0,0,1,20) if inputface is None else inputface.mask_offsets
|
566 |
+
|
567 |
+
|
568 |
+
if enhanced_frame is None:
|
569 |
+
scale_factor = int(upscale / orig_width)
|
570 |
+
result = self.paste_upscale(fake_frame, fake_frame, target_face.matrix, frame, scale_factor, mask_offsets)
|
571 |
+
else:
|
572 |
+
result = self.paste_upscale(fake_frame, enhanced_frame, target_face.matrix, frame, scale_factor, mask_offsets)
|
573 |
+
|
574 |
+
if rotation_action is not None:
|
575 |
+
fake_frame = self.auto_unrotate_frame(result, rotation_action)
|
576 |
+
return self.paste_simple(fake_frame, saved_frame, startX, startY)
|
577 |
+
|
578 |
+
return result
|
579 |
+
|
580 |
+
|
581 |
+
|
582 |
+
|
583 |
+
def cutout(self, frame:Frame, start_x, start_y, end_x, end_y):
|
584 |
+
if start_x < 0:
|
585 |
+
start_x = 0
|
586 |
+
if start_y < 0:
|
587 |
+
start_y = 0
|
588 |
+
if end_x > frame.shape[1]:
|
589 |
+
end_x = frame.shape[1]
|
590 |
+
if end_y > frame.shape[0]:
|
591 |
+
end_y = frame.shape[0]
|
592 |
+
return frame[start_y:end_y, start_x:end_x], start_x, start_y, end_x, end_y
|
593 |
+
|
594 |
+
def paste_simple(self, src:Frame, dest:Frame, start_x, start_y):
|
595 |
+
end_x = start_x + src.shape[1]
|
596 |
+
end_y = start_y + src.shape[0]
|
597 |
+
|
598 |
+
start_x, end_x, start_y, end_y = clamp_cut_values(start_x, end_x, start_y, end_y, dest)
|
599 |
+
dest[start_y:end_y, start_x:end_x] = src
|
600 |
+
return dest
|
601 |
+
|
602 |
+
def simple_blend_with_mask(self, image1, image2, mask):
|
603 |
+
# Blend the images
|
604 |
+
blended_image = image1.astype(np.float32) * (1.0 - mask) + image2.astype(np.float32) * mask
|
605 |
+
return blended_image.astype(np.uint8)
|
606 |
+
|
607 |
+
|
608 |
+
def paste_upscale(self, fake_face, upsk_face, M, target_img, scale_factor, mask_offsets):
|
609 |
+
M_scale = M * scale_factor
|
610 |
+
IM = cv2.invertAffineTransform(M_scale)
|
611 |
+
|
612 |
+
face_matte = np.full((target_img.shape[0],target_img.shape[1]), 255, dtype=np.uint8)
|
613 |
+
# Generate white square sized as a upsk_face
|
614 |
+
img_matte = np.zeros((upsk_face.shape[0],upsk_face.shape[1]), dtype=np.uint8)
|
615 |
+
|
616 |
+
w = img_matte.shape[1]
|
617 |
+
h = img_matte.shape[0]
|
618 |
+
|
619 |
+
top = int(mask_offsets[0] * h)
|
620 |
+
bottom = int(h - (mask_offsets[1] * h))
|
621 |
+
left = int(mask_offsets[2] * w)
|
622 |
+
right = int(w - (mask_offsets[3] * w))
|
623 |
+
img_matte[top:bottom,left:right] = 255
|
624 |
+
|
625 |
+
# Transform white square back to target_img
|
626 |
+
img_matte = cv2.warpAffine(img_matte, IM, (target_img.shape[1], target_img.shape[0]), flags=cv2.INTER_NEAREST, borderValue=0.0)
|
627 |
+
##Blacken the edges of face_matte by 1 pixels (so the mask in not expanded on the image edges)
|
628 |
+
img_matte[:1,:] = img_matte[-1:,:] = img_matte[:,:1] = img_matte[:,-1:] = 0
|
629 |
+
|
630 |
+
img_matte = self.blur_area(img_matte, mask_offsets[4], mask_offsets[5])
|
631 |
+
#Normalize images to float values and reshape
|
632 |
+
img_matte = img_matte.astype(np.float32)/255
|
633 |
+
face_matte = face_matte.astype(np.float32)/255
|
634 |
+
img_matte = np.minimum(face_matte, img_matte)
|
635 |
+
if self.options.show_face_area_overlay:
|
636 |
+
# Additional steps for green overlay
|
637 |
+
green_overlay = np.zeros_like(target_img)
|
638 |
+
green_color = [0, 255, 0] # RGB for green
|
639 |
+
for i in range(3): # Apply green color where img_matte is not zero
|
640 |
+
green_overlay[:, :, i] = np.where(img_matte > 0, green_color[i], 0) ##Transform upcaled face back to target_img
|
641 |
+
img_matte = np.reshape(img_matte, [img_matte.shape[0],img_matte.shape[1],1])
|
642 |
+
paste_face = cv2.warpAffine(upsk_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE)
|
643 |
+
if upsk_face is not fake_face:
|
644 |
+
fake_face = cv2.warpAffine(fake_face, IM, (target_img.shape[1], target_img.shape[0]), borderMode=cv2.BORDER_REPLICATE)
|
645 |
+
paste_face = cv2.addWeighted(paste_face, self.options.blend_ratio, fake_face, 1.0 - self.options.blend_ratio, 0)
|
646 |
+
|
647 |
+
# Re-assemble image
|
648 |
+
paste_face = img_matte * paste_face
|
649 |
+
paste_face = paste_face + (1-img_matte) * target_img.astype(np.float32)
|
650 |
+
if self.options.show_face_area_overlay:
|
651 |
+
# Overlay the green overlay on the final image
|
652 |
+
paste_face = cv2.addWeighted(paste_face.astype(np.uint8), 1 - 0.5, green_overlay, 0.5, 0)
|
653 |
+
return paste_face.astype(np.uint8)
|
654 |
+
|
655 |
+
|
656 |
+
def blur_area(self, img_matte, num_erosion_iterations, blur_amount):
|
657 |
+
# Detect the affine transformed white area
|
658 |
+
mask_h_inds, mask_w_inds = np.where(img_matte==255)
|
659 |
+
# Calculate the size (and diagonal size) of transformed white area width and height boundaries
|
660 |
+
mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
|
661 |
+
mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
|
662 |
+
mask_size = int(np.sqrt(mask_h*mask_w))
|
663 |
+
# Calculate the kernel size for eroding img_matte by kernel (insightface empirical guess for best size was max(mask_size//10,10))
|
664 |
+
# k = max(mask_size//12, 8)
|
665 |
+
k = max(mask_size//(blur_amount // 2) , blur_amount // 2)
|
666 |
+
kernel = np.ones((k,k),np.uint8)
|
667 |
+
img_matte = cv2.erode(img_matte,kernel,iterations = num_erosion_iterations)
|
668 |
+
#Calculate the kernel size for blurring img_matte by blur_size (insightface empirical guess for best size was max(mask_size//20, 5))
|
669 |
+
# k = max(mask_size//24, 4)
|
670 |
+
k = max(mask_size//blur_amount, blur_amount//5)
|
671 |
+
kernel_size = (k, k)
|
672 |
+
blur_size = tuple(2*i+1 for i in kernel_size)
|
673 |
+
return cv2.GaussianBlur(img_matte, blur_size, 0)
|
674 |
+
|
675 |
+
|
676 |
+
def process_mask(self, processor, frame:Frame, target:Frame):
|
677 |
+
img_mask = processor.Run(frame, self.options.masking_text)
|
678 |
+
img_mask = cv2.resize(img_mask, (target.shape[1], target.shape[0]))
|
679 |
+
img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
|
680 |
+
|
681 |
+
if self.options.show_face_masking:
|
682 |
+
result = (1 - img_mask) * frame.astype(np.float32)
|
683 |
+
return np.uint8(result)
|
684 |
+
|
685 |
+
|
686 |
+
target = target.astype(np.float32)
|
687 |
+
result = (1-img_mask) * target
|
688 |
+
result += img_mask * frame.astype(np.float32)
|
689 |
+
return np.uint8(result)
|
690 |
+
|
691 |
+
|
692 |
+
|
693 |
+
|
694 |
+
def unload_models():
|
695 |
+
pass
|
696 |
+
|
697 |
+
|
698 |
+
def release_resources(self):
|
699 |
+
for p in self.processors:
|
700 |
+
p.Release()
|
701 |
+
self.processors.clear()
|
702 |
+
|
roop/ProcessOptions.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ProcessOptions:
|
2 |
+
|
3 |
+
def __init__(self, processordefines:dict, face_distance, blend_ratio, swap_mode, selected_index, masking_text, imagemask, num_steps, show_face_area, show_mask=False):
|
4 |
+
self.processors = processordefines
|
5 |
+
self.face_distance_threshold = face_distance
|
6 |
+
self.blend_ratio = blend_ratio
|
7 |
+
self.swap_mode = swap_mode
|
8 |
+
self.selected_index = selected_index
|
9 |
+
self.masking_text = masking_text
|
10 |
+
self.imagemask = imagemask
|
11 |
+
self.num_swap_steps = num_steps
|
12 |
+
self.show_face_area_overlay = show_face_area
|
13 |
+
self.show_face_masking = show_mask
|
roop/__init__.py
ADDED
File without changes
|
roop/capturer.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from roop.typing import Frame
|
6 |
+
|
7 |
+
def get_image_frame(filename: str):
|
8 |
+
try:
|
9 |
+
return cv2.imdecode(np.fromfile(filename, dtype=np.uint8), cv2.IMREAD_COLOR)
|
10 |
+
except:
|
11 |
+
print(f"Exception reading {filename}")
|
12 |
+
return None
|
13 |
+
|
14 |
+
|
15 |
+
def get_video_frame(video_path: str, frame_number: int = 0) -> Optional[Frame]:
|
16 |
+
capture = cv2.VideoCapture(video_path)
|
17 |
+
frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
18 |
+
capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1))
|
19 |
+
has_frame, frame = capture.read()
|
20 |
+
capture.release()
|
21 |
+
if has_frame:
|
22 |
+
return frame
|
23 |
+
return None
|
24 |
+
|
25 |
+
|
26 |
+
def get_video_frame_total(video_path: str) -> int:
|
27 |
+
capture = cv2.VideoCapture(video_path)
|
28 |
+
video_frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
29 |
+
capture.release()
|
30 |
+
return video_frame_total
|
roop/core.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import shutil
|
6 |
+
# single thread doubles cuda performance - needs to be set before torch import
|
7 |
+
if any(arg.startswith('--execution-provider') for arg in sys.argv):
|
8 |
+
os.environ['OMP_NUM_THREADS'] = '1'
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
from typing import List
|
12 |
+
import platform
|
13 |
+
import signal
|
14 |
+
import torch
|
15 |
+
import onnxruntime
|
16 |
+
import pathlib
|
17 |
+
|
18 |
+
from time import time
|
19 |
+
|
20 |
+
import roop.globals
|
21 |
+
import roop.metadata
|
22 |
+
import roop.utilities as util
|
23 |
+
import roop.util_ffmpeg as ffmpeg
|
24 |
+
import ui.main as main
|
25 |
+
from settings import Settings
|
26 |
+
from roop.face_util import extract_face_images
|
27 |
+
from roop.ProcessEntry import ProcessEntry
|
28 |
+
from roop.ProcessMgr import ProcessMgr
|
29 |
+
from roop.ProcessOptions import ProcessOptions
|
30 |
+
from roop.capturer import get_video_frame_total
|
31 |
+
|
32 |
+
|
33 |
+
clip_text = None
|
34 |
+
|
35 |
+
call_display_ui = None
|
36 |
+
|
37 |
+
process_mgr = None
|
38 |
+
|
39 |
+
|
40 |
+
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
41 |
+
del torch
|
42 |
+
|
43 |
+
warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
|
44 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
|
45 |
+
|
46 |
+
|
47 |
+
def parse_args() -> None:
|
48 |
+
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
|
49 |
+
roop.globals.headless = False
|
50 |
+
# Always enable all processors when using GUI
|
51 |
+
if len(sys.argv) > 1:
|
52 |
+
print('No CLI args supported - use Settings Tab instead')
|
53 |
+
roop.globals.frame_processors = ['face_swapper', 'face_enhancer']
|
54 |
+
|
55 |
+
|
56 |
+
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
|
57 |
+
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
|
58 |
+
|
59 |
+
|
60 |
+
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
61 |
+
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
|
62 |
+
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
|
63 |
+
|
64 |
+
|
65 |
+
def suggest_max_memory() -> int:
|
66 |
+
if platform.system().lower() == 'darwin':
|
67 |
+
return 4
|
68 |
+
return 16
|
69 |
+
|
70 |
+
|
71 |
+
def suggest_execution_providers() -> List[str]:
|
72 |
+
return encode_execution_providers(onnxruntime.get_available_providers())
|
73 |
+
|
74 |
+
|
75 |
+
def suggest_execution_threads() -> int:
|
76 |
+
if 'DmlExecutionProvider' in roop.globals.execution_providers:
|
77 |
+
return 1
|
78 |
+
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
79 |
+
return 1
|
80 |
+
return 8
|
81 |
+
|
82 |
+
|
83 |
+
def limit_resources() -> None:
|
84 |
+
# limit memory usage
|
85 |
+
if roop.globals.max_memory:
|
86 |
+
memory = roop.globals.max_memory * 1024 ** 3
|
87 |
+
if platform.system().lower() == 'darwin':
|
88 |
+
memory = roop.globals.max_memory * 1024 ** 6
|
89 |
+
if platform.system().lower() == 'windows':
|
90 |
+
import ctypes
|
91 |
+
kernel32 = ctypes.windll.kernel32 # type: ignore[attr-defined]
|
92 |
+
kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
|
93 |
+
else:
|
94 |
+
import resource
|
95 |
+
resource.setrlimit(resource.RLIMIT_DATA, (memory, memory))
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
def release_resources() -> None:
|
100 |
+
import gc
|
101 |
+
global process_mgr
|
102 |
+
|
103 |
+
if process_mgr is not None:
|
104 |
+
process_mgr.release_resources()
|
105 |
+
process_mgr = None
|
106 |
+
|
107 |
+
gc.collect()
|
108 |
+
# if 'CUDAExecutionProvider' in roop.globals.execution_providers and torch.cuda.is_available():
|
109 |
+
# with torch.cuda.device('cuda'):
|
110 |
+
# torch.cuda.empty_cache()
|
111 |
+
# torch.cuda.ipc_collect()
|
112 |
+
|
113 |
+
|
114 |
+
def pre_check() -> bool:
|
115 |
+
if sys.version_info < (3, 9):
|
116 |
+
update_status('Python version is not supported - please upgrade to 3.9 or higher.')
|
117 |
+
return False
|
118 |
+
|
119 |
+
download_directory_path = util.resolve_relative_path('../models')
|
120 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/inswapper_128.onnx'])
|
121 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GFPGANv1.4.onnx'])
|
122 |
+
util.conditional_download(download_directory_path, ['https://github.com/csxmli2016/DMDNet/releases/download/v1/DMDNet.pth'])
|
123 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/GPEN-BFR-512.onnx'])
|
124 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/restoreformer_plus_plus.onnx'])
|
125 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/xseg.onnx'])
|
126 |
+
download_directory_path = util.resolve_relative_path('../models/CLIP')
|
127 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/rd64-uni-refined.pth'])
|
128 |
+
download_directory_path = util.resolve_relative_path('../models/CodeFormer')
|
129 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/CodeFormerv0.1.onnx'])
|
130 |
+
download_directory_path = util.resolve_relative_path('../models/Frame')
|
131 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_artistic.onnx'])
|
132 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/deoldify_stable.onnx'])
|
133 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/isnet-general-use.onnx'])
|
134 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x4.onnx'])
|
135 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/real_esrgan_x2.onnx'])
|
136 |
+
util.conditional_download(download_directory_path, ['https://huggingface.co/countfloyd/deepfake/resolve/main/lsdir_x4.onnx'])
|
137 |
+
|
138 |
+
if not shutil.which('ffmpeg'):
|
139 |
+
update_status('ffmpeg is not installed.')
|
140 |
+
return True
|
141 |
+
|
142 |
+
def set_display_ui(function):
|
143 |
+
global call_display_ui
|
144 |
+
|
145 |
+
call_display_ui = function
|
146 |
+
|
147 |
+
|
148 |
+
def update_status(message: str) -> None:
|
149 |
+
global call_display_ui
|
150 |
+
|
151 |
+
print(message)
|
152 |
+
if call_display_ui is not None:
|
153 |
+
call_display_ui(message)
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
def start() -> None:
|
159 |
+
if roop.globals.headless:
|
160 |
+
print('Headless mode currently unsupported - starting UI!')
|
161 |
+
# faces = extract_face_images(roop.globals.source_path, (False, 0))
|
162 |
+
# roop.globals.INPUT_FACES.append(faces[roop.globals.source_face_index])
|
163 |
+
# faces = extract_face_images(roop.globals.target_path, (False, util.has_image_extension(roop.globals.target_path)))
|
164 |
+
# roop.globals.TARGET_FACES.append(faces[roop.globals.target_face_index])
|
165 |
+
# if 'face_enhancer' in roop.globals.frame_processors:
|
166 |
+
# roop.globals.selected_enhancer = 'GFPGAN'
|
167 |
+
|
168 |
+
batch_process_regular(None, False, None)
|
169 |
+
|
170 |
+
|
171 |
+
def get_processing_plugins(masking_engine):
|
172 |
+
processors = { "faceswap": {}}
|
173 |
+
if masking_engine is not None:
|
174 |
+
processors.update({masking_engine: {}})
|
175 |
+
|
176 |
+
if roop.globals.selected_enhancer == 'GFPGAN':
|
177 |
+
processors.update({"gfpgan": {}})
|
178 |
+
elif roop.globals.selected_enhancer == 'Codeformer':
|
179 |
+
processors.update({"codeformer": {}})
|
180 |
+
elif roop.globals.selected_enhancer == 'DMDNet':
|
181 |
+
processors.update({"dmdnet": {}})
|
182 |
+
elif roop.globals.selected_enhancer == 'GPEN':
|
183 |
+
processors.update({"gpen": {}})
|
184 |
+
elif roop.globals.selected_enhancer == 'Restoreformer++':
|
185 |
+
processors.update({"restoreformer++": {}})
|
186 |
+
return processors
|
187 |
+
|
188 |
+
|
189 |
+
def live_swap(frame, options):
|
190 |
+
global process_mgr
|
191 |
+
|
192 |
+
if frame is None:
|
193 |
+
return frame
|
194 |
+
|
195 |
+
if process_mgr is None:
|
196 |
+
process_mgr = ProcessMgr(None)
|
197 |
+
|
198 |
+
# if len(roop.globals.INPUT_FACESETS) <= selected_index:
|
199 |
+
# selected_index = 0
|
200 |
+
process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)
|
201 |
+
newframe = process_mgr.process_frame(frame)
|
202 |
+
if newframe is None:
|
203 |
+
return frame
|
204 |
+
return newframe
|
205 |
+
|
206 |
+
|
207 |
+
def batch_process_regular(files:list[ProcessEntry], masking_engine:str, new_clip_text:str, use_new_method, imagemask, num_swap_steps, progress, selected_index = 0) -> None:
|
208 |
+
global clip_text, process_mgr
|
209 |
+
|
210 |
+
release_resources()
|
211 |
+
limit_resources()
|
212 |
+
if process_mgr is None:
|
213 |
+
process_mgr = ProcessMgr(progress)
|
214 |
+
mask = imagemask["layers"][0] if imagemask is not None else None
|
215 |
+
if len(roop.globals.INPUT_FACESETS) <= selected_index:
|
216 |
+
selected_index = 0
|
217 |
+
options = ProcessOptions(get_processing_plugins(masking_engine), roop.globals.distance_threshold, roop.globals.blend_ratio, roop.globals.face_swap_mode, selected_index, new_clip_text, mask, num_swap_steps, False)
|
218 |
+
process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)
|
219 |
+
batch_process(files, use_new_method)
|
220 |
+
return
|
221 |
+
|
222 |
+
def batch_process_with_options(files:list[ProcessEntry], options, progress):
|
223 |
+
global clip_text, process_mgr
|
224 |
+
|
225 |
+
release_resources()
|
226 |
+
limit_resources()
|
227 |
+
if process_mgr is None:
|
228 |
+
process_mgr = ProcessMgr(progress)
|
229 |
+
process_mgr.initialize(roop.globals.INPUT_FACESETS, roop.globals.TARGET_FACES, options)
|
230 |
+
roop.globals.keep_frames = False
|
231 |
+
roop.globals.wait_after_extraction = False
|
232 |
+
roop.globals.skip_audio = False
|
233 |
+
batch_process(files, True)
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
def batch_process(files:list[ProcessEntry], use_new_method) -> None:
|
238 |
+
global clip_text, process_mgr
|
239 |
+
|
240 |
+
roop.globals.processing = True
|
241 |
+
|
242 |
+
# limit threads for some providers
|
243 |
+
max_threads = suggest_execution_threads()
|
244 |
+
if max_threads == 1:
|
245 |
+
roop.globals.execution_threads = 1
|
246 |
+
|
247 |
+
imagefiles:list[ProcessEntry] = []
|
248 |
+
videofiles:list[ProcessEntry] = []
|
249 |
+
|
250 |
+
update_status('Sorting videos/images')
|
251 |
+
|
252 |
+
|
253 |
+
for index, f in enumerate(files):
|
254 |
+
fullname = f.filename
|
255 |
+
if util.has_image_extension(fullname):
|
256 |
+
destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'.{roop.globals.CFG.output_image_format}')
|
257 |
+
destination = util.replace_template(destination, index=index)
|
258 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)
|
259 |
+
f.finalname = destination
|
260 |
+
imagefiles.append(f)
|
261 |
+
|
262 |
+
elif util.is_video(fullname) or util.has_extension(fullname, ['gif']):
|
263 |
+
destination = util.get_destfilename_from_path(fullname, roop.globals.output_path, f'__temp.{roop.globals.CFG.output_video_format}')
|
264 |
+
f.finalname = destination
|
265 |
+
videofiles.append(f)
|
266 |
+
|
267 |
+
|
268 |
+
|
269 |
+
if(len(imagefiles) > 0):
|
270 |
+
update_status('Processing image(s)')
|
271 |
+
origimages = []
|
272 |
+
fakeimages = []
|
273 |
+
for f in imagefiles:
|
274 |
+
origimages.append(f.filename)
|
275 |
+
fakeimages.append(f.finalname)
|
276 |
+
|
277 |
+
process_mgr.run_batch(origimages, fakeimages, roop.globals.execution_threads)
|
278 |
+
origimages.clear()
|
279 |
+
fakeimages.clear()
|
280 |
+
|
281 |
+
if(len(videofiles) > 0):
|
282 |
+
for index,v in enumerate(videofiles):
|
283 |
+
if not roop.globals.processing:
|
284 |
+
end_processing('Processing stopped!')
|
285 |
+
return
|
286 |
+
fps = v.fps if v.fps > 0 else util.detect_fps(v.filename)
|
287 |
+
if v.endframe == 0:
|
288 |
+
v.endframe = get_video_frame_total(v.filename)
|
289 |
+
|
290 |
+
update_status(f'Creating {os.path.basename(v.finalname)} with {fps} FPS...')
|
291 |
+
start_processing = time()
|
292 |
+
if roop.globals.keep_frames or not use_new_method:
|
293 |
+
util.create_temp(v.filename)
|
294 |
+
update_status('Extracting frames...')
|
295 |
+
ffmpeg.extract_frames(v.filename,v.startframe,v.endframe, fps)
|
296 |
+
if not roop.globals.processing:
|
297 |
+
end_processing('Processing stopped!')
|
298 |
+
return
|
299 |
+
|
300 |
+
temp_frame_paths = util.get_temp_frame_paths(v.filename)
|
301 |
+
process_mgr.run_batch(temp_frame_paths, temp_frame_paths, roop.globals.execution_threads)
|
302 |
+
if not roop.globals.processing:
|
303 |
+
end_processing('Processing stopped!')
|
304 |
+
return
|
305 |
+
if roop.globals.wait_after_extraction:
|
306 |
+
extract_path = os.path.dirname(temp_frame_paths[0])
|
307 |
+
util.open_folder(extract_path)
|
308 |
+
input("Press any key to continue...")
|
309 |
+
print("Resorting frames to create video")
|
310 |
+
util.sort_rename_frames(extract_path)
|
311 |
+
|
312 |
+
ffmpeg.create_video(v.filename, v.finalname, fps)
|
313 |
+
if not roop.globals.keep_frames:
|
314 |
+
util.delete_temp_frames(temp_frame_paths[0])
|
315 |
+
else:
|
316 |
+
if util.has_extension(v.filename, ['gif']):
|
317 |
+
skip_audio = True
|
318 |
+
else:
|
319 |
+
skip_audio = roop.globals.skip_audio
|
320 |
+
process_mgr.run_batch_inmem(v.filename, v.finalname, v.startframe, v.endframe, fps,roop.globals.execution_threads, skip_audio)
|
321 |
+
|
322 |
+
if not roop.globals.processing:
|
323 |
+
end_processing('Processing stopped!')
|
324 |
+
return
|
325 |
+
|
326 |
+
video_file_name = v.finalname
|
327 |
+
if os.path.isfile(video_file_name):
|
328 |
+
destination = ''
|
329 |
+
if util.has_extension(v.filename, ['gif']):
|
330 |
+
gifname = util.get_destfilename_from_path(v.filename, roop.globals.output_path, '.gif')
|
331 |
+
destination = util.replace_template(gifname, index=index)
|
332 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)
|
333 |
+
|
334 |
+
update_status('Creating final GIF')
|
335 |
+
ffmpeg.create_gif_from_video(video_file_name, destination)
|
336 |
+
if os.path.isfile(destination):
|
337 |
+
os.remove(video_file_name)
|
338 |
+
else:
|
339 |
+
skip_audio = roop.globals.skip_audio
|
340 |
+
destination = util.replace_template(video_file_name, index=index)
|
341 |
+
pathlib.Path(os.path.dirname(destination)).mkdir(parents=True, exist_ok=True)
|
342 |
+
|
343 |
+
if not skip_audio:
|
344 |
+
ffmpeg.restore_audio(video_file_name, v.filename, v.startframe, v.endframe, destination)
|
345 |
+
if os.path.isfile(destination):
|
346 |
+
os.remove(video_file_name)
|
347 |
+
else:
|
348 |
+
shutil.move(video_file_name, destination)
|
349 |
+
update_status(f'\nProcessing {os.path.basename(destination)} took {time() - start_processing} secs')
|
350 |
+
|
351 |
+
else:
|
352 |
+
update_status(f'Failed processing {os.path.basename(v.finalname)}!')
|
353 |
+
end_processing('Finished')
|
354 |
+
|
355 |
+
|
356 |
+
def end_processing(msg:str):
|
357 |
+
update_status(msg)
|
358 |
+
roop.globals.target_folder_path = None
|
359 |
+
release_resources()
|
360 |
+
|
361 |
+
|
362 |
+
def destroy() -> None:
|
363 |
+
if roop.globals.target_path:
|
364 |
+
util.clean_temp(roop.globals.target_path)
|
365 |
+
release_resources()
|
366 |
+
sys.exit()
|
367 |
+
|
368 |
+
|
369 |
+
def run() -> None:
|
370 |
+
parse_args()
|
371 |
+
if not pre_check():
|
372 |
+
return
|
373 |
+
roop.globals.CFG = Settings('config.yaml')
|
374 |
+
roop.globals.execution_threads = roop.globals.CFG.max_threads
|
375 |
+
roop.globals.video_encoder = roop.globals.CFG.output_video_codec
|
376 |
+
roop.globals.video_quality = roop.globals.CFG.video_quality
|
377 |
+
roop.globals.max_memory = roop.globals.CFG.memory_limit if roop.globals.CFG.memory_limit > 0 else None
|
378 |
+
main.run()
|
roop/face_util.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
from typing import Any
|
3 |
+
import insightface
|
4 |
+
|
5 |
+
import roop.globals
|
6 |
+
from roop.typing import Frame, Face
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
from skimage import transform as trans
|
11 |
+
from roop.capturer import get_video_frame
|
12 |
+
from roop.utilities import resolve_relative_path, conditional_download
|
13 |
+
|
14 |
+
FACE_ANALYSER = None
|
15 |
+
THREAD_LOCK_ANALYSER = threading.Lock()
|
16 |
+
THREAD_LOCK_SWAPPER = threading.Lock()
|
17 |
+
FACE_SWAPPER = None
|
18 |
+
|
19 |
+
|
20 |
+
def get_face_analyser() -> Any:
|
21 |
+
global FACE_ANALYSER
|
22 |
+
|
23 |
+
with THREAD_LOCK_ANALYSER:
|
24 |
+
if FACE_ANALYSER is None or roop.globals.g_current_face_analysis != roop.globals.g_desired_face_analysis:
|
25 |
+
model_path = resolve_relative_path('..')
|
26 |
+
# removed genderage
|
27 |
+
allowed_modules = roop.globals.g_desired_face_analysis
|
28 |
+
roop.globals.g_current_face_analysis = roop.globals.g_desired_face_analysis
|
29 |
+
if roop.globals.CFG.force_cpu:
|
30 |
+
print("Forcing CPU for Face Analysis")
|
31 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(
|
32 |
+
name="buffalo_l",
|
33 |
+
root=model_path, providers=["CPUExecutionProvider"],allowed_modules=allowed_modules
|
34 |
+
)
|
35 |
+
else:
|
36 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(
|
37 |
+
name="buffalo_l", root=model_path, providers=roop.globals.execution_providers,allowed_modules=allowed_modules
|
38 |
+
)
|
39 |
+
FACE_ANALYSER.prepare(
|
40 |
+
ctx_id=0,
|
41 |
+
det_size=(640, 640) if roop.globals.default_det_size else (320, 320),
|
42 |
+
)
|
43 |
+
return FACE_ANALYSER
|
44 |
+
|
45 |
+
|
46 |
+
def get_first_face(frame: Frame) -> Any:
|
47 |
+
try:
|
48 |
+
faces = get_face_analyser().get(frame)
|
49 |
+
return min(faces, key=lambda x: x.bbox[0])
|
50 |
+
# return sorted(faces, reverse=True, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[0]
|
51 |
+
except:
|
52 |
+
return None
|
53 |
+
|
54 |
+
|
55 |
+
def get_all_faces(frame: Frame) -> Any:
|
56 |
+
try:
|
57 |
+
faces = get_face_analyser().get(frame)
|
58 |
+
return sorted(faces, key=lambda x: x.bbox[0])
|
59 |
+
except:
|
60 |
+
return None
|
61 |
+
|
62 |
+
|
63 |
+
def extract_face_images(source_filename, video_info, extra_padding=-1.0):
|
64 |
+
face_data = []
|
65 |
+
source_image = None
|
66 |
+
|
67 |
+
if video_info[0]:
|
68 |
+
frame = get_video_frame(source_filename, video_info[1])
|
69 |
+
if frame is not None:
|
70 |
+
source_image = frame
|
71 |
+
else:
|
72 |
+
return face_data
|
73 |
+
else:
|
74 |
+
source_image = cv2.imdecode(np.fromfile(source_filename, dtype=np.uint8), cv2.IMREAD_COLOR)
|
75 |
+
|
76 |
+
faces = get_all_faces(source_image)
|
77 |
+
if faces is None:
|
78 |
+
return face_data
|
79 |
+
|
80 |
+
i = 0
|
81 |
+
for face in faces:
|
82 |
+
(startX, startY, endX, endY) = face["bbox"].astype("int")
|
83 |
+
startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image)
|
84 |
+
if extra_padding > 0.0:
|
85 |
+
if source_image.shape[:2] == (512, 512):
|
86 |
+
i += 1
|
87 |
+
face_data.append([face, source_image])
|
88 |
+
continue
|
89 |
+
|
90 |
+
found = False
|
91 |
+
for i in range(1, 3):
|
92 |
+
(startX, startY, endX, endY) = face["bbox"].astype("int")
|
93 |
+
startX, endX, startY, endY = clamp_cut_values(startX, endX, startY, endY, source_image)
|
94 |
+
cutout_padding = extra_padding
|
95 |
+
# top needs extra room for detection
|
96 |
+
padding = int((endY - startY) * cutout_padding)
|
97 |
+
oldY = startY
|
98 |
+
startY -= padding
|
99 |
+
|
100 |
+
factor = 0.25 if i == 1 else 0.5
|
101 |
+
cutout_padding = factor
|
102 |
+
padding = int((endY - oldY) * cutout_padding)
|
103 |
+
endY += padding
|
104 |
+
padding = int((endX - startX) * cutout_padding)
|
105 |
+
startX -= padding
|
106 |
+
endX += padding
|
107 |
+
startX, endX, startY, endY = clamp_cut_values(
|
108 |
+
startX, endX, startY, endY, source_image
|
109 |
+
)
|
110 |
+
face_temp = source_image[startY:endY, startX:endX]
|
111 |
+
face_temp = resize_image_keep_content(face_temp)
|
112 |
+
testfaces = get_all_faces(face_temp)
|
113 |
+
if testfaces is not None and len(testfaces) > 0:
|
114 |
+
i += 1
|
115 |
+
face_data.append([testfaces[0], face_temp])
|
116 |
+
found = True
|
117 |
+
break
|
118 |
+
|
119 |
+
if not found:
|
120 |
+
print("No face found after resizing, this shouldn't happen!")
|
121 |
+
continue
|
122 |
+
|
123 |
+
face_temp = source_image[startY:endY, startX:endX]
|
124 |
+
if face_temp.size < 1:
|
125 |
+
continue
|
126 |
+
|
127 |
+
i += 1
|
128 |
+
face_data.append([face, face_temp])
|
129 |
+
return face_data
|
130 |
+
|
131 |
+
|
132 |
+
def clamp_cut_values(startX, endX, startY, endY, image):
|
133 |
+
if startX < 0:
|
134 |
+
startX = 0
|
135 |
+
if endX > image.shape[1]:
|
136 |
+
endX = image.shape[1]
|
137 |
+
if startY < 0:
|
138 |
+
startY = 0
|
139 |
+
if endY > image.shape[0]:
|
140 |
+
endY = image.shape[0]
|
141 |
+
return startX, endX, startY, endY
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
def face_offset_top(face: Face, offset):
|
146 |
+
face["bbox"][1] += offset
|
147 |
+
face["bbox"][3] += offset
|
148 |
+
lm106 = face.landmark_2d_106
|
149 |
+
add = np.full_like(lm106, [0, offset])
|
150 |
+
face["landmark_2d_106"] = lm106 + add
|
151 |
+
return face
|
152 |
+
|
153 |
+
|
154 |
+
def resize_image_keep_content(image, new_width=512, new_height=512):
|
155 |
+
dim = None
|
156 |
+
(h, w) = image.shape[:2]
|
157 |
+
if h > w:
|
158 |
+
r = new_height / float(h)
|
159 |
+
dim = (int(w * r), new_height)
|
160 |
+
else:
|
161 |
+
# Calculate the ratio of the width and construct the dimensions
|
162 |
+
r = new_width / float(w)
|
163 |
+
dim = (new_width, int(h * r))
|
164 |
+
image = cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
|
165 |
+
(h, w) = image.shape[:2]
|
166 |
+
if h == new_height and w == new_width:
|
167 |
+
return image
|
168 |
+
resize_img = np.zeros(shape=(new_height, new_width, 3), dtype=image.dtype)
|
169 |
+
offs = (new_width - w) if h == new_height else (new_height - h)
|
170 |
+
startoffs = int(offs // 2) if offs % 2 == 0 else int(offs // 2) + 1
|
171 |
+
offs = int(offs // 2)
|
172 |
+
|
173 |
+
if h == new_height:
|
174 |
+
resize_img[0:new_height, startoffs : new_width - offs] = image
|
175 |
+
else:
|
176 |
+
resize_img[startoffs : new_height - offs, 0:new_width] = image
|
177 |
+
return resize_img
|
178 |
+
|
179 |
+
|
180 |
+
def rotate_image_90(image, rotate=True):
|
181 |
+
if rotate:
|
182 |
+
return np.rot90(image)
|
183 |
+
else:
|
184 |
+
return np.rot90(image, 1, (1, 0))
|
185 |
+
|
186 |
+
|
187 |
+
def rotate_anticlockwise(frame):
|
188 |
+
return rotate_image_90(frame)
|
189 |
+
|
190 |
+
|
191 |
+
def rotate_clockwise(frame):
|
192 |
+
return rotate_image_90(frame, False)
|
193 |
+
|
194 |
+
|
195 |
+
def rotate_image_180(image):
|
196 |
+
return np.flip(image, 0)
|
197 |
+
|
198 |
+
|
199 |
+
# alignment code from insightface https://github.com/deepinsight/insightface/blob/master/python-package/insightface/utils/face_align.py
|
200 |
+
|
201 |
+
arcface_dst = np.array(
|
202 |
+
[
|
203 |
+
[38.2946, 51.6963],
|
204 |
+
[73.5318, 51.5014],
|
205 |
+
[56.0252, 71.7366],
|
206 |
+
[41.5493, 92.3655],
|
207 |
+
[70.7299, 92.2041],
|
208 |
+
],
|
209 |
+
dtype=np.float32,
|
210 |
+
)
|
211 |
+
|
212 |
+
|
213 |
+
def estimate_norm(lmk, image_size=112, mode="arcface"):
|
214 |
+
assert lmk.shape == (5, 2)
|
215 |
+
assert image_size % 112 == 0 or image_size % 128 == 0
|
216 |
+
if image_size % 112 == 0:
|
217 |
+
ratio = float(image_size) / 112.0
|
218 |
+
diff_x = 0
|
219 |
+
else:
|
220 |
+
ratio = float(image_size) / 128.0
|
221 |
+
diff_x = 8.0 * ratio
|
222 |
+
dst = arcface_dst * ratio
|
223 |
+
dst[:, 0] += diff_x
|
224 |
+
tform = trans.SimilarityTransform()
|
225 |
+
tform.estimate(lmk, dst)
|
226 |
+
M = tform.params[0:2, :]
|
227 |
+
return M
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
# aligned, M = norm_crop2(f[1], face.kps, 512)
|
232 |
+
def align_crop(img, landmark, image_size=112, mode="arcface"):
|
233 |
+
M = estimate_norm(landmark, image_size, mode)
|
234 |
+
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
235 |
+
return warped, M
|
236 |
+
|
237 |
+
|
238 |
+
def square_crop(im, S):
|
239 |
+
if im.shape[0] > im.shape[1]:
|
240 |
+
height = S
|
241 |
+
width = int(float(im.shape[1]) / im.shape[0] * S)
|
242 |
+
scale = float(S) / im.shape[0]
|
243 |
+
else:
|
244 |
+
width = S
|
245 |
+
height = int(float(im.shape[0]) / im.shape[1] * S)
|
246 |
+
scale = float(S) / im.shape[1]
|
247 |
+
resized_im = cv2.resize(im, (width, height))
|
248 |
+
det_im = np.zeros((S, S, 3), dtype=np.uint8)
|
249 |
+
det_im[: resized_im.shape[0], : resized_im.shape[1], :] = resized_im
|
250 |
+
return det_im, scale
|
251 |
+
|
252 |
+
|
253 |
+
def transform(data, center, output_size, scale, rotation):
|
254 |
+
scale_ratio = scale
|
255 |
+
rot = float(rotation) * np.pi / 180.0
|
256 |
+
# translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
|
257 |
+
t1 = trans.SimilarityTransform(scale=scale_ratio)
|
258 |
+
cx = center[0] * scale_ratio
|
259 |
+
cy = center[1] * scale_ratio
|
260 |
+
t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
|
261 |
+
t3 = trans.SimilarityTransform(rotation=rot)
|
262 |
+
t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2))
|
263 |
+
t = t1 + t2 + t3 + t4
|
264 |
+
M = t.params[0:2]
|
265 |
+
cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0)
|
266 |
+
return cropped, M
|
267 |
+
|
268 |
+
|
269 |
+
def trans_points2d(pts, M):
|
270 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
271 |
+
for i in range(pts.shape[0]):
|
272 |
+
pt = pts[i]
|
273 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
274 |
+
new_pt = np.dot(M, new_pt)
|
275 |
+
# print('new_pt', new_pt.shape, new_pt)
|
276 |
+
new_pts[i] = new_pt[0:2]
|
277 |
+
|
278 |
+
return new_pts
|
279 |
+
|
280 |
+
|
281 |
+
def trans_points3d(pts, M):
|
282 |
+
scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
|
283 |
+
# print(scale)
|
284 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
285 |
+
for i in range(pts.shape[0]):
|
286 |
+
pt = pts[i]
|
287 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
288 |
+
new_pt = np.dot(M, new_pt)
|
289 |
+
# print('new_pt', new_pt.shape, new_pt)
|
290 |
+
new_pts[i][0:2] = new_pt[0:2]
|
291 |
+
new_pts[i][2] = pts[i][2] * scale
|
292 |
+
|
293 |
+
return new_pts
|
294 |
+
|
295 |
+
|
296 |
+
def trans_points(pts, M):
|
297 |
+
if pts.shape[1] == 2:
|
298 |
+
return trans_points2d(pts, M)
|
299 |
+
else:
|
300 |
+
return trans_points3d(pts, M)
|
301 |
+
|
302 |
+
def create_blank_image(width, height):
|
303 |
+
img = np.zeros((height, width, 4), dtype=np.uint8)
|
304 |
+
img[:] = [0,0,0,0]
|
305 |
+
return img
|
306 |
+
|
roop/ffmpeg_writer.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
FFMPEG_Writer - write set of frames to video file
|
3 |
+
|
4 |
+
original from
|
5 |
+
https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py
|
6 |
+
|
7 |
+
removed unnecessary dependencies
|
8 |
+
|
9 |
+
The MIT License (MIT)
|
10 |
+
|
11 |
+
Copyright (c) 2015 Zulko
|
12 |
+
Copyright (c) 2023 Janvarev Vladislav
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
import subprocess as sp
|
17 |
+
|
18 |
+
PIPE = -1
|
19 |
+
STDOUT = -2
|
20 |
+
DEVNULL = -3
|
21 |
+
|
22 |
+
FFMPEG_BINARY = "ffmpeg"
|
23 |
+
|
24 |
+
class FFMPEG_VideoWriter:
|
25 |
+
""" A class for FFMPEG-based video writing.
|
26 |
+
|
27 |
+
A class to write videos using ffmpeg. ffmpeg will write in a large
|
28 |
+
choice of formats.
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
-----------
|
32 |
+
|
33 |
+
filename
|
34 |
+
Any filename like 'video.mp4' etc. but if you want to avoid
|
35 |
+
complications it is recommended to use the generic extension
|
36 |
+
'.avi' for all your videos.
|
37 |
+
|
38 |
+
size
|
39 |
+
Size (width,height) of the output video in pixels.
|
40 |
+
|
41 |
+
fps
|
42 |
+
Frames per second in the output video file.
|
43 |
+
|
44 |
+
codec
|
45 |
+
FFMPEG codec. It seems that in terms of quality the hierarchy is
|
46 |
+
'rawvideo' = 'png' > 'mpeg4' > 'libx264'
|
47 |
+
'png' manages the same lossless quality as 'rawvideo' but yields
|
48 |
+
smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list
|
49 |
+
of accepted codecs.
|
50 |
+
|
51 |
+
Note for default 'libx264': by default the pixel format yuv420p
|
52 |
+
is used. If the video dimensions are not both even (e.g. 720x405)
|
53 |
+
another pixel format is used, and this can cause problem in some
|
54 |
+
video readers.
|
55 |
+
|
56 |
+
audiofile
|
57 |
+
Optional: The name of an audio file that will be incorporated
|
58 |
+
to the video.
|
59 |
+
|
60 |
+
preset
|
61 |
+
Sets the time that FFMPEG will take to compress the video. The slower,
|
62 |
+
the better the compression rate. Possibilities are: ultrafast,superfast,
|
63 |
+
veryfast, faster, fast, medium (default), slow, slower, veryslow,
|
64 |
+
placebo.
|
65 |
+
|
66 |
+
bitrate
|
67 |
+
Only relevant for codecs which accept a bitrate. "5000k" offers
|
68 |
+
nice results in general.
|
69 |
+
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, filename, size, fps, codec="libx265", crf=14, audiofile=None,
|
73 |
+
preset="medium", bitrate=None,
|
74 |
+
logfile=None, threads=None, ffmpeg_params=None):
|
75 |
+
|
76 |
+
if logfile is None:
|
77 |
+
logfile = sp.PIPE
|
78 |
+
|
79 |
+
self.filename = filename
|
80 |
+
self.codec = codec
|
81 |
+
self.ext = self.filename.split(".")[-1]
|
82 |
+
w = size[0] - 1 if size[0] % 2 != 0 else size[0]
|
83 |
+
h = size[1] - 1 if size[1] % 2 != 0 else size[1]
|
84 |
+
|
85 |
+
|
86 |
+
# order is important
|
87 |
+
cmd = [
|
88 |
+
FFMPEG_BINARY,
|
89 |
+
'-hide_banner',
|
90 |
+
'-hwaccel', 'auto',
|
91 |
+
'-y',
|
92 |
+
'-loglevel', 'error' if logfile == sp.PIPE else 'info',
|
93 |
+
'-f', 'rawvideo',
|
94 |
+
'-vcodec', 'rawvideo',
|
95 |
+
'-s', '%dx%d' % (size[0], size[1]),
|
96 |
+
#'-pix_fmt', 'rgba' if withmask else 'rgb24',
|
97 |
+
'-pix_fmt', 'bgr24',
|
98 |
+
'-r', str(fps),
|
99 |
+
'-an', '-i', '-'
|
100 |
+
]
|
101 |
+
|
102 |
+
if audiofile is not None:
|
103 |
+
cmd.extend([
|
104 |
+
'-i', audiofile,
|
105 |
+
'-acodec', 'copy'
|
106 |
+
])
|
107 |
+
|
108 |
+
cmd.extend([
|
109 |
+
'-vcodec', codec,
|
110 |
+
'-crf', str(crf)
|
111 |
+
#'-preset', preset,
|
112 |
+
])
|
113 |
+
if ffmpeg_params is not None:
|
114 |
+
cmd.extend(ffmpeg_params)
|
115 |
+
if bitrate is not None:
|
116 |
+
cmd.extend([
|
117 |
+
'-b', bitrate
|
118 |
+
])
|
119 |
+
|
120 |
+
# scale to a resolution divisible by 2 if not even
|
121 |
+
cmd.extend(['-vf', f'scale={w}:{h}' if w != size[0] or h != size[1] else 'colorspace=bt709:iall=bt601-6-625:fast=1'])
|
122 |
+
|
123 |
+
if threads is not None:
|
124 |
+
cmd.extend(["-threads", str(threads)])
|
125 |
+
|
126 |
+
cmd.extend([
|
127 |
+
'-pix_fmt', 'yuv420p',
|
128 |
+
|
129 |
+
])
|
130 |
+
cmd.extend([
|
131 |
+
filename
|
132 |
+
])
|
133 |
+
|
134 |
+
test = str(cmd)
|
135 |
+
print(test)
|
136 |
+
|
137 |
+
popen_params = {"stdout": DEVNULL,
|
138 |
+
"stderr": logfile,
|
139 |
+
"stdin": sp.PIPE}
|
140 |
+
|
141 |
+
# This was added so that no extra unwanted window opens on windows
|
142 |
+
# when the child process is created
|
143 |
+
if os.name == "nt":
|
144 |
+
popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW
|
145 |
+
|
146 |
+
self.proc = sp.Popen(cmd, **popen_params)
|
147 |
+
|
148 |
+
|
149 |
+
def write_frame(self, img_array):
|
150 |
+
""" Writes one frame in the file."""
|
151 |
+
try:
|
152 |
+
#if PY3:
|
153 |
+
self.proc.stdin.write(img_array.tobytes())
|
154 |
+
# else:
|
155 |
+
# self.proc.stdin.write(img_array.tostring())
|
156 |
+
except IOError as err:
|
157 |
+
_, ffmpeg_error = self.proc.communicate()
|
158 |
+
error = (str(err) + ("\n\nroop unleashed error: FFMPEG encountered "
|
159 |
+
"the following error while writing file %s:"
|
160 |
+
"\n\n %s" % (self.filename, str(ffmpeg_error))))
|
161 |
+
|
162 |
+
if b"Unknown encoder" in ffmpeg_error:
|
163 |
+
|
164 |
+
error = error+("\n\nThe video export "
|
165 |
+
"failed because FFMPEG didn't find the specified "
|
166 |
+
"codec for video encoding (%s). Please install "
|
167 |
+
"this codec or change the codec when calling "
|
168 |
+
"write_videofile. For instance:\n"
|
169 |
+
" >>> clip.write_videofile('myvid.webm', codec='libvpx')")%(self.codec)
|
170 |
+
|
171 |
+
elif b"incorrect codec parameters ?" in ffmpeg_error:
|
172 |
+
|
173 |
+
error = error+("\n\nThe video export "
|
174 |
+
"failed, possibly because the codec specified for "
|
175 |
+
"the video (%s) is not compatible with the given "
|
176 |
+
"extension (%s). Please specify a valid 'codec' "
|
177 |
+
"argument in write_videofile. This would be 'libx264' "
|
178 |
+
"or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. "
|
179 |
+
"Another possible reason is that the audio codec was not "
|
180 |
+
"compatible with the video codec. For instance the video "
|
181 |
+
"extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a"
|
182 |
+
"video codec."
|
183 |
+
)%(self.codec, self.ext)
|
184 |
+
|
185 |
+
elif b"encoder setup failed" in ffmpeg_error:
|
186 |
+
|
187 |
+
error = error+("\n\nThe video export "
|
188 |
+
"failed, possibly because the bitrate you specified "
|
189 |
+
"was too high or too low for the video codec.")
|
190 |
+
|
191 |
+
elif b"Invalid encoder type" in ffmpeg_error:
|
192 |
+
|
193 |
+
error = error + ("\n\nThe video export failed because the codec "
|
194 |
+
"or file extension you provided is not a video")
|
195 |
+
|
196 |
+
|
197 |
+
raise IOError(error)
|
198 |
+
|
199 |
+
def close(self):
|
200 |
+
if self.proc:
|
201 |
+
self.proc.stdin.close()
|
202 |
+
if self.proc.stderr is not None:
|
203 |
+
self.proc.stderr.close()
|
204 |
+
self.proc.wait()
|
205 |
+
|
206 |
+
self.proc = None
|
207 |
+
|
208 |
+
# Support the Context Manager protocol, to ensure that resources are cleaned up.
|
209 |
+
|
210 |
+
def __enter__(self):
|
211 |
+
return self
|
212 |
+
|
213 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
214 |
+
self.close()
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
|
roop/globals.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from settings import Settings
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
source_path = None
|
5 |
+
target_path = None
|
6 |
+
output_path = None
|
7 |
+
target_folder_path = None
|
8 |
+
|
9 |
+
frame_processors: List[str] = []
|
10 |
+
keep_fps = None
|
11 |
+
keep_frames = None
|
12 |
+
autorotate_faces = None
|
13 |
+
vr_mode = None
|
14 |
+
skip_audio = None
|
15 |
+
wait_after_extraction = None
|
16 |
+
many_faces = None
|
17 |
+
use_batch = None
|
18 |
+
source_face_index = 0
|
19 |
+
target_face_index = 0
|
20 |
+
face_position = None
|
21 |
+
video_encoder = None
|
22 |
+
video_quality = None
|
23 |
+
max_memory = None
|
24 |
+
execution_providers: List[str] = []
|
25 |
+
execution_threads = None
|
26 |
+
headless = None
|
27 |
+
log_level = 'error'
|
28 |
+
selected_enhancer = None
|
29 |
+
face_swap_mode = None
|
30 |
+
blend_ratio = 0.5
|
31 |
+
distance_threshold = 0.65
|
32 |
+
default_det_size = True
|
33 |
+
|
34 |
+
no_face_action = 0
|
35 |
+
|
36 |
+
processing = False
|
37 |
+
|
38 |
+
g_current_face_analysis = None
|
39 |
+
g_desired_face_analysis = None
|
40 |
+
|
41 |
+
FACE_ENHANCER = None
|
42 |
+
|
43 |
+
INPUT_FACESETS = []
|
44 |
+
TARGET_FACES = []
|
45 |
+
|
46 |
+
|
47 |
+
IMAGE_CHAIN_PROCESSOR = None
|
48 |
+
VIDEO_CHAIN_PROCESSOR = None
|
49 |
+
BATCH_IMAGE_CHAIN_PROCESSOR = None
|
50 |
+
|
51 |
+
CFG: Settings = None
|
52 |
+
|
53 |
+
|
roop/metadata.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
name = 'roop unleashed'
|
2 |
+
version = '4.0.0'
|
roop/processors/Enhance_CodeFormer.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Callable
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime
|
5 |
+
import roop.globals
|
6 |
+
|
7 |
+
from roop.typing import Face, Frame, FaceSet
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
|
11 |
+
# THREAD_LOCK = threading.Lock()
|
12 |
+
|
13 |
+
|
14 |
+
class Enhance_CodeFormer():
|
15 |
+
model_codeformer = None
|
16 |
+
|
17 |
+
plugin_options:dict = None
|
18 |
+
|
19 |
+
processorname = 'codeformer'
|
20 |
+
type = 'enhance'
|
21 |
+
|
22 |
+
|
23 |
+
def Initialize(self, plugin_options:dict):
|
24 |
+
if self.plugin_options is not None:
|
25 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
26 |
+
self.Release()
|
27 |
+
|
28 |
+
self.plugin_options = plugin_options
|
29 |
+
if self.model_codeformer is None:
|
30 |
+
# replace Mac mps with cpu for the moment
|
31 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
32 |
+
model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx')
|
33 |
+
self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
34 |
+
self.model_inputs = self.model_codeformer.get_inputs()
|
35 |
+
model_outputs = self.model_codeformer.get_outputs()
|
36 |
+
self.io_binding = self.model_codeformer.io_binding()
|
37 |
+
self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5]))
|
38 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
39 |
+
|
40 |
+
|
41 |
+
def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
|
42 |
+
input_size = temp_frame.shape[1]
|
43 |
+
# preprocess
|
44 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
45 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
46 |
+
temp_frame = temp_frame.astype('float32') / 255.0
|
47 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
48 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
49 |
+
|
50 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32))
|
51 |
+
self.model_codeformer.run_with_iobinding(self.io_binding)
|
52 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
53 |
+
result = ort_outs[0][0]
|
54 |
+
del ort_outs
|
55 |
+
|
56 |
+
# post-process
|
57 |
+
result = result.transpose((1, 2, 0))
|
58 |
+
|
59 |
+
un_min = -1.0
|
60 |
+
un_max = 1.0
|
61 |
+
result = np.clip(result, un_min, un_max)
|
62 |
+
result = (result - un_min) / (un_max - un_min)
|
63 |
+
|
64 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
65 |
+
result = (result * 255.0).round()
|
66 |
+
scale_factor = int(result.shape[1] / input_size)
|
67 |
+
return result.astype(np.uint8), scale_factor
|
68 |
+
|
69 |
+
|
70 |
+
def Release(self):
|
71 |
+
del self.model_codeformer
|
72 |
+
self.model_codeformer = None
|
73 |
+
del self.io_binding
|
74 |
+
self.io_binding = None
|
75 |
+
|
roop/processors/Enhance_DMDNet.py
ADDED
@@ -0,0 +1,898 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Callable
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.nn.utils.spectral_norm as SpectralNorm
|
8 |
+
import threading
|
9 |
+
from torchvision.ops import roi_align
|
10 |
+
|
11 |
+
from math import sqrt
|
12 |
+
|
13 |
+
from torchvision.transforms.functional import normalize
|
14 |
+
|
15 |
+
from roop.typing import Face, Frame, FaceSet
|
16 |
+
|
17 |
+
|
18 |
+
THREAD_LOCK_DMDNET = threading.Lock()
|
19 |
+
|
20 |
+
|
21 |
+
class Enhance_DMDNet():
|
22 |
+
plugin_options:dict = None
|
23 |
+
model_dmdnet = None
|
24 |
+
torchdevice = None
|
25 |
+
|
26 |
+
processorname = 'dmdnet'
|
27 |
+
type = 'enhance'
|
28 |
+
|
29 |
+
|
30 |
+
def Initialize(self, plugin_options:dict):
|
31 |
+
if self.plugin_options is not None:
|
32 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
33 |
+
self.Release()
|
34 |
+
|
35 |
+
self.plugin_options = plugin_options
|
36 |
+
if self.model_dmdnet is None:
|
37 |
+
self.model_dmdnet = self.create(self.plugin_options["devicename"])
|
38 |
+
|
39 |
+
|
40 |
+
# temp_frame already cropped+aligned, bbox not
|
41 |
+
def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
|
42 |
+
input_size = temp_frame.shape[1]
|
43 |
+
|
44 |
+
result = self.enhance_face(source_faceset, temp_frame, target_face)
|
45 |
+
scale_factor = int(result.shape[1] / input_size)
|
46 |
+
return result.astype(np.uint8), scale_factor
|
47 |
+
|
48 |
+
|
49 |
+
def Release(self):
|
50 |
+
self.model_gfpgan = None
|
51 |
+
|
52 |
+
|
53 |
+
# https://stackoverflow.com/a/67174339
|
54 |
+
def landmarks106_to_68(self, pt106):
|
55 |
+
map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
|
56 |
+
43,48,49,51,50,
|
57 |
+
102,103,104,105,101,
|
58 |
+
72,73,74,86,78,79,80,85,84,
|
59 |
+
35,41,42,39,37,36,
|
60 |
+
89,95,96,93,91,90,
|
61 |
+
52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
|
62 |
+
]
|
63 |
+
|
64 |
+
pt68 = []
|
65 |
+
for i in range(68):
|
66 |
+
index = map106to68[i]
|
67 |
+
pt68.append(pt106[index])
|
68 |
+
return pt68
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def check_bbox(self, imgs, boxes):
|
74 |
+
boxes = boxes.view(-1, 4, 4)
|
75 |
+
colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
|
76 |
+
i = 0
|
77 |
+
for img, box in zip(imgs, boxes):
|
78 |
+
img = (img + 1)/2 * 255
|
79 |
+
img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
|
80 |
+
for idx, point in enumerate(box):
|
81 |
+
cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2)
|
82 |
+
cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2)
|
83 |
+
i += 1
|
84 |
+
|
85 |
+
|
86 |
+
def trans_points2d(self, pts, M):
|
87 |
+
new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
|
88 |
+
for i in range(pts.shape[0]):
|
89 |
+
pt = pts[i]
|
90 |
+
new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
|
91 |
+
new_pt = np.dot(M, new_pt)
|
92 |
+
new_pts[i] = new_pt[0:2]
|
93 |
+
|
94 |
+
return new_pts
|
95 |
+
|
96 |
+
|
97 |
+
def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face):
|
98 |
+
# preprocess
|
99 |
+
start_x, start_y, end_x, end_y = map(int, face['bbox'])
|
100 |
+
lm106 = face.landmark_2d_106
|
101 |
+
lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
|
102 |
+
|
103 |
+
if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512:
|
104 |
+
# scale to 512x512
|
105 |
+
scale_factor = 512 / temp_frame.shape[1]
|
106 |
+
|
107 |
+
M = face.matrix * scale_factor
|
108 |
+
|
109 |
+
lq_landmarks = self.trans_points2d(lq_landmarks, M)
|
110 |
+
temp_frame = cv2.resize(temp_frame, (512,512), interpolation = cv2.INTER_AREA)
|
111 |
+
|
112 |
+
if temp_frame.ndim == 2:
|
113 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
|
114 |
+
# else:
|
115 |
+
# temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB
|
116 |
+
|
117 |
+
lq = read_img_tensor(temp_frame)
|
118 |
+
|
119 |
+
LQLocs = get_component_location(lq_landmarks)
|
120 |
+
# self.check_bbox(lq, LQLocs.unsqueeze(0))
|
121 |
+
|
122 |
+
# specific, change 1000 to 1 to activate
|
123 |
+
if len(ref_faceset.faces) > 1:
|
124 |
+
SpecificImgs = []
|
125 |
+
SpecificLocs = []
|
126 |
+
for i,face in enumerate(ref_faceset.faces):
|
127 |
+
lm106 = face.landmark_2d_106
|
128 |
+
lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
|
129 |
+
ref_image = ref_faceset.ref_images[i]
|
130 |
+
if ref_image.shape[0] != 512 or ref_image.shape[1] != 512:
|
131 |
+
# scale to 512x512
|
132 |
+
scale_factor = 512 / ref_image.shape[1]
|
133 |
+
|
134 |
+
M = face.matrix * scale_factor
|
135 |
+
|
136 |
+
lq_landmarks = self.trans_points2d(lq_landmarks, M)
|
137 |
+
ref_image = cv2.resize(ref_image, (512,512), interpolation = cv2.INTER_AREA)
|
138 |
+
|
139 |
+
if ref_image.ndim == 2:
|
140 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
|
141 |
+
# else:
|
142 |
+
# temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB
|
143 |
+
|
144 |
+
ref_tensor = read_img_tensor(ref_image)
|
145 |
+
ref_locs = get_component_location(lq_landmarks)
|
146 |
+
# self.check_bbox(ref_tensor, ref_locs.unsqueeze(0))
|
147 |
+
|
148 |
+
SpecificImgs.append(ref_tensor)
|
149 |
+
SpecificLocs.append(ref_locs.unsqueeze(0))
|
150 |
+
|
151 |
+
SpecificImgs = torch.cat(SpecificImgs, dim=0)
|
152 |
+
SpecificLocs = torch.cat(SpecificLocs, dim=0)
|
153 |
+
# check_bbox(SpecificImgs, SpecificLocs)
|
154 |
+
SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs)
|
155 |
+
SpMem256Para = {}
|
156 |
+
SpMem128Para = {}
|
157 |
+
SpMem64Para = {}
|
158 |
+
for k, v in SpMem256.items():
|
159 |
+
SpMem256Para[k] = v
|
160 |
+
for k, v in SpMem128.items():
|
161 |
+
SpMem128Para[k] = v
|
162 |
+
for k, v in SpMem64.items():
|
163 |
+
SpMem64Para[k] = v
|
164 |
+
else:
|
165 |
+
# generic
|
166 |
+
SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
|
167 |
+
|
168 |
+
with torch.no_grad():
|
169 |
+
with THREAD_LOCK_DMDNET:
|
170 |
+
try:
|
171 |
+
GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
|
172 |
+
except Exception as e:
|
173 |
+
print(f'Error {e} there may be something wrong with the detected component locations.')
|
174 |
+
return temp_frame
|
175 |
+
|
176 |
+
if SpecificResult is not None:
|
177 |
+
save_specific = SpecificResult * 0.5 + 0.5
|
178 |
+
save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
179 |
+
save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0
|
180 |
+
temp_frame = save_specific.astype("uint8")
|
181 |
+
if False:
|
182 |
+
save_generic = GenericResult * 0.5 + 0.5
|
183 |
+
save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
184 |
+
save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
|
185 |
+
check_lq = lq * 0.5 + 0.5
|
186 |
+
check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
187 |
+
check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
|
188 |
+
cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)),cv2.COLOR_RGB2BGR))
|
189 |
+
else:
|
190 |
+
save_generic = GenericResult * 0.5 + 0.5
|
191 |
+
save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
|
192 |
+
save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
|
193 |
+
temp_frame = save_generic.astype("uint8")
|
194 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB
|
195 |
+
return temp_frame
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
def create(self, devicename):
|
200 |
+
self.torchdevice = torch.device(devicename)
|
201 |
+
model_dmdnet = DMDNet().to(self.torchdevice)
|
202 |
+
weights = torch.load('./models/DMDNet.pth')
|
203 |
+
model_dmdnet.load_state_dict(weights, strict=True)
|
204 |
+
|
205 |
+
model_dmdnet.eval()
|
206 |
+
num_params = 0
|
207 |
+
for param in model_dmdnet.parameters():
|
208 |
+
num_params += param.numel()
|
209 |
+
return model_dmdnet
|
210 |
+
|
211 |
+
# print('{:>8s} : {}'.format('Using device', device))
|
212 |
+
# print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
def read_img_tensor(Img=None): #rgb -1~1
|
217 |
+
Img = Img.transpose((2, 0, 1))/255.0
|
218 |
+
Img = torch.from_numpy(Img).float()
|
219 |
+
normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
|
220 |
+
ImgTensor = Img.unsqueeze(0)
|
221 |
+
return ImgTensor
|
222 |
+
|
223 |
+
|
224 |
+
def get_component_location(Landmarks, re_read=False):
|
225 |
+
if re_read:
|
226 |
+
ReadLandmark = []
|
227 |
+
with open(Landmarks,'r') as f:
|
228 |
+
for line in f:
|
229 |
+
tmp = [float(i) for i in line.split(' ') if i != '\n']
|
230 |
+
ReadLandmark.append(tmp)
|
231 |
+
ReadLandmark = np.array(ReadLandmark) #
|
232 |
+
Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
|
233 |
+
Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
|
234 |
+
Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
|
235 |
+
Map_LE = list(range(36,42))
|
236 |
+
Map_RE = list(range(42,48))
|
237 |
+
Map_NO = list(range(29,36))
|
238 |
+
Map_MO = list(range(48,68))
|
239 |
+
|
240 |
+
Landmarks[Landmarks>504]=504
|
241 |
+
Landmarks[Landmarks<8]=8
|
242 |
+
|
243 |
+
#left eye
|
244 |
+
Mean_LE = np.mean(Landmarks[Map_LE],0)
|
245 |
+
L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
|
246 |
+
L_LE1 = L_LE1 * 1.3
|
247 |
+
L_LE2 = L_LE1 / 1.9
|
248 |
+
L_LE_xy = L_LE1 + L_LE2
|
249 |
+
L_LE_lt = [L_LE_xy/2, L_LE1]
|
250 |
+
L_LE_rb = [L_LE_xy/2, L_LE2]
|
251 |
+
Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
|
252 |
+
|
253 |
+
#right eye
|
254 |
+
Mean_RE = np.mean(Landmarks[Map_RE],0)
|
255 |
+
L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
|
256 |
+
L_RE1 = L_RE1 * 1.3
|
257 |
+
L_RE2 = L_RE1 / 1.9
|
258 |
+
L_RE_xy = L_RE1 + L_RE2
|
259 |
+
L_RE_lt = [L_RE_xy/2, L_RE1]
|
260 |
+
L_RE_rb = [L_RE_xy/2, L_RE2]
|
261 |
+
Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
|
262 |
+
|
263 |
+
#nose
|
264 |
+
Mean_NO = np.mean(Landmarks[Map_NO],0)
|
265 |
+
L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
|
266 |
+
L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
|
267 |
+
L_NO_xy = L_NO1 * 2
|
268 |
+
L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
|
269 |
+
L_NO_rb = [L_NO_xy/2, L_NO2]
|
270 |
+
Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
|
271 |
+
|
272 |
+
#mouth
|
273 |
+
Mean_MO = np.mean(Landmarks[Map_MO],0)
|
274 |
+
L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
|
275 |
+
MO_O = Mean_MO - L_MO + 1
|
276 |
+
MO_T = Mean_MO + L_MO
|
277 |
+
MO_T[MO_T>510]=510
|
278 |
+
Location_MO = np.hstack((MO_O, MO_T)).astype(int)
|
279 |
+
return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
def calc_mean_std_4D(feat, eps=1e-5):
|
285 |
+
# eps is a small value added to the variance to avoid divide-by-zero.
|
286 |
+
size = feat.size()
|
287 |
+
assert (len(size) == 4)
|
288 |
+
N, C = size[:2]
|
289 |
+
feat_var = feat.view(N, C, -1).var(dim=2) + eps
|
290 |
+
feat_std = feat_var.sqrt().view(N, C, 1, 1)
|
291 |
+
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
292 |
+
return feat_mean, feat_std
|
293 |
+
|
294 |
+
def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
|
295 |
+
size = content_feat.size()
|
296 |
+
style_mean, style_std = calc_mean_std_4D(style_feat)
|
297 |
+
content_mean, content_std = calc_mean_std_4D(content_feat)
|
298 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
299 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
300 |
+
|
301 |
+
|
302 |
+
def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
|
303 |
+
return nn.Sequential(
|
304 |
+
SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
|
305 |
+
nn.LeakyReLU(0.2),
|
306 |
+
SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
class MSDilateBlock(nn.Module):
|
311 |
+
def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
|
312 |
+
super(MSDilateBlock, self).__init__()
|
313 |
+
self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
|
314 |
+
self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
|
315 |
+
self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
|
316 |
+
self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
|
317 |
+
self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
|
318 |
+
def forward(self, x):
|
319 |
+
conv1 = self.conv1(x)
|
320 |
+
conv2 = self.conv2(x)
|
321 |
+
conv3 = self.conv3(x)
|
322 |
+
conv4 = self.conv4(x)
|
323 |
+
cat = torch.cat([conv1, conv2, conv3, conv4], 1)
|
324 |
+
out = self.convi(cat) + x
|
325 |
+
return out
|
326 |
+
|
327 |
+
|
328 |
+
class AdaptiveInstanceNorm(nn.Module):
|
329 |
+
def __init__(self, in_channel):
|
330 |
+
super().__init__()
|
331 |
+
self.norm = nn.InstanceNorm2d(in_channel)
|
332 |
+
|
333 |
+
def forward(self, input, style):
|
334 |
+
style_mean, style_std = calc_mean_std_4D(style)
|
335 |
+
out = self.norm(input)
|
336 |
+
size = input.size()
|
337 |
+
out = style_std.expand(size) * out + style_mean.expand(size)
|
338 |
+
return out
|
339 |
+
|
340 |
+
class NoiseInjection(nn.Module):
|
341 |
+
def __init__(self, channel):
|
342 |
+
super().__init__()
|
343 |
+
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
344 |
+
def forward(self, image, noise):
|
345 |
+
if noise is None:
|
346 |
+
b, c, h, w = image.shape
|
347 |
+
noise = image.new_empty(b, 1, h, w).normal_()
|
348 |
+
return image + self.weight * noise
|
349 |
+
|
350 |
+
class StyledUpBlock(nn.Module):
|
351 |
+
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
|
352 |
+
super().__init__()
|
353 |
+
|
354 |
+
self.noise_inject = noise_inject
|
355 |
+
if upsample:
|
356 |
+
self.conv1 = nn.Sequential(
|
357 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
358 |
+
SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
|
359 |
+
nn.LeakyReLU(0.2),
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
self.conv1 = nn.Sequential(
|
363 |
+
SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
|
364 |
+
nn.LeakyReLU(0.2),
|
365 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
|
366 |
+
)
|
367 |
+
self.convup = nn.Sequential(
|
368 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
369 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
|
370 |
+
nn.LeakyReLU(0.2),
|
371 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
|
372 |
+
)
|
373 |
+
if self.noise_inject:
|
374 |
+
self.noise1 = NoiseInjection(out_channel)
|
375 |
+
|
376 |
+
self.lrelu1 = nn.LeakyReLU(0.2)
|
377 |
+
|
378 |
+
self.ScaleModel1 = nn.Sequential(
|
379 |
+
SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
|
380 |
+
nn.LeakyReLU(0.2),
|
381 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
|
382 |
+
)
|
383 |
+
self.ShiftModel1 = nn.Sequential(
|
384 |
+
SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
|
385 |
+
nn.LeakyReLU(0.2),
|
386 |
+
SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
|
387 |
+
)
|
388 |
+
|
389 |
+
def forward(self, input, style):
|
390 |
+
out = self.conv1(input)
|
391 |
+
out = self.lrelu1(out)
|
392 |
+
Shift1 = self.ShiftModel1(style)
|
393 |
+
Scale1 = self.ScaleModel1(style)
|
394 |
+
out = out * Scale1 + Shift1
|
395 |
+
if self.noise_inject:
|
396 |
+
out = self.noise1(out, noise=None)
|
397 |
+
outup = self.convup(out)
|
398 |
+
return outup
|
399 |
+
|
400 |
+
|
401 |
+
####################################################################
|
402 |
+
###############Face Dictionary Generator
|
403 |
+
####################################################################
|
404 |
+
def AttentionBlock(in_channel):
|
405 |
+
return nn.Sequential(
|
406 |
+
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
|
407 |
+
nn.LeakyReLU(0.2),
|
408 |
+
SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
|
409 |
+
)
|
410 |
+
|
411 |
+
class DilateResBlock(nn.Module):
|
412 |
+
def __init__(self, dim, dilation=[5,3] ):
|
413 |
+
super(DilateResBlock, self).__init__()
|
414 |
+
self.Res = nn.Sequential(
|
415 |
+
SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
|
416 |
+
nn.LeakyReLU(0.2),
|
417 |
+
SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
|
418 |
+
)
|
419 |
+
def forward(self, x):
|
420 |
+
out = x + self.Res(x)
|
421 |
+
return out
|
422 |
+
|
423 |
+
|
424 |
+
class KeyValue(nn.Module):
|
425 |
+
def __init__(self, indim, keydim, valdim):
|
426 |
+
super(KeyValue, self).__init__()
|
427 |
+
self.Key = nn.Sequential(
|
428 |
+
SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
429 |
+
nn.LeakyReLU(0.2),
|
430 |
+
SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
431 |
+
)
|
432 |
+
self.Value = nn.Sequential(
|
433 |
+
SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
434 |
+
nn.LeakyReLU(0.2),
|
435 |
+
SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
436 |
+
)
|
437 |
+
def forward(self, x):
|
438 |
+
return self.Key(x), self.Value(x)
|
439 |
+
|
440 |
+
class MaskAttention(nn.Module):
|
441 |
+
def __init__(self, indim):
|
442 |
+
super(MaskAttention, self).__init__()
|
443 |
+
self.conv1 = nn.Sequential(
|
444 |
+
SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
445 |
+
nn.LeakyReLU(0.2),
|
446 |
+
SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
447 |
+
)
|
448 |
+
self.conv2 = nn.Sequential(
|
449 |
+
SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
450 |
+
nn.LeakyReLU(0.2),
|
451 |
+
SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
452 |
+
)
|
453 |
+
self.conv3 = nn.Sequential(
|
454 |
+
SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
455 |
+
nn.LeakyReLU(0.2),
|
456 |
+
SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
|
457 |
+
)
|
458 |
+
self.convCat = nn.Sequential(
|
459 |
+
SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
460 |
+
nn.LeakyReLU(0.2),
|
461 |
+
SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
462 |
+
)
|
463 |
+
def forward(self, x, y, z):
|
464 |
+
c1 = self.conv1(x)
|
465 |
+
c2 = self.conv2(y)
|
466 |
+
c3 = self.conv3(z)
|
467 |
+
return self.convCat(torch.cat([c1,c2,c3], dim=1))
|
468 |
+
|
469 |
+
class Query(nn.Module):
|
470 |
+
def __init__(self, indim, quedim):
|
471 |
+
super(Query, self).__init__()
|
472 |
+
self.Query = nn.Sequential(
|
473 |
+
SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
474 |
+
nn.LeakyReLU(0.2),
|
475 |
+
SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
|
476 |
+
)
|
477 |
+
def forward(self, x):
|
478 |
+
return self.Query(x)
|
479 |
+
|
480 |
+
def roi_align_self(input, location, target_size):
|
481 |
+
test = (target_size.item(),target_size.item())
|
482 |
+
return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],test,mode='bilinear',align_corners=False) for i in range(input.size(0))],0)
|
483 |
+
|
484 |
+
class FeatureExtractor(nn.Module):
|
485 |
+
def __init__(self, ngf = 64, key_scale = 4):#
|
486 |
+
super().__init__()
|
487 |
+
|
488 |
+
self.key_scale = 4
|
489 |
+
self.part_sizes = np.array([80,80,50,110]) #
|
490 |
+
self.feature_sizes = np.array([256,128,64]) #
|
491 |
+
|
492 |
+
self.conv1 = nn.Sequential(
|
493 |
+
SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
|
494 |
+
nn.LeakyReLU(0.2),
|
495 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
496 |
+
)
|
497 |
+
self.conv2 = nn.Sequential(
|
498 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
499 |
+
nn.LeakyReLU(0.2),
|
500 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
|
501 |
+
)
|
502 |
+
self.res1 = DilateResBlock(ngf, [5,3])
|
503 |
+
self.res2 = DilateResBlock(ngf, [5,3])
|
504 |
+
|
505 |
+
|
506 |
+
self.conv3 = nn.Sequential(
|
507 |
+
SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
|
508 |
+
nn.LeakyReLU(0.2),
|
509 |
+
SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
|
510 |
+
)
|
511 |
+
self.conv4 = nn.Sequential(
|
512 |
+
SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
|
513 |
+
nn.LeakyReLU(0.2),
|
514 |
+
SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
|
515 |
+
)
|
516 |
+
self.res3 = DilateResBlock(ngf*2, [3,1])
|
517 |
+
self.res4 = DilateResBlock(ngf*2, [3,1])
|
518 |
+
|
519 |
+
self.conv5 = nn.Sequential(
|
520 |
+
SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
|
521 |
+
nn.LeakyReLU(0.2),
|
522 |
+
SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
|
523 |
+
)
|
524 |
+
self.conv6 = nn.Sequential(
|
525 |
+
SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
|
526 |
+
nn.LeakyReLU(0.2),
|
527 |
+
SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
|
528 |
+
)
|
529 |
+
self.res5 = DilateResBlock(ngf*4, [1,1])
|
530 |
+
self.res6 = DilateResBlock(ngf*4, [1,1])
|
531 |
+
|
532 |
+
self.LE_256_Q = Query(ngf, ngf // self.key_scale)
|
533 |
+
self.RE_256_Q = Query(ngf, ngf // self.key_scale)
|
534 |
+
self.MO_256_Q = Query(ngf, ngf // self.key_scale)
|
535 |
+
self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
536 |
+
self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
537 |
+
self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
|
538 |
+
self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
539 |
+
self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
540 |
+
self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
|
541 |
+
|
542 |
+
|
543 |
+
def forward(self, img, locs):
|
544 |
+
le_location = locs[:,0,:].int().cpu().numpy()
|
545 |
+
re_location = locs[:,1,:].int().cpu().numpy()
|
546 |
+
no_location = locs[:,2,:].int().cpu().numpy()
|
547 |
+
mo_location = locs[:,3,:].int().cpu().numpy()
|
548 |
+
|
549 |
+
|
550 |
+
f1_0 = self.conv1(img)
|
551 |
+
f1_1 = self.res1(f1_0)
|
552 |
+
f2_0 = self.conv2(f1_1)
|
553 |
+
f2_1 = self.res2(f2_0)
|
554 |
+
|
555 |
+
f3_0 = self.conv3(f2_1)
|
556 |
+
f3_1 = self.res3(f3_0)
|
557 |
+
f4_0 = self.conv4(f3_1)
|
558 |
+
f4_1 = self.res4(f4_0)
|
559 |
+
|
560 |
+
f5_0 = self.conv5(f4_1)
|
561 |
+
f5_1 = self.res5(f5_0)
|
562 |
+
f6_0 = self.conv6(f5_1)
|
563 |
+
f6_1 = self.res6(f6_0)
|
564 |
+
|
565 |
+
|
566 |
+
####ROI Align
|
567 |
+
le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
|
568 |
+
re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
|
569 |
+
mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)
|
570 |
+
|
571 |
+
le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
|
572 |
+
re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
|
573 |
+
mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)
|
574 |
+
|
575 |
+
le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
|
576 |
+
re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
|
577 |
+
mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)
|
578 |
+
|
579 |
+
|
580 |
+
le_256_q = self.LE_256_Q(le_part_256)
|
581 |
+
re_256_q = self.RE_256_Q(re_part_256)
|
582 |
+
mo_256_q = self.MO_256_Q(mo_part_256)
|
583 |
+
|
584 |
+
le_128_q = self.LE_128_Q(le_part_128)
|
585 |
+
re_128_q = self.RE_128_Q(re_part_128)
|
586 |
+
mo_128_q = self.MO_128_Q(mo_part_128)
|
587 |
+
|
588 |
+
le_64_q = self.LE_64_Q(le_part_64)
|
589 |
+
re_64_q = self.RE_64_Q(re_part_64)
|
590 |
+
mo_64_q = self.MO_64_Q(mo_part_64)
|
591 |
+
|
592 |
+
return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,\
|
593 |
+
'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256, \
|
594 |
+
'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128, \
|
595 |
+
'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64, \
|
596 |
+
'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,\
|
597 |
+
'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,\
|
598 |
+
'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
|
599 |
+
|
600 |
+
|
601 |
+
class DMDNet(nn.Module):
|
602 |
+
def __init__(self, ngf = 64, banks_num = 128):
|
603 |
+
super().__init__()
|
604 |
+
self.part_sizes = np.array([80,80,50,110]) # size for 512
|
605 |
+
self.feature_sizes = np.array([256,128,64]) # size for 512
|
606 |
+
|
607 |
+
self.banks_num = banks_num
|
608 |
+
self.key_scale = 4
|
609 |
+
|
610 |
+
self.E_lq = FeatureExtractor(key_scale = self.key_scale)
|
611 |
+
self.E_hq = FeatureExtractor(key_scale = self.key_scale)
|
612 |
+
|
613 |
+
self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
614 |
+
self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
615 |
+
self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
|
616 |
+
|
617 |
+
self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
|
618 |
+
self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
|
619 |
+
self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
|
620 |
+
|
621 |
+
self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
|
622 |
+
self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
|
623 |
+
self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
|
624 |
+
|
625 |
+
|
626 |
+
self.LE_256_Attention = AttentionBlock(64)
|
627 |
+
self.RE_256_Attention = AttentionBlock(64)
|
628 |
+
self.MO_256_Attention = AttentionBlock(64)
|
629 |
+
|
630 |
+
self.LE_128_Attention = AttentionBlock(128)
|
631 |
+
self.RE_128_Attention = AttentionBlock(128)
|
632 |
+
self.MO_128_Attention = AttentionBlock(128)
|
633 |
+
|
634 |
+
self.LE_64_Attention = AttentionBlock(256)
|
635 |
+
self.RE_64_Attention = AttentionBlock(256)
|
636 |
+
self.MO_64_Attention = AttentionBlock(256)
|
637 |
+
|
638 |
+
self.LE_256_Mask = MaskAttention(64)
|
639 |
+
self.RE_256_Mask = MaskAttention(64)
|
640 |
+
self.MO_256_Mask = MaskAttention(64)
|
641 |
+
|
642 |
+
self.LE_128_Mask = MaskAttention(128)
|
643 |
+
self.RE_128_Mask = MaskAttention(128)
|
644 |
+
self.MO_128_Mask = MaskAttention(128)
|
645 |
+
|
646 |
+
self.LE_64_Mask = MaskAttention(256)
|
647 |
+
self.RE_64_Mask = MaskAttention(256)
|
648 |
+
self.MO_64_Mask = MaskAttention(256)
|
649 |
+
|
650 |
+
self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])
|
651 |
+
|
652 |
+
self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) #
|
653 |
+
self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) #
|
654 |
+
self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
|
655 |
+
self.up4 = nn.Sequential(
|
656 |
+
SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
|
657 |
+
nn.LeakyReLU(0.2),
|
658 |
+
UpResBlock(ngf),
|
659 |
+
UpResBlock(ngf),
|
660 |
+
SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
|
661 |
+
nn.Tanh()
|
662 |
+
)
|
663 |
+
|
664 |
+
# define generic memory, revise register_buffer to register_parameter for backward update
|
665 |
+
self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
|
666 |
+
self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
|
667 |
+
self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
|
668 |
+
self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
|
669 |
+
self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
|
670 |
+
self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))
|
671 |
+
|
672 |
+
|
673 |
+
self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
|
674 |
+
self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
|
675 |
+
self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
|
676 |
+
self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
|
677 |
+
self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
|
678 |
+
self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))
|
679 |
+
|
680 |
+
self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
|
681 |
+
self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
|
682 |
+
self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
|
683 |
+
self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
|
684 |
+
self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
|
685 |
+
self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))
|
686 |
+
|
687 |
+
|
688 |
+
def readMem(self, k, v, q):
|
689 |
+
sim = F.conv2d(q, k)
|
690 |
+
score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
|
691 |
+
sb,sn,sw,sh = score.size()
|
692 |
+
s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
|
693 |
+
vb,vn,vw,vh = v.size()
|
694 |
+
v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
|
695 |
+
mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
|
696 |
+
max_inds = torch.argmax(score, dim=1).squeeze()
|
697 |
+
return mem_out, max_inds
|
698 |
+
|
699 |
+
|
700 |
+
def memorize(self, img, locs):
|
701 |
+
fs = self.E_hq(img, locs)
|
702 |
+
LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
|
703 |
+
RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
|
704 |
+
MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])
|
705 |
+
|
706 |
+
LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
|
707 |
+
RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
|
708 |
+
MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])
|
709 |
+
|
710 |
+
LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
|
711 |
+
RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
|
712 |
+
MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])
|
713 |
+
|
714 |
+
Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
|
715 |
+
Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
|
716 |
+
Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}
|
717 |
+
|
718 |
+
FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']}
|
719 |
+
FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']}
|
720 |
+
FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']}
|
721 |
+
|
722 |
+
return Mem256, Mem128, Mem64
|
723 |
+
|
724 |
+
def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
|
725 |
+
le_256_q = fs_in['le_256_q']
|
726 |
+
re_256_q = fs_in['re_256_q']
|
727 |
+
mo_256_q = fs_in['mo_256_q']
|
728 |
+
|
729 |
+
le_128_q = fs_in['le_128_q']
|
730 |
+
re_128_q = fs_in['re_128_q']
|
731 |
+
mo_128_q = fs_in['mo_128_q']
|
732 |
+
|
733 |
+
le_64_q = fs_in['le_64_q']
|
734 |
+
re_64_q = fs_in['re_64_q']
|
735 |
+
mo_64_q = fs_in['mo_64_q']
|
736 |
+
|
737 |
+
|
738 |
+
####for 256
|
739 |
+
le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
|
740 |
+
re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
|
741 |
+
mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)
|
742 |
+
|
743 |
+
le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
|
744 |
+
re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
|
745 |
+
mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)
|
746 |
+
|
747 |
+
le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
|
748 |
+
re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
|
749 |
+
mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)
|
750 |
+
|
751 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
752 |
+
le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
|
753 |
+
re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
|
754 |
+
mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
|
755 |
+
le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
|
756 |
+
le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
|
757 |
+
re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
|
758 |
+
re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
|
759 |
+
mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
|
760 |
+
mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g
|
761 |
+
|
762 |
+
le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
|
763 |
+
re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
|
764 |
+
mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
|
765 |
+
le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
|
766 |
+
le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
|
767 |
+
re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
|
768 |
+
re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
|
769 |
+
mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
|
770 |
+
mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g
|
771 |
+
|
772 |
+
le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
|
773 |
+
re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
|
774 |
+
mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
|
775 |
+
le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
|
776 |
+
le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
|
777 |
+
re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
|
778 |
+
re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
|
779 |
+
mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
|
780 |
+
mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
|
781 |
+
else:
|
782 |
+
le_256_mem = le_256_mem_g
|
783 |
+
re_256_mem = re_256_mem_g
|
784 |
+
mo_256_mem = mo_256_mem_g
|
785 |
+
le_128_mem = le_128_mem_g
|
786 |
+
re_128_mem = re_128_mem_g
|
787 |
+
mo_128_mem = mo_128_mem_g
|
788 |
+
le_64_mem = le_64_mem_g
|
789 |
+
re_64_mem = re_64_mem_g
|
790 |
+
mo_64_mem = mo_64_mem_g
|
791 |
+
|
792 |
+
le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
|
793 |
+
re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
|
794 |
+
mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])
|
795 |
+
|
796 |
+
####for 128
|
797 |
+
le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
|
798 |
+
re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
|
799 |
+
mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])
|
800 |
+
|
801 |
+
####for 64
|
802 |
+
le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
|
803 |
+
re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
|
804 |
+
mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])
|
805 |
+
|
806 |
+
|
807 |
+
EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
|
808 |
+
EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
|
809 |
+
EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
|
810 |
+
Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
|
811 |
+
Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
|
812 |
+
Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
|
813 |
+
return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
|
814 |
+
|
815 |
+
def reconstruct(self, fs_in, locs, memstar):
|
816 |
+
le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
|
817 |
+
le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
|
818 |
+
le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']
|
819 |
+
|
820 |
+
le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
|
821 |
+
re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
|
822 |
+
mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']
|
823 |
+
|
824 |
+
le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
|
825 |
+
re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
|
826 |
+
mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']
|
827 |
+
|
828 |
+
le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
|
829 |
+
re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
|
830 |
+
mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']
|
831 |
+
|
832 |
+
|
833 |
+
le_location = locs[:,0,:]
|
834 |
+
re_location = locs[:,1,:]
|
835 |
+
mo_location = locs[:,3,:]
|
836 |
+
|
837 |
+
# Somehow with latest Torch it doesn't like numpy wrappers anymore
|
838 |
+
|
839 |
+
# le_location = le_location.cpu().int().numpy()
|
840 |
+
# re_location = re_location.cpu().int().numpy()
|
841 |
+
# mo_location = mo_location.cpu().int().numpy()
|
842 |
+
le_location = le_location.cpu().int()
|
843 |
+
re_location = re_location.cpu().int()
|
844 |
+
mo_location = mo_location.cpu().int()
|
845 |
+
|
846 |
+
up_in_256 = fs_in['f256'].clone()# * 0
|
847 |
+
up_in_128 = fs_in['f128'].clone()# * 0
|
848 |
+
up_in_64 = fs_in['f64'].clone()# * 0
|
849 |
+
|
850 |
+
for i in range(fs_in['f256'].size(0)):
|
851 |
+
up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
|
852 |
+
up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
|
853 |
+
up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)
|
854 |
+
|
855 |
+
up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
|
856 |
+
up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
|
857 |
+
up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)
|
858 |
+
|
859 |
+
up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
|
860 |
+
up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
|
861 |
+
up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)
|
862 |
+
|
863 |
+
ms_in_64 = self.MSDilate(fs_in['f64'].clone())
|
864 |
+
fea_up1 = self.up1(ms_in_64, up_in_64)
|
865 |
+
fea_up2 = self.up2(fea_up1, up_in_128) #
|
866 |
+
fea_up3 = self.up3(fea_up2, up_in_256) #
|
867 |
+
output = self.up4(fea_up3) #
|
868 |
+
return output
|
869 |
+
|
870 |
+
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
|
871 |
+
return self.memorize(sp_imgs, sp_locs)
|
872 |
+
|
873 |
+
def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
|
874 |
+
try:
|
875 |
+
fs_in = self.E_lq(lq, loc) # low quality images
|
876 |
+
except Exception as e:
|
877 |
+
print(e)
|
878 |
+
|
879 |
+
GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
|
880 |
+
GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
|
881 |
+
if sp_256 is not None and sp_128 is not None and sp_64 is not None:
|
882 |
+
GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
|
883 |
+
GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
|
884 |
+
else:
|
885 |
+
GSOut = None
|
886 |
+
return GeOut, GSOut
|
887 |
+
|
888 |
+
class UpResBlock(nn.Module):
|
889 |
+
def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
|
890 |
+
super(UpResBlock, self).__init__()
|
891 |
+
self.Model = nn.Sequential(
|
892 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
893 |
+
nn.LeakyReLU(0.2),
|
894 |
+
SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
|
895 |
+
)
|
896 |
+
def forward(self, x):
|
897 |
+
out = x + self.Model(x)
|
898 |
+
return out
|
roop/processors/Enhance_GFPGAN.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Callable
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime
|
5 |
+
import roop.globals
|
6 |
+
|
7 |
+
from roop.typing import Face, Frame, FaceSet
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
|
11 |
+
# THREAD_LOCK = threading.Lock()
|
12 |
+
|
13 |
+
|
14 |
+
class Enhance_GFPGAN():
|
15 |
+
plugin_options:dict = None
|
16 |
+
|
17 |
+
model_gfpgan = None
|
18 |
+
name = None
|
19 |
+
devicename = None
|
20 |
+
|
21 |
+
processorname = 'gfpgan'
|
22 |
+
type = 'enhance'
|
23 |
+
|
24 |
+
|
25 |
+
def Initialize(self, plugin_options:dict):
|
26 |
+
if self.plugin_options is not None:
|
27 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
28 |
+
self.Release()
|
29 |
+
|
30 |
+
self.plugin_options = plugin_options
|
31 |
+
if self.model_gfpgan is None:
|
32 |
+
model_path = resolve_relative_path('../models/GFPGANv1.4.onnx')
|
33 |
+
self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
34 |
+
# replace Mac mps with cpu for the moment
|
35 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
36 |
+
|
37 |
+
self.name = self.model_gfpgan.get_inputs()[0].name
|
38 |
+
|
39 |
+
def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
|
40 |
+
# preprocess
|
41 |
+
input_size = temp_frame.shape[1]
|
42 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
43 |
+
|
44 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
45 |
+
temp_frame = temp_frame.astype('float32') / 255.0
|
46 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
47 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
48 |
+
|
49 |
+
io_binding = self.model_gfpgan.io_binding()
|
50 |
+
io_binding.bind_cpu_input("input", temp_frame)
|
51 |
+
io_binding.bind_output("1288", self.devicename)
|
52 |
+
self.model_gfpgan.run_with_iobinding(io_binding)
|
53 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
54 |
+
result = ort_outs[0][0]
|
55 |
+
|
56 |
+
# post-process
|
57 |
+
result = np.clip(result, -1, 1)
|
58 |
+
result = (result + 1) / 2
|
59 |
+
result = result.transpose(1, 2, 0) * 255.0
|
60 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
61 |
+
scale_factor = int(result.shape[1] / input_size)
|
62 |
+
return result.astype(np.uint8), scale_factor
|
63 |
+
|
64 |
+
|
65 |
+
def Release(self):
|
66 |
+
self.model_gfpgan = None
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
roop/processors/Enhance_GPEN.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Callable
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime
|
5 |
+
import roop.globals
|
6 |
+
|
7 |
+
from roop.typing import Face, Frame, FaceSet
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
|
11 |
+
class Enhance_GPEN():
|
12 |
+
plugin_options:dict = None
|
13 |
+
|
14 |
+
model_gpen = None
|
15 |
+
name = None
|
16 |
+
devicename = None
|
17 |
+
|
18 |
+
processorname = 'gpen'
|
19 |
+
type = 'enhance'
|
20 |
+
|
21 |
+
|
22 |
+
def Initialize(self, plugin_options:dict):
|
23 |
+
if self.plugin_options is not None:
|
24 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
25 |
+
self.Release()
|
26 |
+
|
27 |
+
self.plugin_options = plugin_options
|
28 |
+
if self.model_gpen is None:
|
29 |
+
model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx')
|
30 |
+
self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
31 |
+
# replace Mac mps with cpu for the moment
|
32 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
33 |
+
|
34 |
+
self.name = self.model_gpen.get_inputs()[0].name
|
35 |
+
|
36 |
+
def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
|
37 |
+
# preprocess
|
38 |
+
input_size = temp_frame.shape[1]
|
39 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
40 |
+
|
41 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
42 |
+
temp_frame = temp_frame.astype('float32') / 255.0
|
43 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
44 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
45 |
+
|
46 |
+
io_binding = self.model_gpen.io_binding()
|
47 |
+
io_binding.bind_cpu_input("input", temp_frame)
|
48 |
+
io_binding.bind_output("output", self.devicename)
|
49 |
+
self.model_gpen.run_with_iobinding(io_binding)
|
50 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
51 |
+
result = ort_outs[0][0]
|
52 |
+
|
53 |
+
# post-process
|
54 |
+
result = np.clip(result, -1, 1)
|
55 |
+
result = (result + 1) / 2
|
56 |
+
result = result.transpose(1, 2, 0) * 255.0
|
57 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
58 |
+
scale_factor = int(result.shape[1] / input_size)
|
59 |
+
return result.astype(np.uint8), scale_factor
|
60 |
+
|
61 |
+
|
62 |
+
def Release(self):
|
63 |
+
self.model_gpen = None
|
roop/processors/Enhance_RestoreFormerPPlus.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Callable
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime
|
5 |
+
import roop.globals
|
6 |
+
|
7 |
+
from roop.typing import Face, Frame, FaceSet
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
class Enhance_RestoreFormerPPlus():
|
11 |
+
plugin_options:dict = None
|
12 |
+
model_restoreformerpplus = None
|
13 |
+
devicename = None
|
14 |
+
name = None
|
15 |
+
|
16 |
+
processorname = 'restoreformer++'
|
17 |
+
type = 'enhance'
|
18 |
+
|
19 |
+
|
20 |
+
def Initialize(self, plugin_options:dict):
|
21 |
+
if self.plugin_options is not None:
|
22 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
23 |
+
self.Release()
|
24 |
+
|
25 |
+
self.plugin_options = plugin_options
|
26 |
+
if self.model_restoreformerpplus is None:
|
27 |
+
# replace Mac mps with cpu for the moment
|
28 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
29 |
+
model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx')
|
30 |
+
self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
31 |
+
self.model_inputs = self.model_restoreformerpplus.get_inputs()
|
32 |
+
model_outputs = self.model_restoreformerpplus.get_outputs()
|
33 |
+
self.io_binding = self.model_restoreformerpplus.io_binding()
|
34 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
35 |
+
|
36 |
+
def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
|
37 |
+
# preprocess
|
38 |
+
input_size = temp_frame.shape[1]
|
39 |
+
temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
|
40 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
|
41 |
+
temp_frame = temp_frame.astype('float32') / 255.0
|
42 |
+
temp_frame = (temp_frame - 0.5) / 0.5
|
43 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
|
44 |
+
|
45 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) # .astype(np.float32)
|
46 |
+
self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
|
47 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
48 |
+
result = ort_outs[0][0]
|
49 |
+
del ort_outs
|
50 |
+
|
51 |
+
result = np.clip(result, -1, 1)
|
52 |
+
result = (result + 1) / 2
|
53 |
+
result = result.transpose(1, 2, 0) * 255.0
|
54 |
+
result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
|
55 |
+
scale_factor = int(result.shape[1] / input_size)
|
56 |
+
return result.astype(np.uint8), scale_factor
|
57 |
+
|
58 |
+
|
59 |
+
def Release(self):
|
60 |
+
del self.model_restoreformerpplus
|
61 |
+
self.model_restoreformerpplus = None
|
62 |
+
del self.io_binding
|
63 |
+
self.io_binding = None
|
64 |
+
|
roop/processors/FaceSwapInsightFace.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import roop.globals
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import onnx
|
5 |
+
import onnxruntime
|
6 |
+
|
7 |
+
from roop.typing import Face, Frame
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
class FaceSwapInsightFace():
|
13 |
+
plugin_options:dict = None
|
14 |
+
model_swap_insightface = None
|
15 |
+
|
16 |
+
processorname = 'faceswap'
|
17 |
+
type = 'swap'
|
18 |
+
|
19 |
+
|
20 |
+
def Initialize(self, plugin_options:dict):
|
21 |
+
if self.plugin_options is not None:
|
22 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
23 |
+
self.Release()
|
24 |
+
|
25 |
+
self.plugin_options = plugin_options
|
26 |
+
if self.model_swap_insightface is None:
|
27 |
+
model_path = resolve_relative_path('../models/inswapper_128.onnx')
|
28 |
+
graph = onnx.load(model_path).graph
|
29 |
+
self.emap = onnx.numpy_helper.to_array(graph.initializer[-1])
|
30 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
31 |
+
self.input_mean = 0.0
|
32 |
+
self.input_std = 255.0
|
33 |
+
#cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'}
|
34 |
+
sess_options = onnxruntime.SessionOptions()
|
35 |
+
sess_options.enable_cpu_mem_arena = False
|
36 |
+
self.model_swap_insightface = onnxruntime.InferenceSession(model_path, sess_options, providers=roop.globals.execution_providers)
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
|
41 |
+
blob = cv2.dnn.blobFromImage(temp_frame, 1.0 / self.input_std, (128, 128),
|
42 |
+
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
43 |
+
latent = source_face.normed_embedding.reshape((1,-1))
|
44 |
+
latent = np.dot(latent, self.emap)
|
45 |
+
latent /= np.linalg.norm(latent)
|
46 |
+
io_binding = self.model_swap_insightface.io_binding()
|
47 |
+
io_binding.bind_cpu_input("target", blob)
|
48 |
+
io_binding.bind_cpu_input("source", latent)
|
49 |
+
io_binding.bind_output("output", self.devicename)
|
50 |
+
self.model_swap_insightface.run_with_iobinding(io_binding)
|
51 |
+
ort_outs = io_binding.copy_outputs_to_cpu()[0]
|
52 |
+
img_fake = ort_outs.transpose((0,2,3,1))[0]
|
53 |
+
return np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1]
|
54 |
+
|
55 |
+
|
56 |
+
img_fake, M = self.model_swap_insightface.get(temp_frame, target_face, source_face, paste_back=False)
|
57 |
+
# target_face.matrix = M
|
58 |
+
# return img_fake
|
59 |
+
|
60 |
+
|
61 |
+
def Release(self):
|
62 |
+
del self.model_swap_insightface
|
63 |
+
self.model_swap_insightface = None
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
roop/processors/Frame_Colorizer.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
import roop.globals
|
5 |
+
|
6 |
+
from roop.utilities import resolve_relative_path
|
7 |
+
from roop.typing import Frame
|
8 |
+
|
9 |
+
class Frame_Colorizer():
|
10 |
+
plugin_options:dict = None
|
11 |
+
model_colorizer = None
|
12 |
+
devicename = None
|
13 |
+
prev_type = None
|
14 |
+
|
15 |
+
processorname = 'deoldify'
|
16 |
+
type = 'frame_colorizer'
|
17 |
+
|
18 |
+
|
19 |
+
def Initialize(self, plugin_options:dict):
|
20 |
+
if self.plugin_options is not None:
|
21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
22 |
+
self.Release()
|
23 |
+
|
24 |
+
self.plugin_options = plugin_options
|
25 |
+
if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]:
|
26 |
+
self.Release()
|
27 |
+
self.prev_type = self.plugin_options["subtype"]
|
28 |
+
if self.model_colorizer is None:
|
29 |
+
# replace Mac mps with cpu for the moment
|
30 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
31 |
+
if self.prev_type == "deoldify_artistic":
|
32 |
+
model_path = resolve_relative_path('../models/Frame/deoldify_artistic.onnx')
|
33 |
+
elif self.prev_type == "deoldify_stable":
|
34 |
+
model_path = resolve_relative_path('../models/Frame/deoldify_stable.onnx')
|
35 |
+
|
36 |
+
onnxruntime.set_default_logger_severity(3)
|
37 |
+
self.model_colorizer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
38 |
+
self.model_inputs = self.model_colorizer.get_inputs()
|
39 |
+
model_outputs = self.model_colorizer.get_outputs()
|
40 |
+
self.io_binding = self.model_colorizer.io_binding()
|
41 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
42 |
+
|
43 |
+
def Run(self, input_frame: Frame) -> Frame:
|
44 |
+
temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY)
|
45 |
+
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB)
|
46 |
+
temp_frame = cv2.resize(temp_frame, (256, 256))
|
47 |
+
temp_frame = temp_frame.transpose((2, 0, 1))
|
48 |
+
temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32)
|
49 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
|
50 |
+
self.model_colorizer.run_with_iobinding(self.io_binding)
|
51 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
52 |
+
result = ort_outs[0][0]
|
53 |
+
del ort_outs
|
54 |
+
colorized_frame = result.transpose(1, 2, 0)
|
55 |
+
colorized_frame = cv2.resize(colorized_frame, (input_frame.shape[1], input_frame.shape[0]))
|
56 |
+
temp_blue_channel, _, _ = cv2.split(input_frame)
|
57 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(np.uint8)
|
58 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB)
|
59 |
+
_, color_green_channel, color_red_channel = cv2.split(colorized_frame)
|
60 |
+
colorized_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel))
|
61 |
+
colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR)
|
62 |
+
return colorized_frame.astype(np.uint8)
|
63 |
+
|
64 |
+
|
65 |
+
def Release(self):
|
66 |
+
del self.model_colorizer
|
67 |
+
self.model_colorizer = None
|
68 |
+
del self.io_binding
|
69 |
+
self.io_binding = None
|
70 |
+
|
roop/processors/Frame_Filter.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from roop.typing import Frame
|
5 |
+
|
6 |
+
class Frame_Filter():
|
7 |
+
processorname = 'generic_filter'
|
8 |
+
type = 'frame_processor'
|
9 |
+
|
10 |
+
plugin_options:dict = None
|
11 |
+
|
12 |
+
c64_palette = np.array([
|
13 |
+
[0, 0, 0],
|
14 |
+
[255, 255, 255],
|
15 |
+
[0x81, 0x33, 0x38],
|
16 |
+
[0x75, 0xce, 0xc8],
|
17 |
+
[0x8e, 0x3c, 0x97],
|
18 |
+
[0x56, 0xac, 0x4d],
|
19 |
+
[0x2e, 0x2c, 0x9b],
|
20 |
+
[0xed, 0xf1, 0x71],
|
21 |
+
[0x8e, 0x50, 0x29],
|
22 |
+
[0x55, 0x38, 0x00],
|
23 |
+
[0xc4, 0x6c, 0x71],
|
24 |
+
[0x4a, 0x4a, 0x4a],
|
25 |
+
[0x7b, 0x7b, 0x7b],
|
26 |
+
[0xa9, 0xff, 0x9f],
|
27 |
+
[0x70, 0x6d, 0xeb],
|
28 |
+
[0xb2, 0xb2, 0xb2]
|
29 |
+
])
|
30 |
+
|
31 |
+
|
32 |
+
def RenderC64Screen(self, image):
|
33 |
+
# Simply round the color values to the nearest color in the palette
|
34 |
+
image = cv2.resize(image,(320,200))
|
35 |
+
palette = self.c64_palette / 255.0 # Normalize palette
|
36 |
+
img_normalized = image / 255.0 # Normalize image
|
37 |
+
|
38 |
+
# Calculate the index in the palette that is closest to each pixel in the image
|
39 |
+
indices = np.sqrt(((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3)).argmin(axis=2)
|
40 |
+
# Map the image to the palette colors
|
41 |
+
mapped_image = palette[indices]
|
42 |
+
return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image
|
43 |
+
|
44 |
+
|
45 |
+
def RenderDetailEnhance(self, image):
|
46 |
+
return cv2.detailEnhance(image)
|
47 |
+
|
48 |
+
def RenderStylize(self, image):
|
49 |
+
return cv2.stylization(image)
|
50 |
+
|
51 |
+
def RenderPencilSketch(self, image):
|
52 |
+
imgray, imout = cv2.pencilSketch(image, sigma_s=60, sigma_r=0.07, shade_factor=0.05)
|
53 |
+
return imout
|
54 |
+
|
55 |
+
def RenderCartoon(self, image):
|
56 |
+
numDownSamples = 2 # number of downscaling steps
|
57 |
+
numBilateralFilters = 7 # number of bilateral filtering steps
|
58 |
+
|
59 |
+
img_color = image
|
60 |
+
for _ in range(numDownSamples):
|
61 |
+
img_color = cv2.pyrDown(img_color)
|
62 |
+
for _ in range(numBilateralFilters):
|
63 |
+
img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
|
64 |
+
for _ in range(numDownSamples):
|
65 |
+
img_color = cv2.pyrUp(img_color)
|
66 |
+
img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
67 |
+
img_blur = cv2.medianBlur(img_gray, 7)
|
68 |
+
img_edge = cv2.adaptiveThreshold(img_blur, 255,
|
69 |
+
cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2)
|
70 |
+
img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB)
|
71 |
+
if img_color.shape != image.shape:
|
72 |
+
img_color = cv2.resize(img_color, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
|
73 |
+
if img_color.shape != img_edge.shape:
|
74 |
+
img_edge = cv2.resize(img_edge, (img_color.shape[1], img_color.shape[0]), interpolation=cv2.INTER_LINEAR)
|
75 |
+
return cv2.bitwise_and(img_color, img_edge)
|
76 |
+
|
77 |
+
|
78 |
+
def Initialize(self, plugin_options:dict):
|
79 |
+
if self.plugin_options is not None:
|
80 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
81 |
+
self.Release()
|
82 |
+
self.plugin_options = plugin_options
|
83 |
+
|
84 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
85 |
+
subtype = self.plugin_options["subtype"]
|
86 |
+
if subtype == "stylize":
|
87 |
+
return self.RenderStylize(temp_frame).astype(np.uint8)
|
88 |
+
if subtype == "detailenhance":
|
89 |
+
return self.RenderDetailEnhance(temp_frame).astype(np.uint8)
|
90 |
+
if subtype == "pencil":
|
91 |
+
return self.RenderPencilSketch(temp_frame).astype(np.uint8)
|
92 |
+
if subtype == "cartoon":
|
93 |
+
return self.RenderCartoon(temp_frame).astype(np.uint8)
|
94 |
+
if subtype == "C64":
|
95 |
+
return self.RenderC64Screen(temp_frame).astype(np.uint8)
|
96 |
+
|
97 |
+
|
98 |
+
def Release(self):
|
99 |
+
pass
|
100 |
+
|
101 |
+
def getProcessedResolution(self, width, height):
|
102 |
+
if self.plugin_options["subtype"] == "C64":
|
103 |
+
return (320,200)
|
104 |
+
return None
|
105 |
+
|
roop/processors/Frame_Masking.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
import roop.globals
|
5 |
+
|
6 |
+
from roop.utilities import resolve_relative_path
|
7 |
+
from roop.typing import Frame
|
8 |
+
|
9 |
+
class Frame_Masking():
|
10 |
+
plugin_options:dict = None
|
11 |
+
model_masking = None
|
12 |
+
devicename = None
|
13 |
+
name = None
|
14 |
+
|
15 |
+
processorname = 'removebg'
|
16 |
+
type = 'frame_masking'
|
17 |
+
|
18 |
+
|
19 |
+
def Initialize(self, plugin_options:dict):
|
20 |
+
if self.plugin_options is not None:
|
21 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
22 |
+
self.Release()
|
23 |
+
|
24 |
+
self.plugin_options = plugin_options
|
25 |
+
if self.model_masking is None:
|
26 |
+
# replace Mac mps with cpu for the moment
|
27 |
+
self.devicename = self.plugin_options["devicename"]
|
28 |
+
self.devicename = self.devicename.replace('mps', 'cpu')
|
29 |
+
model_path = resolve_relative_path('../models/Frame/isnet-general-use.onnx')
|
30 |
+
self.model_masking = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
31 |
+
self.model_inputs = self.model_masking.get_inputs()
|
32 |
+
model_outputs = self.model_masking.get_outputs()
|
33 |
+
self.io_binding = self.model_masking.io_binding()
|
34 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
35 |
+
|
36 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
37 |
+
# Pre process:Resize, BGR->RGB, float32 cast
|
38 |
+
input_image = cv2.resize(temp_frame, (1024, 1024))
|
39 |
+
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
|
40 |
+
mean = [0.5, 0.5, 0.5]
|
41 |
+
std = [1.0, 1.0, 1.0]
|
42 |
+
input_image = (input_image / 255.0 - mean) / std
|
43 |
+
input_image = input_image.transpose(2, 0, 1)
|
44 |
+
input_image = np.expand_dims(input_image, axis=0)
|
45 |
+
input_image = input_image.astype('float32')
|
46 |
+
|
47 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image)
|
48 |
+
self.model_masking.run_with_iobinding(self.io_binding)
|
49 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
50 |
+
result = ort_outs[0][0]
|
51 |
+
del ort_outs
|
52 |
+
# Post process:squeeze, Sigmoid, Normarize, uint8 cast
|
53 |
+
mask = np.squeeze(result[0])
|
54 |
+
min_value = np.min(mask)
|
55 |
+
max_value = np.max(mask)
|
56 |
+
mask = (mask - min_value) / (max_value - min_value)
|
57 |
+
#mask = np.where(mask < score_th, 0, 1)
|
58 |
+
#mask *= 255
|
59 |
+
mask = cv2.resize(mask, (temp_frame.shape[1], temp_frame.shape[0]), interpolation=cv2.INTER_LINEAR)
|
60 |
+
mask = np.reshape(mask, [mask.shape[0],mask.shape[1],1])
|
61 |
+
result = mask * temp_frame.astype(np.float32)
|
62 |
+
return result.astype(np.uint8)
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
def Release(self):
|
67 |
+
del self.model_masking
|
68 |
+
self.model_masking = None
|
69 |
+
del self.io_binding
|
70 |
+
self.io_binding = None
|
71 |
+
|
roop/processors/Frame_Upscale.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
import roop.globals
|
5 |
+
import threading
|
6 |
+
|
7 |
+
from roop.utilities import resolve_relative_path
|
8 |
+
from roop.typing import Frame
|
9 |
+
|
10 |
+
class Frame_Upscale():
|
11 |
+
plugin_options:dict = None
|
12 |
+
model_upscale = None
|
13 |
+
devicename = None
|
14 |
+
prev_type = None
|
15 |
+
|
16 |
+
processorname = 'upscale'
|
17 |
+
type = 'frame_enhancer'
|
18 |
+
|
19 |
+
THREAD_LOCK_UPSCALE = threading.Lock()
|
20 |
+
|
21 |
+
|
22 |
+
def Initialize(self, plugin_options:dict):
|
23 |
+
if self.plugin_options is not None:
|
24 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
25 |
+
self.Release()
|
26 |
+
|
27 |
+
self.plugin_options = plugin_options
|
28 |
+
if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]:
|
29 |
+
self.Release()
|
30 |
+
self.prev_type = self.plugin_options["subtype"]
|
31 |
+
if self.model_upscale is None:
|
32 |
+
# replace Mac mps with cpu for the moment
|
33 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
34 |
+
if self.prev_type == "esrganx4":
|
35 |
+
model_path = resolve_relative_path('../models/Frame/real_esrgan_x4.onnx')
|
36 |
+
self.scale = 4
|
37 |
+
elif self.prev_type == "esrganx2":
|
38 |
+
model_path = resolve_relative_path('../models/Frame/real_esrgan_x2.onnx')
|
39 |
+
self.scale = 2
|
40 |
+
elif self.prev_type == "lsdirx4":
|
41 |
+
model_path = resolve_relative_path('../models/Frame/lsdir_x4.onnx')
|
42 |
+
self.scale = 4
|
43 |
+
|
44 |
+
self.model_upscale = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
45 |
+
self.model_inputs = self.model_upscale.get_inputs()
|
46 |
+
model_outputs = self.model_upscale.get_outputs()
|
47 |
+
self.io_binding = self.model_upscale.io_binding()
|
48 |
+
self.io_binding.bind_output(model_outputs[0].name, self.devicename)
|
49 |
+
|
50 |
+
def getProcessedResolution(self, width, height):
|
51 |
+
return (width * self.scale, height * self.scale)
|
52 |
+
|
53 |
+
# borrowed from facefusion -> https://github.com/facefusion/facefusion
|
54 |
+
def prepare_tile_frame(self, tile_frame : Frame) -> Frame:
|
55 |
+
tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis = 0)
|
56 |
+
tile_frame = tile_frame.transpose(0, 3, 1, 2)
|
57 |
+
tile_frame = tile_frame.astype(np.float32) / 255
|
58 |
+
return tile_frame
|
59 |
+
|
60 |
+
|
61 |
+
def normalize_tile_frame(self, tile_frame : Frame) -> Frame:
|
62 |
+
tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255
|
63 |
+
tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1]
|
64 |
+
return tile_frame
|
65 |
+
|
66 |
+
def create_tile_frames(self, input_frame : Frame, size):
|
67 |
+
input_frame = np.pad(input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)))
|
68 |
+
tile_width = size[0] - 2 * size[2]
|
69 |
+
pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width
|
70 |
+
pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width
|
71 |
+
pad_vision_frame = np.pad(input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0)))
|
72 |
+
pad_height, pad_width = pad_vision_frame.shape[:2]
|
73 |
+
row_range = range(size[2], pad_height - size[2], tile_width)
|
74 |
+
col_range = range(size[2], pad_width - size[2], tile_width)
|
75 |
+
tile_frames = []
|
76 |
+
|
77 |
+
for row_frame in row_range:
|
78 |
+
top = row_frame - size[2]
|
79 |
+
bottom = row_frame + size[2] + tile_width
|
80 |
+
for column_vision_frame in col_range:
|
81 |
+
left = column_vision_frame - size[2]
|
82 |
+
right = column_vision_frame + size[2] + tile_width
|
83 |
+
tile_frames.append(pad_vision_frame[top:bottom, left:right, :])
|
84 |
+
return tile_frames, pad_width, pad_height
|
85 |
+
|
86 |
+
|
87 |
+
def merge_tile_frames(self, tile_frames, temp_width : int, temp_height : int, pad_width : int, pad_height : int, size) -> Frame:
|
88 |
+
merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8)
|
89 |
+
tile_width = tile_frames[0].shape[1] - 2 * size[2]
|
90 |
+
tiles_per_row = min(pad_width // tile_width, len(tile_frames))
|
91 |
+
|
92 |
+
for index, tile_frame in enumerate(tile_frames):
|
93 |
+
tile_frame = tile_frame[size[2]:-size[2], size[2]:-size[2]]
|
94 |
+
row_index = index // tiles_per_row
|
95 |
+
col_index = index % tiles_per_row
|
96 |
+
top = row_index * tile_frame.shape[0]
|
97 |
+
bottom = top + tile_frame.shape[0]
|
98 |
+
left = col_index * tile_frame.shape[1]
|
99 |
+
right = left + tile_frame.shape[1]
|
100 |
+
merge_frame[top:bottom, left:right, :] = tile_frame
|
101 |
+
merge_frame = merge_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :]
|
102 |
+
return merge_frame
|
103 |
+
|
104 |
+
|
105 |
+
def Run(self, temp_frame: Frame) -> Frame:
|
106 |
+
size = (128, 8, 2)
|
107 |
+
temp_height, temp_width = temp_frame.shape[:2]
|
108 |
+
upscale_tile_frames, pad_width, pad_height = self.create_tile_frames(temp_frame, size)
|
109 |
+
|
110 |
+
for index, tile_frame in enumerate(upscale_tile_frames):
|
111 |
+
tile_frame = self.prepare_tile_frame(tile_frame)
|
112 |
+
with self.THREAD_LOCK_UPSCALE:
|
113 |
+
self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame)
|
114 |
+
self.model_upscale.run_with_iobinding(self.io_binding)
|
115 |
+
ort_outs = self.io_binding.copy_outputs_to_cpu()
|
116 |
+
result = ort_outs[0]
|
117 |
+
upscale_tile_frames[index] = self.normalize_tile_frame(result)
|
118 |
+
final_frame = self.merge_tile_frames(upscale_tile_frames, temp_width * self.scale
|
119 |
+
, temp_height * self.scale
|
120 |
+
, pad_width * self.scale, pad_height * self.scale
|
121 |
+
, (size[0] * self.scale, size[1] * self.scale, size[2] * self.scale))
|
122 |
+
return final_frame.astype(np.uint8)
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
def Release(self):
|
127 |
+
del self.model_upscale
|
128 |
+
self.model_upscale = None
|
129 |
+
del self.io_binding
|
130 |
+
self.io_binding = None
|
131 |
+
|
roop/processors/Mask_Clip2Seg.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import threading
|
5 |
+
from torchvision import transforms
|
6 |
+
from clip.clipseg import CLIPDensePredT
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from roop.typing import Frame
|
10 |
+
|
11 |
+
THREAD_LOCK_CLIP = threading.Lock()
|
12 |
+
|
13 |
+
|
14 |
+
class Mask_Clip2Seg():
|
15 |
+
plugin_options:dict = None
|
16 |
+
model_clip = None
|
17 |
+
|
18 |
+
processorname = 'clip2seg'
|
19 |
+
type = 'mask'
|
20 |
+
|
21 |
+
|
22 |
+
def Initialize(self, plugin_options:dict):
|
23 |
+
if self.plugin_options is not None:
|
24 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
25 |
+
self.Release()
|
26 |
+
|
27 |
+
self.plugin_options = plugin_options
|
28 |
+
if self.model_clip is None:
|
29 |
+
self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
|
30 |
+
self.model_clip.eval();
|
31 |
+
self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
|
32 |
+
|
33 |
+
device = torch.device(self.plugin_options["devicename"])
|
34 |
+
self.model_clip.to(device)
|
35 |
+
|
36 |
+
|
37 |
+
def Run(self, img1, keywords:str) -> Frame:
|
38 |
+
if keywords is None or len(keywords) < 1 or img1 is None:
|
39 |
+
return img1
|
40 |
+
|
41 |
+
source_image_small = cv2.resize(img1, (256,256))
|
42 |
+
|
43 |
+
img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
|
44 |
+
mask_border = 1
|
45 |
+
l = 0
|
46 |
+
t = 0
|
47 |
+
r = 1
|
48 |
+
b = 1
|
49 |
+
|
50 |
+
mask_blur = 5
|
51 |
+
clip_blur = 5
|
52 |
+
|
53 |
+
img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
|
54 |
+
(256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
|
55 |
+
img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
|
56 |
+
img_mask /= 255
|
57 |
+
|
58 |
+
|
59 |
+
input_image = source_image_small
|
60 |
+
|
61 |
+
transform = transforms.Compose([
|
62 |
+
transforms.ToTensor(),
|
63 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
64 |
+
transforms.Resize((256, 256)),
|
65 |
+
])
|
66 |
+
img = transform(input_image).unsqueeze(0)
|
67 |
+
|
68 |
+
thresh = 0.5
|
69 |
+
prompts = keywords.split(',')
|
70 |
+
with THREAD_LOCK_CLIP:
|
71 |
+
with torch.no_grad():
|
72 |
+
preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
|
73 |
+
clip_mask = torch.sigmoid(preds[0][0])
|
74 |
+
for i in range(len(prompts)-1):
|
75 |
+
clip_mask += torch.sigmoid(preds[i+1][0])
|
76 |
+
|
77 |
+
clip_mask = clip_mask.data.cpu().numpy()
|
78 |
+
np.clip(clip_mask, 0, 1)
|
79 |
+
|
80 |
+
clip_mask[clip_mask>thresh] = 1.0
|
81 |
+
clip_mask[clip_mask<=thresh] = 0.0
|
82 |
+
kernel = np.ones((5, 5), np.float32)
|
83 |
+
clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
|
84 |
+
clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)
|
85 |
+
|
86 |
+
img_mask *= clip_mask
|
87 |
+
img_mask[img_mask<0.0] = 0.0
|
88 |
+
return img_mask
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
def Release(self):
|
93 |
+
self.model_clip = None
|
94 |
+
|
roop/processors/Mask_XSeg.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import onnxruntime
|
4 |
+
import threading
|
5 |
+
import roop.globals
|
6 |
+
|
7 |
+
from roop.typing import Frame
|
8 |
+
from roop.utilities import resolve_relative_path
|
9 |
+
|
10 |
+
THREAD_LOCK_CLIP = threading.Lock()
|
11 |
+
|
12 |
+
|
13 |
+
class Mask_XSeg():
|
14 |
+
plugin_options:dict = None
|
15 |
+
|
16 |
+
model_xseg = None
|
17 |
+
|
18 |
+
processorname = 'mask_xseg'
|
19 |
+
type = 'mask'
|
20 |
+
|
21 |
+
|
22 |
+
def Initialize(self, plugin_options:dict):
|
23 |
+
if self.plugin_options is not None:
|
24 |
+
if self.plugin_options["devicename"] != plugin_options["devicename"]:
|
25 |
+
self.Release()
|
26 |
+
|
27 |
+
self.plugin_options = plugin_options
|
28 |
+
if self.model_xseg is None:
|
29 |
+
model_path = resolve_relative_path('../models/xseg.onnx')
|
30 |
+
onnxruntime.set_default_logger_severity(3)
|
31 |
+
self.model_xseg = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
|
32 |
+
self.model_inputs = self.model_xseg.get_inputs()
|
33 |
+
self.model_outputs = self.model_xseg.get_outputs()
|
34 |
+
|
35 |
+
# replace Mac mps with cpu for the moment
|
36 |
+
self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
|
37 |
+
|
38 |
+
|
39 |
+
def Run(self, img1, keywords:str) -> Frame:
|
40 |
+
temp_frame = cv2.resize(img1, (256, 256), cv2.INTER_CUBIC)
|
41 |
+
temp_frame = temp_frame.astype('float32') / 255.0
|
42 |
+
temp_frame = temp_frame[None, ...]
|
43 |
+
io_binding = self.model_xseg.io_binding()
|
44 |
+
io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
|
45 |
+
io_binding.bind_output(self.model_outputs[0].name, self.devicename)
|
46 |
+
self.model_xseg.run_with_iobinding(io_binding)
|
47 |
+
ort_outs = io_binding.copy_outputs_to_cpu()
|
48 |
+
result = ort_outs[0][0]
|
49 |
+
result = np.clip(result, 0, 1.0)
|
50 |
+
result[result < 0.1] = 0
|
51 |
+
# invert values to mask areas to keep
|
52 |
+
result = 1.0 - result
|
53 |
+
return result
|
54 |
+
|
55 |
+
|
56 |
+
def Release(self):
|
57 |
+
del self.model_xseg
|
58 |
+
self.model_xseg = None
|
59 |
+
|
60 |
+
|
roop/processors/__init__.py
ADDED
File without changes
|
roop/template_parser.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from datetime import datetime
|
3 |
+
|
4 |
+
template_functions = {
|
5 |
+
"timestamp": lambda data: str(int(datetime.now().timestamp())),
|
6 |
+
"i": lambda data: data.get("index", False),
|
7 |
+
"file": lambda data: data.get("file", False),
|
8 |
+
"date": lambda data: datetime.now().strftime("%Y-%m-%d"),
|
9 |
+
"time": lambda data: datetime.now().strftime("%H-%M-%S"),
|
10 |
+
}
|
11 |
+
|
12 |
+
|
13 |
+
def parse(text: str, data: dict):
|
14 |
+
pattern = r"\{([^}]+)\}"
|
15 |
+
|
16 |
+
matches = re.findall(pattern, text)
|
17 |
+
|
18 |
+
for match in matches:
|
19 |
+
replacement = template_functions[match](data)
|
20 |
+
if replacement is not False:
|
21 |
+
text = text.replace(f"{{{match}}}", replacement)
|
22 |
+
|
23 |
+
return text
|
roop/typing.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
from insightface.app.common import Face
|
4 |
+
from roop.FaceSet import FaceSet
|
5 |
+
import numpy
|
6 |
+
|
7 |
+
Face = Face
|
8 |
+
FaceSet = FaceSet
|
9 |
+
Frame = numpy.ndarray[Any, Any]
|