File size: 5,763 Bytes
b7f3942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import math
from functools import partial

import cv2
import numpy as np
import torch
from basicsr.data import degradations as degradations
from basicsr.data.transforms import augment
from basicsr.utils import img2tensor
from torch.nn.functional import interpolate
from torchvision.transforms import Compose
from utils.basicsr_custom import (
    random_mixed_kernels,
    random_add_gaussian_noise,
    random_add_jpg_compression,
)


def create_degradation(degradation):
    """Build a degradation operator from its string identifier.

    Every returned callable takes a single image tensor and returns a tuple
    ``(degraded, target)``. ``target`` is ``None`` for every operator except
    ``'difface'``, which also returns the clean image because it applies a
    random flip (and occasionally grayscale) to it as well.

    Supported identifiers:
        ``'sr_bicubic_x8_gaussian_noise_005'``: antialiased bicubic x1/8
            downscale, Gaussian noise (std 0.05), nearest-exact x8 upscale,
            clip to ``[0, 1]``.
        ``'gaussian_noise_035'``: additive Gaussian noise (std 0.35), clip.
        ``'colorization_gaussian_noise_025'``: grayscale by channel mean
            (output keeps a single channel), Gaussian noise (std 0.25), clip.
        ``'random_inpainting_gaussian_noise_01'``: zero out 90% of pixel
            locations (same mask for all channels), Gaussian noise (std 0.1),
            clip.
        ``'difface'``: blur / downsample / noise / JPEG pipeline, see
            ``_difface_degradation``.

    Args:
        degradation: Identifier of the degradation (see above).

    Returns:
        A callable ``x -> (degraded, target_or_None)``. The input is presumably
        a ``(C, H, W)`` tensor with values in ``[0, 1]`` — TODO confirm against
        callers.

    Raises:
        NotImplementedError: If ``degradation`` is not a known identifier.
    """
    if degradation == 'sr_bicubic_x8_gaussian_noise_005':
        return Compose([
            partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
            partial(add_gaussian_noise, std=0.05),
            partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
            partial(torch.clip, min=0, max=1),
            # down_scale adds a leading batch dim; drop it again.
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])
    if degradation == 'gaussian_noise_035':
        return Compose([
            partial(add_gaussian_noise, std=0.35),
            partial(torch.clip, min=0, max=1),
            # Removes dim 0 only if it has size 1.
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])
    if degradation == 'colorization_gaussian_noise_025':
        return Compose([
            # Grayscale by averaging over dim 0; keepdim leaves a 1-channel dim.
            lambda x: torch.mean(x, dim=0, keepdim=True),
            partial(add_gaussian_noise, std=0.25),
            partial(torch.clip, min=0, max=1),
            lambda x: (x, None),
        ])
    if degradation == 'random_inpainting_gaussian_noise_01':
        return _random_inpainting_degradation
    if degradation == 'difface':
        return _difface_degradation
    raise NotImplementedError(f'Unknown degradation: {degradation!r}')


def _random_inpainting_degradation(x):
    """Zero out a random 90% of pixel locations, then add Gaussian noise (std 0.1).

    The same spatial mask is applied to every channel. The mask is built over
    the last two dims of ``x`` and broadcast over the leading ones, so any
    channel count and non-square images are supported.

    Returns:
        ``(degraded, None)`` with ``degraded`` clipped to ``[0, 1]``.
    """
    height, width = x.shape[-2], x.shape[-1]
    num_pixels = height * width
    # Fraction of pixels to drop. The range is degenerate, so this is always
    # 0.9, but the draw is kept so the sequence of NumPy random numbers
    # consumed per call is unchanged.
    low, high = 0.9, 0.9
    drop_prob = np.random.uniform(low, high)
    dropped = np.random.choice(num_pixels, int(num_pixels * drop_prob), replace=False)
    mask = torch.ones(num_pixels)
    mask[dropped] = 0
    # Match x's device/dtype so the product keeps x's dtype; the (H, W) mask
    # broadcasts across channels.
    mask = mask.view(height, width).to(device=x.device, dtype=x.dtype)
    return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None


