Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- .devcontainer/Dockerfile +11 -0
- .devcontainer/devcontainer.json +27 -0
- .devcontainer/noop.txt +3 -0
- .gitattributes +7 -0
- .github/dependabot.yml +12 -0
- LICENSE +35 -0
- README.md +189 -8
- __pycache__/sampler.cpython-310.pyc +0 -0
- __pycache__/sampler.cpython-38.pyc +0 -0
- app.py +200 -0
- assets/0015.png +3 -0
- assets/0030.png +3 -0
- assets/Lincon.png +3 -0
- assets/cat.png +3 -0
- assets/dog2.png +3 -0
- assets/framework.png +0 -0
- assets/frog.png +3 -0
- assets/oldphoto6.png +3 -0
- basicsr/__init__.py +4 -0
- basicsr/__pycache__/__init__.cpython-310.pyc +0 -0
- basicsr/__pycache__/__init__.cpython-38.pyc +0 -0
- basicsr/data/__init__.py +101 -0
- basicsr/data/__pycache__/__init__.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/__init__.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/data_util.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/data_util.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/degradations.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/degradations.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/reds_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/reds_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/transforms.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/transforms.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc +0 -0
- basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc +0 -0
- basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc +0 -0
- basicsr/data/data_sampler.py +48 -0
- basicsr/data/data_util.py +315 -0
.devcontainer/Dockerfile
ADDED
@@ -0,0 +1,11 @@
+FROM mcr.microsoft.com/devcontainers/anaconda:0-3
+
+# Copy environment.yml (if found) to a temp location so we update the environment. Also
+# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
+COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
+RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
+    && rm -rf /tmp/conda-tmp
+
+# [Optional] Uncomment this section to install additional OS packages.
+# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
+#     && apt-get -y install --no-install-recommends <your-package-list-here>
.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,27 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
+{
+    "name": "Anaconda (Python 3)",
+    "build": {
+        "context": "..",
+        "dockerfile": "Dockerfile"
+    },
+    "features": {
+        "ghcr.io/flexwie/devcontainer-features/op:1": {}
+    }
+
+    // Features to add to the dev container. More info: https://containers.dev/features.
+    // "features": {},
+
+    // Use 'forwardPorts' to make a list of ports inside the container available locally.
+    // "forwardPorts": [],
+
+    // Use 'postCreateCommand' to run commands after the container is created.
+    // "postCreateCommand": "python --version",
+
+    // Configure tool-specific properties.
+    // "customizations": {},
+
+    // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
+    // "remoteUser": "root"
+}
.devcontainer/noop.txt
ADDED
@@ -0,0 +1,3 @@
+This file is copied into the container along with environment.yml* from the
+parent folder. It is included to prevent the Dockerfile COPY instruction from
+failing if no environment.yml is found.
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/0015.png filter=lfs diff=lfs merge=lfs -text
+assets/0030.png filter=lfs diff=lfs merge=lfs -text
+assets/Lincon.png filter=lfs diff=lfs merge=lfs -text
+assets/cat.png filter=lfs diff=lfs merge=lfs -text
+assets/dog2.png filter=lfs diff=lfs merge=lfs -text
+assets/frog.png filter=lfs diff=lfs merge=lfs -text
+assets/oldphoto6.png filter=lfs diff=lfs merge=lfs -text
.github/dependabot.yml
ADDED
@@ -0,0 +1,12 @@
+# To get started with Dependabot version updates, you'll need to specify which
+# package ecosystems to update and where the package manifests are located.
+# Please see the documentation for more information:
+# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
+# https://containers.dev/guide/dependabot
+
+version: 2
+updates:
+  - package-ecosystem: "devcontainers"
+    directory: "/"
+    schedule:
+      interval: weekly
LICENSE
ADDED
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
README.md
CHANGED
@@ -1,12 +1,193 @@
 ---
-title:
-emoji: 📉
-colorFrom: yellow
-colorTo: green
-sdk: gradio
-sdk_version: 4.36.1
+title: resshift
 app_file: app.py
-
+sdk: gradio
+sdk_version: 4.29.0
 ---
-
+# ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting (NeurIPS 2023, Spotlight)
+
+[Zongsheng Yue](https://zsyoaoa.github.io/), [Jianyi Wang](https://iceclear.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+[Conference Paper](https://arxiv.org/abs/2307.12348) | [Journal Paper](http://arxiv.org/abs/2403.07319) | [Project Page](https://zsyoaoa.github.io/projects/resshift/) | [Video](https://www.youtube.com/watch?v=8DB-6Xvvl5o)
+
+<a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/resshift) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/Zongsheng/ResShift) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift)
+
+:star: If ResShift is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+---
+>Diffusion-based image super-resolution (SR) methods are mainly limited by the low inference speed due to the requirements of hundreds or even thousands of sampling steps. Existing acceleration sampling techniques inevitably sacrifice performance to some extent, leading to over-blurry SR results. To address this issue, we propose a novel and efficient diffusion model for SR that significantly reduces the number of diffusion steps, thereby eliminating the need for post-acceleration during inference and its associated performance deterioration. Our method constructs a Markov chain that transfers between the high-resolution image and the low-resolution image by shifting the residual between them, substantially improving the transition efficiency. Additionally, an elaborate noise schedule is developed to flexibly control the shifting speed and the noise strength during the diffusion process. Extensive experiments demonstrate that the proposed method obtains superior or at least comparable performance to current state-of-the-art methods on both synthetic and real-world datasets, *even only with 15 sampling steps*.
+><img src="./assets/framework.png" align="middle" width="800">
+
+---
+## Update
+- **2024.03.11**: Update the code for the Journal paper.
+- **2023.12.02**: Add configurations for the x2 super-resolution task.
+- **2023.08.15**: Add [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/Zongsheng/ResShift).
+- **2023.08.15**: Add Gradio demo.
+- **2023.08.14**: Add bicubic (Matlab resize) model.
+- **2023.08.14**: Add [Project Page](https://zsyoaoa.github.io/projects/resshift/).
+- **2023.08.02**: Add [Replicate](https://replicate.com/) demo [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/resshift).
+- **2023.07.31**: Add Colab demo <a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
+- **2023.07.24**: Create this repo.
+
+## Requirements
+* Python 3.10, PyTorch 2.1.2, [xformers](https://github.com/facebookresearch/xformers) 0.0.23
+* For more details, see [environment.yml](environment.yml).
+
+A suitable [conda](https://conda.io/) environment named `resshift` can be created and activated with:
+
+```
+conda create -n resshift python=3.10
+conda activate resshift
+pip install -r requirements.txt
+```
+or
+```
+conda env create -f environment.yml
+conda activate resshift
+```
+
+## Applications
+### :point_right: Real-world image super-resolution
+[<img src="assets/0015.png" height="324px"/>](https://imgsli.com/MTkzNzgz) [<img src="assets/0030.png" height="324px"/>](https://imgsli.com/MTkzNzgx)
+
+[<img src="assets/frog.png" height="324px"/>](https://imgsli.com/MTkzNzg0) [<img src="assets/dog2.png" height="324px">](https://imgsli.com/MTkzNzg3)
+
+[<img src="assets/cat.png" height="252px"/>](https://imgsli.com/MTkzNzkx) [<img src="assets/Lincon.png" height="252px"/>](https://imgsli.com/MTkzNzk5) [<img src="assets/oldphoto6.png" height="252px"/>](https://imgsli.com/MTkzNzk2)
+
+### :point_right: Image inpainting
+<img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001639_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001639.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001810_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001810.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001204_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001204.png" height="126px"/>
+<img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00002438_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00002438.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00005693_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00005693.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00005814_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00005814.png" height="126px"/>
+<img src="testdata/inpainting/face/lq_mark/94_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/94.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/410_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/410.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/269_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/269.png" height="126px"/>
+<img src="testdata/inpainting/face/lq_mark/321_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/321.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/5915_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/5915.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/5489_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/5489.png" height="126px"/>
+
+### :point_right: Blind Face Restoration
+<img src="testdata/faceir/cropped_faces/lq/0729.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0729.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/0444.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0444.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/0885.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0885.png" height="126px"/>
+<img src="testdata/faceir/cropped_faces/lq/0500.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0500.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/Solvay_conference_1927_2_16.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/Solvay_conference_1927_0018.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/Solvay_conference_1927_2_16.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/Solvay_conference_1927_2_16.png" height="126px"/>
+
+## Online Demo
+You can try our method through the online Gradio demo, launched locally with:
+```
+python app.py
+```
+
+## Fast Testing
+#### :tiger: Real-world image super-resolution
+
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3
+```
+#### :lion: Bicubic (resize by Matlab) image super-resolution
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task bicsr --scale 4
+```
+#### :snake: Natural image inpainting
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_imagenet --scale 1
+```
+#### :crocodile: Face image inpainting
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_face --scale 1
+```
+#### :octopus: Blind Face Restoration
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task faceir --scale 1
+```
+
+### Training
+#### :turtle: Preparing stage
+1. Download the pre-trained VQGAN model from this [link](https://github.com/zsyOAOA/ResShift/releases) and put it in the 'weights' folder.
+2. Adjust the data path in the [config](configs) file.
+3. Adjust the batch size according to your GPUs (a worked example of the arithmetic follows this diff):
+    + configs.train.batch: [training batchsize, validation batchsize]
+    + configs.train.microbatch: total batchsize = microbatch * #GPUS * num_grad_accumulation
+
+#### :dolphin: Real-world Image Super-resolution for NeurIPS
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/realsr_swinunet_realesrgan256.yaml --save_dir [Logging Folder]
+```
+#### :whale: Real-world Image Super-resolution for Journal
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/realsr_swinunet_realesrgan256_journal.yaml --save_dir [Logging Folder]
+```
+#### :ox: Image inpainting (Natural) for Journal
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/inpaint_lama256_imagenet.yaml --save_dir [Logging Folder]
+```
+#### :honeybee: Image inpainting (Face) for Journal
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/inpaint_lama256_face.yaml --save_dir [Logging Folder]
+```
+#### :frog: Blind face restoration for Journal
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/faceir_gfpgan512_lpips.yaml --save_dir [Logging Folder]
+```
+
+### Reproducing the results in our paper
+#### :red_car: Prepare data
++ Synthetic data for image super-resolution: [Link](https://drive.google.com/file/d/1NhmpON2dB2LjManfX6uIj8Pj_Jx6N-6l/view?usp=sharing)
+
++ Real data for image super-resolution: [RealSet65](testdata/RealSet65) | [RealSet80](testdata/RealSet80)
+
++ Synthetic data for natural image inpainting: [Link](https://drive.google.com/file/d/11_1xntiGnZzRX87Ve6thbeWDPVksTHNC/view?usp=sharing)
+
++ Synthetic data for face image inpainting: [Link](https://drive.google.com/file/d/1nyfry2XjgA_qV8fS2_Y5TNQwtG5Irwux/view?usp=sharing)
+
++ Synthetic data for blind face restoration: [Link](https://drive.google.com/file/d/15Ij-UaI8BQ7fBDF0i4M1wDOk-bnn_C4X/view?usp=drive_link)
+
+#### :rocket: Image super-resolution
+Reproduce the results in Table 3 of our NeurIPS paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v1 --chop_size 64 --chop_stride 64 --bs 64
+```
+Reproduce the results in Table 4 of our NeurIPS paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v1 --chop_size 512 --chop_stride 448
+```
+Reproduce the results in Table 2 of our Journal paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3 --chop_size 64 --chop_stride 64 --bs 64
+```
+Reproduce the results in Table 3 of our Journal paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3 --chop_size 512 --chop_stride 448
+```
+##### Model card:
++ version-1: Conference paper, 15 diffusion steps, trained for 300k iterations.
++ version-2: Conference paper, 15 diffusion steps, trained for 500k iterations.
++ version-3: Journal paper, 4 diffusion steps.
+
+#### :airplane: Image inpainting
+Reproduce the results in Table 4 of our Journal paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_imagenet --scale 1 --chop_size 256 --chop_stride 256 --bs 32
+```
+Reproduce the results in Table 5 of our Journal paper:
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_face --scale 1 --chop_size 256 --chop_stride 256 --bs 32
+```
+#### :boat: Blind Face Restoration
+Reproduce the results in Table 6 of our Journal paper (arXiv):
+```
+python inference_resshift.py -i [image folder/image path] -o [result folder] --task faceir --scale 1 --chop_size 256 --chop_stride 256 --bs 16
+```
+
+<!--## Note on General Restoration Task-->
+<!--For general restoration task, please adjust the settings in the config file:-->
+<!--```-->
+<!--model.params.lq_size: resolution of the low-quality image. # should be divided by 64-->
+<!--diffusion.params.sf: scale factor for super-resolution, 1 for restoration task.-->
+<!--degradation.sf: scale factor for super-resolution, 1 for restoration task. # only required for the pipeline of Real-Esrgan -->
+<!--```-->
+<!--In some cases, you need to rewrite the data loading process. -->
+
+## License
+
+This project is licensed under the <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
+
+## Acknowledgement
+
+This project is based on [Improved Diffusion Model](https://github.com/openai/improved-diffusion), [LDM](https://github.com/CompVis/latent-diffusion), and [BasicSR](https://github.com/XPixelGroup/BasicSR). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to synthesize the training data for real-world super-resolution. Thanks for their awesome work.
+
+### Contact
+If you have any questions, please feel free to contact me via `[email protected]`.
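The training notes above give the effective batch size as `microbatch * #GPUS * num_grad_accumulation`. A minimal sketch of that arithmetic (the helper function and the microbatch value are illustrative, not taken from the repo's configs):

```python
# Worked example of the batch-size rule quoted in the README:
#   total batchsize = microbatch * #GPUS * num_grad_accumulation
# Names below are illustrative; the real values live in configs/*.yaml.

def effective_batch_size(microbatch: int, num_gpus: int, grad_accum: int = 1) -> int:
    """Samples contributing to one optimizer step."""
    return microbatch * num_gpus * grad_accum

# E.g. the 8-GPU torchrun commands above, with a hypothetical microbatch of 8:
print(effective_batch_size(microbatch=8, num_gpus=8))  # 64
```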
__pycache__/sampler.cpython-310.pyc
ADDED
Binary file (8.29 kB).
__pycache__/sampler.cpython-38.pyc
ADDED
Binary file (8.1 kB).
app.py
ADDED
@@ -0,0 +1,200 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Power by Zongsheng Yue 2023-08-15 09:39:58
+
+import argparse
+import gradio as gr
+from pathlib import Path
+
+from omegaconf import OmegaConf
+from sampler import ResShiftSampler
+
+from utils import util_image
+from basicsr.utils.download_util import load_file_from_url
+
+# number of diffusion steps for each released model
+_STEP = {
+    'v1': 15,
+    'v2': 15,
+    'v3': 4,
+    'bicsr': 4,
+    'inpaint_imagenet': 4,
+    'inpaint_face': 4,
+    'faceir': 4,
+}
+_LINK = {
+    'vqgan': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/autoencoder_vq_f4.pth',
+    'vqgan_face256': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/celeba256_vq_f4_dim3_face.pth',
+    'vqgan_face512': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/ffhq512_vq_f8_dim8_face.pth',
+    'v1': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v1.pth',
+    'v2': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v2.pth',
+    'v3': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s4_v3.pth',
+    'bicsr': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_bicsrx4_s4.pth',
+    'inpaint_imagenet': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_imagenet_s4.pth',
+    'inpaint_face': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_face_s4.pth',
+    'faceir': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_faceir_s4.pth',
+}
+
+def get_configs(task='realsr', version='v3', scale=4):
+    ckpt_dir = Path('./weights')
+    if not ckpt_dir.exists():
+        ckpt_dir.mkdir()
+
+    if task == 'realsr':
+        if version in ['v1', 'v2']:
+            configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256.yaml')
+        elif version == 'v3':
+            configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml')
+        else:
+            raise ValueError(f"Unexpected version type: {version}")
+        assert scale == 4, 'We only support the 4x super-resolution now!'
+        ckpt_url = _LINK[version]
+        ckpt_path = ckpt_dir / f'resshift_{task}x{scale}_s{_STEP[version]}_{version}.pth'
+        vqgan_url = _LINK['vqgan']
+        vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+    elif task == 'bicsr':
+        configs = OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml')
+        assert scale == 4, 'We only support the 4x super-resolution now!'
+        ckpt_url = _LINK[task]
+        ckpt_path = ckpt_dir / f'resshift_{task}x{scale}_s{_STEP[task]}.pth'
+        vqgan_url = _LINK['vqgan']
+        vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+    # elif task == 'inpaint_imagenet':
+    #     configs = OmegaConf.load('./configs/inpaint_lama256_imagenet.yaml')
+    #     assert scale == 1, 'Please set scale equals 1 for image inpainting!'
+    #     ckpt_url = _LINK[task]
+    #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+    #     vqgan_url = _LINK['vqgan']
+    #     vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+    # elif task == 'inpaint_face':
+    #     configs = OmegaConf.load('./configs/inpaint_lama256_face.yaml')
+    #     assert scale == 1, 'Please set scale equals 1 for image inpainting!'
+    #     ckpt_url = _LINK[task]
+    #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+    #     vqgan_url = _LINK['vqgan_face256']
+    #     vqgan_path = ckpt_dir / 'celeba256_vq_f4_dim3_face.pth'
+    # elif task == 'faceir':
+    #     configs = OmegaConf.load('./configs/faceir_gfpgan512_lpips.yaml')
+    #     assert scale == 1, 'Please set scale equals 1 for face restoration!'
+    #     ckpt_url = _LINK[task]
+    #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+    #     vqgan_url = _LINK['vqgan_face512']
+    #     vqgan_path = ckpt_dir / 'ffhq512_vq_f8_dim8_face.pth'
+    else:
+        raise TypeError(f"Unexpected task type: {task}!")
+
+    # prepare the checkpoint
+    if not ckpt_path.exists():
+        load_file_from_url(
+            url=ckpt_url,
+            model_dir=ckpt_dir,
+            progress=True,
+            file_name=ckpt_path.name,
+        )
+    if not vqgan_path.exists():
+        load_file_from_url(
+            url=vqgan_url,
+            model_dir=ckpt_dir,
+            progress=True,
+            file_name=vqgan_path.name,
+        )
+
+    configs.model.ckpt_path = str(ckpt_path)
+    configs.diffusion.params.sf = scale
+    configs.autoencoder.ckpt_path = str(vqgan_path)
+
+    return configs
+
+def predict(in_path, task='realsr', seed=12345, scale=4, version='v3'):
+    # note: the default task was 'realsrx4', which get_configs() does not recognize
+    configs = get_configs(task, version, scale)
+    resshift_sampler = ResShiftSampler(
+        configs,
+        sf=scale,
+        chop_size=256,
+        chop_stride=224,
+        chop_bs=1,
+        use_amp=True,
+        seed=seed,
+        padding_offset=configs.model.params.get('lq_size', 64),
+    )
+
+    out_dir = Path('restored_output')
+    if not out_dir.exists():
+        out_dir.mkdir()
+
+    resshift_sampler.inference(
+        in_path,
+        out_dir,
+        mask_path=None,
+        bs=1,
+        noise_repeat=False,
+    )
+
+    out_path = out_dir / f"{Path(in_path).stem}.png"
+    assert out_path.exists(), 'Super-resolution failed!'
+    im_sr = util_image.imread(out_path, chn="rgb", dtype="uint8")
+
+    return im_sr, str(out_path)
+
+title = "ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting"
+description = r"""
+<b>Official Gradio demo</b> for <a href='https://github.com/zsyOAOA/ResShift' target='_blank'><b>ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting</b></a>.<br>
+🔥 ResShift is an efficient diffusion model designed for image super-resolution or restoration.<br>
+"""
+article = r"""
+If ResShift is helpful for your work, please help to ⭐ the <a href='https://github.com/zsyOAOA/ResShift' target='_blank'>Github Repo</a>. Thanks!
+[![GitHub Stars](https://img.shields.io/github/stars/zsyOAOA/ResShift?affiliations=OWNER&color=green&style=social)](https://github.com/zsyOAOA/ResShift)
+
+---
+If our work is useful for your research, please consider citing:
+```bibtex
+@inproceedings{yue2023resshift,
+    title={ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting},
+    author={Yue, Zongsheng and Wang, Jianyi and Loy, Chen Change},
+    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
+    year={2023},
+    volume = {36},
+    pages = {13294--13307},
+}
+```
+
+📋 **License**
+
+This project is licensed under <a rel="license" href="https://github.com/zsyOAOA/ResShift/blob/master/LICENSE">S-Lab License 1.0</a>.
+Redistribution and use for non-commercial purposes should follow this license.
+
+📧 **Contact**
+
+If you have any questions, please feel free to contact me via <b>[email protected]</b>.
+![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift)
+"""
+demo = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Image(type="filepath", label="Input: Low Quality Image"),
+        gr.Dropdown(
+            choices=["realsr", "bicsr"],
+            value="realsr",
+            label="Task",
+        ),
+        gr.Number(value=12345, precision=0, label="Random seed"),
+    ],
+    outputs=[
+        gr.Image(type="numpy", label="Output: High Quality Image"),
+        gr.File(label="Download the output"),  # gr.outputs.File was removed in Gradio 4
+    ],
+    title=title,
+    description=description,
+    article=article,
+    examples=[
+        ['./testdata/RealSet65/0030.jpg', "realsr", 12345],
+        ['./testdata/RealSet65/dog2.png', "realsr", 12345],
+        ['./testdata/RealSet65/bears.jpg', "realsr", 12345],
+        ['./testdata/RealSet65/oldphoto6.png', "realsr", 12345],
+        ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00000067.png', "bicsr", 12345],
+        ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00016898.png', "bicsr", 12345],
+    ]
+)
+
+demo.queue(default_concurrency_limit=4)  # Gradio 4 renamed queue()'s concurrency_count
+demo.launch(share=True)
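The sampler above is built with `chop_size=256, chop_stride=224`, i.e. it processes large inputs as overlapping tiles. A small sketch of that tiling arithmetic (an illustrative helper, not part of the repo; the actual chopping logic lives in `ResShiftSampler` and may clamp the last tile to the image border):

```python
import math

# Overlapping-tile arithmetic behind chop_size / chop_stride above:
# consecutive 256-px tiles start every 224 px, overlapping by 32 px.
def num_tiles(length: int, chop_size: int = 256, chop_stride: int = 224) -> int:
    """Estimate of tiles needed to cover one spatial dimension."""
    if length <= chop_size:
        return 1
    return math.ceil((length - chop_size) / chop_stride) + 1

# A 1024x768 input needs roughly 5 x 4 = 20 overlapping tiles.
print(num_tiles(1024), num_tiles(768))  # 5 4
```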
assets/0015.png
ADDED
Git LFS Details
assets/0030.png
ADDED
Git LFS Details
assets/Lincon.png
ADDED
Git LFS Details
assets/cat.png
ADDED
Git LFS Details
assets/dog2.png
ADDED
Git LFS Details
assets/framework.png
ADDED
assets/frog.png
ADDED
Git LFS Details
assets/oldphoto6.png
ADDED
Git LFS Details
basicsr/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .data import *
+from .utils import *
basicsr/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (160 Bytes).
basicsr/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (193 Bytes).
basicsr/data/__init__.py
ADDED
@@ -0,0 +1,101 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+    """Build dataset from options.
+
+    Args:
+        dataset_opt (dict): Configuration for dataset. It must contain:
+            name (str): Dataset name.
+            type (str): Dataset type.
+    """
+    dataset_opt = deepcopy(dataset_opt)
+    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+    logger = get_root_logger()
+    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
+    return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+    """Build dataloader.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset.
+        dataset_opt (dict): Dataset options. It contains the following keys:
+            phase (str): 'train' or 'val'.
+            num_worker_per_gpu (int): Number of workers for each GPU.
+            batch_size_per_gpu (int): Training batch size for each GPU.
+        num_gpu (int): Number of GPUs. Used only in the train phase.
+            Default: 1.
+        dist (bool): Whether in distributed training. Used only in the train
+            phase. Default: False.
+        sampler (torch.utils.data.sampler): Data sampler. Default: None.
+        seed (int | None): Seed. Default: None
+    """
+    phase = dataset_opt['phase']
+    rank, _ = get_dist_info()
+    if phase == 'train':
+        if dist:  # distributed training
+            batch_size = dataset_opt['batch_size_per_gpu']
+            num_workers = dataset_opt['num_worker_per_gpu']
+        else:  # non-distributed training
+            multiplier = 1 if num_gpu == 0 else num_gpu
+            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+        dataloader_args = dict(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            sampler=sampler,
+            drop_last=True)
+        if sampler is None:
+            dataloader_args['shuffle'] = True
+        dataloader_args['worker_init_fn'] = partial(
+            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+    elif phase in ['val', 'test']:  # validation
+        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+    else:
+        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
+
+    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
+
+    prefetch_mode = dataset_opt.get('prefetch_mode')
+    if prefetch_mode == 'cpu':  # CPUPrefetcher
+        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+        logger = get_root_logger()
+        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
+        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+    else:
+        # prefetch_mode=None: Normal dataloader
+        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+        return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # Set the worker seed to num_workers * rank + worker_id + seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
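As a quick illustration of `worker_init_fn` above (a sketch, not repo code): every `(rank, worker_id)` pair maps to a distinct deterministic seed, so data augmentation differs across workers and ranks yet stays reproducible across runs:

```python
# Reproduce the seed layout of worker_init_fn:
#   worker_seed = num_workers * rank + worker_id + base_seed
num_workers, base_seed = 4, 10
for rank in range(2):  # e.g. two GPUs / processes
    seeds = [num_workers * rank + worker_id + base_seed for worker_id in range(num_workers)]
    print(f'rank {rank}: {seeds}')
# rank 0: [10, 11, 12, 13]
# rank 1: [14, 15, 16, 17]
```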
basicsr/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (3.55 kB).
basicsr/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.59 kB).
basicsr/data/__pycache__/data_util.cpython-310.pyc
ADDED
Binary file (11.2 kB).
basicsr/data/__pycache__/data_util.cpython-38.pyc
ADDED
Binary file (11.2 kB).
basicsr/data/__pycache__/degradations.cpython-310.pyc
ADDED
Binary file (20.3 kB).
basicsr/data/__pycache__/degradations.cpython-38.pyc
ADDED
Binary file (21.7 kB).
basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc
ADDED
Binary file (3.05 kB).
basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc
ADDED
Binary file (3.02 kB).
basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc
ADDED
Binary file (3.84 kB).
basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc
ADDED
Binary file (3.88 kB).
basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc
ADDED
Binary file (4.32 kB).
basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc
ADDED
Binary file (4.37 kB).
basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc
ADDED
Binary file (8.58 kB).
basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc
ADDED
Binary file (5.1 kB).
basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc
ADDED
Binary file (4.03 kB).
basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-38.pyc
ADDED
Binary file (4.03 kB).
basicsr/data/__pycache__/reds_dataset.cpython-310.pyc
ADDED
Binary file (10.5 kB).
basicsr/data/__pycache__/reds_dataset.cpython-38.pyc
ADDED
Binary file (10.7 kB).
basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc
ADDED
Binary file (2.82 kB).
basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc
ADDED
Binary file (2.82 kB).
basicsr/data/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (5.97 kB).
basicsr/data/__pycache__/transforms.cpython-38.pyc
ADDED
Binary file (6.04 kB).
basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc
ADDED
Binary file (10.1 kB).
basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc
ADDED
Binary file (10.3 kB).
basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc
ADDED
Binary file (5.77 kB).
basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc
ADDED
Binary file (5.79 kB).
basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+
+    Modified from torch.utils.data.distributed.DistributedSampler.
+    Supports enlarging the dataset for iteration-based training, saving
+    time when restarting the dataloader after each epoch.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Dataset used for sampling.
+        num_replicas (int | None): Number of processes participating in
+            the training. It is usually the world_size.
+        rank (int | None): Rank of the current process within num_replicas.
+        ratio (int): Enlarging ratio. Default: 1.
+    """
+
+    def __init__(self, dataset, num_replicas, rank, ratio=1):
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices = torch.randperm(self.total_size, generator=g).tolist()
+
+        dataset_size = len(self.dataset)
+        indices = [v % dataset_size for v in indices]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
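A small usage sketch for `EnlargedSampler` (illustrative; any object with `__len__` stands in for a Dataset here): with `ratio=3`, each of two ranks draws `ceil(10 * 3 / 2) = 15` indices per epoch, all wrapped into the valid index range, so iteration-based training restarts the dataloader less often:

```python
from basicsr.data.data_sampler import EnlargedSampler

dataset = list(range(10))  # stands in for a torch Dataset of length 10
sampler = EnlargedSampler(dataset, num_replicas=2, rank=0, ratio=3)
sampler.set_epoch(0)       # drives the deterministic per-epoch shuffle
indices = list(iter(sampler))
print(len(indices))        # 15 == ceil(10 * 3 / 2)
assert all(0 <= i < len(dataset) for i in indices)
```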
basicsr/data/data_util.py
ADDED
@@ -0,0 +1,315 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+    """Read a sequence of images from a given folder path.
+
+    Args:
+        path (list[str] | str): List of image paths or image folder path.
+        require_mod_crop (bool): Require mod crop for each image.
+            Default: False.
+        scale (int): Scale factor for mod_crop. Default: 1.
+        return_imgname (bool): Whether to return image names. Default: False.
+
+    Returns:
+        Tensor: size (t, c, h, w), RGB, [0, 1].
+        list[str]: Returned image name list.
+    """
+    if isinstance(path, list):
+        img_paths = path
+    else:
+        img_paths = sorted(list(scandir(path, full_path=True)))
+    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+    if require_mod_crop:
+        imgs = [mod_crop(img, scale) for img in imgs]
+    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+    imgs = torch.stack(imgs, dim=0)
+
+    if return_imgname:
+        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+        return imgs, imgnames
+    else:
+        return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+    """Generate an index list for reading `num_frames` frames from a sequence
+    of images.
+
+    Args:
+        crt_idx (int): Current center index.
+        max_frame_num (int): Max number of the sequence of images (from 1).
+        num_frames (int): Reading num_frames frames.
+        padding (str): Padding mode, one of
+            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+            Examples: current_idx = 0, num_frames = 5
+            The generated frame indices under different padding modes:
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            reflection_circle: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        list[int]: A list of indices.
+    """
+    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+    max_frame_num = max_frame_num - 1  # start from 0
+    num_pad = num_frames // 2
+
+    indices = []
+    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                pad_idx = 0
+            elif padding == 'reflection':
+                pad_idx = -i
+            elif padding == 'reflection_circle':
+                pad_idx = crt_idx + num_pad - i
+            else:
+                pad_idx = num_frames + i
+        elif i > max_frame_num:
+            if padding == 'replicate':
+                pad_idx = max_frame_num
+            elif padding == 'reflection':
+                pad_idx = max_frame_num * 2 - i
+            elif padding == 'reflection_circle':
+                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+            else:
+                pad_idx = i - num_frames
+        else:
+            pad_idx = i
+        indices.append(pad_idx)
+    return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+    """Generate paired paths from lmdb files.
+
+    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+    ::
+
+        lq.lmdb
+        ├── data.mdb
+        ├── lock.mdb
+        ├── meta_info.txt
+
+    The data.mdb and lock.mdb are standard lmdb files and you can refer to
+    https://lmdb.readthedocs.io/en/release/ for more details.
+
+    The meta_info.txt is a specified txt file to record the meta information
+    of our datasets. It will be automatically created when preparing
+    datasets by our provided dataset tools.
+    Each line in the txt file records
+    1) image name (with extension),
+    2) image shape,
+    3) compression level, separated by a white space.
+    Example: `baboon.png (120,125,3) 1`
+
+    We use the image name without extension as the lmdb key.
+    Note that we use the same key for the corresponding lq and gt images.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+            Note that this key is different from lmdb keys.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                         f'format. But received {input_key}: {input_folder}; '
+                         f'{gt_key}: {gt_folder}')
+    # ensure that the two meta_info files are the same
+    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+        input_lmdb_keys = [line.split('.')[0] for line in fin]
+    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+        gt_lmdb_keys = [line.split('.')[0] for line in fin]
+    if set(input_lmdb_keys) != set(gt_lmdb_keys):
+        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+    else:
+        paths = []
+        for lmdb_key in sorted(input_lmdb_keys):
+            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+        return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from a meta information file.
+
+    Each line in the meta information file contains the image names and
+    image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+    ```
+    0001_s001.png (480,480,3)
+    0001_s002.png (480,480,3)
+    ```
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        meta_info_file (str): Path to the meta information file.
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    with open(meta_info_file, 'r') as fin:
+        gt_names = [line.strip().split(' ')[0] for line in fin]
+
+    paths = []
+    for gt_name in gt_names:
+        basename, ext = osp.splitext(osp.basename(gt_name))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        gt_path = osp.join(gt_folder, gt_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+    """Generate paired paths from folders.
+
+    Args:
+        folders (list[str]): A list of folder paths. The order of the list
+            should be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+        filename_tmpl (str): Template for each filename. Note that the
+            template excludes the file extension. Usually the filename_tmpl is
+            for files in the input folder.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                               f'But got {len(folders)}')
+    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+    input_folder, gt_folder = folders
+    input_key, gt_key = keys
+
+    input_paths = list(scandir(input_folder))
+    gt_paths = list(scandir(gt_folder))
+    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different numbers of images: '
+                                               f'{len(input_paths)}, {len(gt_paths)}.')
+    paths = []
+    for gt_path in gt_paths:
+        basename, ext = osp.splitext(osp.basename(gt_path))
+        input_name = f'{filename_tmpl.format(basename)}{ext}'
+        input_path = osp.join(input_folder, input_name)
+        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
+        gt_path = osp.join(gt_folder, gt_path)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+    return paths
+
+
+def paths_from_folder(folder):
+    """Generate paths from folder.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+
+    paths = list(scandir(folder))
+    paths = [osp.join(folder, path) for path in paths]
+    return paths
+
+
+def paths_from_lmdb(folder):
+    """Generate paths from lmdb.
+
+    Args:
+        folder (str): Folder path.
+
+    Returns:
+        list[str]: Returned path list.
+    """
+    if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+    with open(osp.join(folder, 'meta_info.txt')) as fin:
+        paths = [line.split('.')[0] for line in fin]
+    return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+    """Generate Gaussian kernel used in `duf_downsample`.
+
+    Args:
+        kernel_size (int): Kernel size. Default: 13.
+        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+    Returns:
+        np.array: The Gaussian kernel.
+    """
+    from scipy.ndimage import filters as filters
+    kernel = np.zeros((kernel_size, kernel_size))
+    # set element at the middle to one, a dirac delta
+    kernel[kernel_size // 2, kernel_size // 2] = 1
+    # gaussian-smooth the dirac, resulting in a gaussian filter
+    return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+    """Downsampling with the Gaussian kernel used in the DUF official code.
+
+    Args:
+        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+        kernel_size (int): Kernel size. Default: 13.
+        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+            Default: 4.
+
+    Returns:
+        Tensor: DUF downsampled frames.
+    """
+    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+    squeeze_flag = False
+    if x.ndim == 4:
+        squeeze_flag = True
+        x = x.unsqueeze(0)
+    b, t, c, h, w = x.size()
+    x = x.view(-1, 1, h, w)
+    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+    x = F.conv2d(x, gaussian_filter, stride=scale)
+    x = x[:, :, 2:-2, 2:-2]
+    x = x.view(b, t, c, x.size(2), x.size(3))
+    if squeeze_flag:
+        x = x.squeeze(0)
+    return x
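To make the padding modes of `generate_frame_indices` concrete, a usage sketch that reproduces the docstring example (`crt_idx=0`, `num_frames=5`; the 100-frame sequence length is arbitrary):

```python
from basicsr.data.data_util import generate_frame_indices

for mode in ('replicate', 'reflection', 'reflection_circle', 'circle'):
    print(f'{mode:18s}', generate_frame_indices(0, max_frame_num=100, num_frames=5, padding=mode))
# replicate          [0, 0, 0, 1, 2]
# reflection         [2, 1, 0, 1, 2]
# reflection_circle  [4, 3, 0, 1, 2]
# circle             [3, 4, 0, 1, 2]
```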