yuhj95 committed on
Commit
4730cdc
1 Parent(s): 22efa73

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .devcontainer/Dockerfile +11 -0
  2. .devcontainer/devcontainer.json +27 -0
  3. .devcontainer/noop.txt +3 -0
  4. .gitattributes +7 -0
  5. .github/dependabot.yml +12 -0
  6. LICENSE +35 -0
  7. README.md +189 -8
  8. __pycache__/sampler.cpython-310.pyc +0 -0
  9. __pycache__/sampler.cpython-38.pyc +0 -0
  10. app.py +200 -0
  11. assets/0015.png +3 -0
  12. assets/0030.png +3 -0
  13. assets/Lincon.png +3 -0
  14. assets/cat.png +3 -0
  15. assets/dog2.png +3 -0
  16. assets/framework.png +0 -0
  17. assets/frog.png +3 -0
  18. assets/oldphoto6.png +3 -0
  19. basicsr/__init__.py +4 -0
  20. basicsr/__pycache__/__init__.cpython-310.pyc +0 -0
  21. basicsr/__pycache__/__init__.cpython-38.pyc +0 -0
  22. basicsr/data/__init__.py +101 -0
  23. basicsr/data/__pycache__/__init__.cpython-310.pyc +0 -0
  24. basicsr/data/__pycache__/__init__.cpython-38.pyc +0 -0
  25. basicsr/data/__pycache__/data_util.cpython-310.pyc +0 -0
  26. basicsr/data/__pycache__/data_util.cpython-38.pyc +0 -0
  27. basicsr/data/__pycache__/degradations.cpython-310.pyc +0 -0
  28. basicsr/data/__pycache__/degradations.cpython-38.pyc +0 -0
  29. basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc +0 -0
  30. basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc +0 -0
  31. basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc +0 -0
  32. basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc +0 -0
  33. basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc +0 -0
  34. basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc +0 -0
  35. basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc +0 -0
  36. basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc +0 -0
  37. basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc +0 -0
  38. basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-38.pyc +0 -0
  39. basicsr/data/__pycache__/reds_dataset.cpython-310.pyc +0 -0
  40. basicsr/data/__pycache__/reds_dataset.cpython-38.pyc +0 -0
  41. basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc +0 -0
  42. basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc +0 -0
  43. basicsr/data/__pycache__/transforms.cpython-310.pyc +0 -0
  44. basicsr/data/__pycache__/transforms.cpython-38.pyc +0 -0
  45. basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc +0 -0
  46. basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc +0 -0
  47. basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc +0 -0
  48. basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc +0 -0
  49. basicsr/data/data_sampler.py +48 -0
  50. basicsr/data/data_util.py +315 -0
.devcontainer/Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM mcr.microsoft.com/devcontainers/anaconda:0-3
+
+ # Copy environment.yml (if found) to a temp location so we update the environment. Also
+ # copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
+ COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
+ RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
+     && rm -rf /tmp/conda-tmp
+
+ # [Optional] Uncomment this section to install additional OS packages.
+ # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
+ #     && apt-get -y install --no-install-recommends <your-package-list-here>
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,27 @@
+ // For format details, see https://aka.ms/devcontainer.json. For config options, see the
+ // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
+ {
+     "name": "Anaconda (Python 3)",
+     "build": {
+         "context": "..",
+         "dockerfile": "Dockerfile"
+     },
+     "features": {
+         "ghcr.io/flexwie/devcontainer-features/op:1": {}
+     }
+
+     // Features to add to the dev container. More info: https://containers.dev/features.
+     // "features": {},
+
+     // Use 'forwardPorts' to make a list of ports inside the container available locally.
+     // "forwardPorts": [],
+
+     // Use 'postCreateCommand' to run commands after the container is created.
+     // "postCreateCommand": "python --version",
+
+     // Configure tool-specific properties.
+     // "customizations": {},
+
+     // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
+     // "remoteUser": "root"
+ }
.devcontainer/noop.txt ADDED
@@ -0,0 +1,3 @@
+ This file is copied into the container along with environment.yml* from the parent
+ folder. It is included to prevent the Dockerfile COPY instruction from
+ failing if no environment.yml is found.
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/0015.png filter=lfs diff=lfs merge=lfs -text
+ assets/0030.png filter=lfs diff=lfs merge=lfs -text
+ assets/Lincon.png filter=lfs diff=lfs merge=lfs -text
+ assets/cat.png filter=lfs diff=lfs merge=lfs -text
+ assets/dog2.png filter=lfs diff=lfs merge=lfs -text
+ assets/frog.png filter=lfs diff=lfs merge=lfs -text
+ assets/oldphoto6.png filter=lfs diff=lfs merge=lfs -text
.github/dependabot.yml ADDED
@@ -0,0 +1,12 @@
+ # To get started with Dependabot version updates, you'll need to specify which
+ # package ecosystems to update and where the package manifests are located.
+ # Please see the documentation for more information:
+ # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
+ # https://containers.dev/guide/dependabot
+
+ version: 2
+ updates:
+   - package-ecosystem: "devcontainers"
+     directory: "/"
+     schedule:
+       interval: weekly
LICENSE ADDED
@@ -0,0 +1,35 @@
+ S-Lab License 1.0
+
+ Copyright 2022 S-Lab
+
+ Redistribution and use for non-commercial purpose in source and
+ binary forms, with or without modification, are permitted provided
+ that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+    notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+    notice, this list of conditions and the following disclaimer in
+    the documentation and/or other materials provided with the
+    distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived
+    from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ In the event that redistribution and/or use for commercial purpose in
+ source or binary forms, with or without modification is required,
+ please contact the contributor(s) of the work.
README.md CHANGED
@@ -1,12 +1,193 @@
  ---
- title: Resshift
- emoji: 📉
- colorFrom: yellow
- colorTo: green
- sdk: gradio
- sdk_version: 4.36.1
+ title: resshift
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 4.29.0
+ ---
+ # ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting (NeurIPS 2023, Spotlight)
+
+ [Zongsheng Yue](https://zsyoaoa.github.io/), [Jianyi Wang](https://iceclear.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+ [Conference Paper](https://arxiv.org/abs/2307.12348) | [Journal Paper](http://arxiv.org/abs/2403.07319) | [Project Page](https://zsyoaoa.github.io/projects/resshift/) | [Video](https://www.youtube.com/watch?v=8DB-6Xvvl5o)
+
+ <a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/resshift) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/Zongsheng/ResShift) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift)
+
+ :star: If ResShift is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
  ---
+ > Diffusion-based image super-resolution (SR) methods are mainly limited by their low inference speed, caused by the hundreds or even thousands of sampling steps they require. Existing acceleration techniques inevitably sacrifice performance to some extent, leading to over-blurry SR results. To address this issue, we propose a novel and efficient diffusion model for SR that significantly reduces the number of diffusion steps, thereby eliminating the need for post-hoc acceleration during inference and its associated performance deterioration. Our method constructs a Markov chain that transfers between the high-resolution image and the low-resolution image by shifting the residual between them, substantially improving the transition efficiency. Additionally, an elaborate noise schedule is developed to flexibly control the shifting speed and the noise strength during the diffusion process. Extensive experiments demonstrate that the proposed method obtains performance superior or at least comparable to current state-of-the-art methods on both synthetic and real-world datasets, *even with only 15 sampling steps*.
+ > <img src="./assets/framework.png" align="middle" width="800">
+
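+ To make the residual-shifting idea concrete, here is a sketch of the forward process (see the papers for the exact definitions and noise schedule): with $e_0 = y_0 - x_0$ the residual between the low-resolution input $y_0$ and the high-resolution image $x_0$, each marginal of the Markov chain is
+
+ $$q(x_t \mid x_0, y_0) = \mathcal{N}\big(x_t;\ x_0 + \eta_t\, e_0,\ \kappa^2 \eta_t \mathbf{I}\big),$$
+
+ where the shifting sequence $\eta_t$ increases monotonically from $\eta_1 \approx 0$ to $\eta_T \approx 1$, so the chain starts near the high-resolution image and ends near the low-resolution one, and $\kappa$ controls the noise strength.
+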
+ ---
+ ## Update
+ - **2024.03.11**: Update the code for the Journal paper.
+ - **2023.12.02**: Add configurations for the x2 super-resolution task.
+ - **2023.08.15**: Add the OpenXLab demo [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/Zongsheng/ResShift).
+ - **2023.08.15**: Add the Gradio demo.
+ - **2023.08.14**: Add the bicubic (Matlab resize) model.
+ - **2023.08.14**: Add the [Project Page](https://zsyoaoa.github.io/projects/resshift/).
+ - **2023.08.02**: Add the [Replicate](https://replicate.com/) demo [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/resshift).
+ - **2023.07.31**: Add the Colab demo <a href="https://colab.research.google.com/drive/1CL8aJO7a_RA4MetanrCLqQO5H7KWO8KI?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>.
+ - **2023.07.24**: Create this repo.
+
+ ## Requirements
+ * Python 3.10, PyTorch 2.1.2, [xformers](https://github.com/facebookresearch/xformers) 0.0.23
+ * For more details, see [environment.yml](environment.yml).
+
+ A suitable [conda](https://conda.io/) environment named `resshift` can be created and activated with:
+ ```
+ conda create -n resshift python=3.10
+ conda activate resshift
+ pip install -r requirements.txt
+ ```
+ or
+ ```
+ conda env create -f environment.yml
+ conda activate resshift
+ ```
+
+ ## Applications
+ ### :point_right: Real-world image super-resolution
+ [<img src="assets/0015.png" height="324px"/>](https://imgsli.com/MTkzNzgz) [<img src="assets/0030.png" height="324px"/>](https://imgsli.com/MTkzNzgx)
+
+ [<img src="assets/frog.png" height="324px"/>](https://imgsli.com/MTkzNzg0) [<img src="assets/dog2.png" height="324px">](https://imgsli.com/MTkzNzg3)
+
+ [<img src="assets/cat.png" height="252px"/>](https://imgsli.com/MTkzNzkx) [<img src="assets/Lincon.png" height="252px"/>](https://imgsli.com/MTkzNzk5) [<img src="assets/oldphoto6.png" height="252px"/>](https://imgsli.com/MTkzNzk2)
+
+ ### :point_right: Image inpainting
+ <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001639_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001639.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001810_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001810.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00001204_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00001204.png" height="126px"/>
+ <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00002438_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00002438.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00005693_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00005693.png" height="126px"/> <img src="testdata/inpainting/imagenet/lq_mark/ILSVRC2012_val_00005814_mark.png" height="126px"/> <img src="testdata/inpainting/imagenet/results/ILSVRC2012_val_00005814.png" height="126px"/>
+ <img src="testdata/inpainting/face/lq_mark/94_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/94.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/410_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/410.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/269_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/269.png" height="126px"/>
+ <img src="testdata/inpainting/face/lq_mark/321_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/321.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/5915_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/5915.png" height="126px"/> <img src="testdata/inpainting/face/lq_mark/5489_mark.png" height="126px"/> <img src="testdata/inpainting/face/results/5489.png" height="126px"/>
+
+ ### :point_right: Blind Face Restoration
+ <img src="testdata/faceir/cropped_faces/lq/0729.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0729.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/0444.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0444.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/0885.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0885.png" height="126px"/>
+ <img src="testdata/faceir/cropped_faces/lq/0500.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/0500.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/Solvay_conference_1927_2_16.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/Solvay_conference_1927_0018.png" height="126px"/> <img src="testdata/faceir/cropped_faces/lq/Solvay_conference_1927_2_16.png" height="126px"/> <img src="testdata/faceir/cropped_faces/results/Solvay_conference_1927_2_16.png" height="126px"/>
+
+ ## Online Demo
+ You can try our method through the Gradio demo, launched with:
+ ```
+ python app.py
+ ```
+
+ ## Fast Testing
+ #### :tiger: Real-world image super-resolution
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3
+ ```
+ #### :lion: Bicubic (Matlab resize) image super-resolution
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task bicsr --scale 4
+ ```
+ #### :snake: Natural image inpainting
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_imagenet --scale 1
+ ```
+ #### :crocodile: Face image inpainting
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_face --scale 1
+ ```
+ #### :octopus: Blind Face Restoration
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task faceir --scale 1
+ ```
+
+ ### Training
+ #### :turtle: Preparing stage
+ 1. Download the pre-trained VQGAN model from this [link](https://github.com/zsyOAOA/ResShift/releases) and put it in the 'weights' folder.
+ 2. Adjust the data paths in the [config](configs) file.
+ 3. Adjust the batch size according to your GPUs (a worked example follows this list):
+     + configs.train.batch: [training batch size, validation batch size]
+     + configs.train.microbatch: total batch size = microbatch * #GPUs * num_grad_accumulation
+
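+ For example (hypothetical numbers, not taken from the shipped configs): `microbatch = 8` on 8 GPUs with `num_grad_accumulation = 1` gives a total batch size of 8 × 8 × 1 = 64.
+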
+ #### :dolphin: Real-world Image Super-resolution for NeurIPS
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/realsr_swinunet_realesrgan256.yaml --save_dir [Logging Folder]
+ ```
+ #### :whale: Real-world Image Super-resolution for Journal
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/realsr_swinunet_realesrgan256_journal.yaml --save_dir [Logging Folder]
+ ```
+ #### :ox: Image inpainting (Natural) for Journal
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/inpaint_lama256_imagenet.yaml --save_dir [Logging Folder]
+ ```
+ #### :honeybee: Image inpainting (Face) for Journal
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/inpaint_lama256_face.yaml --save_dir [Logging Folder]
+ ```
+ #### :frog: Blind face restoration for Journal
+ ```
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nproc_per_node=8 --nnodes=1 main.py --cfg_path configs/faceir_gfpgan512_lpips.yaml --save_dir [Logging Folder]
+ ```
+
+ ### Reproducing the results in our paper
+ #### :red_car: Prepare data
+ + Synthetic data for image super-resolution: [Link](https://drive.google.com/file/d/1NhmpON2dB2LjManfX6uIj8Pj_Jx6N-6l/view?usp=sharing)
+
+ + Real data for image super-resolution: [RealSet65](testdata/RealSet65) | [RealSet80](testdata/RealSet80)
+
+ + Synthetic data for natural image inpainting: [Link](https://drive.google.com/file/d/11_1xntiGnZzRX87Ve6thbeWDPVksTHNC/view?usp=sharing)
+
+ + Synthetic data for face image inpainting: [Link](https://drive.google.com/file/d/1nyfry2XjgA_qV8fS2_Y5TNQwtG5Irwux/view?usp=sharing)
+
+ + Synthetic data for blind face restoration: [Link](https://drive.google.com/file/d/15Ij-UaI8BQ7fBDF0i4M1wDOk-bnn_C4X/view?usp=drive_link)
+
+ #### :rocket: Image super-resolution
+ Reproduce the results in Table 3 of our NeurIPS paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v1 --chop_size 64 --chop_stride 64 --bs 64
+ ```
+ Reproduce the results in Table 4 of our NeurIPS paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v1 --chop_size 512 --chop_stride 448
+ ```
+ Reproduce the results in Table 2 of our Journal paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3 --chop_size 64 --chop_stride 64 --bs 64
+ ```
+ Reproduce the results in Table 3 of our Journal paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task realsr --scale 4 --version v3 --chop_size 512 --chop_stride 448
+ ```
+ ##### Model card
+ + version-1: Conference paper, 15 diffusion steps, trained for 300k iterations.
+ + version-2: Conference paper, 15 diffusion steps, trained for 500k iterations.
+ + version-3: Journal paper, 4 diffusion steps.
+
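+ A note on the chopping flags used above (a hedged reading; consult `inference_resshift.py` for the authoritative behavior): `--chop_size`/`--chop_stride` tile the input into patches so large images fit in memory. `--chop_size 512 --chop_stride 448` yields 512×512 patches with a 512 − 448 = 64-pixel overlap, while `--chop_size 64 --chop_stride 64` uses non-overlapping 64×64 patches processed `--bs` at a time.
+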
+ #### :airplane: Image inpainting
+ Reproduce the results in Table 4 of our Journal paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_imagenet --scale 1 --chop_size 256 --chop_stride 256 --bs 32
+ ```
+ Reproduce the results in Table 5 of our Journal paper:
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --mask_path [mask path] --task inpaint_face --scale 1 --chop_size 256 --chop_stride 256 --bs 32
+ ```
+ #### :boat: Blind Face Restoration
+ Reproduce the results in Table 6 of our Journal paper (arXiv):
+ ```
+ python inference_resshift.py -i [image folder/image path] -o [result folder] --task faceir --scale 1 --chop_size 256 --chop_stride 256 --bs 16
+ ```
+
+ <!--## Note on General Restoration Task-->
+ <!--For general restoration tasks, please adjust the settings in the config file:-->
+ <!--```-->
+ <!--model.params.lq_size: resolution of the low-quality image. # should be divisible by 64-->
+ <!--diffusion.params.sf: scale factor for super-resolution, 1 for restoration tasks.-->
+ <!--degradation.sf: scale factor for super-resolution, 1 for restoration tasks. # only required for the Real-ESRGAN pipeline-->
+ <!--```-->
+ <!--In some cases, you may need to rewrite the data loading process.-->
+
+ ## License
+ This project is licensed under the <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
+
+ ## Acknowledgement
+ This project is based on [Improved Diffusion Model](https://github.com/openai/improved-diffusion), [LDM](https://github.com/CompVis/latent-diffusion), and [BasicSR](https://github.com/XPixelGroup/BasicSR). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to synthesize the training data for real-world super-resolution. Thanks for their awesome work.
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ### Contact
+ If you have any questions, please feel free to contact me via `[email protected]`.
__pycache__/sampler.cpython-310.pyc ADDED
Binary file (8.29 kB)
__pycache__/sampler.cpython-38.pyc ADDED
Binary file (8.1 kB)
app.py ADDED
@@ -0,0 +1,200 @@
+ #!/usr/bin/env python
+ # -*- coding:utf-8 -*-
+ # Power by Zongsheng Yue 2023-08-15 09:39:58
+
+ import argparse
+ import gradio as gr
+ from pathlib import Path
+
+ from omegaconf import OmegaConf
+ from sampler import ResShiftSampler
+
+ from utils import util_image
+ from basicsr.utils.download_util import load_file_from_url
+
+ # number of diffusion steps per model version/task
+ _STEP = {
+     'v1': 15,
+     'v2': 15,
+     'v3': 4,
+     'bicsr': 4,
+     'inpaint_imagenet': 4,
+     'inpaint_face': 4,
+     'faceir': 4,
+ }
+ # download links for the released checkpoints
+ _LINK = {
+     'vqgan': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/autoencoder_vq_f4.pth',
+     'vqgan_face256': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/celeba256_vq_f4_dim3_face.pth',
+     'vqgan_face512': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/ffhq512_vq_f8_dim8_face.pth',
+     'v1': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v1.pth',
+     'v2': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v2.pth',
+     'v3': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s4_v3.pth',
+     'bicsr': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_bicsrx4_s4.pth',
+     'inpaint_imagenet': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_imagenet_s4.pth',
+     'inpaint_face': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_face_s4.pth',
+     'faceir': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_faceir_s4.pth',
+ }
+
+ def get_configs(task='realsr', version='v3', scale=4):
+     ckpt_dir = Path('./weights')
+     if not ckpt_dir.exists():
+         ckpt_dir.mkdir()
+
+     if task == 'realsr':
+         if version in ['v1', 'v2']:
+             configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256.yaml')
+         elif version == 'v3':
+             configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml')
+         else:
+             raise ValueError(f"Unexpected version type: {version}")
+         assert scale == 4, 'We only support 4x super-resolution for now!'
+         ckpt_url = _LINK[version]
+         ckpt_path = ckpt_dir / f'resshift_{task}x{scale}_s{_STEP[version]}_{version}.pth'
+         vqgan_url = _LINK['vqgan']
+         vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+     elif task == 'bicsr':
+         configs = OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml')
+         assert scale == 4, 'We only support 4x super-resolution for now!'
+         ckpt_url = _LINK[task]
+         ckpt_path = ckpt_dir / f'resshift_{task}x{scale}_s{_STEP[task]}.pth'
+         vqgan_url = _LINK['vqgan']
+         vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+     # elif task == 'inpaint_imagenet':
+     #     configs = OmegaConf.load('./configs/inpaint_lama256_imagenet.yaml')
+     #     assert scale == 1, 'Please set scale equal to 1 for image inpainting!'
+     #     ckpt_url = _LINK[task]
+     #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+     #     vqgan_url = _LINK['vqgan']
+     #     vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
+     # elif task == 'inpaint_face':
+     #     configs = OmegaConf.load('./configs/inpaint_lama256_face.yaml')
+     #     assert scale == 1, 'Please set scale equal to 1 for image inpainting!'
+     #     ckpt_url = _LINK[task]
+     #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+     #     vqgan_url = _LINK['vqgan_face256']
+     #     vqgan_path = ckpt_dir / 'celeba256_vq_f4_dim3_face.pth'
+     # elif task == 'faceir':
+     #     configs = OmegaConf.load('./configs/faceir_gfpgan512_lpips.yaml')
+     #     assert scale == 1, 'Please set scale equal to 1 for face restoration!'
+     #     ckpt_url = _LINK[task]
+     #     ckpt_path = ckpt_dir / f'resshift_{task}_s{_STEP[task]}.pth'
+     #     vqgan_url = _LINK['vqgan_face512']
+     #     vqgan_path = ckpt_dir / 'ffhq512_vq_f8_dim8_face.pth'
+     else:
+         raise TypeError(f"Unexpected task type: {task}!")
+
+     # prepare the checkpoints
+     if not ckpt_path.exists():
+         load_file_from_url(
+             url=ckpt_url,
+             model_dir=ckpt_dir,
+             progress=True,
+             file_name=ckpt_path.name,
+         )
+     if not vqgan_path.exists():
+         load_file_from_url(
+             url=vqgan_url,
+             model_dir=ckpt_dir,
+             progress=True,
+             file_name=vqgan_path.name,
+         )
+
+     configs.model.ckpt_path = str(ckpt_path)
+     configs.diffusion.params.sf = scale
+     configs.autoencoder.ckpt_path = str(vqgan_path)
+
+     return configs
+
+ def predict(in_path, task='realsr', seed=12345, scale=4, version='v3'):
+     configs = get_configs(task, version, scale)
+     resshift_sampler = ResShiftSampler(
+         configs,
+         sf=scale,
+         chop_size=256,
+         chop_stride=224,
+         chop_bs=1,
+         use_amp=True,
+         seed=seed,
+         padding_offset=configs.model.params.get('lq_size', 64),
+     )
+
+     out_dir = Path('restored_output')
+     if not out_dir.exists():
+         out_dir.mkdir()
+
+     resshift_sampler.inference(
+         in_path,
+         out_dir,
+         mask_path=None,
+         bs=1,
+         noise_repeat=False,
+     )
+
+     out_path = out_dir / f"{Path(in_path).stem}.png"
+     assert out_path.exists(), 'Super-resolution failed!'
+     im_sr = util_image.imread(out_path, chn="rgb", dtype="uint8")
+
+     return im_sr, str(out_path)
+
+ title = "ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting"
+ description = r"""
+ <b>Official Gradio demo</b> for <a href='https://github.com/zsyOAOA/ResShift' target='_blank'><b>ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting</b></a>.<br>
+ 🔥 ResShift is an efficient diffusion model designed for image super-resolution and restoration.<br>
+ """
+ article = r"""
+ If ResShift is helpful for your work, please help to ⭐ the <a href='https://github.com/zsyOAOA/ResShift' target='_blank'>GitHub repo</a>. Thanks!
+ [![GitHub Stars](https://img.shields.io/github/stars/zsyOAOA/ResShift?affiliations=OWNER&color=green&style=social)](https://github.com/zsyOAOA/ResShift)
+
+ ---
+ If our work is useful for your research, please consider citing:
+ ```bibtex
+ @inproceedings{yue2023resshift,
+     title={ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting},
+     author={Yue, Zongsheng and Wang, Jianyi and Loy, Chen Change},
+     booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
+     year={2023},
+     volume={36},
+     pages={13294--13307},
+ }
+ ```
+
+ 📋 **License**
+
+ This project is licensed under <a rel="license" href="https://github.com/zsyOAOA/ResShift/blob/master/LICENSE">S-Lab License 1.0</a>.
+ Redistribution and use for non-commercial purposes should follow this license.
+
+ 📧 **Contact**
+
+ If you have any questions, please feel free to contact me via <b>[email protected]</b>.
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift)
+ """
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(type="filepath", label="Input: Low Quality Image"),
+         gr.Dropdown(
+             choices=["realsr", "bicsr"],
+             value="realsr",
+             label="Task",
+         ),
+         gr.Number(value=12345, precision=0, label="Random seed"),
+     ],
+     outputs=[
+         gr.Image(type="numpy", label="Output: High Quality Image"),
+         gr.File(label="Download the output"),  # gr.outputs.File was removed in Gradio 4.x
+     ],
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ['./testdata/RealSet65/0030.jpg', "realsr", 12345],
+         ['./testdata/RealSet65/dog2.png', "realsr", 12345],
+         ['./testdata/RealSet65/bears.jpg', "realsr", 12345],
+         ['./testdata/RealSet65/oldphoto6.png', "realsr", 12345],
+         ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00000067.png', "bicsr", 12345],
+         ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00016898.png', "bicsr", 12345],
+     ],
+ )
+
+ demo.queue(default_concurrency_limit=4)  # queue(concurrency_count=...) was removed in Gradio 4.x
+ demo.launch(share=True)
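For reference, a minimal sketch of driving the same pipeline without the Gradio UI, reusing `predict()` from the `app.py` shown above (the input path is one of the demo's own registered examples; the output location follows the `restored_output` convention in the code):

```python
# Hedged sketch: app.py has no __main__ guard, so importing it would also
# launch the demo; in practice you would copy predict() into your own script.
from app import predict

# The first call downloads the checkpoints into ./weights, then restores the image.
im_sr, out_path = predict('./testdata/RealSet65/dog2.png', task='realsr', seed=12345)
print(out_path)  # expected: restored_output/dog2.png
```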
assets/0015.png ADDED
Git LFS Details
  • SHA256: 2045be367a165930c7e038fb05518ec87f092e17e981b69fb2ba0820d858aefc
  • Pointer size: 132 Bytes
  • Size of remote file: 3.34 MB
assets/0030.png ADDED
Git LFS Details
  • SHA256: f30d34c2129a79b190cec64256987a5c8e616d09e13d1a6499c06f89bbeb3ab5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.15 MB
assets/Lincon.png ADDED
Git LFS Details
  • SHA256: 1b42f58763f809f42bfcba8c3db2bb51f586b1a220b656942a9788a621e12351
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
assets/cat.png ADDED
Git LFS Details
  • SHA256: 0de46bf98f7eef30a4a853f3550fbae26b5ec5ca8c155926798ccad135bdac01
  • Pointer size: 132 Bytes
  • Size of remote file: 3.36 MB
assets/dog2.png ADDED
Git LFS Details
  • SHA256: 52b2c1ae4c38bac4ad642033e81b7b2421b39512711a76c952fe3095a88716b7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.71 MB
assets/framework.png ADDED
assets/frog.png ADDED
Git LFS Details
  • SHA256: c94a08be4678b354785de7e80f2d3c3ce59990f56519a8ee93bd730e25f11283
  • Pointer size: 132 Bytes
  • Size of remote file: 3.06 MB
assets/oldphoto6.png ADDED
Git LFS Details
  • SHA256: 3c4c25216a6c2a43b0aff62ee03b37a7b68b627a0a70cbdc932ba11ca7856de0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.52 MB
basicsr/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # https://github.com/xinntao/BasicSR
+ # flake8: noqa
+ from .data import *
+ from .utils import *
basicsr/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes)
basicsr/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (193 Bytes)
basicsr/data/__init__.py ADDED
@@ -0,0 +1,101 @@
+ import importlib
+ import numpy as np
+ import random
+ import torch
+ import torch.utils.data
+ from copy import deepcopy
+ from functools import partial
+ from os import path as osp
+
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+ from basicsr.utils import get_root_logger, scandir
+ from basicsr.utils.dist_util import get_dist_info
+ from basicsr.utils.registry import DATASET_REGISTRY
+
+ __all__ = ['build_dataset', 'build_dataloader']
+
+ # automatically scan and import dataset modules for the registry
+ # scan all the files under the data folder with '_dataset' in file names
+ data_folder = osp.dirname(osp.abspath(__file__))
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+ # import all the dataset modules
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+ def build_dataset(dataset_opt):
+     """Build dataset from options.
+
+     Args:
+         dataset_opt (dict): Configuration for dataset. It must contain:
+             name (str): Dataset name.
+             type (str): Dataset type.
+     """
+     dataset_opt = deepcopy(dataset_opt)
+     dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+     logger = get_root_logger()
+     logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
+     return dataset
+
+
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+     """Build dataloader.
+
+     Args:
+         dataset (torch.utils.data.Dataset): Dataset.
+         dataset_opt (dict): Dataset options. It contains the following keys:
+             phase (str): 'train' or 'val'.
+             num_worker_per_gpu (int): Number of workers for each GPU.
+             batch_size_per_gpu (int): Training batch size for each GPU.
+         num_gpu (int): Number of GPUs. Used only in the train phase.
+             Default: 1.
+         dist (bool): Whether in distributed training. Used only in the train
+             phase. Default: False.
+         sampler (torch.utils.data.sampler): Data sampler. Default: None.
+         seed (int | None): Seed. Default: None.
+     """
+     phase = dataset_opt['phase']
+     rank, _ = get_dist_info()
+     if phase == 'train':
+         if dist:  # distributed training
+             batch_size = dataset_opt['batch_size_per_gpu']
+             num_workers = dataset_opt['num_worker_per_gpu']
+         else:  # non-distributed training
+             multiplier = 1 if num_gpu == 0 else num_gpu
+             batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+             num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+         dataloader_args = dict(
+             dataset=dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+             sampler=sampler,
+             drop_last=True)
+         if sampler is None:
+             dataloader_args['shuffle'] = True
+         dataloader_args['worker_init_fn'] = partial(
+             worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+     elif phase in ['val', 'test']:  # validation
+         dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+     else:
+         raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
+
+     dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+     dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
+
+     prefetch_mode = dataset_opt.get('prefetch_mode')
+     if prefetch_mode == 'cpu':  # CPUPrefetcher
+         num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+         logger = get_root_logger()
+         logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
+         return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+     else:
+         # prefetch_mode=None: normal dataloader
+         # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+         return torch.utils.data.DataLoader(**dataloader_args)
+
+
+ def worker_init_fn(worker_id, num_workers, rank, seed):
+     # Set the worker seed to num_workers * rank + worker_id + seed
+     worker_seed = num_workers * rank + worker_id + seed
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
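A minimal sketch of how the two builders compose, with option keys taken from the docstrings above; the dataset type and option values are hypothetical (any class registered in `DATASET_REGISTRY`, e.g. BasicSR's `PairedImageDataset`, would do, plus whatever keys that class itself expects):

```python
from basicsr.data import build_dataset, build_dataloader

# Hypothetical options; the keys follow the docstrings above.
dataset_opt = {
    'name': 'demo',                # used for logging only
    'type': 'PairedImageDataset',  # must be registered in DATASET_REGISTRY
    'phase': 'train',              # 'train' enables shuffling and workers
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
    # ... plus the keys required by the chosen dataset class itself
}

train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
```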
basicsr/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.55 kB)
basicsr/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.59 kB)
basicsr/data/__pycache__/data_util.cpython-310.pyc ADDED
Binary file (11.2 kB)
basicsr/data/__pycache__/data_util.cpython-38.pyc ADDED
Binary file (11.2 kB)
basicsr/data/__pycache__/degradations.cpython-310.pyc ADDED
Binary file (20.3 kB)
basicsr/data/__pycache__/degradations.cpython-38.pyc ADDED
Binary file (21.7 kB)
basicsr/data/__pycache__/ffhq_dataset.cpython-310.pyc ADDED
Binary file (3.05 kB)
basicsr/data/__pycache__/ffhq_dataset.cpython-38.pyc ADDED
Binary file (3.02 kB)
basicsr/data/__pycache__/paired_image_dataset.cpython-310.pyc ADDED
Binary file (3.84 kB)
basicsr/data/__pycache__/paired_image_dataset.cpython-38.pyc ADDED
Binary file (3.88 kB)
basicsr/data/__pycache__/prefetch_dataloader.cpython-310.pyc ADDED
Binary file (4.32 kB)
basicsr/data/__pycache__/prefetch_dataloader.cpython-38.pyc ADDED
Binary file (4.37 kB)
basicsr/data/__pycache__/realesrgan_dataset.cpython-310.pyc ADDED
Binary file (8.58 kB)
basicsr/data/__pycache__/realesrgan_dataset.cpython-38.pyc ADDED
Binary file (5.1 kB)
basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-310.pyc ADDED
Binary file (4.03 kB)
basicsr/data/__pycache__/realesrgan_paired_dataset.cpython-38.pyc ADDED
Binary file (4.03 kB)
basicsr/data/__pycache__/reds_dataset.cpython-310.pyc ADDED
Binary file (10.5 kB)
basicsr/data/__pycache__/reds_dataset.cpython-38.pyc ADDED
Binary file (10.7 kB)
basicsr/data/__pycache__/single_image_dataset.cpython-310.pyc ADDED
Binary file (2.82 kB)
basicsr/data/__pycache__/single_image_dataset.cpython-38.pyc ADDED
Binary file (2.82 kB)
basicsr/data/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (5.97 kB)
basicsr/data/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (6.04 kB)
basicsr/data/__pycache__/video_test_dataset.cpython-310.pyc ADDED
Binary file (10.1 kB)
basicsr/data/__pycache__/video_test_dataset.cpython-38.pyc ADDED
Binary file (10.3 kB)
basicsr/data/__pycache__/vimeo90k_dataset.cpython-310.pyc ADDED
Binary file (5.77 kB)
basicsr/data/__pycache__/vimeo90k_dataset.cpython-38.pyc ADDED
Binary file (5.79 kB)
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
+ import math
+ import torch
+ from torch.utils.data.sampler import Sampler
+
+
+ class EnlargedSampler(Sampler):
+     """Sampler that restricts data loading to a subset of the dataset.
+
+     Modified from torch.utils.data.distributed.DistributedSampler.
+     Supports enlarging the dataset for iteration-based training, saving
+     time when restarting the dataloader after each epoch.
+
+     Args:
+         dataset (torch.utils.data.Dataset): Dataset used for sampling.
+         num_replicas (int | None): Number of processes participating in
+             the training. It is usually the world_size.
+         rank (int | None): Rank of the current process within num_replicas.
+         ratio (int): Enlarging ratio. Default: 1.
+     """
+
+     def __init__(self, dataset, num_replicas, rank, ratio=1):
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+         self.total_size = self.num_samples * self.num_replicas
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+         indices = torch.randperm(self.total_size, generator=g).tolist()
+
+         dataset_size = len(self.dataset)
+         indices = [v % dataset_size for v in indices]
+
+         # subsample
+         indices = indices[self.rank:self.total_size:self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices)
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
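A brief usage sketch continuing the dataloader example above (single-process values, so `num_replicas=1` and `rank=0`; `ratio=2` is an arbitrary illustration): the sampler is handed to `build_dataloader`, and `set_epoch` re-seeds the deterministic shuffle each epoch.

```python
from basicsr.data import build_dataloader
from basicsr.data.data_sampler import EnlargedSampler

# Enlarge the virtual epoch 2x so the dataloader restarts less often;
# train_set and dataset_opt come from the previous sketch.
sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=2)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=sampler, seed=0)

for epoch in range(10):
    sampler.set_epoch(epoch)  # changes the shuffle order deterministically
    for batch in train_loader:
        ...  # training step
```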
basicsr/data/data_util.py ADDED
@@ -0,0 +1,315 @@
+ import cv2
+ import numpy as np
+ import torch
+ from os import path as osp
+ from torch.nn import functional as F
+
+ from basicsr.data.transforms import mod_crop
+ from basicsr.utils import img2tensor, scandir
+
+
+ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+     """Read a sequence of images from a given folder path.
+
+     Args:
+         path (list[str] | str): List of image paths or image folder path.
+         require_mod_crop (bool): Require mod crop for each image.
+             Default: False.
+         scale (int): Scale factor for mod_crop. Default: 1.
+         return_imgname (bool): Whether to return image names. Default: False.
+
+     Returns:
+         Tensor: size (t, c, h, w), RGB, [0, 1].
+         list[str]: Returned image name list.
+     """
+     if isinstance(path, list):
+         img_paths = path
+     else:
+         img_paths = sorted(list(scandir(path, full_path=True)))
+     imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+     if require_mod_crop:
+         imgs = [mod_crop(img, scale) for img in imgs]
+     imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+     imgs = torch.stack(imgs, dim=0)
+
+     if return_imgname:
+         imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+         return imgs, imgnames
+     else:
+         return imgs
+
+
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+     """Generate an index list for reading `num_frames` frames from a sequence
+     of images.
+
+     Args:
+         crt_idx (int): Current center index.
+         max_frame_num (int): Max number of the sequence of images (from 1).
+         num_frames (int): Reading num_frames frames.
+         padding (str): Padding mode, one of
+             'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+             Examples: current_idx = 0, num_frames = 5
+             The generated frame indices under different padding modes:
+             replicate: [0, 0, 0, 1, 2]
+             reflection: [2, 1, 0, 1, 2]
+             reflection_circle: [4, 3, 0, 1, 2]
+             circle: [3, 4, 0, 1, 2]
+
+     Returns:
+         list[int]: A list of indices.
+     """
+     assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+     assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+     max_frame_num = max_frame_num - 1  # start from 0
+     num_pad = num_frames // 2
+
+     indices = []
+     for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+         if i < 0:
+             if padding == 'replicate':
+                 pad_idx = 0
+             elif padding == 'reflection':
+                 pad_idx = -i
+             elif padding == 'reflection_circle':
+                 pad_idx = crt_idx + num_pad - i
+             else:
+                 pad_idx = num_frames + i
+         elif i > max_frame_num:
+             if padding == 'replicate':
+                 pad_idx = max_frame_num
+             elif padding == 'reflection':
+                 pad_idx = max_frame_num * 2 - i
+             elif padding == 'reflection_circle':
+                 pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+             else:
+                 pad_idx = i - num_frames
+         else:
+             pad_idx = i
+         indices.append(pad_idx)
+     return indices
+
+
+ def paired_paths_from_lmdb(folders, keys):
+     """Generate paired paths from lmdb files.
+
+     Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+     ::
+
+         lq.lmdb
+         ├── data.mdb
+         ├── lock.mdb
+         ├── meta_info.txt
+
+     The data.mdb and lock.mdb are standard lmdb files; refer to
+     https://lmdb.readthedocs.io/en/release/ for more details.
+
+     The meta_info.txt is a specified txt file to record the meta information
+     of our datasets. It will be automatically created when preparing
+     datasets by our provided dataset tools.
+     Each line in the txt file records
+     1) image name (with extension),
+     2) image shape,
+     3) compression level, separated by a white space.
+     Example: `baboon.png (120,125,3) 1`
+
+     We use the image name without extension as the lmdb key.
+     Note that we use the same key for the corresponding lq and gt images.
+
+     Args:
+         folders (list[str]): A list of folder paths. The order of the list
+             should be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+             Note that this key is different from lmdb keys.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                                f'But got {len(folders)}')
+     assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+         raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                          f'format. But received {input_key}: {input_folder}; '
+                          f'{gt_key}: {gt_folder}')
+     # ensure that the two meta_info files are the same
+     with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+         input_lmdb_keys = [line.split('.')[0] for line in fin]
+     with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+         gt_lmdb_keys = [line.split('.')[0] for line in fin]
+     if set(input_lmdb_keys) != set(gt_lmdb_keys):
+         raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+     else:
+         paths = []
+         for lmdb_key in sorted(input_lmdb_keys):
+             paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+         return paths
+
+
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+     """Generate paired paths from a meta information file.
+
+     Each line in the meta information file contains the image names and
+     image shape (usually for gt), separated by a white space.
+
+     Example of a meta information file:
+     ```
+     0001_s001.png (480,480,3)
+     0001_s002.png (480,480,3)
+     ```
+
+     Args:
+         folders (list[str]): A list of folder paths. The order of the list
+             should be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+         meta_info_file (str): Path to the meta information file.
+         filename_tmpl (str): Template for each filename. Note that the
+             template excludes the file extension. Usually the filename_tmpl is
+             for files in the input folder.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                                f'But got {len(folders)}')
+     assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     with open(meta_info_file, 'r') as fin:
+         gt_names = [line.strip().split(' ')[0] for line in fin]
+
+     paths = []
+     for gt_name in gt_names:
+         basename, ext = osp.splitext(osp.basename(gt_name))
+         input_name = f'{filename_tmpl.format(basename)}{ext}'
+         input_path = osp.join(input_folder, input_name)
+         gt_path = osp.join(gt_folder, gt_name)
+         paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+     return paths
+
+
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
+     """Generate paired paths from folders.
+
+     Args:
+         folders (list[str]): A list of folder paths. The order of the list
+             should be [input_folder, gt_folder].
+         keys (list[str]): A list of keys identifying folders. The order should
+             be consistent with folders, e.g., ['lq', 'gt'].
+         filename_tmpl (str): Template for each filename. Note that the
+             template excludes the file extension. Usually the filename_tmpl is
+             for files in the input folder.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+                                f'But got {len(folders)}')
+     assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+     input_folder, gt_folder = folders
+     input_key, gt_key = keys
+
+     input_paths = list(scandir(input_folder))
+     gt_paths = list(scandir(gt_folder))
+     assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different numbers of images: '
+                                                f'{len(input_paths)}, {len(gt_paths)}.')
+     paths = []
+     for gt_path in gt_paths:
+         basename, ext = osp.splitext(osp.basename(gt_path))
+         input_name = f'{filename_tmpl.format(basename)}{ext}'
+         input_path = osp.join(input_folder, input_name)
+         assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
+         gt_path = osp.join(gt_folder, gt_path)
+         paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+     return paths
+
+
+ def paths_from_folder(folder):
+     """Generate paths from a folder.
+
+     Args:
+         folder (str): Folder path.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     paths = list(scandir(folder))
+     paths = [osp.join(folder, path) for path in paths]
+     return paths
+
+
+ def paths_from_lmdb(folder):
+     """Generate paths from an lmdb folder.
+
+     Args:
+         folder (str): Folder path.
+
+     Returns:
+         list[str]: Returned path list.
+     """
+     if not folder.endswith('.lmdb'):
+         raise ValueError(f'Folder {folder} should be in lmdb format.')
+     with open(osp.join(folder, 'meta_info.txt')) as fin:
+         paths = [line.split('.')[0] for line in fin]
+     return paths
+
+
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+     """Generate a Gaussian kernel used in `duf_downsample`.
+
+     Args:
+         kernel_size (int): Kernel size. Default: 13.
+         sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+     Returns:
+         np.array: The Gaussian kernel.
+     """
+     # scipy.ndimage.filters is deprecated; gaussian_filter lives at the top level
+     from scipy.ndimage import gaussian_filter
+     kernel = np.zeros((kernel_size, kernel_size))
+     # set the element at the middle to one, a dirac delta
+     kernel[kernel_size // 2, kernel_size // 2] = 1
+     # gaussian-smooth the dirac, resulting in a gaussian filter
+     return gaussian_filter(kernel, sigma)
+
+
+ def duf_downsample(x, kernel_size=13, scale=4):
+     """Downsampling with the Gaussian kernel used in the DUF official code.
+
+     Args:
+         x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+         kernel_size (int): Kernel size. Default: 13.
+         scale (int): Downsampling factor. Supported scales: (2, 3, 4).
+             Default: 4.
+
+     Returns:
+         Tensor: DUF downsampled frames.
+     """
+     assert scale in (2, 3, 4), f'Only scales (2, 3, 4) are supported, but got {scale}.'
+
+     squeeze_flag = False
+     if x.ndim == 4:
+         squeeze_flag = True
+         x = x.unsqueeze(0)
+     b, t, c, h, w = x.size()
+     x = x.view(-1, 1, h, w)
+     pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+     x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+     gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+     gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+     x = F.conv2d(x, gaussian_filter, stride=scale)
+     x = x[:, :, 2:-2, 2:-2]
+     x = x.view(b, t, c, x.size(2), x.size(3))
+     if squeeze_flag:
+         x = x.squeeze(0)
+     return x
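Two quick checks of the helpers above (a sketch; the indices match the docstring example, and the output shape follows from the padding/stride arithmetic in `duf_downsample`, with an arbitrary input size):

```python
import torch
from basicsr.data.data_util import generate_frame_indices, duf_downsample

# Matches the docstring example: center index 0, 5 frames, reflection padding.
print(generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='reflection'))
# -> [2, 1, 0, 1, 2]

# DUF-style 4x downsampling of a (t, c, h, w) clip; the batch dim is added
# and squeezed away internally.
clip = torch.rand(7, 3, 256, 256)
print(duf_downsample(clip, kernel_size=13, scale=4).shape)
# -> torch.Size([7, 3, 64, 64])
```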