gy65896 commited on
Commit
73ba284
·
verified ·
1 Parent(s): 8690f76

Upload 51 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. README.md +298 -0
  3. app.py +89 -0
  4. ckpts/ckpts_file.txt +0 -0
  5. data/data_file.txt +0 -0
  6. image/low_haze_rain_00469_01_lq.png +3 -0
  7. image/low_haze_snow_00337_01_lq.png +3 -0
  8. img_file/OneRestore_poster.png +3 -0
  9. img_file/abstract.jpg +3 -0
  10. img_file/cal_psnr_ssim.py +96 -0
  11. img_file/clear_img.jpg +0 -0
  12. img_file/control1.jpg +0 -0
  13. img_file/control2.jpg +0 -0
  14. img_file/depth_map.jpg +0 -0
  15. img_file/l+h+r.jpg +0 -0
  16. img_file/l+h+s.jpg +0 -0
  17. img_file/light_map.jpg +0 -0
  18. img_file/logo_onerestore.png +0 -0
  19. img_file/metric.png +0 -0
  20. img_file/metrics_CDD-11_psnr_ssim.xlsx +0 -0
  21. img_file/pipeline.jpg +3 -0
  22. img_file/rain_mask.jpg +0 -0
  23. img_file/real.jpg +3 -0
  24. img_file/snow_mask.png +0 -0
  25. img_file/syn.jpg +0 -0
  26. makedataset.py +157 -0
  27. model/Embedder.py +238 -0
  28. model/OneRestore.py +314 -0
  29. model/loss.py +222 -0
  30. output/low_haze_rain_00469_01_lq.png +3 -0
  31. output/low_haze_snow_00337_01_lq.png +3 -0
  32. remove_optim.py +32 -0
  33. requirements.txt +10 -0
  34. syn_data/data/clear/1.jpg +0 -0
  35. syn_data/data/depth_map/1.jpg +0 -0
  36. syn_data/data/light_map/1.jpg +0 -0
  37. syn_data/data/rain_mask/00001.jpg +0 -0
  38. syn_data/data/rain_mask/00002.jpg +0 -0
  39. syn_data/data/rain_mask/00003.jpg +0 -0
  40. syn_data/data/snow_mask/beautiful_smile_00001.jpg +0 -0
  41. syn_data/data/snow_mask/beautiful_smile_00006.jpg +0 -0
  42. syn_data/data/snow_mask/beautiful_smile_00008.jpg +0 -0
  43. syn_data/out/1.jpg +0 -0
  44. syn_data/syn_data.py +86 -0
  45. test.py +82 -0
  46. train_Embedder.py +104 -0
  47. train_OneRestore_multi-gpu.py +153 -0
  48. train_OneRestore_single-gpu.py +140 -0
  49. utils/glove.6B.300d.txt +5 -0
  50. utils/utils.py +232 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ image/low_haze_rain_00469_01_lq.png filter=lfs diff=lfs merge=lfs -text