def _difface_degradation(x):
    """DifFace-style synthetic degradation (blur, downsample, noise, JPEG).

    Works on a BGR float32 HWC copy of ``x``:
      1. random horizontal flip (the flipped image is also the returned GT);
      2. blur with a kernel from basicsr ``random_mixed_kernels``
         ('iso'/'aniso', size 41, sigma range [0.1, 15]);
      3. bilinear downsample by a random factor drawn from [0.8, 32];
      4. ``random_add_gaussian_noise`` with range [0, 20], then
         ``random_add_jpg_compression`` with range [30, 100];
      5. bilinear resize back to the original size;
      6. with probability 0.01, convert LQ (and GT) to 3-channel gray.

    Args:
        x: Image tensor; presumably RGB ``(C, H, W)`` float in ``[0, 1]``. It
            must be convertible with ``.numpy()`` (i.e. a CPU tensor without
            grad) — TODO confirm against callers.

    Returns:
        ``(img_lq, img_gt)`` as RGB ``(C, H, W)`` float32 tensors. ``img_lq``
        is rounded to multiples of 1/255, and both are clipped to ``[0, 1]``.
    """
    blur_kernel_size = 41
    kernel_list = ['iso', 'aniso']
    kernel_prob = [0.5, 0.5]
    blur_sigma = [0.1, 15]
    downsample_range = [0.8, 32]
    noise_range = [0, 20]
    jpeg_range = [30, 100]
    gt_gray = True
    gray_prob = 0.01

    # CHW RGB tensor -> HWC BGR float32 array (OpenCV channel order).
    x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)
    # Random horizontal flip. The LQ image is derived from img_gt below, so
    # LQ and GT stay spatially aligned.
    img_gt = augment(x.copy(), hflip=True, rotation=False)
    h, w, _ = img_gt.shape

    # ------------------------ generate lq image ------------------------ #
    # blur
    kernel = degradations.random_mixed_kernels(
        kernel_list,
        kernel_prob,
        blur_kernel_size,
        blur_sigma,
        blur_sigma, [-math.pi, math.pi],
        noise_range=None)
    img_lq = cv2.filter2D(img_gt, -1, kernel)

    # downsample; clamp the target to >= 1 px so a small image combined with a
    # large factor cannot ask cv2.resize for a zero-sized output.
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    small_size = (max(1, int(w // scale)), max(1, int(h // scale)))
    img_lq = cv2.resize(img_lq, small_size, interpolation=cv2.INTER_LINEAR)

    # noise
    if noise_range is not None:
        img_lq = random_add_gaussian_noise(img_lq, noise_range)
    # jpeg compression
    if jpeg_range is not None:
        img_lq = random_add_jpg_compression(img_lq, jpeg_range)

    # resize to original size
    img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

    # NOTE: color jitter on the LQ image (numpy and torch variants) is not
    # applied in this pipeline.

    # random to gray (only for lq, and for gt when gt_gray is set)
    if np.random.uniform() < gray_prob:
        img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
        img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
        if gt_gray:  # whether convert GT to gray images
            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # Round to 8-bit levels and clip.
    img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

    return img_lq, img_gt.clip(0, 1)


def down_scale(x, scale_factor, mode):
    """Resample ``x`` by ``scale_factor`` with ``torch.nn.functional.interpolate``.

    A batch dimension is added with ``unsqueeze(0)`` and is NOT removed, so a
    ``(C, H, W)`` input yields a ``(1, C, h, w)`` output. The result is clipped
    to ``[0, 1]``, which also removes bicubic overshoot. Runs without autograd.

    ``align_corners=False`` is passed only for modes that accept it
    (linear / bilinear / bicubic / trilinear). ``antialias=True`` is passed
    only for bilinear / bicubic. Other modes such as ``'nearest'`` or
    ``'area'`` therefore work instead of raising ``ValueError``.

    Args:
        x: Image tensor, presumably ``(C, H, W)`` — TODO confirm against callers.
        scale_factor: Spatial multiplier (e.g. ``1/8`` to downscale by 8).
        mode: Any ``interpolate`` mode.

    Returns:
        The resampled tensor with a leading batch dim, clipped to ``[0, 1]``.
    """
    extra_kwargs = {}
    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        extra_kwargs['align_corners'] = False
    if mode in ('bilinear', 'bicubic'):
        # torch only supports antialiasing for these two modes.
        extra_kwargs['antialias'] = True
    with torch.no_grad():
        return interpolate(x.unsqueeze(0),
                           scale_factor=scale_factor,
                           mode=mode,
                           **extra_kwargs).clip(0, 1)


def add_gaussian_noise(x, std):
    """Return ``x`` plus zero-mean Gaussian noise scaled by ``std``.

    Noise is drawn with ``torch.randn_like`` (same shape, dtype and device as
    ``x``). The computation runs under ``torch.no_grad()``, so the result
    carries no autograd history. ``x`` itself is not modified; a new tensor is
    returned.
    """
    with torch.no_grad():
        noise = std * torch.randn_like(x)
        return x + noise