37
+ image/low_haze_snow_00337_01_lq.png filter=lfs diff=lfs merge=lfs -text
38
+ img_file/abstract.jpg filter=lfs diff=lfs merge=lfs -text
39
+ img_file/OneRestore_poster.png filter=lfs diff=lfs merge=lfs -text
40
+ img_file/pipeline.jpg filter=lfs diff=lfs merge=lfs -text
41
+ img_file/real.jpg filter=lfs diff=lfs merge=lfs -text
42
+ output/low_haze_rain_00469_01_lq.png filter=lfs diff=lfs merge=lfs -text
43
+ output/low_haze_snow_00337_01_lq.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ </div>
2
+ <div align=center>
3
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/logo_onerestore.png" width="200">
4
+ </div>
5
+
6
+ # <p align=center> [ECCV 2024] OneRestore: A Universal Restoration Framework for Composite Degradation</p>
7
+
8
+
9
+ <div align="center">
10
+
11
+ [![ArXiv](https://img.shields.io/badge/OneRestore-ArXiv-red.svg)](https://arxiv.org/abs/2407.04621)
12
+ [![Paper](https://img.shields.io/badge/OneRestore-Paper-purple.svg)](https://arxiv.org/abs/2407.04621)
13
+ [![Web](https://img.shields.io/badge/OneRestore-Web-blue.svg)](https://gy65896.github.io/projects/ECCV2024_OneRestore/index.html)
14
+ [![Poster](https://img.shields.io/badge/OneRestore-Poster-green.svg)](https://github.com/gy65896/OneRestore/blob/main/img_file/OneRestore_poster.png)
15
+ [![Video](https://img.shields.io/badge/OneRestore-Video-orange.svg)](https://www.youtube.com/watch?v=AFr5tZdPlZ4)
16
+
17
+ [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fgy65896%2FOneRestore&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)
18
+ [![Hugging Face Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue)](https://huggingface.co/spaces/gy65896/OneRestore)
19
+ [![Closed Issues](https://img.shields.io/github/issues-closed/gy65896/OneRestore)](https://github.com/gy65896/OneRestore/issues?q=is%3Aissue+is%3Aclosed)
20
+ [![Open Issues](https://img.shields.io/github/issues/gy65896/OneRestore)](https://github.com/gy65896/OneRestore/issues)
21
+
22
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/low-light-image-enhancement-on-lol)](https://paperswithcode.com/sota/low-light-image-enhancement-on-lol?p=onerestore-a-universal-restoration-framework)
23
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/image-dehazing-on-sots-outdoor)](https://paperswithcode.com/sota/image-dehazing-on-sots-outdoor?p=onerestore-a-universal-restoration-framework)
24
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/rain-removal-on-did-mdn)](https://paperswithcode.com/sota/rain-removal-on-did-mdn?p=onerestore-a-universal-restoration-framework)
25
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/onerestore-a-universal-restoration-framework/snow-removal-on-snow100k)](https://paperswithcode.com/sota/snow-removal-on-snow100k?p=onerestore-a-universal-restoration-framework)
26
+
27
+ </div>
28
+ <div align=center>
29
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/abstract.jpg" width="720">
30
+ </div>
31
+
32
+ ---
33
+ >**OneRestore: A Universal Restoration Framework for Composite Degradation**<br> [Yu Guo](https://scholar.google.com/citations?user=klYz-acAAAAJ&hl=zh-CN)<sup>† </sup>, [Yuan Gao](https://scholar.google.com.hk/citations?user=4JpRnU4AAAAJ&hl=zh-CN)<sup>† </sup>, [Yuxu Lu](https://scholar.google.com.hk/citations?user=XXge2_0AAAAJ&hl=zh-CN), [Huilin Zhu](https://scholar.google.com.hk/citations?hl=zh-CN&user=fluPrxcAAAAJ), [Ryan Wen Liu](http://mipc.whut.edu.cn/index.html)<sup>* </sup>, [Shengfeng He](http://www.shengfenghe.com/)<sup>* </sup> <br>
34
+ († Co-first Author, * Corresponding Author)<br>
35
+ >European Conference on Computer Vision
36
+
37
+ > **Abstract:** *In real-world scenarios, image impairments often manifest as composite degradations, presenting a complex interplay of elements such as low light, haze, rain, and snow. Despite this reality, existing restoration methods typically target isolated degradation types, thereby falling short in environments where multiple degrading factors coexist. To bridge this gap, our study proposes a versatile imaging model that consolidates four physical corruption paradigms to accurately represent complex, composite degradation scenarios. In this context, we propose OneRestore, a novel transformer-based framework designed for adaptive, controllable scene restoration. The proposed framework leverages a unique cross-attention mechanism, merging degraded scene descriptors with image features, allowing for nuanced restoration. Our model allows versatile input scene descriptors, ranging from manual text embeddings to automatic extractions based on visual attributes. Our methodology is further enhanced through a composite degradation restoration loss, using extra degraded images as negative samples to fortify model constraints. Comparative results on synthetic and real-world datasets demonstrate OneRestore as a superior solution, significantly advancing the state-of-the-art in addressing complex, composite degradations.*
38
+ ---
39
+
40
+ ## News 🚀
41
+ * **2024.09.07**: [Hugging Face Demo](https://huggingface.co/spaces/gy65896/OneRestore) is released.
42
+ * **2024.09.05**: Video and poster are released.
43
+ * **2024.09.04**: Code for data synthesis is released.
44
+ * **2024.07.27**: Code for multiple GPUs training is released.
45
+ * **2024.07.20**: [New Website](https://gy65896.github.io/projects/ECCV2024_OneRestore) has been created.
46
+ * **2024.07.10**: [Paper](https://arxiv.org/abs/2407.04621) is released on ArXiv.
47
+ * **2024.07.07**: Code and Dataset are released.
48
+ * **2024.07.02**: OneRestore is accepted by [ECCV2024](https://eccv.ecva.net/).
49
+
50
+ ## Network Architecture
51
+
52
+ </div>
53
+ <div align=center>
54
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/pipeline.jpg" width="1080">
55
+ </div>
56
+
57
+ ## Quick Start
58
+
59
+ ### Install
60
+
61
+ - python 3.7
62
+ - cuda 11.7
63
+
64
+ ```
65
+ # git clone this repository
66
+ git clone https://github.com/gy65896/OneRestore.git
67
+ cd OneRestore
68
+
69
+ # create new anaconda env
70
+ conda create -n onerestore python=3.7
71
+ conda activate onerestore
72
+
73
+ # download ckpts
74
+ put embedder_model.tar and onerestore_cdd-11.tar in ckpts folder
75
+
76
+ # install pytorch (Take cuda 11.7 as an example to install torch 1.13)
77
+ pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
78
+
79
+ # install other packages
80
+ pip install -r requirements.txt
81
+ pip install gensim
82
+ ```
83
+
84
+ ### Pretrained Models
85
+
86
+ Please download our pre-trained models and put them in `./ckpts`.
87
+
88
+ | Model | Description
89
+ | :--- | :----------
90
+ |[embedder_model.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpnhSQFIoD9msXWOA?e=aUpHOT) | Text/Visual Embedder trained on our CDD-11.
91
+ |[onerestore_cdd-11.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpmWkGBku6oj33efg?e=7yUGfN) | OneRestore trained on our CDD-11.
92
+ |[onerestore_real.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpi-iJOyN6OSYqiaA?e=QFfMeL) | OneRestore trained on our CDD-11 for Real Scenes.
93
+ |[onerestore_lol.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpkSoVB1j-wYHFpHg?e=0gR9pn) | OneRestore trained on LOL (low light enhancement benchmark).
94
+ |[onerestore_reside_ots.tar](https://1drv.ms/u/s!As3rCDROnrbLgqpjGh8KjfM_QIJzEw?e=zabGTw) | OneRestore trained on RESIDE-OTS (image dehazing benchmark).
95
+ |[onerestore_rain1200.tar](https://1drv.ms/u/s!As3rCDROnrbLgqplAFHv6B348jarGA?e=GuduMT) | OneRestore trained on Rain1200 (image deraining benchmark).
96
+ |[onerestore_snow100k.tar](https://1drv.ms/u/s!As3rCDROnrbLgqphsWWxLZN_7JFJDQ?e=pqezzo) | OneRestore trained on Snow100k-L (image desnowing benchmark).
97
+
98
+ ### Inference
99
+
100
+ We provide two samples in `./image` for the quick inference:
101
+
102
+ ```
103
+ python test.py --embedder-model-path ./ckpts/embedder_model.tar --restore-model-path ./ckpts/onerestore_cdd-11.tar --input ./image/ --output ./output/ --concat
104
+ ```
105
+
106
+ You can also input the prompt to perform controllable restoration. For example:
107
+
108
+ ```
109
+ python test.py --embedder-model-path ./ckpts/embedder_model.tar --restore-model-path ./ckpts/onerestore_cdd-11.tar --prompt low_haze --input ./image/ --output ./output/ --concat
110
+ ```
111
+
112
+ ## Training
113
+
114
+ ### Prepare Dataset
115
+
116
+ We provide the download link of our Composite Degradation Dataset with 11 types of degradation ([CDD-11](https://1drv.ms/f/s!As3rCDROnrbLgqpezG4sao-u9ddDhw?e=A0REHx)).
117
+
118
+ Preparing the train and test datasets as follows:
119
+
120
+ ```
121
+ ./data/
122
+ |--train
123
+ | |--clear
124
+ | | |--000001.png
125
+ | | |--000002.png
126
+ | |--low
127
+ | |--haze
128
+ | |--rain
129
+ | |--snow
130
+ | |--low_haze
131
+ | |--low_rain
132
+ | |--low_snow
133
+ | |--haze_rain
134
+ | |--haze_snow
135
+ | |--low_haze_rain
136
+ | |--low_haze_snow
137
+ |--test
138
+ ```
139
+ ### Train Model
140
+
141
+ **1. Train Text/Visual Embedder by**
142
+
143
+ ```
144
+ python train_Embedder.py --train-dir ./data/CDD-11_train --test-dir ./data/CDD-11_test --check-dir ./ckpts --batch 256 --num-workers 0 --epoch 200 --lr 1e-4 --lr-decay 50
145
+ ```
146
+
147
+ **2. Remove the optimizer weights in the Embedder model file by**
148
+
149
+ ```
150
+ python remove_optim.py --type Embedder --input-file ./ckpts/embedder_model.tar --output-file ./ckpts/embedder_model.tar
151
+ ```
152
+
153
+ **3. Generate the `dataset.h5` file for training OneRestore by**
154
+
155
+ ```
156
+ python makedataset.py --train-path ./data/CDD-11_train --data-name dataset.h5 --patch-size 256 --stride 200
157
+ ```
158
+
159
+ **4. Train OneRestore model by**
160
+
161
+ - **Single GPU**
162
+
163
+ ```
164
+ python train_OneRestore_single-gpu.py --embedder-model-path ./ckpts/embedder_model.tar --save-model-path ./ckpts --train-input ./dataset.h5 --test-input ./data/CDD-11_test --output ./result/ --epoch 120 --bs 4 --lr 1e-4 --adjust-lr 30 --num-works 4
165
+ ```
166
+
167
+ - **Multiple GPUs**
168
+
169
+ Assuming you train the OneRestore model using 4 GPUs (e.g., 0, 1, 2, and 3), you can use the following command. Note that the number of nproc_per_node should equal the number of GPUs.
170
+
171
+ ```
172
+ CUDA_VISIBLE_DEVICES=0, 1, 2, 3 torchrun --nproc_per_node=4 train_OneRestore_multi-gpu.py --embedder-model-path ./ckpts/embedder_model.tar --save-model-path ./ckpts --train-input ./dataset.h5 --test-input ./data/CDD-11_test --output ./result/ --epoch 120 --bs 4 --lr 1e-4 --adjust-lr 30 --num-works 4
173
+ ```
174
+
175
+ **5. Remove the optimizer weights in the OneRestore model file by**
176
+
177
+ ```
178
+ python remove_optim.py --type OneRestore --input-file ./ckpts/onerestore_model.tar --output-file ./ckpts/onerestore_model.tar
179
+ ```
180
+
181
+ ### Customize your own composite degradation dataset
182
+
183
+ **1. Prepare raw data**
184
+
185
+ - Collect your own clear images.
186
+ - Generate the depth map based on [MegaDepth](https://github.com/zhengqili/MegaDepth).
187
+ - Generate the light map based on [LIME](https://github.com/estija/LIME).
188
+ - Generate the rain mask database based on [RainStreakGen](https://github.com/liruoteng/RainStreakGen?tab=readme-ov-file).
189
+ - Download the snow mask database from [Snow100k](https://sites.google.com/view/yunfuliu/desnownet).
190
+
191
+ A generated example is as follows:
192
+
193
+ | Clear Image | Depth Map | Light Map | Rain Mask | Snow Mask
194
+ | :--- | :---| :---| :--- | :---
195
+ | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/clear_img.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/depth_map.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/light_map.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/rain_mask.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/snow_mask.png" width="200">
196
+
197
+ (Note: The rain and snow masks do not require strict alignment with the image.)
198
+
199
+ - Prepare the dataset as follows:
200
+
201
+ ```
202
+ ./syn_data/
203
+ |--data
204
+ | |--clear
205
+ | | |--000001.png
206
+ | | |--000002.png
207
+ | |--depth_map
208
+ | | |--000001.png
209
+ | | |--000002.png
210
+ | |--light_map
211
+ | | |--000001.png
212
+ | | |--000002.png
213
+ | |--rain_mask
214
+ | | |--aaaaaa.png
215
+ | | |--bbbbbb.png
216
+ | |--snow_mask
217
+ | | |--cccccc.png
218
+ | | |--dddddd.png
219
+ |--out
220
+ ```
221
+
222
+ **2. Generate composite degradation images**
223
+
224
+ - low+haze+rain
225
+
226
+ ```
227
+ python syn_data.py --hq-file ./data/clear/ --light-file ./data/light_map/ --depth-file ./data/depth_map/ --rain-file ./data/rain_mask/ --snow-file ./data/snow_mask/ --out-file ./out/ --low --haze --rain
228
+ ```
229
+
230
+ - low+haze+snow
231
+
232
+ ```
233
+ python syn_data.py --hq-file ./data/clear/ --light-file ./data/light_map/ --depth-file ./data/depth_map/ --rain-file ./data/rain_mask/ --snow-file ./data/snow_mask/ --out-file ./out/ --low --haze --snow
234
+ ```
235
+ (Note: The degradation types can be customized according to specific needs.)
236
+
237
+ | Clear Image | low+haze+rain | low+haze+snow
238
+ | :--- | :--- | :---
239
+ | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/clear_img.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/l+h+r.jpg" width="200"> | <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/l+h+s.jpg" width="200">
240
+
241
+ ## Performance
242
+
243
+ ### CDD-11
244
+
245
+ | Types | Methods | Venue & Year | PSNR ↑ | SSIM ↑ | #Params |
246
+ |-------------------|-----------------------------------------------|--------------|----------|----------|------------|
247
+ | Input | [Input](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuNlQAAAAABf9KaFodlfC8H-K_MNiriFw?e=SiOrWU) | | 16.00 | 0.6008 | - |
248
+ | One-to-One | [MIRNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuMlQAAAAABBzDLjLu69noXflImQ2V9ng?e=4wohVK) | ECCV2020 | 25.97 | 0.8474 | 31.79M |
249
+ | One-to-One | [MPRNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuLlQAAAAAB_iz3hjLHZDMi-RyxHKgDDg?e=SwSQML) | CVPR2021 | 25.47 | 0.8555 | 15.74M |
250
+ | One-to-One | [MIRNetv2](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuQlQAAAAAB2miyepdTE3qdy4z2-LM4pg?e=moXVAR) | TPAMI2022 | 25.37 | 0.8335 | 5.86M |
251
+ | One-to-One | [Restormer](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuPlQAAAAABE86t03kpAVm_TZDIBPKolw?e=vHAR7A) | CVPR2022 | 26.99 | 0.8646 | 26.13M |
252
+ | One-to-One | [DGUNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuOlQAAAAABZkHj8tMamqaGhQ0w4VwFrg?e=lfDUlx) | CVPR2022 | 26.92 | 0.8559 | 17.33M |
253
+ | One-to-One | [NAFNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/EWm9jiJiZLlLgq1trYO67EsB42LrjGpepvpS4oLqKnj8xg?e=5Efa4W) | ECCV2022 | 24.13 | 0.7964 | 17.11M |
254
+ | One-to-One | [SRUDC](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuWlQAAAAABf9RNAUZH_xL6wF4aODWKqA?e=h4EqVN) | ICCV2023 | 27.64 | 0.8600 | 6.80M |
255
+ | One-to-One | [Fourmer](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuXlQAAAAABQKrbA47G8kMD2cf7Chq5EQ?e=vOiWV0) | ICML2023 | 23.44 | 0.7885 | 0.55M |
256
+ | One-to-One | [OKNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuVlQAAAAABSMzfS1xEOxLeuvw8HsGyMw?e=jRmf9t) | AAAI2024 | 26.33 | 0.8605 | 4.72M |
257
+ | One-to-Many | [AirNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMualQAAAAABYJ96PX0fipkP93zRXN_NVw?e=sXFOl8) | CVPR2022 | 23.75 | 0.8140 | 8.93M |
258
+ | One-to-Many | [TransWeather](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuZlQAAAAABoBiLjwJ8L2kl6rGQO5PeJA?e=msprhI) | CVPR2022 | 23.13 | 0.7810 | 21.90M |
259
+ | One-to-Many | [WeatherDiff](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuYlQAAAAABxdWbznZA1CQ0Bh1JH_ze-A?e=LEkcZw) | TPAMI2023 | 22.49 | 0.7985 | 82.96M |
260
+ | One-to-Many | [PromptIR](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMublQAAAAAB9aGo3QK-WlKkL5ItITW9Hg?e=wXrJf1) | NIPS2023 | 25.90 | 0.8499 | 38.45M |
261
+ | One-to-Many | [WGWSNet](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMudlQAAAAABi3HUMldxdoLHgDcUNoWMPw?e=z0qjAH) | CVPR2023 | 26.96 | 0.8626 | 25.76M |
262
+ | One-to-Composite | [OneRestore](https://1drv.ms/u/c/cbb69e4e3408ebcd/Ec3rCDROnrYggMuclQAAAAABSmNvDBKR1u5rDtqQnZ8X7A?e=OcnrjY) | ECCV2024 | 28.47 | 0.8784 | 5.98M |
263
+ | One-to-Composite | [OneRestore<sup>† </sup>](https://1drv.ms/u/c/cbb69e4e3408ebcd/EVM43y_W_WxAjrZqZdK9sfoBk1vpSzKilG0m7T-3i3la-A?e=dbNsD3) | ECCV2024 | 28.72 | 0.8821 | 5.98M |
264
+
265
+ [Indicator calculation code](https://github.com/gy65896/OneRestore/blob/main/img_file/cal_psnr_ssim.py) and [numerical results](https://github.com/gy65896/OneRestore/blob/main/img_file/metrics_CDD-11_psnr_ssim.xlsx) can be download here.
266
+
267
+ </div>
268
+ <div align=center>
269
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/syn.jpg" width="1080">
270
+ </div>
271
+
272
+ ### Real Scene
273
+
274
+ </div>
275
+ <div align=center>
276
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/real.jpg" width="1080">
277
+ </div>
278
+
279
+ ### Controllability
280
+
281
+ </div>
282
+ <div align=center>
283
+ <img src="https://github.com/gy65896/OneRestore/blob/main/img_file/control1.jpg" width="410"><img src="https://github.com/gy65896/OneRestore/blob/main/img_file/control2.jpg" width="410">
284
+ </div>
285
+
286
+
287
+ ## Citation
288
+
289
+ ```
290
+ @inproceedings{guo2024onerestore,
291
+ title={OneRestore: A Universal Restoration Framework for Composite Degradation},
292
+ author={Guo, Yu and Gao, Yuan and Lu, Yuxu and Liu, Ryan Wen and He, Shengfeng},
293
+ booktitle={European Conference on Computer Vision},
294
+ year={2024}
295
+ }
296
+ ```
297
+
298
+ #### If you have any questions, please get in touch with me ([email protected]).
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import gradio as gr
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ from utils.utils import load_restore_ckpt, load_embedder_ckpt
8
+ import os
9
+ from gradio_imageslider import ImageSlider
10
+
11
+ # Enforce CPU usage
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ embedder_model_path = "ckpts/embedder_model.tar" # Update with actual path to embedder checkpoint
15
+ restorer_model_path = "ckpts/onerestore_cdd-11.tar" # Update with actual path to restorer checkpoint
16
+
17
+ # Load models on CPU only
18
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=embedder_model_path)
19
+ restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=restorer_model_path)
20
+
21
+ # Define image preprocessing and postprocessing
22
+ transform_resize = transforms.Compose([
23
+ transforms.Resize([224,224]),
24
+ transforms.ToTensor()
25
+ ])
26
+
27
+
28
+ def postprocess_image(tensor):
29
+ image = tensor.squeeze(0).cpu().detach().numpy()
30
+ image = (image) * 255 # Assuming output in [-1, 1], rescale to [0, 255]
31
+ image = np.clip(image, 0, 255).astype("uint8") # Clip values to [0, 255]
32
+ return Image.fromarray(image.transpose(1, 2, 0)) # Reorder to (H, W, C)
33
+
34
+ # Define the enhancement function
35
+ def enhance_image(image, degradation_type=None):
36
+ # Preprocess the image
37
+ input_tensor = torch.Tensor((np.array(image)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
38
+ lq_em = transform_resize(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
39
+ lq_em = transform_resize(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ # Generate embedding
42
+ if degradation_type == "auto" or degradation_type is None:
43
+ text_embedding, _, [text] = embedder(lq_em, 'image_encoder')
44
+ else:
45
+ text_embedding, _, [text] = embedder([degradation_type], 'text_encoder')
46
+
47
+ # Model inference
48
+ with torch.no_grad():
49
+ enhanced_tensor = restorer(input_tensor, text_embedding)
50
+
51
+ # Postprocess the output
52
+ return (image, postprocess_image(enhanced_tensor)), text
53
+
54
+ # Define the Gradio interface
55
+ def inference(image, degradation_type=None):
56
+ return enhance_image(image, degradation_type)
57
+
58
+ #### Image,Prompts examples
59
+ examples = [
60
+ ['image/low_haze_rain_00469_01_lq.png'],
61
+ ['image/low_haze_snow_00337_01_lq.png'],
62
+ ]
63
+
64
+
65
+
66
+ # Create the Gradio app interface using updated API
67
+ interface = gr.Interface(
68
+ fn=inference,
69
+ inputs=[
70
+ gr.Image(type="pil", value="image/low_haze_rain_00469_01_lq.png"), # Image input
71
+ gr.Dropdown(['auto', 'low', 'haze', 'rain', 'snow',\
72
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
73
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow'], label="Degradation Type", value="auto") # Manual or auto degradation
74
+ ],
75
+ outputs=[
76
+ ImageSlider(label="Restored Image",
77
+ type="pil",
78
+ show_download_button=True,
79
+ ), # Enhanced image outputImageSlider(type="pil", show_download_button=True, ),
80
+ gr.Textbox(label="Degradation Type") # Display the estimated degradation type
81
+ ],
82
+ title="Image Restoration with OneRestore",
83
+ description="Upload an image and enhance it using OneRestore model. You can choose to let the model automatically estimate the degradation type or set it manually.",
84
+ examples=examples,
85
+ )
86
+
87
+ # Launch the app
88
+ if __name__ == "__main__":
89
+ interface.launch()
ckpts/ckpts_file.txt ADDED
File without changes
data/data_file.txt ADDED
File without changes
image/low_haze_rain_00469_01_lq.png ADDED

Git LFS Details

  • SHA256: ac5c71a539806d961d33b98e39c04c70be1a01b27a457d00493be4132b7facdf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.64 MB
image/low_haze_snow_00337_01_lq.png ADDED

Git LFS Details

  • SHA256: b89f728f4b9498d7fcd15ab79d6a46ed76eb490a6e9971e7c2cab071b8f8cc20
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
img_file/OneRestore_poster.png ADDED

Git LFS Details

  • SHA256: 86ee7b33d4e6b3024b12d60eb420a58b4f3b1cccb40f0569440a46e93daf816d
  • Pointer size: 133 Bytes
  • Size of remote file: 12 MB
img_file/abstract.jpg ADDED

Git LFS Details

  • SHA256: c3e79021de0ed441274a97a4c141c64e730293a95af65c3d963ee5cd3205ace1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
img_file/cal_psnr_ssim.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import random
5
+ from skimage.metrics import peak_signal_noise_ratio as compare_psnr
6
+ from skimage.metrics import mean_squared_error as compare_mse
7
+ from skimage.metrics import structural_similarity as compare_ssim
8
+ # Modified function to add progress display using tqdm for better progress tracking
9
+ from tqdm import tqdm
10
+ import pandas as pd
11
+ # Updated function with progress display for PSNR and SSIM calculation
12
+ def calculate_psnr_ssim_with_progress(clear_folder, methods, degradation_types, win_size=7):
13
+ # Get list of all clear images
14
+ img_list = [img for img in os.listdir(clear_folder) if img.endswith('.png')]
15
+
16
+ # Initialize matrices to store mean PSNR and SSIM values
17
+ psnr_matrix = np.zeros((len(methods), len(degradation_types)))
18
+ ssim_matrix = np.zeros((len(methods), len(degradation_types)))
19
+
20
+ # Total number of tasks for progress tracking
21
+ total_tasks = len(methods) * len(degradation_types) * len(img_list)
22
+ print(len(methods), len(degradation_types), len(img_list))
23
+
24
+ # Create a progress bar
25
+ with tqdm(total=total_tasks, desc="Processing Images", unit="task") as pbar:
26
+ # Loop over methods
27
+ for k, method in enumerate(methods):
28
+ print(f"Processing method: {method}")
29
+
30
+ # Loop over degradation types
31
+ for j, degradation_type in enumerate(degradation_types):
32
+ psnr_values = []
33
+ ssim_values = []
34
+
35
+ # Loop over each image in the clear folder
36
+ for img_name in img_list:
37
+ clear_img_path = os.path.join(clear_folder, img_name)
38
+ degraded_img_path = f'./{method}/{degradation_type}/{img_name}'
39
+
40
+ # Read the clear and degraded images
41
+ clear_img = cv2.imread(clear_img_path) / 255.0
42
+ degraded_img = cv2.imread(degraded_img_path) / 255.0
43
+
44
+ # Ensure the images are read correctly
45
+ if clear_img is not None and degraded_img is not None:
46
+ # Compute PSNR and SSIM between clear and degraded image
47
+ psnr_value = compare_psnr(clear_img, degraded_img, data_range=1.0)
48
+
49
+ # Compute SSIM with specified window size and for multichannel images
50
+ ssim_value = compare_ssim(clear_img, degraded_img, multichannel=True,
51
+ win_size=min(win_size, clear_img.shape[0], clear_img.shape[1]),
52
+ channel_axis=-1, data_range=1.0)
53
+
54
+ # Store values
55
+ psnr_values.append(psnr_value)
56
+ ssim_values.append(ssim_value)
57
+
58
+ # Update progress bar after processing each image
59
+ pbar.update(1)
60
+
61
+ # Calculate mean PSNR and SSIM for the current method and degradation type
62
+ if psnr_values:
63
+ psnr_matrix[k, j] = np.mean(psnr_values)
64
+ if ssim_values:
65
+ ssim_matrix[k, j] = np.mean(ssim_values)
66
+
67
+ return psnr_matrix, ssim_matrix
68
+
69
+ def save_matrices_to_excel(psnr_matrix, ssim_matrix, methods, degradation_types, output_file='metrics.xlsx'):
70
+ # Create DataFrames for PSNR and SSIM matrices
71
+ psnr_df = pd.DataFrame(psnr_matrix, index=methods, columns=degradation_types)
72
+ ssim_df = pd.DataFrame(ssim_matrix, index=methods, columns=degradation_types)
73
+
74
+ # Create a writer to write both DataFrames to the same Excel file
75
+ with pd.ExcelWriter(output_file) as writer:
76
+ psnr_df.to_excel(writer, sheet_name='PSNR')
77
+ ssim_df.to_excel(writer, sheet_name='SSIM')
78
+
79
+ print(f'Matrices saved to {output_file}')
80
+
81
+ # Define the parameters
82
+ clear_folder = './00_gt'
83
+ methods = ['01_input', '02_MIRNet', '03_MPRNet', '04_MIRNetv2', '05_Restormer',
84
+ '06_DGUNet', '07_NAFNet', '08_SRUDC', '09_Fourmer', '10_OKNet', '11_AirNet',
85
+ '12_TransWeather', '13_WeatherDiff', '14_PromptIR', '15_WGWSNet', '16_OneRestore_visual', '17_OneRestore']
86
+ degradation_types = ['low', 'haze', 'rain', 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow']
87
+
88
+ # This is the function that will be used to calculate the PSNR and SSIM values across methods and degradation types
89
+ # To use the function, uncomment the line below and ensure the file paths are set correctly in your environment
90
+
91
+
92
+ psnr_matrix, ssim_matrix = calculate_psnr_ssim_with_progress(clear_folder, methods, degradation_types)
93
+ save_matrices_to_excel(psnr_matrix, ssim_matrix, methods, degradation_types)
94
+
95
+
96
+
img_file/clear_img.jpg ADDED
img_file/control1.jpg ADDED
img_file/control2.jpg ADDED
img_file/depth_map.jpg ADDED
img_file/l+h+r.jpg ADDED
img_file/l+h+s.jpg ADDED
img_file/light_map.jpg ADDED
img_file/logo_onerestore.png ADDED
img_file/metric.png ADDED
img_file/metrics_CDD-11_psnr_ssim.xlsx ADDED
Binary file (15.7 kB). View file
 
img_file/pipeline.jpg ADDED

Git LFS Details

  • SHA256: 80600ecac7ff326ef3d322ea4db08d27edf3595befdba088deb783ceb260afa3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.72 MB
img_file/rain_mask.jpg ADDED
img_file/real.jpg ADDED

Git LFS Details

  • SHA256: bdd33e434a8370527d4c8d00603ce7b2b3d981eb5db58c38fe3ec32ed6a73c2d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
img_file/snow_mask.png ADDED
img_file/syn.jpg ADDED
makedataset.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Feb 12 20:00:46 2020
4
+
5
+ @author: Administrator
6
+ """
7
+
8
+ import os
9
+ import os.path
10
+ import random
11
+ import numpy as np
12
+ import cv2
13
+ import h5py
14
+ import torch
15
+ import torch.utils.data as udata
16
+ import argparse
17
+ from PIL import Image
18
+ class Dataset(udata.Dataset):
19
+ r"""Implements torch.utils.data.Dataset
20
+ """
21
+ def __init__(self, file, trainrgb=True,trainsyn = True, shuffle=False):
22
+ super(Dataset, self).__init__()
23
+ self.trainrgb = trainrgb
24
+ self.trainsyn = trainsyn
25
+ self.train_haze = file
26
+
27
+ h5f = h5py.File(self.train_haze, 'r')
28
+
29
+ self.keys = list(h5f.keys())
30
+ if shuffle:
31
+ random.shuffle(self.keys)
32
+ h5f.close()
33
+
34
+ def __len__(self):
35
+ return len(self.keys)
36
+
37
+ def __getitem__(self, index):
38
+
39
+ h5f = h5py.File(self.train_haze, 'r')
40
+
41
+ key = self.keys[index]
42
+ data = np.array(h5f[key])
43
+ h5f.close()
44
+ return torch.Tensor(data)
45
+
46
+ def data_augmentation(clear, mode):
47
+ r"""Performs dat augmentation of the input image
48
+
49
+ Args:
50
+ image: a cv2 (OpenCV) image
51
+ mode: int. Choice of transformation to apply to the image
52
+ 0 - no transformation
53
+ 1 - flip up and down
54
+ 2 - rotate counterwise 90 degree
55
+ 3 - rotate 90 degree and flip up and down
56
+ 4 - rotate 180 degree
57
+ 5 - rotate 180 degree and flip
58
+ 6 - rotate 270 degree
59
+ 7 - rotate 270 degree and flip
60
+ """
61
+ clear = np.transpose(clear, (2, 3, 0, 1))
62
+ if mode == 0:
63
+ # original
64
+ clear = clear
65
+ elif mode == 1:
66
+ # flip up and down
67
+ clear = np.flipud(clear)
68
+ elif mode == 2:
69
+ # rotate counterwise 90 degree
70
+ clear = np.rot90(clear)
71
+ elif mode == 3:
72
+ # rotate 90 degree and flip up and down
73
+ clear = np.rot90(clear)
74
+ clear = np.flipud(clear)
75
+ elif mode == 4:
76
+ # rotate 180 degree
77
+ clear = np.rot90(clear, k=2)
78
+ elif mode == 5:
79
+ # rotate 180 degree and flip
80
+ clear = np.rot90(clear, k=2)
81
+ clear = np.flipud(clear)
82
+ elif mode == 6:
83
+ # rotate 270 degree
84
+ clear = np.rot90(clear, k=3)
85
+ elif mode == 7:
86
+ # rotate 270 degree and flip
87
+ clear = np.rot90(clear, k=3)
88
+ clear = np.flipud(clear)
89
+ else:
90
+ raise Exception('Invalid choice of image transformation')
91
+ return np.transpose(clear, (2, 3, 0, 1))
92
+
93
+ def img_to_patches(img,win,stride,Syn=True):
94
+ typ, chl, raw, col = img.shape
95
+ chl = int(chl)
96
+ num_raw = np.ceil((raw-win)/stride+1).astype(np.uint8)
97
+ num_col = np.ceil((col-win)/stride+1).astype(np.uint8)
98
+ count = 0
99
+ total_process = int(num_col)*int(num_raw)
100
+ img_patches = np.zeros([typ, chl, win, win, total_process])
101
+ if Syn:
102
+ for i in range(num_raw):
103
+ for j in range(num_col):
104
+ if stride * i + win <= raw and stride * j + win <=col:
105
+ img_patches[:,:,:,:,count] = img[:, :, stride*i : stride*i + win, stride*j : stride*j + win]
106
+ elif stride * i + win > raw and stride * j + win<=col:
107
+ img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,stride * j : stride * j + win]
108
+ elif stride * i + win <= raw and stride*j + win>col:
109
+ img_patches[:,:,:,:,count] = img[:, :,stride*i : stride*i + win, col-win : col]
110
+ else:
111
+ img_patches[:,:,:,:,count] = img[:, :,raw-win : raw,col-win : col]
112
+ img_patches[:,:,:,:,count] = data_augmentation(img_patches[:, :, :, :, count], np.random.randint(0, 7))
113
+ count +=1
114
+ return img_patches
115
+
116
+ def read_img(img):
117
+ return np.array(Image.open(img))/255.
118
+
119
+ def Train_data(args):
120
+ file_list = os.listdir(f'{args.train_path}/{args.gt_name}')
121
+
122
+ with h5py.File(args.data_name, 'w') as h5f:
123
+ count = 0
124
+ for i in range(len(file_list)):
125
+ print(file_list[i])
126
+ img_list = []
127
+
128
+ img_list.append(read_img(f'{args.train_path}/{args.gt_name}/{file_list[i]}'))
129
+ for j in args.degradation_name:
130
+ img_list.append(read_img(f'{args.train_path}/{j}/{file_list[i]}'))
131
+
132
+ img = np.stack(img_list,0)
133
+ img = img_to_patches(img.transpose(0, 3, 1, 2), args.patch_size, args.stride)
134
+
135
+ for nx in range(img.shape[4]):
136
+ data = img[:,:,:,:,nx]
137
+ print(count, data.shape)
138
+ h5f.create_dataset(str(count), data=data)
139
+ count += 1
140
+ h5f.close()
141
+
142
+ if __name__ == "__main__":
143
+
144
+ parser = argparse.ArgumentParser(description = "Building the training patch database")
145
+ parser.add_argument("--patch-size", type = int, default=256, help="Patch size")
146
+ parser.add_argument("--stride", type = int, default=200, help="Size of stride")
147
+
148
+ parser.add_argument("--train-path", type = str, default='./data/CDD-11_train', help="Train path")
149
+ parser.add_argument("--data-name", type = str, default='dataset.h5', help="Data name")
150
+
151
+ parser.add_argument("--gt-name", type = str, default='clear', help="HQ name")
152
+ parser.add_argument("--degradation-name", type = list, default=['low','haze','rain','snow',\
153
+ 'low_haze','low_rain','low_snow','haze_rain','haze_snow','low_haze_rain','low_haze_snow'], help="LQ name")
154
+
155
+ args = parser.parse_args()
156
+
157
+ Train_data(args)
model/Embedder.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch, torchvision
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms as transforms
6
+ from utils.utils_word_embedding import initialize_wordembedding_matrix
7
+
8
+ class Backbone(nn.Module):
9
+ def __init__(self, backbone='resnet18'):
10
+ super(Backbone, self).__init__()
11
+
12
+ if backbone == 'resnet18':
13
+ resnet = torchvision.models.resnet.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
14
+ elif backbone == 'resnet50':
15
+ resnet = torchvision.models.resnet.resnet50(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
16
+ elif backbone == 'resnet101':
17
+ resnet = torchvision.models.resnet.resnet101(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
18
+
19
+ self.block0 = nn.Sequential(
20
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
21
+ )
22
+ self.block1 = resnet.layer1
23
+ self.block2 = resnet.layer2
24
+ self.block3 = resnet.layer3
25
+ self.block4 = resnet.layer4
26
+
27
+ def forward(self, x, returned=[4]):
28
+ blocks = [self.block0(x)]
29
+
30
+ blocks.append(self.block1(blocks[-1]))
31
+ blocks.append(self.block2(blocks[-1]))
32
+ blocks.append(self.block3(blocks[-1]))
33
+ blocks.append(self.block4(blocks[-1]))
34
+
35
+ out = [blocks[i] for i in returned]
36
+ return out
37
+
38
+ class CosineClassifier(nn.Module):
39
+ def __init__(self, temp=0.05):
40
+ super(CosineClassifier, self).__init__()
41
+ self.temp = temp
42
+
43
+ def forward(self, img, concept, scale=True):
44
+ """
45
+ img: (bs, emb_dim)
46
+ concept: (n_class, emb_dim)
47
+ """
48
+ img_norm = F.normalize(img, dim=-1)
49
+ concept_norm = F.normalize(concept, dim=-1)
50
+ pred = torch.matmul(img_norm, concept_norm.transpose(0, 1))
51
+ if scale:
52
+ pred = pred / self.temp
53
+ return pred
54
+
55
+ class Embedder(nn.Module):
56
+ """
57
+ Text and Visual Embedding Model.
58
+ """
59
+ def __init__(self,
60
+ type_name,
61
+ feat_dim = 512,
62
+ mid_dim = 1024,
63
+ out_dim = 324,
64
+ drop_rate = 0.35,
65
+ cosine_cls_temp = 0.05,
66
+ wordembs = 'glove',
67
+ extractor_name = 'resnet18'):
68
+ super(Embedder, self).__init__()
69
+
70
+ mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
71
+ self.type_name = type_name
72
+ self.feat_dim = feat_dim
73
+ self.mid_dim = mid_dim
74
+ self.out_dim = out_dim
75
+ self.drop_rate = drop_rate
76
+ self.cosine_cls_temp = cosine_cls_temp
77
+ self.wordembs = wordembs
78
+ self.extractor_name = extractor_name
79
+ self.transform = transforms.Normalize(mean, std)
80
+
81
+ self._setup_word_embedding()
82
+ self._setup_image_embedding()
83
+
84
+ def _setup_image_embedding(self):
85
+ # image embedding
86
+ self.feat_extractor = Backbone(self.extractor_name)
87
+
88
+ img_emb_modules = [
89
+ nn.Conv2d(self.feat_dim, self.mid_dim, kernel_size=1, bias=False),
90
+ nn.BatchNorm2d(self.mid_dim),
91
+ nn.ReLU()
92
+ ]
93
+ if self.drop_rate > 0:
94
+ img_emb_modules += [nn.Dropout2d(self.drop_rate)]
95
+ self.img_embedder = nn.Sequential(*img_emb_modules)
96
+
97
+ self.img_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
98
+ self.img_final = nn.Linear(self.mid_dim, self.out_dim)
99
+
100
+ self.classifier = CosineClassifier(temp=self.cosine_cls_temp)
101
+
102
+ def _setup_word_embedding(self):
103
+
104
+ self.type2idx = {self.type_name[i]: i for i in range(len(self.type_name))}
105
+ self.num_type = len(self.type_name)
106
+ train_type = [self.type2idx[type_i] for type_i in self.type_name]
107
+ self.train_type = torch.LongTensor(train_type).to("cuda" if torch.cuda.is_available() else "cpu")
108
+
109
+ wordemb, self.word_dim = \
110
+ initialize_wordembedding_matrix(self.wordembs, self.type_name)
111
+
112
+ self.embedder = nn.Embedding(self.num_type, self.word_dim)
113
+ self.embedder.weight.data.copy_(wordemb)
114
+
115
+ self.mlp = nn.Sequential(
116
+ nn.Linear(self.word_dim, self.out_dim),
117
+ nn.ReLU(True)
118
+ )
119
+
120
+ def train_forward(self, batch):
121
+
122
+ scene, img = batch[0], self.transform(batch[1])
123
+ bs = img.shape[0]
124
+
125
+ # word embedding
126
+ scene_emb = self.embedder(self.train_type)
127
+ scene_weight = self.mlp(scene_emb)
128
+
129
+ #image embedding
130
+ img = self.feat_extractor(img)[0]
131
+ img = self.img_embedder(img)
132
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
133
+ img = self.img_final(img)
134
+
135
+ pred = self.classifier(img, scene_weight)
136
+ label_loss = F.cross_entropy(pred, scene)
137
+ pred = torch.max(pred, dim=1)[1]
138
+ type_pred = self.train_type[pred]
139
+ correct_type = (type_pred == scene)
140
+ out = {
141
+ 'loss_total': label_loss,
142
+ 'acc_type': torch.div(correct_type.sum(),float(bs)),
143
+ }
144
+
145
+ return out
146
+
147
+ def image_encoder_forward(self, batch):
148
+ img = self.transform(batch)
149
+
150
+ # word embedding
151
+ scene_emb = self.embedder(self.train_type)
152
+ scene_weight = self.mlp(scene_emb)
153
+
154
+ #image embedding
155
+ img = self.feat_extractor(img)[0]
156
+ bs, _, h, w = img.shape
157
+ img = self.img_embedder(img)
158
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
159
+ img = self.img_final(img)
160
+
161
+ pred = self.classifier(img, scene_weight)
162
+ pred = torch.max(pred, dim=1)[1]
163
+
164
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
165
+ for i in range(bs):
166
+ out_embedding[i,:] = scene_weight[pred[i],:]
167
+ num_type = self.train_type[pred]
168
+ text_type = [self.type_name[num_type[i]] for i in range(bs)]
169
+
170
+ return out_embedding, num_type, text_type
171
+
172
+ def text_encoder_forward(self, text):
173
+
174
+ bs = len(text)
175
+
176
+ # word embedding
177
+ scene_emb = self.embedder(self.train_type)
178
+ scene_weight = self.mlp(scene_emb)
179
+
180
+ num_type = torch.zeros((bs)).to("cuda" if torch.cuda.is_available() else "cpu")
181
+ for i in range(bs):
182
+ num_type[i] = self.type2idx[text[i]]
183
+
184
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
185
+ for i in range(bs):
186
+ out_embedding[i,:] = scene_weight[int(num_type[i]),:]
187
+ text_type = text
188
+
189
+ return out_embedding, num_type, text_type
190
+
191
+ def text_idx_encoder_forward(self, idx):
192
+
193
+ bs = idx.shape[0]
194
+
195
+ # word embedding
196
+ scene_emb = self.embedder(self.train_type)
197
+ scene_weight = self.mlp(scene_emb)
198
+
199
+ num_type = idx
200
+
201
+ out_embedding = torch.zeros((bs,self.out_dim)).to("cuda" if torch.cuda.is_available() else "cpu")
202
+ for i in range(bs):
203
+ out_embedding[i,:] = scene_weight[int(num_type[i]),:]
204
+
205
+ return out_embedding
206
+
207
+ def contrast_loss_forward(self, batch):
208
+
209
+ img = self.transform(batch)
210
+
211
+ #image embedding
212
+ img = self.feat_extractor(img)[0]
213
+ img = self.img_embedder(img)
214
+ img = self.img_avg_pool(img).squeeze(3).squeeze(2)
215
+ img = self.img_final(img)
216
+
217
+ return img
218
+
219
+ def forward(self, x, type = 'image_encoder'):
220
+
221
+ if type == 'train':
222
+ out = self.train_forward(x)
223
+
224
+ elif type == 'image_encoder':
225
+ with torch.no_grad():
226
+ out = self.image_encoder_forward(x)
227
+
228
+ elif type == 'text_encoder':
229
+ out = self.text_encoder_forward(x)
230
+
231
+ elif type == 'text_idx_encoder':
232
+ out = self.text_idx_encoder_forward(x)
233
+
234
+ elif type == 'visual_embed':
235
+ x = F.interpolate(x,size=(224,224),mode='bilinear')
236
+ out = self.contrast_loss_forward(x)
237
+
238
+ return out
model/OneRestore.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Jun 20 16:14:37 2021
4
+
5
+ @author: Administrator
6
+ """
7
+
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import division
11
+ from __future__ import print_function
12
+ from torchvision import transforms
13
+ import torch, math
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange, repeat
17
+ import numbers
18
+
19
+ from thop import profile
20
+ import numpy as np
21
+ import time
22
+ from torchvision import transforms
23
+
24
+
25
+ class OneRestore(nn.Module):
26
+ def __init__(self, channel = 32):
27
+ super(OneRestore,self).__init__()
28
+ self.norm = lambda x: (x-0.5)/0.5
29
+ self.denorm = lambda x: (x+1)/2
30
+ self.in_conv = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False)
31
+ self.encoder = encoder(channel)
32
+ self.middle = backbone(channel)
33
+ self.decoder = decoder(channel)
34
+ self.out_conv = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False)
35
+
36
+ def forward(self,x,embedding):
37
+ x_in = self.in_conv(self.norm(x))
38
+ x_l, x_m, x_s, x_ss = self.encoder(x_in, embedding)
39
+ x_mid = self.middle(x_ss, embedding)
40
+ x_out = self.decoder(x_mid, x_ss, x_s, x_m, x_l, embedding)
41
+ out = self.out_conv(x_out) + x
42
+ return self.denorm(out)
43
+
44
+ class encoder(nn.Module):
45
+ def __init__(self,channel):
46
+ super(encoder,self).__init__()
47
+
48
+ self.el = ResidualBlock(channel)#16
49
+ self.em = ResidualBlock(channel*2)#32
50
+ self.es = ResidualBlock(channel*4)#64
51
+ self.ess = ResidualBlock(channel*8)#128
52
+
53
+ self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
54
+ self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32
55
+ self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64
56
+ self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128
57
+ self.conv_esstesss = nn.Conv2d(8*channel,16*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 256
58
+
59
+ def forward(self,x,embedding):
60
+
61
+ elout = self.el(x, embedding)#16
62
+ x_emin = self.conv_eltem(self.maxpool(elout))#32
63
+ emout = self.em(x_emin, embedding)
64
+ x_esin = self.conv_emtes(self.maxpool(emout))
65
+ esout = self.es(x_esin, embedding)
66
+ x_esin = self.conv_estess(self.maxpool(esout))
67
+ essout = self.ess(x_esin, embedding)#128
68
+
69
+ return elout, emout, esout, essout#,esssout
70
+
71
+ class backbone(nn.Module):
72
+ def __init__(self,channel):
73
+ super(backbone,self).__init__()
74
+
75
+ self.s1 = ResidualBlock(channel*8)#128
76
+ self.s2 = ResidualBlock(channel*8)#128
77
+
78
+ def forward(self,x,embedding):
79
+
80
+ share1 = self.s1(x, embedding)
81
+ share2 = self.s2(share1, embedding)
82
+
83
+ return share2
84
+
85
+ class decoder(nn.Module):
86
+ def __init__(self,channel):
87
+ super(decoder,self).__init__()
88
+
89
+ self.dss = ResidualBlock(channel*8)#128
90
+ self.ds = ResidualBlock(channel*4)#64
91
+ self.dm = ResidualBlock(channel*2)#32
92
+ self.dl = ResidualBlock(channel)#16
93
+
94
+ #self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128
95
+ self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64
96
+ self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32
97
+ self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16
98
+
99
+ def _upsample(self,x,y):
100
+ _,_,H0,W0 = y.size()
101
+ return F.interpolate(x,size=(H0,W0),mode='bilinear')
102
+
103
+ def forward(self, x, x_ss, x_s, x_m, x_l, embedding):
104
+
105
+ dssout = self.dss(x + x_ss, embedding)
106
+ x_dsin = self.conv_dsstds(self._upsample(dssout, x_s))
107
+ dsout = self.ds(x_dsin + x_s, embedding)
108
+ x_dmin = self.conv_dstdm(self._upsample(dsout, x_m))
109
+ dmout = self.dm(x_dmin + x_m, embedding)
110
+ x_dlin = self.conv_dmtdl(self._upsample(dmout, x_l))
111
+ dlout = self.dl(x_dlin + x_l, embedding)
112
+
113
+ return dlout
114
+
115
+
116
+ class ResidualBlock(nn.Module): # Edge-oriented Residual Convolution Block 面向边缘的残差网络块 解决梯度消失的问题
117
+ def __init__(self, channel, norm=False):
118
+ super(ResidualBlock, self).__init__()
119
+
120
+ self.el = TransformerBlock(channel, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
121
+
122
+ def forward(self, x,embedding):
123
+ return self.el(x,embedding)
124
+
125
+ def to_3d(x):
126
+ return rearrange(x, 'b c h w -> b (h w) c')
127
+
128
+ def to_4d(x, h, w):
129
+ return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
130
+
131
+
132
+ class BiasFree_LayerNorm(nn.Module):
133
+ def __init__(self, normalized_shape):
134
+ super(BiasFree_LayerNorm, self).__init__()
135
+ if isinstance(normalized_shape, numbers.Integral):
136
+ normalized_shape = (normalized_shape,)
137
+ normalized_shape = torch.Size(normalized_shape)
138
+ assert len(normalized_shape) == 1
139
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
140
+ self.normalized_shape = normalized_shape
141
+
142
+ def forward(self, x):
143
+ sigma = x.var(-1, keepdim=True, unbiased=False)
144
+ return x / torch.sqrt(sigma + 1e-5) * self.weight
145
+
146
+ class WithBias_LayerNorm(nn.Module):
147
+ def __init__(self, normalized_shape):
148
+ super(WithBias_LayerNorm, self).__init__()
149
+ if isinstance(normalized_shape, numbers.Integral):
150
+ normalized_shape = (normalized_shape,)
151
+ normalized_shape = torch.Size(normalized_shape)
152
+ assert len(normalized_shape) == 1
153
+
154
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
155
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
156
+ self.normalized_shape = normalized_shape
157
+
158
+ def forward(self, x):
159
+ mu = x.mean(-1, keepdim=True)
160
+ sigma = x.var(-1, keepdim=True, unbiased=False)
161
+ return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
162
+
163
+ class LayerNorm(nn.Module):
164
+ def __init__(self, dim, LayerNorm_type):
165
+ super(LayerNorm, self).__init__()
166
+ if LayerNorm_type == 'BiasFree':
167
+ self.body = BiasFree_LayerNorm(dim)
168
+ else:
169
+ self.body = WithBias_LayerNorm(dim)
170
+
171
+ def forward(self, x):
172
+ h, w = x.shape[-2:]
173
+ return to_4d(self.body(to_3d(x)), h, w)
174
+
175
+ class Cross_Attention(nn.Module):
176
+ def __init__(self,
177
+ dim,
178
+ num_heads,
179
+ bias,
180
+ q_dim = 324):
181
+ super(Cross_Attention, self).__init__()
182
+ self.dim = dim
183
+ self.num_heads = num_heads
184
+ sqrt_q_dim = int(math.sqrt(q_dim))
185
+ self.resize = transforms.Resize([sqrt_q_dim, sqrt_q_dim])
186
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
187
+
188
+ self.q = nn.Linear(q_dim, q_dim, bias=bias)
189
+
190
+ self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
191
+ self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
192
+
193
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
194
+ def forward(self, x, query):
195
+ b,c,h,w = x.shape
196
+
197
+ q = self.q(query)
198
+ k, v = self.kv_dwconv(self.kv(x)).chunk(2, dim=1)
199
+ k = self.resize(k)
200
+
201
+ q = repeat(q, 'b l -> b head c l', head=self.num_heads, c=self.dim//self.num_heads)
202
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
203
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
204
+
205
+ q = torch.nn.functional.normalize(q, dim=-1)
206
+ k = torch.nn.functional.normalize(k, dim=-1)
207
+
208
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
209
+ attn = attn.softmax(dim=-1)
210
+
211
+ out = (attn @ v)
212
+
213
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
214
+
215
+ out = self.project_out(out)
216
+ return out
217
+
218
+ class Self_Attention(nn.Module):
219
+ def __init__(self,
220
+ dim,
221
+ num_heads,
222
+ bias):
223
+ super(Self_Attention, self).__init__()
224
+ self.num_heads = num_heads
225
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
226
+
227
+ self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
228
+ self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
229
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
230
+ def forward(self, x):
231
+ b,c,h,w = x.shape
232
+
233
+ qkv = self.qkv_dwconv(self.qkv(x))
234
+ q,k,v = qkv.chunk(3, dim=1)
235
+
236
+ q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
237
+ k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
238
+ v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
239
+
240
+ q = torch.nn.functional.normalize(q, dim=-1)
241
+ k = torch.nn.functional.normalize(k, dim=-1)
242
+
243
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
244
+ attn = attn.softmax(dim=-1)
245
+
246
+ out = (attn @ v)
247
+
248
+ out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
249
+
250
+ out = self.project_out(out)
251
+ return out
252
+
253
+ class FeedForward(nn.Module):
254
+ def __init__(self,
255
+ dim,
256
+ ffn_expansion_factor,
257
+ bias):
258
+ super(FeedForward, self).__init__()
259
+
260
+ hidden_features = int(dim * ffn_expansion_factor)
261
+
262
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
263
+
264
+ self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
265
+ groups=hidden_features * 2, bias=bias)
266
+
267
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
268
+
269
+ def forward(self, x):
270
+ x = self.project_in(x)
271
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
272
+ x = F.gelu(x1) * x2
273
+ x = self.project_out(x)
274
+ return x
275
+
276
+ class TransformerBlock(nn.Module):
277
+ def __init__(self,
278
+ dim,
279
+ num_heads=8,
280
+ ffn_expansion_factor=2.66,
281
+ bias=False,
282
+ LayerNorm_type='WithBias'):
283
+ super(TransformerBlock, self).__init__()
284
+ self.norm1 = LayerNorm(dim, LayerNorm_type)
285
+ self.cross_attn = Cross_Attention(dim, num_heads, bias)
286
+ self.norm2 = LayerNorm(dim, LayerNorm_type)
287
+ self.self_attn = Self_Attention(dim, num_heads, bias)
288
+ self.norm3 = LayerNorm(dim, LayerNorm_type)
289
+ self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
290
+
291
+ def forward(self, x, query):
292
+ x = x + self.cross_attn(self.norm1(x),query)
293
+ x = x + self.self_attn(self.norm2(x))
294
+ x = x + self.ffn(self.norm3(x))
295
+ return x
296
+
297
+ if __name__ == '__main__':
298
+ net = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
299
+ # x = torch.Tensor(np.random.random((2,3,256,256))).to("cuda" if torch.cuda.is_available() else "cpu")
300
+ # query = torch.Tensor(np.random.random((2, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
301
+ # out = net(x, query)
302
+ # print(out.shape)
303
+ input = torch.randn(1, 3, 512, 512).to("cuda" if torch.cuda.is_available() else "cpu")
304
+ query = torch.Tensor(np.random.random((1, 324))).to("cuda" if torch.cuda.is_available() else "cpu")
305
+ macs, _ = profile(net, inputs=(input, query))
306
+ total = sum([param.nelement() for param in net.parameters()])
307
+ print('Macs = ' + str(macs/1000**3) + 'G')
308
+ print('Params = ' + str(total/1e6) + 'M')
309
+
310
+ from fvcore.nn import FlopCountAnalysis, parameter_count_table
311
+ flops = FlopCountAnalysis(net, (input, query))
312
+ print("FLOPs", flops.total()/1000**3)
313
+
314
+
model/loss.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ import torch.nn.functional as F
5
+ import cv2 as cv
6
+ import numpy as np
7
+ from matplotlib import pyplot as plt
8
+ from math import exp
9
+ from torchvision import transforms
10
+ from torchvision.models import vgg16
11
+ import torchvision
12
+ '''
13
+ MS-SSIM Loss
14
+ '''
15
+
16
+ def gaussian(window_size, sigma):
17
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
18
+ return gauss/gauss.sum()
19
+
20
+
21
+ def create_window(window_size, channel=1):
22
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
23
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
24
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
25
+ return window
26
+
27
+
28
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
29
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
30
+ if val_range is None:
31
+ if torch.max(img1) > 128:
32
+ max_val = 255
33
+ else:
34
+ max_val = 1
35
+
36
+ if torch.min(img1) < -0.5:
37
+ min_val = -1
38
+ else:
39
+ min_val = 0
40
+ L = max_val - min_val
41
+ else:
42
+ L = val_range
43
+
44
+ padd = 0
45
+ (_, channel, height, width) = img1.size()
46
+ if window is None:
47
+ real_size = min(window_size, height, width)
48
+ window = create_window(real_size, channel=channel).to(img1.device)
49
+
50
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
51
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
52
+
53
+ mu1_sq = mu1.pow(2)
54
+ mu2_sq = mu2.pow(2)
55
+ mu1_mu2 = mu1 * mu2
56
+
57
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
58
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
59
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
60
+
61
+ C1 = (0.01 * L) ** 2
62
+ C2 = (0.03 * L) ** 2
63
+
64
+ v1 = 2.0 * sigma12 + C2
65
+ v2 = sigma1_sq + sigma2_sq + C2
66
+ cs = torch.mean(v1 / v2) # contrast sensitivity
67
+
68
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
69
+
70
+ if size_average:
71
+ ret = ssim_map.mean()
72
+ else:
73
+ ret = ssim_map.mean(1).mean(1).mean(1)
74
+
75
+ if full:
76
+ return ret, cs
77
+ return ret
78
+
79
+
80
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
81
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(img1.device)
82
+ levels = weights.size()[0]
83
+ mssim = []
84
+ mcs = []
85
+ for _ in range(levels):
86
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
87
+ mssim.append(sim)
88
+ mcs.append(cs)
89
+
90
+ img1 = F.avg_pool2d(img1, (2, 2))
91
+ img2 = F.avg_pool2d(img2, (2, 2))
92
+
93
+ mssim = torch.stack(mssim)
94
+ mcs = torch.stack(mcs)
95
+
96
+ # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
97
+ if normalize:
98
+ mssim = (mssim + 1) / 2
99
+ mcs = (mcs + 1) / 2
100
+
101
+ pow1 = mcs ** weights
102
+ pow2 = mssim ** weights
103
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
104
+ output = torch.prod(pow1[:-1] * pow2[-1])
105
+ return output
106
+
107
+
108
+ # Classes to re-use window
109
+ class SSIM(torch.nn.Module):
110
+ def __init__(self, window_size=11, size_average=True, val_range=None):
111
+ super(SSIM, self).__init__()
112
+ self.window_size = window_size
113
+ self.size_average = size_average
114
+ self.val_range = val_range
115
+
116
+ # Assume 1 channel for SSIM
117
+ self.channel = 1
118
+ self.window = create_window(window_size)
119
+
120
+ def forward(self, img1, img2):
121
+ (_, channel, _, _) = img1.size()
122
+
123
+ if channel == self.channel and self.window.dtype == img1.dtype:
124
+ window = self.window
125
+ else:
126
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
127
+ self.window = window
128
+ self.channel = channel
129
+
130
+ return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
131
+
132
+ class MSSSIM(torch.nn.Module):
133
+ def __init__(self, window_size=11, size_average=True, channel=3):
134
+ super(MSSSIM, self).__init__()
135
+ self.window_size = window_size
136
+ self.size_average = size_average
137
+ self.channel = channel
138
+
139
+ def forward(self, img1, img2):
140
+ # TODO: store window between calls if possible
141
+ return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
142
+
143
+ class TVLoss(nn.Module):
144
+ def __init__(self,TVLoss_weight=1):
145
+ super(TVLoss,self).__init__()
146
+ self.TVLoss_weight = TVLoss_weight
147
+
148
+ def forward(self,x):
149
+ batch_size = x.size()[0]
150
+ h_x = x.size()[2]
151
+ w_x = x.size()[3]
152
+ count_h = self._tensor_size(x[:,:,1:,:]) #算出总共求了多少次差
153
+ count_w = self._tensor_size(x[:,:,:,1:])
154
+ h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
155
+ # x[:,:,1:,:]-x[:,:,:h_x-1,:]就是对原图进行错位,分成两张像素位置差1的图片,第一张图片
156
+ # 从像素点1开始(原图从0开始),到最后一个像素点,第二张图片从像素点0开始,到倒数第二个
157
+ # 像素点,这样就实现了对原图进行错位,分成两张图的操作,做差之后就是原图中每个像素点与相
158
+ # 邻的下一个像素点的差。
159
+ w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
160
+ return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
161
+
162
+ def _tensor_size(self,t):
163
+ return t.size()[1]*t.size()[2]*t.size()[3]
164
+
165
+ def _tensor_size(self,t):
166
+ return t.size()[1]*t.size()[2]*t.size()[3]
167
+
168
+ class ContrastLoss(nn.Module):
169
+ def __init__(self):
170
+ super(ContrastLoss, self).__init__()
171
+ self.l1 = nn.L1Loss()
172
+ self.model = vgg16(weights = torchvision.models.VGG16_Weights.DEFAULT)
173
+ self.model = self.model.features[:16].to("cuda" if torch.cuda.is_available() else "cpu")
174
+ for param in self.model.parameters():
175
+ param.requires_grad = False
176
+ self.layer_name_mapping = {
177
+ '3': "relu1_2",
178
+ '8': "relu2_2",
179
+ '15': "relu3_3"
180
+ }
181
+
182
+ def gen_features(self, x):
183
+ output = []
184
+ for name, module in self.model._modules.items():
185
+ x = module(x)
186
+ if name in self.layer_name_mapping:
187
+ output.append(x)
188
+ return output
189
+ def forward(self, inp, pos, neg, out):
190
+ inp_t = inp
191
+ inp_x0 = self.gen_features(inp_t)
192
+ pos_t = pos
193
+ pos_x0 = self.gen_features(pos_t)
194
+ out_t = out
195
+ out_x0 = self.gen_features(out_t)
196
+ neg_t, neg_x0 = [],[]
197
+ for i in range(neg.shape[1]):
198
+ neg_i = neg[:,i,:,:]
199
+ neg_t.append(neg_i)
200
+ neg_x0_i = self.gen_features(neg_i)
201
+ neg_x0.append(neg_x0_i)
202
+ loss = 0
203
+ for i in range(len(pos_x0)):
204
+ pos_term = self.l1(out_x0[i], pos_x0[i].detach())
205
+ inp_term = self.l1(out_x0[i], inp_x0[i].detach())/(len(neg_x0)+1)
206
+ neg_term = sum(self.l1(out_x0[i], neg_x0[j][i].detach()) for j in range(len(neg_x0)))/(len(neg_x0)+1)
207
+ loss = loss + pos_term / (inp_term+neg_term+1e-7)
208
+ return loss / len(pos_x0)
209
+
210
+ class Total_loss(nn.Module):
211
+ def __init__(self, args):
212
+ super(Total_loss, self).__init__()
213
+ self.con_loss = ContrastLoss()
214
+ self.weight_sl1, self.weight_msssim, self.weight_drl = args.loss_weight
215
+
216
+ def forward(self, inp, pos, neg, out):
217
+ smooth_loss_l1 = F.smooth_l1_loss(out, pos)
218
+ msssim_loss = 1-msssim(out, pos, normalize=True)
219
+ c_loss = self.con_loss(inp[0], pos, neg, out)
220
+
221
+ total_loss = self.weight_sl1 * smooth_loss_l1 + self.weight_msssim * msssim_loss + self.weight_drl * c_loss
222
+ return total_loss
output/low_haze_rain_00469_01_lq.png ADDED

Git LFS Details

  • SHA256: 44f9a861af1f4672c1799e6c3cf20ca2759522dae5f78d9fe8b4540eefb206f6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.44 MB
output/low_haze_snow_00337_01_lq.png ADDED

Git LFS Details

  • SHA256: d007771f4541535683819733631de220abf76a3afbe01e39a69878477d3736b7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
remove_optim.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, argparse
2
+ from model.OneRestore import OneRestore
3
+ from model.Embedder import Embedder
4
+
5
+ parser = argparse.ArgumentParser()
6
+
7
+ parser.add_argument("--type", type=str, default = 'OneRestore')
8
+ parser.add_argument("--input-file", type=str, default = './ckpts/onerestore_cdd-11.tar')
9
+ parser.add_argument("--output-file", type=str, default = './ckpts/onerestore_cdd-11.tar')
10
+
11
+ args = parser.parse_args()
12
+
13
+ if args.type == 'OneRestore':
14
+ restorer = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
15
+ restorer_info = torch.load(args.input_file, map_location='cuda:0')
16
+ weights_dict = {}
17
+ for k, v in restorer_info['state_dict'].items():
18
+ new_k = k.replace('module.', '') if 'module' in k else k
19
+ weights_dict[new_k] = v
20
+ restorer.load_state_dict(weights_dict)
21
+ torch.save(restorer.state_dict(), args.output_file)
22
+ elif args.type == 'Embedder':
23
+ combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
24
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
25
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow']
26
+ embedder = Embedder(combine_type).to("cuda" if torch.cuda.is_available() else "cpu")
27
+ embedder_info = torch.load(args.input_file)
28
+ embedder.load_state_dict(embedder_info['state_dict'])
29
+ torch.save(embedder.state_dict(), args.output_file)
30
+ else:
31
+ print('ERROR!')
32
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ pillow
2
+ numpy
3
+ scikit-image
4
+ pandas
5
+ einops
6
+ thop
7
+ fasttext
8
+ opencv-python
9
+ h5py
10
+ matplotlib
syn_data/data/clear/1.jpg ADDED
syn_data/data/depth_map/1.jpg ADDED
syn_data/data/light_map/1.jpg ADDED
syn_data/data/rain_mask/00001.jpg ADDED
syn_data/data/rain_mask/00002.jpg ADDED
syn_data/data/rain_mask/00003.jpg ADDED
syn_data/data/snow_mask/beautiful_smile_00001.jpg ADDED
syn_data/data/snow_mask/beautiful_smile_00006.jpg ADDED
syn_data/data/snow_mask/beautiful_smile_00008.jpg ADDED
syn_data/out/1.jpg ADDED
syn_data/syn_data.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, argparse, cv2, random
2
+ import numpy as np
3
+ from skimage import exposure
4
+
5
+ def guideFilter(I, p, winSize, eps):
6
+ mean_I = cv2.blur(I, winSize)
7
+ mean_p = cv2.blur(p, winSize)
8
+ mean_II = cv2.blur(I * I, winSize)
9
+ mean_Ip = cv2.blur(I * p, winSize)
10
+ var_I = mean_II - mean_I * mean_I
11
+ cov_Ip = mean_Ip - mean_I * mean_p
12
+ a = cov_Ip / (var_I + eps)
13
+ b = mean_p - a * mean_I
14
+ mean_a = cv2.blur(a, winSize)
15
+ mean_b = cv2.blur(b, winSize)
16
+ q = mean_a * I + mean_b
17
+ return q
18
+
19
+ def syn_low(img, light, img_gray, light_max=3,
20
+ light_min=2, noise_max=0.08, noise_min=0.03):
21
+ light = guideFilter(light, img_gray,(3,3),0.01)[:, :, np.newaxis]
22
+ n = np.random.uniform(noise_min, noise_max)
23
+ R = img / (light + 1e-7)
24
+ L = (light + 1e-7) ** np.random.uniform(light_min, light_max)
25
+ return np.clip(R * L + np.random.normal(0, n, img.shape), 0, 1)
26
+
27
+ def syn_haze(img, depth, beta_max=2.0, beta_min=1.0, A_max=0.9, A_min=0.6,
28
+ color_max=0, color_min=0):
29
+ beta = np.random.rand(1) * (beta_max - beta_min) + beta_min
30
+ t = np.exp(-np.minimum(1 - cv2.blur(depth,(22,22)),0.7) * beta)
31
+ A = np.random.rand(1) * (A_max - A_min) + A_min
32
+ A_random = np.random.rand(3) * (color_max - color_min) + color_min
33
+ A = A + A_random
34
+ return np.clip(img * t + A * (1 - t), 0, 1)
35
+
36
+ def syn_data(hq_file, light_file, depth_file, rain_file, snow_file, out_file,
37
+ low, haze, rain, snow):
38
+ file_list = os.listdir(hq_file)
39
+ rain_list = os.listdir(rain_file)
40
+ snow_list = os.listdir(snow_file)
41
+ num_rain = random.sample(range(0,len(rain_list)),len(rain_list))
42
+ num_snow = random.sample(range(0,len(snow_list)),len(snow_list))
43
+ for i in range(1, len(file_list)):
44
+ img = cv2.imread(hq_file+file_list[i])
45
+ w, h, _ = img.shape
46
+ light = cv2.cvtColor(cv2.imread(light_file + file_list[i]), cv2.COLOR_RGB2GRAY) / 255.0
47
+ depth = cv2.imread(depth_file + file_list[i]) / 255.0
48
+ rain_mask = cv2.imread(rain_file + rain_list[num_rain[i]]) / 255.0
49
+ rain_mask = cv2.resize(rain_mask,(h,w))
50
+ snow_mask = cv2.imread(snow_file + snow_list[num_snow[i]]) / 255.0
51
+ snow_mask = cv2.resize(snow_mask, (h, w))
52
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)/ 255.0
53
+ lq = img.copy()/255.0
54
+ color_dis = 1
55
+
56
+ if low:
57
+ lq = syn_low(lq, light, img_gray)
58
+ if rain:
59
+ lq = lq+rain_mask
60
+ if snow:
61
+ lq = lq*(1-snow_mask)+color_dis*snow_mask
62
+ if haze:
63
+ lq = syn_haze(lq, depth)
64
+
65
+ # out = np.concatenate((lq*255.0,img),1)
66
+ out = lq*255.0
67
+ cv2.imwrite(out_file + file_list[i], out)
68
+
69
+ if __name__ == "__main__":
70
+ parser = argparse.ArgumentParser()
71
+ # load model
72
+ parser.add_argument("--hq-file", type=str, default = './data/clear/')
73
+ parser.add_argument("--light-file", type=str, default = './data/light_map/')
74
+ parser.add_argument("--depth-file", type=str, default = './data/depth_map/')
75
+ parser.add_argument("--rain-file", type=str, default = './data/rain_mask/')
76
+ parser.add_argument("--snow-file", type=str, default = './data/snow_mask/')
77
+ parser.add_argument("--out-file", type=str, default = './out/')
78
+ parser.add_argument("--low", action='store_true')
79
+ parser.add_argument("--haze", action='store_true')
80
+ parser.add_argument("--rain", action='store_true')
81
+ parser.add_argument("--snow", action='store_true')
82
+
83
+ args = parser.parse_args()
84
+
85
+ syn_data(args.hq_file, args.light_file, args.depth_file, args.rain_file,
86
+ args.snow_file, args.out_file, args.low, args.haze, args.rain, args.snow)
test.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, argparse
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+
6
+ import torch
7
+ from torchvision import transforms
8
+
9
+ from torchvision.utils import save_image as imwrite
10
+ from utils.utils import print_args, load_restore_ckpt, load_embedder_ckpt
11
+
12
+ transform_resize = transforms.Compose([
13
+ transforms.Resize([224,224]),
14
+ transforms.ToTensor()
15
+ ])
16
+
17
+ def main(args):
18
+
19
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
+ #train
21
+ print('> Model Initialization...')
22
+
23
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
24
+ restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=args.restore_model_path)
25
+
26
+ os.makedirs(args.output,exist_ok=True)
27
+
28
+ files = os.listdir(argspar.input)
29
+ time_record = []
30
+ for i in files:
31
+ lq = Image.open(f'{argspar.input}/{i}')
32
+
33
+ with torch.no_grad():
34
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
35
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ start_time = time.time()
38
+
39
+ if args.prompt == None:
40
+ text_embedding, _, [text] = embedder(lq_em,'image_encoder')
41
+ print(f'This is {text} degradation estimated by visual embedder.')
42
+ else:
43
+ text_embedding, _, [text] = embedder([args.prompt],'text_encoder')
44
+ print(f'This is {text} degradation generated by input text.')
45
+
46
+ out = restorer(lq_re, text_embedding)
47
+
48
+ run_time = time.time()-start_time
49
+ time_record.append(run_time)
50
+
51
+ if args.concat:
52
+ out = torch.cat((lq_re, out), dim=3)
53
+
54
+ imwrite(out, f'{args.output}/{i}', range=(0, 1))
55
+
56
+ print(f'{i} Running Time: {run_time:.4f}.')
57
+ print(f'Average time is {np.mean(np.array(run_time))}')
58
+
59
+
60
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
61
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
62
+ if __name__ == '__main__':
63
+
64
+ parser = argparse.ArgumentParser(description = "OneRestore Running")
65
+
66
+ # load model
67
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
68
+ parser.add_argument("--restore-model-path", type=str, default = "./ckpts/onerestore_cdd-11.tar", help = 'restore model path')
69
+
70
+ # select model automatic (prompt=False) or manual (prompt=True, text={'clear', 'low', 'haze', 'rain', 'snow',\
71
+ # 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'})
72
+ parser.add_argument("--prompt", type=str, default = None, help = 'prompt')
73
+
74
+ parser.add_argument("--input", type=str, default = "./image/", help = 'image path')
75
+ parser.add_argument("--output", type=str, default = "./output/", help = 'output path')
76
+ parser.add_argument("--concat", action='store_true', help = 'output path')
77
+
78
+ argspar = parser.parse_args()
79
+
80
+ print_args(argspar)
81
+
82
+ main(argspar)
train_Embedder.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, torch, time
2
+ import torch.optim
3
+
4
+ from utils.utils import load_embedder_ckpt_with_optim, adjust_learning_rate, freeze_text_embedder, AverageMeter
5
+ from utils.utils_data import init_embedding_data
6
+
7
+
8
+
9
+ def train_embedding(cur_epoch, model, optimizer, trainloader, testloader, device, cfg_em):
10
+ torch.backends.cudnn.benchmark = False
11
+ torch.backends.cudnn.enabled = True
12
+
13
+ acc_train_meter = AverageMeter()
14
+ acc_test_meter = AverageMeter()
15
+ loss_train_meter = AverageMeter()
16
+ loss_test_meter = AverageMeter()
17
+ time_train_meter = AverageMeter()
18
+ time_test_meter = AverageMeter()
19
+
20
+ freeze_text_embedder(model)
21
+ for k,v in model.named_parameters():
22
+ print('{}: {}'.format(k, v.requires_grad))
23
+ for epoch in range(cur_epoch, cfg_em.epoch+1):
24
+
25
+ optimizer = adjust_learning_rate(optimizer, epoch-1, cfg_em.lr_decay)
26
+ lr = optimizer.param_groups[-1]['lr']
27
+
28
+ model.train()
29
+ for idx, batch in enumerate(trainloader):
30
+ for i in range(len(batch)):
31
+ batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu")
32
+ time_start = time.time()
33
+ out = model(batch, 'train')
34
+ loss = out['loss_total']
35
+ acc = out['acc_type']
36
+ time_train_meter.update(time.time() - time_start)
37
+
38
+ acc_train_meter.update(acc)
39
+ loss_train_meter.update(loss)
40
+
41
+ optimizer.zero_grad()
42
+ loss.backward()
43
+ optimizer.step()
44
+
45
+ print(f'Epoch:{epoch}|Iter:{idx+1}/{len(trainloader)}|lr:{lr},'
46
+ f'Loss: {loss_train_meter.avg:.3f},'
47
+ f'Acc: {acc_train_meter.avg:.3f},'
48
+ f'Time: {time_train_meter.avg:.3f},', flush=True)
49
+
50
+ model.eval()
51
+ for idx, batch in enumerate(testloader):
52
+ for i in range(len(batch)):
53
+ batch[i] = batch[i].to("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ time_start = time.time()
56
+ out = model(batch, 'train')
57
+ loss = out['loss_total']
58
+ acc = out['acc_type']
59
+ time_test_meter.update(time.time() - time_start)
60
+
61
+ acc_test_meter.update(acc)
62
+ loss_test_meter.update(loss)
63
+ print(f'Epoch:{epoch}|Iter:{idx+1}/{len(testloader)}|lr:{lr},'
64
+ f'Loss: {loss_test_meter.avg:.3f},'
65
+ f'Acc: {acc_test_meter.avg:.3f},'
66
+ f'Time: {time_test_meter.avg:.3f},', flush=True)
67
+
68
+ torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()},
69
+ f'{cfg_em.check_dir}/embedder_model_epoch{epoch}_{acc_train_meter.avg:.3f}_{loss_train_meter.avg:.3f}_{acc_test_meter.avg:.3f}_{loss_test_meter.avg:.3f}.tar')
70
+ acc_train_meter.reset()
71
+ acc_test_meter.reset()
72
+ loss_train_meter.reset()
73
+ loss_test_meter.reset()
74
+ time_train_meter.reset()
75
+ time_test_meter.reset()
76
+ print('Done!')
77
+
78
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
79
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser()
82
+ # load model
83
+ parser.add_argument("--seed", type=int, default = 124)
84
+ parser.add_argument("--pre_weight", type=str, default = '')
85
+ parser.add_argument("--lr", type=float, default = 0.0001)
86
+ parser.add_argument("--type_name", type=list, default = ['clear', 'low', 'haze', 'rain',\
87
+ 'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
88
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow'])
89
+ parser.add_argument("--train-dir", type=str, default = './data/CDD-11_train/')
90
+ parser.add_argument("--test-dir", type=str, default = './data/CDD-11_test/')
91
+ parser.add_argument("--batch", type=int, default = 128)
92
+ parser.add_argument("--num-workers", type=int, default = 0)
93
+ parser.add_argument("--epoch", type=int, default = 200)
94
+ parser.add_argument("--lr-decay", type=int, default = 50)
95
+ parser.add_argument("--check-dir", type=str, default = "./ckpts")
96
+
97
+ args = parser.parse_args()
98
+
99
+ os.makedirs(args.check_dir,exist_ok=True)
100
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
101
+
102
+ embedder, optimizer, cur_epoch, device = load_embedder_ckpt_with_optim(device, args)
103
+ trainloader, testloader = init_embedding_data(args, 'train')
104
+ train_embedding(cur_epoch, embedder, optimizer, trainloader, testloader, device, args)
train_OneRestore_multi-gpu.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, torch, argparse
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import DataLoader
4
+ from torchvision.utils import save_image as imwrite
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ from makedataset import Dataset
8
+ from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
9
+ from model.loss import Total_loss
10
+ from model.Embedder import Embedder
11
+ from model.OneRestore import OneRestore
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from PIL import Image
14
+
15
+ torch.distributed.init_process_group(backend="nccl")
16
+ local_rank = torch.distributed.get_rank()
17
+ torch.cuda.set_device(local_rank)
18
+ device = torch.device("cuda", local_rank)
19
+
20
+
21
+ transform_resize = transforms.Compose([
22
+ transforms.Resize([224,224]),
23
+ transforms.ToTensor()
24
+ ])
25
+
26
+ def main(args):
27
+
28
+
29
+ print('> Model Initialization...')
30
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
31
+ restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, local_rank=local_rank, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
32
+ loss = Total_loss(args)
33
+
34
+ print('> Loading dataset...')
35
+ data = Dataset(args.train_input)
36
+ dataset = DataLoader(dataset=data, batch_size=args.bs,
37
+ shuffle=False,
38
+ num_workers=args.num_works,
39
+ pin_memory=True,drop_last=False,
40
+ sampler=DistributedSampler(data,shuffle=True))
41
+
42
+ print('> Start training...')
43
+ start_all = time.time()
44
+ train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
45
+ end_all = time.time()
46
+ print('Whloe Training Time:' +str(end_all-start_all)+'s.')
47
+
48
+ def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
49
+
50
+ metric = []
51
+ for epoch in range(cur_epoch, args.epoch):
52
+ optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
53
+ learnrate = optimizer.param_groups[-1]['lr']
54
+ restorer.train()
55
+
56
+ for i, data in enumerate(dataset,0):
57
+ pos, inp, neg = data_process(data, args, device)
58
+
59
+ text_embedding,_,_ = embedder(inp[1],'text_encoder')
60
+ out = restorer(inp[0], text_embedding)
61
+
62
+ restorer.zero_grad()
63
+ total_loss = loss(inp, pos, neg, out)
64
+ total_loss.backward()
65
+ optimizer.step()
66
+
67
+ mse = tensor_metric(pos,out, 'MSE', data_range=1)
68
+ psnr = tensor_metric(pos,out, 'PSNR', data_range=1)
69
+ ssim = tensor_metric(pos,out, 'SSIM', data_range=1)
70
+
71
+ print("[epoch %d][%d/%d] lr :%f Floss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f"%(epoch+1, i+1, \
72
+ len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))
73
+
74
+
75
+ psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
76
+ metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
77
+ print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f"%(epoch+1, psnr_t1,ssim_t1))
78
+
79
+ load_excel(metric)
80
+ save_checkpoint({'epoch': epoch + 1,'state_dict': restorer.state_dict(),'optimizer' : optimizer.state_dict()},\
81
+ args.save_model_path, epoch+1, psnr_t1,ssim_t1,psnr_t2,ssim_t2)
82
+
83
+ def test(args, restorer, embedder, device, epoch=-1):
84
+ combine_type = args.degr_type
85
+ psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
86
+ os.makedirs(args.output,exist_ok=True)
87
+
88
+ for i in range(len(combine_type)-1):
89
+ file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
90
+ for j in range(len(file_list)):
91
+ hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
92
+ lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')
93
+ restorer.eval()
94
+ with torch.no_grad():
95
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
96
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
97
+ hq = torch.Tensor((np.array(hq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
98
+
99
+ starttime = time.time()
100
+
101
+ text_embedding_1,_,text_1 = embedder([combine_type[i+1]],'text_encoder')
102
+ text_embedding_2,_, text_2 = embedder(lq_em,'image_encoder')
103
+ out_1 = restorer(lq_re, text_embedding_1)
104
+ if text_1 != text_2:
105
+ print(text_1, text_2)
106
+ out_2 = restorer(lq_re, text_embedding_2)
107
+ else:
108
+ out_2 = out_1
109
+
110
+ endtime1 = time.time()
111
+
112
+ imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
113
+ + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
114
+ # due to the vision problem, you can replace above line by
115
+ # imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
116
+ # + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png')
117
+ psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
118
+ ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
119
+ psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
120
+ ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
121
+ print('The ' + file_list[j][:-4] + ' Time:' + str(endtime1 - starttime) + 's.')
122
+
123
+ return psnr_1 / (len(file_list)*len(combine_type)), ssim_1 / (len(file_list)*len(combine_type)),\
124
+ psnr_2 / (len(file_list)*len(combine_type)), ssim_2 / (len(file_list)*len(combine_type))
125
+
126
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
127
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
128
+ if __name__ == '__main__':
129
+
130
+ parser = argparse.ArgumentParser(description = "OneRestore Training")
131
+
132
+ # load model
133
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
134
+ parser.add_argument("--restore-model-path", type=str, default = None, help = 'restore model path')
135
+ parser.add_argument("--save-model-path", type=str, default = "./ckpts/", help = 'restore model path')
136
+
137
+ parser.add_argument("--epoch", type=int, default = 300, help = 'epoch number')
138
+ parser.add_argument("--bs", type=int, default = 4, help = 'batchsize')
139
+ parser.add_argument("--lr", type=float, default = 1e-4, help = 'learning rate')
140
+ parser.add_argument("--adjust-lr", type=int, default = 30, help = 'adjust learning rate')
141
+ parser.add_argument("--num-works", type=int, default = 4, help = 'number works')
142
+ parser.add_argument("--loss-weight", type=tuple, default = (0.6,0.3,0.1), help = 'loss weights')
143
+ parser.add_argument("--degr-type", type=list, default = ['clear', 'low', 'haze', 'rain', 'snow',\
144
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help = 'degradation type')
145
+
146
+ parser.add_argument("--train-input", type=str, default = "./dataset.h5", help = 'train data')
147
+ parser.add_argument("--test-input", type=str, default = "./data/CDD-11_test", help = 'test path')
148
+ parser.add_argument("--output", type=str, default = "./result/", help = 'output path')
149
+
150
+ argspar = parser.parse_args()
151
+
152
+ print_args(argspar)
153
+ main(argspar)
train_OneRestore_single-gpu.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, torch, argparse
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import DataLoader
4
+ from torchvision.utils import save_image as imwrite
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ from makedataset import Dataset
8
+ from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
9
+ from model.loss import Total_loss
10
+
11
+ from PIL import Image
12
+
13
+ transform_resize = transforms.Compose([
14
+ transforms.Resize([224,224]),
15
+ transforms.ToTensor()
16
+ ])
17
+
18
+ def main(args):
19
+
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+ print('> Model Initialization...')
23
+
24
+ embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
25
+ restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
26
+ loss = Total_loss(args)
27
+
28
+ print('> Loading dataset...')
29
+ data = Dataset(args.train_input)
30
+ dataset = DataLoader(dataset=data, num_workers=args.num_works, batch_size=args.bs, shuffle=True)
31
+
32
+ print('> Start training...')
33
+ start_all = time.time()
34
+ train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
35
+ end_all = time.time()
36
+ print('Whloe Training Time:' +str(end_all-start_all)+'s.')
37
+
38
+ def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
39
+
40
+ metric = []
41
+ for epoch in range(cur_epoch, args.epoch):
42
+ optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
43
+ learnrate = optimizer.param_groups[-1]['lr']
44
+ restorer.train()
45
+
46
+ for i, data in enumerate(dataset,0):
47
+ pos, inp, neg = data_process(data, args, device)
48
+
49
+ text_embedding,_,_ = embedder(inp[1],'text_encoder')
50
+ out = restorer(inp[0], text_embedding)
51
+
52
+ restorer.zero_grad()
53
+ total_loss = loss(inp, pos, neg, out)
54
+ total_loss.backward()
55
+ optimizer.step()
56
+
57
+ mse = tensor_metric(pos,out, 'MSE', data_range=1)
58
+ psnr = tensor_metric(pos,out, 'PSNR', data_range=1)
59
+ ssim = tensor_metric(pos,out, 'SSIM', data_range=1)
60
+
61
+ print("[epoch %d][%d/%d] lr :%f Floss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f"%(epoch+1, i+1, \
62
+ len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))
63
+
64
+
65
+ psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
66
+ metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
67
+ print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f"%(epoch+1, psnr_t1,ssim_t1))
68
+
69
+ load_excel(metric)
70
+ save_checkpoint({'epoch': epoch + 1,'state_dict': restorer.state_dict(),'optimizer' : optimizer.state_dict()},\
71
+ args.save_model_path, epoch+1, psnr_t1,ssim_t1,psnr_t2,ssim_t2)
72
+
73
+ def test(args, restorer, embedder, device, epoch=-1):
74
+ combine_type = args.degr_type
75
+ psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
76
+ os.makedirs(args.output,exist_ok=True)
77
+
78
+ for i in range(len(combine_type)-1):
79
+ file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
80
+ for j in range(len(file_list)):
81
+ hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
82
+ lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')
83
+ restorer.eval()
84
+ with torch.no_grad():
85
+ lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
86
+ lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
87
+ hq = torch.Tensor((np.array(hq)/255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
88
+
89
+ starttime = time.time()
90
+
91
+ text_embedding_1,_,text_1 = embedder([combine_type[i+1]],'text_encoder')
92
+ text_embedding_2,_, text_2 = embedder(lq_em,'image_encoder')
93
+ out_1 = restorer(lq_re, text_embedding_1)
94
+ if text_1 != text_2:
95
+ print(text_1, text_2)
96
+ out_2 = restorer(lq_re, text_embedding_2)
97
+ else:
98
+ out_2 = out_1
99
+
100
+ endtime1 = time.time()
101
+
102
+ imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output \
103
+ + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
104
+ psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
105
+ ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
106
+ psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
107
+ ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
108
+ print('The ' + file_list[j][:-4] + ' Time:' + str(endtime1 - starttime) + 's.')
109
+
110
+ return psnr_1 / (len(file_list)*len(combine_type)), ssim_1 / (len(file_list)*len(combine_type)),\
111
+ psnr_2 / (len(file_list)*len(combine_type)), ssim_2 / (len(file_list)*len(combine_type))
112
+
113
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
114
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
115
+ if __name__ == '__main__':
116
+
117
+ parser = argparse.ArgumentParser(description = "OneRestore Training")
118
+
119
+ # load model
120
+ parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
121
+ parser.add_argument("--restore-model-path", type=str, default = None, help = 'restore model path')
122
+ parser.add_argument("--save-model-path", type=str, default = "./ckpts/", help = 'restore model path')
123
+
124
+ parser.add_argument("--epoch", type=int, default = 300, help = 'epoch number')
125
+ parser.add_argument("--bs", type=int, default = 4, help = 'batchsize')
126
+ parser.add_argument("--lr", type=float, default = 1e-4, help = 'learning rate')
127
+ parser.add_argument("--adjust-lr", type=int, default = 30, help = 'adjust learning rate')
128
+ parser.add_argument("--num-works", type=int, default = 4, help = 'number works')
129
+ parser.add_argument("--loss-weight", type=tuple, default = (0.6,0.3,0.1), help = 'loss weights')
130
+ parser.add_argument("--degr-type", type=list, default = ['clear', 'low', 'haze', 'rain', 'snow',\
131
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help = 'degradation type')
132
+
133
+ parser.add_argument("--train-input", type=str, default = "./dataset.h5", help = 'train data')
134
+ parser.add_argument("--test-input", type=str, default = "./data/CDD-11_test", help = 'test path')
135
+ parser.add_argument("--output", type=str, default = "./result/", help = 'output path')
136
+
137
+ argspar = parser.parse_args()
138
+
139
+ print_args(argspar)
140
+ main(argspar)
utils/glove.6B.300d.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ clear -0.081023 -0.29179 0.052021 -0.13324 0.028162 -0.0031446 -0.17156 0.063324 0.16568 -2.1722 -0.14127 0.087891 -0.2298 0.069017 0.21673 0.36556 -0.39979 -0.15506 0.099728 0.202 0.16989 0.14807 0.10938 -0.17141 -0.7258 -0.13189 -0.052768 -0.26383 -0.13189 -0.11408 0.081757 0.14773 -0.24342 0.0076364 -1.0992 0.13661 0.19262 -0.30012 0.031524 0.11439 -0.10854 0.21089 -0.037365 0.23449 0.054638 0.21505 0.023071 0.20918 -0.08606 -0.078589 -0.26945 -0.040802 -0.042601 -0.12093 -0.33614 0.25624 -0.35266 -0.17224 0.31018 0.6426 -0.036072 0.1558 0.26609 0.17298 -0.08158 0.0085636 0.13196 -0.11876 -0.19205 -0.32204 -0.092694 -0.19274 0.0056832 0.17194 0.24011 0.014739 0.091188 0.45903 0.0047753 -0.18136 -0.16434 0.012617 0.42791 0.075318 -0.042848 -0.055952 -0.071895 0.086806 0.078092 0.20169 -0.34189 -0.01975 -0.44579 -0.093254 0.23684 0.098079 -0.0018186 -0.13013 0.054252 -0.68408 0.21378 -0.084742 -0.12383 0.36645 -0.46434 0.56799 0.22341 0.31607 -0.23559 0.033889 0.062509 -0.31468 0.27684 -0.13729 -0.027181 0.17143 -0.35535 0.14426 0.14137 -0.27987 0.051007 0.1689 0.48614 0.43247 -0.31014 -0.2273 -0.17253 0.50221 -0.29023 -0.16833 -0.027586 0.25614 0.096051 0.19145 -0.15576 0.50767 0.0064827 -0.047304 0.47358 -0.029665 -0.095882 0.064574 0.1247 -0.3439 -0.59591 -0.17307 0.30627 0.16351 -0.21709 -0.13142 -0.029781 0.079412 0.36018 -0.068721 0.367 0.26454 0.1306 -0.34602 0.22326 0.22999 0.14122 -0.3084 0.22239 -0.13701 0.24538 0.10902 0.33084 0.052159 -0.54817 0.32921 0.33889 -0.060382 -0.16611 -0.26388 0.13997 -0.15486 -0.05013 -0.089628 -0.0080954 0.13155 -0.019735 0.25758 0.37509 -0.012096 -0.49247 0.13436 -0.21072 -0.13763 0.24047 0.13328 -0.043418 0.0070651 0.30496 -0.11184 0.68017 -0.65417 -0.39198 0.075546 -0.2043 0.041099 0.84586 -0.3361 -0.26385 -0.39417 -0.25468 -0.095349 0.19947 -0.30772 -0.53846 0.18258 -0.091379 -0.27183 0.10918 -0.042102 -0.25614 -0.039694 0.34987 -0.24526 -0.011983 -0.024231 0.62785 -0.16641 0.026109 0.029096 -0.16937 0.25329 -0.12066 0.023087 0.16152 -0.14058 0.044846 0.4533 0.34099 -0.028432 -0.39407 -0.068924 -0.29128 -0.012954 0.048176 -0.090455 -0.0098771 -0.022352 0.091535 -0.084673 -0.43955 -0.25237 0.79719 0.21526 0.0019634 -0.10022 -0.075669 -0.25113 -0.12675 0.12179 0.25892 0.026661 -0.38419 -0.18566 -0.15325 0.44484 -0.088815 0.10119 0.0060884 0.293 -0.415 0.26712 0.033683 -0.42317 0.22025 -0.027351 0.40923 -0.013339 -0.29543 0.37699 -0.019656 -0.082896 -1.5198 0.2961 0.81263 -0.18199 0.59082 0.007938 0.2309 0.23573 0.24941 -0.18754 -0.04029 0.17258 0.1948 0.131 -0.21552 0.016352 0.62256 0.41283 0.40387 -0.062911 -0.093159 -0.078137 -0.30083 -0.035913
2
+ low -0.21751 0.43389 0.149 0.14107 0.2574 -0.12448 0.0047523 0.035596 0.10741 -2.1047 0.17181 -0.15079 -0.044546 -0.090869 -0.43288 -0.13611 -0.0058198 -0.064724 0.23531 -0.36224 -0.21305 -0.075476 0.46786 -0.18465 -0.19746 -0.097471 0.39984 -0.084092 -0.53715 0.27303 -0.087786 0.24297 -0.38444 0.28854 -0.7873 0.089192 -0.26376 -0.16287 0.35911 0.30458 0.24502 0.22553 -0.0031653 0.47358 0.31146 -0.13823 0.075685 -0.10776 0.38329 -0.13762 0.51707 -0.16707 -0.037466 -0.7236 -0.4151 -0.42359 0.14354 0.046639 0.17527 0.48721 0.26708 -0.031042 0.86002 -0.3946 -0.50514 -0.51294 0.58527 0.18819 -0.29543 0.68596 -0.1035 0.22565 0.185 0.058375 0.030999 0.11929 0.12353 0.12873 0.42126 0.14188 -0.050079 -0.2683 0.12126 0.32302 0.27623 0.5414 0.074715 -0.1949 -0.47053 0.02313 0.68686 0.60158 -0.16194 -0.3651 0.41796 -0.22905 0.074734 0.17509 -0.44255 0.3518 -0.40079 -0.28305 0.39133 0.32303 -0.63198 -0.1507 -0.16894 0.17169 0.18894 0.027644 -0.36997 -0.26366 0.36344 -0.049584 0.32724 0.049712 0.051381 -0.058867 -0.2621 -0.50359 -0.21435 -0.25527 0.22161 0.66558 0.2224 0.27607 0.58587 -0.3071 0.24905 0.098802 -0.26459 0.77839 0.014585 0.86936 0.2329 -0.0027986 -0.087016 0.10863 0.18987 0.54552 0.24903 0.059293 0.30362 -0.028582 -0.6569 0.1206 -0.055416 -0.093077 -0.0012132 -0.15009 0.11192 -0.62139 -0.035773 0.1165 0.36541 0.55984 -0.19964 -0.065579 0.097118 -0.1672 0.13677 -0.95276 -0.25994 0.064799 -0.042161 0.12046 0.12391 0.0017478 0.29533 0.40176 0.057528 0.57864 -0.9973 0.13805 -0.30689 0.11015 -0.35402 -0.13434 -0.24479 0.50355 -0.18675 -0.22337 0.29573 0.21612 -0.068496 -0.60643 0.79013 -0.26975 -0.15492 0.70849 0.21372 0.62962 -0.0056421 0.53597 -0.54259 -0.34726 -0.29945 -0.51895 0.28471 -0.14973 0.54188 0.53535 -0.11233 0.19291 -0.24707 0.058424 -0.5473 -0.06426 0.47187 0.11149 0.28313 -0.23876 -0.10552 -0.051705 -0.28853 -0.13702 0.040562 -0.032269 0.10368 -0.29381 0.33416 0.038269 0.029697 -0.48604 -0.26334 0.28942 -0.0093944 0.13942 -0.29043 0.27332 0.16614 -0.028973 -0.32829 -0.034614 -0.0012628 0.062871 -0.000894 0.22467 0.16005 0.23141 -0.19918 0.16465 0.15247 0.29742 -1.0225 0.056188 0.91529 -0.47809 -0.24204 -0.3158 0.21033 -0.13616 0.10777 -0.26815 -0.44804 -0.12696 -0.43468 0.17849 -0.48101 0.026114 0.057368 0.26052 -0.030488 0.051275 -0.36344 0.11878 0.2279 -0.086855 -0.01455 0.070256 -0.16753 0.61449 -0.27428 -0.17901 -0.36261 0.093134 -1.5724 0.47192 -0.52493 -0.27512 -0.37945 0.29588 0.020506 0.08707 0.057053 0.37167 -0.056446 -0.38735 -0.31246 0.028304 -0.058202 0.067263 -0.58761 0.074556 0.49917 0.45134 -0.51433 -0.60996 0.076835 -0.078086
3
+ haze -0.0061289 -0.2702 0.16559 -0.29621 -0.66216 -0.1756 0.46686 1.0362 -0.20692 -0.36097 0.98615 0.32297 -0.55094 -0.36163 -0.27046 0.052225 -0.10079 0.22536 -0.095491 0.17188 0.058372 0.083556 -0.28255 0.12623 -0.0094164 -0.028727 -0.20589 -0.3932 -0.2935 -0.36104 1.0595 0.14423 -0.311 -0.20573 0.11827 -0.0048368 -0.8324 -0.10389 0.34491 0.34006 0.10354 0.11593 0.47379 -0.1042 0.38523 -0.57589 0.027253 -0.44913 -0.52822 -0.44094 0.71219 -0.12278 0.034288 -0.6935 -0.57852 0.33917 0.35018 -0.30193 0.55504 0.085603 -0.21189 -0.51958 -0.17589 -0.13369 0.2976 -0.26048 0.068146 0.62144 0.3416 -0.54399 -0.23937 -0.34802 -0.31469 -0.59554 -0.25011 -0.11644 0.19993 -0.1636 0.24289 -0.0022965 0.3064 -0.26188 0.27166 0.1962 0.37527 -0.22408 0.52979 0.59141 0.035196 0.10632 -0.28318 0.18766 -0.12253 0.41932 -0.64713 0.26068 0.67209 -0.23333 0.030945 -0.15135 0.61662 -0.0025061 -0.58374 0.51866 -0.89244 1.0056 0.15919 0.29183 -0.059984 0.10701 -0.32101 -1.0921 -0.050394 -0.074584 0.56258 -0.5915 0.048547 0.085668 -0.39964 -0.40997 0.093632 -0.22538 -0.83102 -0.051418 -0.31192 0.36056 -0.028854 -0.046907 0.09394 0.012504 0.34555 0.56564 0.48111 0.092143 0.82492 -0.20086 -0.27718 0.9004 0.38921 0.028667 0.78904 0.44698 -0.26892 0.073712 -0.73296 -0.46286 0.53386 0.53514 0.04207 -0.11448 0.27771 0.080703 -0.017482 0.43225 0.047742 -0.095399 -0.063173 -0.36341 0.2948 0.15311 -0.55934 -0.88294 0.62005 -0.23936 0.51953 -0.49463 0.41669 0.61169 -0.20471 -0.0056962 -0.29331 0.46269 0.084808 -0.049355 -0.64697 -0.85777 0.34718 -0.16176 0.14756 -0.65658 -0.54259 -0.13124 -0.88851 0.070637 -0.84926 -0.69345 0.4024 -0.5683 -0.68142 -0.1402 -0.36857 0.36013 -0.49769 -0.17478 0.77214 -0.23962 0.32951 1.0984 -0.00011441 0.9649 -0.13312 0.64326 -0.037091 0.35672 0.025156 0.046782 0.19764 -0.22757 -0.39887 -0.3045 -0.45283 -0.0045182 0.032546 -0.076483 0.72189 -0.038917 1.0621 -0.55688 0.56429 0.11264 0.40465 -0.53146 0.16851 0.69236 -0.24456 0.038704 0.69151 0.16591 -0.43451 0.14115 0.84069 0.29081 -0.31053 -0.6849 -0.27188 -0.32813 0.57882 0.13779 0.36621 -0.45935 0.27899 -0.32315 -0.5743 0.19837 0.0046648 0.18459 0.43369 0.22359 0.16652 -0.081114 -0.54539 -1.0103 -0.14539 0.12021 0.078636 -0.26667 -0.65403 0.4096 0.07257 0.036639 0.21757 0.25738 0.51675 -0.031326 -0.3869 0.012763 -0.45692 0.13828 -0.48614 -0.53757 0.50268 0.47865 -0.049528 -0.032281 -0.4486 0.036258 -0.12295 -0.46811 -0.019014 0.035839 -0.55749 0.018281 -0.88963 -0.024676 -0.19482 -0.19364 0.0069875 0.12679 -0.37379 -0.34094 -0.051568 0.55404 -0.29656 0.26045 0.50872 -0.37399 0.20334 0.70298 -0.3271 -0.24116
4
+ rain -0.52618 -0.54041 -0.89537 -0.35598 -0.74356 -0.66838 0.26326 0.89254 0.14362 -0.34904 0.25866 -0.11143 -0.52035 0.1436 -0.075728 -0.84569 -0.28762 0.049872 0.39234 0.52551 -0.39244 -0.2822 -0.097458 -0.12929 -0.38623 0.17261 0.7574 -0.29868 -0.691 -0.36639 0.63951 0.25255 -0.22299 0.16387 -0.83199 -0.30276 -0.32411 -0.36789 -0.073673 0.54726 0.14785 0.26259 0.086208 -0.033827 0.044403 -0.2135 0.3761 0.33816 -0.36696 -0.2096 0.025934 0.47679 0.23046 -0.44333 -0.65379 0.85762 0.62861 -0.70343 1.1284 0.2497 -0.34459 0.17005 0.27826 0.01167 -0.44087 -0.12649 0.31811 0.073688 -0.17127 -0.023486 0.34294 0.18888 -0.15694 -0.37975 -0.58313 -0.45624 -0.5968 0.09743 -0.50593 -0.64092 0.083647 0.38474 -0.15071 0.55042 -0.68742 0.14893 -0.039046 -0.19582 0.61498 -0.066786 0.63395 -0.4659 0.44123 -0.55136 -0.17711 0.97118 0.26321 -0.035901 -0.11096 -0.11161 0.353 1.026 -0.2605 -0.12231 0.31695 0.35807 0.2526 0.21803 -0.47766 -0.13033 -0.36929 -0.88388 -0.1249 0.27972 0.017521 0.19048 0.38647 -0.10236 0.26691 -0.66637 -0.66046 -0.48598 -0.5029 0.59602 -0.23975 -0.054244 0.71177 0.097479 0.18964 0.60496 -0.2421 1.261 0.5195 0.12978 0.28374 0.1499 -0.073072 -0.064345 0.041775 0.20712 -0.13972 0.021692 -0.45101 -0.077633 -0.58888 -0.0062811 0.50587 0.63067 -0.096216 -0.45549 -0.10162 -0.74026 -0.45125 0.16204 0.34589 0.2203 0.73482 -0.72055 0.019937 0.50934 -0.045864 -1.0167 0.4202 0.29336 0.057842 0.19622 0.71137 0.44455 -0.11329 -0.23249 0.3283 0.6458 -0.032498 0.58903 0.067438 -0.21519 0.24967 -0.047893 -0.12095 0.20468 -0.010392 -0.10827 0.5248 -0.013868 -0.40703 -0.2761 0.61498 -0.12118 -0.70097 -0.76415 -0.37243 0.3 -0.32852 -0.13877 0.23339 -0.58504 0.54768 -0.090521 0.30928 -0.19777 0.68883 0.043808 -0.012833 0.25696 0.017598 -0.11323 -0.76201 0.42972 -0.22032 -0.43818 -0.57085 0.23867 -0.098037 -0.4015 0.27659 -0.51578 -0.28637 -0.37785 0.83469 0.10563 1.1508 -0.67165 0.095388 -0.070545 0.039198 0.17726 0.44885 -0.045378 0.22337 -0.24957 0.93144 -0.16601 -0.095582 -0.60227 0.20068 -0.10264 -0.62696 0.048702 0.34737 -0.10634 -0.35068 0.11719 -0.79712 -0.32956 -0.60446 -0.0049038 -0.3351 -0.060065 -0.3063 -0.15462 -0.099521 -0.1788 0.098109 -0.59477 0.53245 -0.15388 0.063044 -0.47686 0.26712 -0.064799 0.2029 -0.093498 -0.44456 0.4692 -0.13718 0.035772 -0.74958 -0.51603 0.47025 -0.65103 0.027106 0.31463 -0.51519 -0.09912 -0.30605 0.2127 -1.6502 -0.34658 -0.19282 0.036578 -0.33871 0.21323 0.54172 -0.17543 -0.60187 -0.14679 0.20983 -0.084584 0.070885 -0.21752 -0.12642 0.030381 0.075461 0.86541 0.30098 0.22916 0.049217 -0.21204 0.32909 -0.021816
5
+ snow -0.6961 -0.3339 -0.66542 -0.16459 -0.70283 0.053264 0.57508 1.1246 -0.41143 -0.93335 -0.397 -0.13949 -0.21725 0.49383 -0.16481 -0.43673 -0.39998 -0.14702 0.5828 0.73123 -0.16808 0.050093 0.20341 0.093283 -0.18944 -0.0092796 0.0064213 -0.5586 0.079708 0.034177 0.503 -0.084123 -0.15241 0.042398 -0.95865 0.13482 0.10695 0.22212 0.16383 0.081416 -0.61437 0.60299 0.53843 0.33915 -0.060046 -0.12329 0.30417 0.067838 -0.058329 -0.24791 -0.28177 0.32273 -0.12639 -0.40664 -0.42578 0.71366 0.18676 -0.49576 0.56635 0.39411 -0.11876 0.62798 0.50193 -0.38534 -0.32333 -0.29613 -0.1984 0.082042 -0.63666 -0.25177 0.070225 0.23886 -0.35341 -0.30615 -0.7898 -0.014515 -0.096662 0.27064 0.37095 -0.3916 0.15589 0.40176 -0.12316 -0.0069311 -0.17538 0.29317 -0.035662 -0.062503 -0.11821 -0.26708 0.33433 -0.41039 -0.44941 -0.058539 -0.5973 -0.060833 0.014623 0.031391 0.041093 0.21223 0.54304 0.51444 -0.2447 -0.034937 -0.61583 0.24116 0.93612 0.29663 -0.01733 0.39864 -0.399 -0.69927 0.010899 0.044804 0.096444 0.20555 0.37109 0.13219 0.29942 -0.28494 -0.071103 -0.45338 -0.22126 -0.31673 -0.10643 0.040453 -0.15324 0.33191 0.27801 -0.25143 -0.41784 1.1352 0.18709 0.57932 0.14912 0.42731 -0.81353 0.35546 0.10287 -0.10858 0.13692 0.11451 -0.68607 -0.17115 -0.52708 0.28953 0.5147 0.25549 -0.23139 -0.44275 0.42679 -0.41475 0.041182 -0.2664 0.60967 0.03783 0.27371 -0.5267 0.12029 0.5208 0.59519 -1.1315 0.19505 -0.2528 0.34636 0.82065 0.63271 0.091682 0.38433 -0.81108 0.18232 0.19068 -0.13031 0.21336 0.074454 -0.094498 0.47594 -0.31026 -0.11718 0.092891 0.22067 -0.16721 0.71703 0.30143 -0.40609 -0.16231 0.31315 -0.59325 -0.53404 -0.1087 -0.23026 0.36507 0.30648 -0.75576 -0.20767 -0.46966 -0.21035 0.0091924 0.5057 0.45564 0.84145 -0.19412 0.23964 0.85852 0.05229 -0.0011899 -0.29387 0.044187 -0.23886 0.19207 -0.0079459 -0.25773 0.31145 -0.47615 -0.00056431 -0.8941 -0.38667 -0.37907 0.52821 -0.45513 0.53567 0.13216 0.39741 -0.4904 0.24118 -0.11714 0.27007 0.15184 0.42316 -0.39708 0.13827 -0.27638 0.29908 -0.76008 0.061752 -0.4452 -0.5132 0.12124 0.15792 -0.57067 -0.68793 -0.33873 -0.43291 -0.46817 -0.84667 -0.65852 -0.59116 -0.043406 -0.013031 0.11246 -0.35374 0.3923 0.1172 -0.56268 0.83477 -0.34675 0.054568 -0.48494 0.12108 -0.15504 -0.047008 -0.2665 0.024593 0.70123 0.21284 -0.077796 0.050835 0.3865 0.37534 -0.48749 -0.013739 0.57852 -0.90425 -0.0062806 -0.28674 -0.017749 -1.0189 -0.71371 -0.36557 -0.73412 -0.027371 -0.071396 0.64792 -0.057281 -0.2512 0.039567 0.076976 0.34572 0.34606 -0.38323 -0.074011 -0.14153 -0.03109 0.53137 -0.35708 -0.28263 0.098663 0.17693 -0.39297 0.27708
utils/utils.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import os
4
+ from torch.autograd import Variable
5
+ from skimage.metrics import peak_signal_noise_ratio as compare_psnr
6
+ from skimage.metrics import mean_squared_error as compare_mse
7
+ from skimage.metrics import structural_similarity as compare_ssim
8
+ import pandas as pd
9
+
10
+ from model.OneRestore import OneRestore
11
+ from model.Embedder import Embedder
12
+
13
+ def load_embedder_ckpt(device, freeze_model=False, ckpt_name=None,
14
+ combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
15
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain',\
16
+ 'haze_snow', 'low_haze_rain', 'low_haze_snow']):
17
+ if ckpt_name != None:
18
+ if torch.cuda.is_available():
19
+ model_info = torch.load(ckpt_name)
20
+ else:
21
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
22
+
23
+ print('==> loading existing Embedder model:', ckpt_name)
24
+ model = Embedder(combine_type)
25
+ model.load_state_dict(model_info)
26
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
27
+
28
+ else:
29
+ print('==> Initialize Embedder model.')
30
+ model = Embedder(combine_type)
31
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ if freeze_model:
34
+ freeze(model)
35
+
36
+ return model
37
+
38
+ def load_restore_ckpt(device, freeze_model=False, ckpt_name=None):
39
+ if ckpt_name != None:
40
+ if torch.cuda.is_available():
41
+ model_info = torch.load(ckpt_name)
42
+ else:
43
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
44
+ print('==> loading existing OneRestore model:', ckpt_name)
45
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
46
+ model.load_state_dict(model_info)
47
+ else:
48
+ print('==> Initialize OneRestore model.')
49
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
50
+ model = torch.nn.DataParallel(model).to("cuda" if torch.cuda.is_available() else "cpu")
51
+
52
+ if freeze_model:
53
+ freeze(model)
54
+ total = sum([param.nelement() for param in model.parameters()])
55
+ print("Number of OneRestore parameter: %.2fM" % (total/1e6))
56
+
57
+ return model
58
+
59
+ def load_restore_ckpt_with_optim(device, local_rank=None, freeze_model=False, ckpt_name=None, lr=None):
60
+ if ckpt_name != None:
61
+ if torch.cuda.is_available():
62
+ model_info = torch.load(ckpt_name)
63
+ else:
64
+ model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
65
+
66
+ print('==> loading existing OneRestore model:', ckpt_name)
67
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
68
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr) if lr != None else None
69
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) if local_rank != None else model
70
+
71
+ if local_rank != None:
72
+ model.load_state_dict(model_info['state_dict'])
73
+ else:
74
+ weights_dict = {}
75
+ for k, v in model_info['state_dict'].items():
76
+ new_k = k.replace('module.', '') if 'module' in k else k
77
+ weights_dict[new_k] = v
78
+ model.load_state_dict(weights_dict)
79
+ optimizer = torch.optim.Adam(model.parameters())
80
+ optimizer.load_state_dict(model_info['optimizer'])
81
+ cur_epoch = model_info['epoch']
82
+ else:
83
+ print('==> Initialize OneRestore model.')
84
+ model = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu")
85
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
86
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) if local_rank != None else torch.nn.DataParallel(model)
87
+ cur_epoch = 0
88
+
89
+ if freeze_model:
90
+ freeze(model)
91
+ total = sum([param.nelement() for param in model.parameters()])
92
+ print("Number of OneRestore parameter: %.2fM" % (total/1e6))
93
+
94
+ return model, optimizer, cur_epoch
95
+
96
+ def load_embedder_ckpt_with_optim(device, args, combine_type = ['clear', 'low', 'haze', 'rain', 'snow',\
97
+ 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow']):
98
+ print('Init embedder')
99
+ # seed
100
+ if args.seed == -1:
101
+ args.seed = np.random.randint(1, 10000)
102
+ seed = args.seed
103
+ np.random.seed(seed)
104
+ torch.manual_seed(seed)
105
+ print('Training embedder seed:', seed)
106
+
107
+ # embedder model
108
+ embedder = Embedder(combine_type).to("cuda" if torch.cuda.is_available() else "cpu")
109
+
110
+ if args.pre_weight == '':
111
+ optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
112
+ cur_epoch = 1
113
+ else:
114
+ try:
115
+ embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}')
116
+ if torch.cuda.is_available():
117
+ embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}')
118
+ else:
119
+ embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}', map_location=torch.device('cpu'))
120
+ embedder.load_state_dict(embedder_info['state_dict'])
121
+ optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
122
+ optimizer.load_state_dict(embedder_info['optimizer'])
123
+ cur_epoch = embedder_info['epoch'] + 1
124
+ except:
125
+ print('Pre-trained model loading error!')
126
+ return embedder, optimizer, cur_epoch, device
127
+
128
+ def freeze_text_embedder(m):
129
+ """Freezes module m.
130
+ """
131
+ m.eval()
132
+ for name, para in m.named_parameters():
133
+ if name == 'embedder.weight' or name == 'mlp.0.weight' or name == 'mlp.0.bias':
134
+ print(name)
135
+ para.requires_grad = False
136
+ para.grad = None
137
+
138
+ class AverageMeter(object):
139
+ """Computes and stores the average and current value"""
140
+
141
+ def __init__(self):
142
+ self.reset()
143
+
144
+ def reset(self):
145
+ self.val = 0
146
+ self.avg = 0
147
+ self.sum = 0
148
+ self.count = 0
149
+
150
+ def update(self, val, n=1):
151
+ self.val = val
152
+ self.sum += val * n
153
+ self.count += n
154
+ self.avg = self.sum / self.count
155
+
156
+ def data_process(data, args, device):
157
+ combine_type = args.degr_type
158
+ b,n,c,w,h = data.size()
159
+
160
+ pos_data = data[:,0,:,:,:]
161
+
162
+ inp_data = torch.zeros((b,c,w,h))
163
+ inp_class = []
164
+
165
+ neg_data = torch.zeros((b,n-2,c,w,h))
166
+
167
+ index = np.random.randint(1, n, (b))
168
+ for i in range(b):
169
+ k = 0
170
+ for j in range(n):
171
+ if j == 0:
172
+ continue
173
+ elif index[i] == j:
174
+ inp_class.append(combine_type[index[i]])
175
+ inp_data[i, :, :, :] = data[i, index[i], :, :,:]
176
+ else:
177
+ neg_data[i,k,:,:,:] = data[i, j, :, :,:]
178
+ k=k+1
179
+ return pos_data.to("cuda" if torch.cuda.is_available() else "cpu"), [inp_data.to("cuda" if torch.cuda.is_available() else "cpu"), inp_class], neg_data.to("cuda" if torch.cuda.is_available() else "cpu")
180
+
181
+ def print_args(argspar):
182
+ print("\nParameter Print")
183
+ for p, v in zip(argspar.__dict__.keys(), argspar.__dict__.values()):
184
+ print('\t{}: {}'.format(p, v))
185
+ print('\n')
186
+
187
+ def adjust_learning_rate(optimizer, epoch, lr_update_freq):
188
+ if not epoch % lr_update_freq and epoch:
189
+ for param_group in optimizer.param_groups:
190
+ param_group['lr'] = param_group['lr'] /2
191
+ return optimizer
192
+
193
+
194
+ def tensor_metric(img, imclean, model, data_range=1):
195
+
196
+ img_cpu = img.data.cpu().numpy().astype(np.float32).transpose(0,2,3,1)
197
+ imgclean = imclean.data.cpu().numpy().astype(np.float32).transpose(0,2,3,1)
198
+
199
+ SUM = 0
200
+ for i in range(img_cpu.shape[0]):
201
+
202
+ if model == 'PSNR':
203
+ SUM += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :],data_range=data_range)
204
+ elif model == 'MSE':
205
+ SUM += compare_mse(imgclean[i, :, :, :], img_cpu[i, :, :, :])
206
+ elif model == 'SSIM':
207
+ SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, multichannel = True)
208
+ # due to the skimage vision problem, you can replace above line by
209
+ # SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, channel_axis=-1)
210
+ else:
211
+ print('Model False!')
212
+
213
+ return SUM/img_cpu.shape[0]
214
+
215
+ def save_checkpoint(stateF, checkpoint, epoch, psnr_t1,ssim_t1,psnr_t2,ssim_t2, filename='model.tar'):
216
+ torch.save(stateF, checkpoint + 'OneRestore_model_%d_%.4f_%.4f_%.4f_%.4f.tar'%(epoch,psnr_t1,ssim_t1,psnr_t2,ssim_t2))
217
+
218
+ def load_excel(x):
219
+ data1 = pd.DataFrame(x)
220
+
221
+ writer = pd.ExcelWriter('./mertic_result.xlsx')
222
+ data1.to_excel(writer, 'PSNR-SSIM', float_format='%.5f')
223
+ # writer.save()
224
+ writer.close()
225
+
226
+ def freeze(m):
227
+ """Freezes module m.
228
+ """
229
+ m.eval()
230
+ for p in m.parameters():
231
+ p.requires_grad = False
232
+ p.grad = None