ohayonguy commited on
Commit
b7f3942
1 Parent(s): 1b8b226

first commit fixed

Browse files
app.py CHANGED
@@ -24,17 +24,12 @@ if not os.path.exists(realesr_model_path):
24
  os.system(
25
  "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O experiments/pretrained_models/RealESRGAN_x4plus.pth")
26
 
27
- pmrf_model_path = 'blind_face_restoration_pmrf.ckpt'
28
-
29
  # background enhancer with RealESRGAN
30
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
31
  half = True if torch.cuda.is_available() else False
32
  upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
33
 
34
- pmrf = MMSERectifiedFlow.load_from_checkpoint('./blind_face_restoration_pmrf.ckpt',
35
- mmse_model_arch='swinir_L',
36
- mmse_model_ckpt_path=None,
37
- map_location='cpu').to(device)
38
 
39
  os.makedirs('output', exist_ok=True)
40
 
 
24
  os.system(
25
  "wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O experiments/pretrained_models/RealESRGAN_x4plus.pth")
26
 
 
 
27
  # background enhancer with RealESRGAN
28
  model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
29
  half = True if torch.cuda.is_available() else False
30
  upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
31
 
32
+ pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device)
 
 
 
33
 
34
  os.makedirs('output', exist_ok=True)
35
 
arch/hourglass/__init__.py ADDED
File without changes
arch/hourglass/axial_rope.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """k-diffusion transformer diffusion models, version 2.
2
+ Codes adopted from https://github.com/crowsonkb/k-diffusion
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+ import torch._dynamo
9
+ from torch import nn
10
+
11
+ from . import flags
12
+
13
+ if flags.get_use_compile():
14
+ torch._dynamo.config.suppress_errors = True
15
+
16
+
17
+ def rotate_half(x):
18
+ x1, x2 = x[..., 0::2], x[..., 1::2]
19
+ x = torch.stack((-x2, x1), dim=-1)
20
+ *shape, d, r = x.shape
21
+ return x.view(*shape, d * r)
22
+
23
+
24
+ @flags.compile_wrap
25
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
26
+ freqs = freqs.to(t)
27
+ rot_dim = freqs.shape[-1]
28
+ end_index = start_index + rot_dim
29
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
30
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
31
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
32
+ return torch.cat((t_left, t, t_right), dim=-1)
33
+
34
+
35
+ def centers(start, stop, num, dtype=None, device=None):
36
+ edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
37
+ return (edges[:-1] + edges[1:]) / 2
38
+
39
+
40
+ def make_grid(h_pos, w_pos):
41
+ grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
42
+ h, w, d = grid.shape
43
+ return grid.view(h * w, d)
44
+
45
+
46
+ def bounding_box(h, w, pixel_aspect_ratio=1.0):
47
+ # Adjusted dimensions
48
+ w_adj = w
49
+ h_adj = h * pixel_aspect_ratio
50
+
51
+ # Adjusted aspect ratio
52
+ ar_adj = w_adj / h_adj
53
+
54
+ # Determine bounding box based on the adjusted aspect ratio
55
+ y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
56
+ if ar_adj > 1:
57
+ y_min, y_max = -1 / ar_adj, 1 / ar_adj
58
+ elif ar_adj < 1:
59
+ x_min, x_max = -ar_adj, ar_adj
60
+
61
+ return y_min, y_max, x_min, x_max
62
+
63
+
64
+ def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
65
+ y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
66
+ if align_corners:
67
+ h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
68
+ w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
69
+ else:
70
+ h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
71
+ w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
72
+ return make_grid(h_pos, w_pos)
73
+
74
+
75
+ def freqs_pixel(max_freq=10.0):
76
+ def init(shape):
77
+ freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
78
+ return freqs.log().expand(shape)
79
+ return init
80
+
81
+
82
+ def freqs_pixel_log(max_freq=10.0):
83
+ def init(shape):
84
+ log_min = math.log(math.pi)
85
+ log_max = math.log(max_freq * math.pi / 2)
86
+ return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
87
+ return init
88
+
89
+
90
+ class AxialRoPE(nn.Module):
91
+ def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
92
+ super().__init__()
93
+ self.n_heads = n_heads
94
+ self.start_index = start_index
95
+ log_freqs = freqs_init((n_heads, dim // 4))
96
+ self.freqs_h = nn.Parameter(log_freqs.clone())
97
+ self.freqs_w = nn.Parameter(log_freqs.clone())
98
+
99
+ def extra_repr(self):
100
+ dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
101
+ return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"
102
+
103
+ def get_freqs(self, pos):
104
+ if pos.shape[-1] != 2:
105
+ raise ValueError("input shape must be (..., 2)")
106
+ freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
107
+ freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
108
+ freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
109
+ return freqs.transpose(-2, -3)
110
+
111
+ def forward(self, x, pos):
112
+ freqs = self.get_freqs(pos)
113
+ return apply_rotary_emb(freqs, x, self.start_index)
arch/hourglass/flags.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """k-diffusion transformer diffusion models, version 2.
2
+ Codes adopted from https://github.com/crowsonkb/k-diffusion
3
+ """
4
+
5
+ from contextlib import contextmanager
6
+ from functools import update_wrapper
7
+ import os
8
+ import threading
9
+
10
+ import torch
11
+
12
+
13
+ def get_use_compile():
14
+ return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
15
+
16
+
17
+ def get_use_flash_attention_2():
18
+ return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
19
+
20
+
21
+ state = threading.local()
22
+ state.checkpointing = False
23
+
24
+
25
+ @contextmanager
26
+ def checkpointing(enable=True):
27
+ try:
28
+ old_checkpointing, state.checkpointing = state.checkpointing, enable
29
+ yield
30
+ finally:
31
+ state.checkpointing = old_checkpointing
32
+
33
+
34
+ def get_checkpointing():
35
+ return getattr(state, "checkpointing", False)
36
+
37
+
38
+ class compile_wrap:
39
+ def __init__(self, function, *args, **kwargs):
40
+ self.function = function
41
+ self.args = args
42
+ self.kwargs = kwargs
43
+ self._compiled_function = None
44
+ update_wrapper(self, function)
45
+
46
+ @property
47
+ def compiled_function(self):
48
+ if self._compiled_function is not None:
49
+ return self._compiled_function
50
+ if get_use_compile():
51
+ try:
52
+ self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
53
+ except RuntimeError:
54
+ self._compiled_function = self.function
55
+ else:
56
+ self._compiled_function = self.function
57
+ return self._compiled_function
58
+
59
+ def __call__(self, *args, **kwargs):
60
+ return self.compiled_function(*args, **kwargs)
arch/hourglass/flops.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """k-diffusion transformer diffusion models, version 2.
2
+ Codes adopted from https://github.com/crowsonkb/k-diffusion
3
+ """
4
+
5
+ from contextlib import contextmanager
6
+ import math
7
+ import threading
8
+
9
+
10
+ state = threading.local()
11
+ state.flop_counter = None
12
+
13
+
14
+ @contextmanager
15
+ def flop_counter(enable=True):
16
+ try:
17
+ old_flop_counter = state.flop_counter
18
+ state.flop_counter = FlopCounter() if enable else None
19
+ yield state.flop_counter
20
+ finally:
21
+ state.flop_counter = old_flop_counter
22
+
23
+
24
+ class FlopCounter:
25
+ def __init__(self):
26
+ self.ops = []
27
+
28
+ def op(self, op, *args, **kwargs):
29
+ self.ops.append((op, args, kwargs))
30
+
31
+ @property
32
+ def flops(self):
33
+ flops = 0
34
+ for op, args, kwargs in self.ops:
35
+ flops += op(*args, **kwargs)
36
+ return flops
37
+
38
+
39
+ def op(op, *args, **kwargs):
40
+ if getattr(state, "flop_counter", None):
41
+ state.flop_counter.op(op, *args, **kwargs)
42
+
43
+
44
+ def op_linear(x, weight):
45
+ return math.prod(x) * weight[0]
46
+
47
+
48
+ def op_attention(q, k, v):
49
+ *b, s_q, d_q = q
50
+ *b, s_k, d_k = k
51
+ *b, s_v, d_v = v
52
+ return math.prod(b) * s_q * s_k * (d_q + d_v)
53
+
54
+
55
+ def op_natten(q, k, v, kernel_size):
56
+ *q_rest, d_q = q
57
+ *_, d_v = v
58
+ return math.prod(q_rest) * (d_q + d_v) * kernel_size**2
arch/hourglass/image_transformer_v2.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """k-diffusion transformer diffusion models, version 2.
2
+ Codes adopted from https://github.com/crowsonkb/k-diffusion
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache, reduce
7
+ import math
8
+ from typing import Union
9
+
10
+ from einops import rearrange
11
+ import torch
12
+ from torch import nn
13
+ import torch._dynamo
14
+ from torch.nn import functional as F
15
+
16
+ from . import flags, flops
17
+ from .axial_rope import make_axial_pos
18
+
19
+
20
+ try:
21
+ import natten
22
+ except ImportError:
23
+ natten = None
24
+
25
+ try:
26
+ import flash_attn
27
+ except ImportError:
28
+ flash_attn = None
29
+
30
+
31
+ if flags.get_use_compile():
32
+ torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit)
33
+ torch._dynamo.config.suppress_errors = True
34
+
35
+
36
+ # Helpers
37
+
38
+ def zero_init(layer):
39
+ nn.init.zeros_(layer.weight)
40
+ if layer.bias is not None:
41
+ nn.init.zeros_(layer.bias)
42
+ return layer
43
+
44
+
45
+ def checkpoint(function, *args, **kwargs):
46
+ if flags.get_checkpointing():
47
+ kwargs.setdefault("use_reentrant", True)
48
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
49
+ else:
50
+ return function(*args, **kwargs)
51
+
52
+
53
+ def downscale_pos(pos):
54
+ pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", nh=2, nw=2)
55
+ return torch.mean(pos, dim=-2)
56
+
57
+
58
+ # Param tags
59
+
60
+ def tag_param(param, tag):
61
+ if not hasattr(param, "_tags"):
62
+ param._tags = set([tag])
63
+ else:
64
+ param._tags.add(tag)
65
+ return param
66
+
67
+
68
+ def tag_module(module, tag):
69
+ for param in module.parameters():
70
+ tag_param(param, tag)
71
+ return module
72
+
73
+
74
+ def apply_wd(module):
75
+ for name, param in module.named_parameters():
76
+ if name.endswith("weight"):
77
+ tag_param(param, "wd")
78
+ return module
79
+
80
+
81
+ def filter_params(function, module):
82
+ for param in module.parameters():
83
+ tags = getattr(param, "_tags", set())
84
+ if function(tags):
85
+ yield param
86
+
87
+
88
+ # Kernels
89
+
90
+ @flags.compile_wrap
91
+ def linear_geglu(x, weight, bias=None):
92
+ x = x @ weight.mT
93
+ if bias is not None:
94
+ x = x + bias
95
+ x, gate = x.chunk(2, dim=-1)
96
+ return x * F.gelu(gate)
97
+
98
+
99
+ @flags.compile_wrap
100
+ def rms_norm(x, scale, eps):
101
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
102
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
103
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
104
+ return x * scale.to(x.dtype)
105
+
106
+
107
+ @flags.compile_wrap
108
+ def scale_for_cosine_sim(q, k, scale, eps):
109
+ dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
110
+ sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
111
+ sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True)
112
+ sqrt_scale = torch.sqrt(scale.to(dtype))
113
+ scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
114
+ scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
115
+ return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
116
+
117
+
118
+ @flags.compile_wrap
119
+ def scale_for_cosine_sim_qkv(qkv, scale, eps):
120
+ q, k, v = qkv.unbind(2)
121
+ q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
122
+ return torch.stack((q, k, v), dim=2)
123
+
124
+
125
+ # Layers
126
+
127
+ class Linear(nn.Linear):
128
+ def forward(self, x):
129
+ flops.op(flops.op_linear, x.shape, self.weight.shape)
130
+ return super().forward(x)
131
+
132
+
133
+ class LinearGEGLU(nn.Linear):
134
+ def __init__(self, in_features, out_features, bias=True):
135
+ super().__init__(in_features, out_features * 2, bias=bias)
136
+ self.out_features = out_features
137
+
138
+ def forward(self, x):
139
+ flops.op(flops.op_linear, x.shape, self.weight.shape)
140
+ return linear_geglu(x, self.weight, self.bias)
141
+
142
+
143
+ class FourierFeatures(nn.Module):
144
+ def __init__(self, in_features, out_features, std=1.):
145
+ super().__init__()
146
+ assert out_features % 2 == 0
147
+ self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
148
+
149
+ def forward(self, input):
150
+ f = 2 * math.pi * input @ self.weight.T
151
+ return torch.cat([f.cos(), f.sin()], dim=-1)
152
+
153
+ class RMSNorm(nn.Module):
154
+ def __init__(self, shape, eps=1e-6):
155
+ super().__init__()
156
+ self.eps = eps
157
+ self.scale = nn.Parameter(torch.ones(shape))
158
+
159
+ def extra_repr(self):
160
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
161
+
162
+ def forward(self, x):
163
+ return rms_norm(x, self.scale, self.eps)
164
+
165
+
166
+ class AdaRMSNorm(nn.Module):
167
+ def __init__(self, features, cond_features, eps=1e-6):
168
+ super().__init__()
169
+ self.eps = eps
170
+ self.linear = apply_wd(zero_init(Linear(cond_features, features, bias=False)))
171
+ tag_module(self.linear, "mapping")
172
+
173
+ def extra_repr(self):
174
+ return f"eps={self.eps},"
175
+
176
+ def forward(self, x, cond):
177
+ return rms_norm(x, self.linear(cond)[:, None, None, :] + 1, self.eps)
178
+
179
+
180
+ # Rotary position embeddings
181
+
182
+ @flags.compile_wrap
183
+ def apply_rotary_emb(x, theta, conj=False):
184
+ out_dtype = x.dtype
185
+ dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
186
+ d = theta.shape[-1]
187
+ assert d * 2 <= x.shape[-1]
188
+ x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :]
189
+ x1, x2, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype)
190
+ cos, sin = torch.cos(theta), torch.sin(theta)
191
+ sin = -sin if conj else sin
192
+ y1 = x1 * cos - x2 * sin
193
+ y2 = x2 * cos + x1 * sin
194
+ y1, y2 = y1.to(out_dtype), y2.to(out_dtype)
195
+ return torch.cat((y1, y2, x3), dim=-1)
196
+
197
+
198
+ @flags.compile_wrap
199
+ def _apply_rotary_emb_inplace(x, theta, conj):
200
+ dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
201
+ d = theta.shape[-1]
202
+ assert d * 2 <= x.shape[-1]
203
+ x1, x2 = x[..., :d], x[..., d : d * 2]
204
+ x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype)
205
+ cos, sin = torch.cos(theta), torch.sin(theta)
206
+ sin = -sin if conj else sin
207
+ y1 = x1_ * cos - x2_ * sin
208
+ y2 = x2_ * cos + x1_ * sin
209
+ x1.copy_(y1)
210
+ x2.copy_(y2)
211
+
212
+
213
+ class ApplyRotaryEmbeddingInplace(torch.autograd.Function):
214
+ @staticmethod
215
+ def forward(x, theta, conj):
216
+ _apply_rotary_emb_inplace(x, theta, conj=conj)
217
+ return x
218
+
219
+ @staticmethod
220
+ def setup_context(ctx, inputs, output):
221
+ _, theta, conj = inputs
222
+ ctx.save_for_backward(theta)
223
+ ctx.conj = conj
224
+
225
+ @staticmethod
226
+ def backward(ctx, grad_output):
227
+ theta, = ctx.saved_tensors
228
+ _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj)
229
+ return grad_output, None, None
230
+
231
+
232
+ def apply_rotary_emb_(x, theta):
233
+ return ApplyRotaryEmbeddingInplace.apply(x, theta, False)
234
+
235
+
236
+ class AxialRoPE(nn.Module):
237
+ def __init__(self, dim, n_heads):
238
+ super().__init__()
239
+ log_min = math.log(math.pi)
240
+ log_max = math.log(10.0 * math.pi)
241
+ freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp()
242
+ self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous())
243
+
244
+ def extra_repr(self):
245
+ return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}"
246
+
247
+ def forward(self, pos):
248
+ theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype)
249
+ theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype)
250
+ return torch.cat((theta_h, theta_w), dim=-1)
251
+
252
+
253
+ # Shifted window attention
254
+
255
+ def window(window_size, x):
256
+ *b, h, w, c = x.shape
257
+ x = torch.reshape(
258
+ x,
259
+ (*b, h // window_size, window_size, w // window_size, window_size, c),
260
+ )
261
+ x = torch.permute(
262
+ x,
263
+ (*range(len(b)), -5, -3, -4, -2, -1),
264
+ )
265
+ return x
266
+
267
+
268
+ def unwindow(x):
269
+ *b, h, w, wh, ww, c = x.shape
270
+ x = torch.permute(x, (*range(len(b)), -5, -3, -4, -2, -1))
271
+ x = torch.reshape(x, (*b, h * wh, w * ww, c))
272
+ return x
273
+
274
+
275
+ def shifted_window(window_size, window_shift, x):
276
+ x = torch.roll(x, shifts=(window_shift, window_shift), dims=(-2, -3))
277
+ windows = window(window_size, x)
278
+ return windows
279
+
280
+
281
+ def shifted_unwindow(window_shift, x):
282
+ x = unwindow(x)
283
+ x = torch.roll(x, shifts=(-window_shift, -window_shift), dims=(-2, -3))
284
+ return x
285
+
286
+
287
+ @lru_cache
288
+ def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None):
289
+ ph_coords = torch.arange(n_h_w, device=device)
290
+ pw_coords = torch.arange(n_w_w, device=device)
291
+ h_coords = torch.arange(w_h, device=device)
292
+ w_coords = torch.arange(w_w, device=device)
293
+ patch_h, patch_w, q_h, q_w, k_h, k_w = torch.meshgrid(
294
+ ph_coords,
295
+ pw_coords,
296
+ h_coords,
297
+ w_coords,
298
+ h_coords,
299
+ w_coords,
300
+ indexing="ij",
301
+ )
302
+ is_top_patch = patch_h == 0
303
+ is_left_patch = patch_w == 0
304
+ q_above_shift = q_h < shift
305
+ k_above_shift = k_h < shift
306
+ q_left_of_shift = q_w < shift
307
+ k_left_of_shift = k_w < shift
308
+ m_corner = (
309
+ is_left_patch
310
+ & is_top_patch
311
+ & (q_left_of_shift == k_left_of_shift)
312
+ & (q_above_shift == k_above_shift)
313
+ )
314
+ m_left = is_left_patch & ~is_top_patch & (q_left_of_shift == k_left_of_shift)
315
+ m_top = ~is_left_patch & is_top_patch & (q_above_shift == k_above_shift)
316
+ m_rest = ~is_left_patch & ~is_top_patch
317
+ m = m_corner | m_left | m_top | m_rest
318
+ return m
319
+
320
+
321
+ def apply_window_attention(window_size, window_shift, q, k, v, scale=None):
322
+ # prep windows and masks
323
+ q_windows = shifted_window(window_size, window_shift, q)
324
+ k_windows = shifted_window(window_size, window_shift, k)
325
+ v_windows = shifted_window(window_size, window_shift, v)
326
+ b, heads, h, w, wh, ww, d_head = q_windows.shape
327
+ mask = make_shifted_window_masks(h, w, wh, ww, window_shift, device=q.device)
328
+ q_seqs = torch.reshape(q_windows, (b, heads, h, w, wh * ww, d_head))
329
+ k_seqs = torch.reshape(k_windows, (b, heads, h, w, wh * ww, d_head))
330
+ v_seqs = torch.reshape(v_windows, (b, heads, h, w, wh * ww, d_head))
331
+ mask = torch.reshape(mask, (h, w, wh * ww, wh * ww))
332
+
333
+ # do the attention here
334
+ flops.op(flops.op_attention, q_seqs.shape, k_seqs.shape, v_seqs.shape)
335
+ qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale)
336
+
337
+ # unwindow
338
+ qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head))
339
+ return shifted_unwindow(window_shift, qkv)
340
+
341
+
342
+ # Transformer layers
343
+
344
+
345
+ def use_flash_2(x):
346
+ if not flags.get_use_flash_attention_2():
347
+ return False
348
+ if flash_attn is None:
349
+ return False
350
+ if x.device.type != "cuda":
351
+ return False
352
+ if x.dtype not in (torch.float16, torch.bfloat16):
353
+ return False
354
+ return True
355
+
356
+
357
+ class SelfAttentionBlock(nn.Module):
358
+ def __init__(self, d_model, d_head, cond_features, dropout=0.0):
359
+ super().__init__()
360
+ self.d_head = d_head
361
+ self.n_heads = d_model // d_head
362
+ self.norm = AdaRMSNorm(d_model, cond_features)
363
+ self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
364
+ self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
365
+ self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
366
+ self.dropout = nn.Dropout(dropout)
367
+ self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
368
+
369
+ def extra_repr(self):
370
+ return f"d_head={self.d_head},"
371
+
372
+ def forward(self, x, pos, cond):
373
+ skip = x
374
+ x = self.norm(x, cond)
375
+ qkv = self.qkv_proj(x)
376
+ pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype)
377
+ theta = self.pos_emb(pos)
378
+ if use_flash_2(qkv):
379
+ qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head)
380
+ qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6)
381
+ theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3)
382
+ qkv = apply_rotary_emb_(qkv, theta)
383
+ flops_shape = qkv.shape[-5], qkv.shape[-2], qkv.shape[-4], qkv.shape[-1]
384
+ flops.op(flops.op_attention, flops_shape, flops_shape, flops_shape)
385
+ x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0)
386
+ x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2])
387
+ else:
388
+ q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head)
389
+ q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6)
390
+ theta = theta.movedim(-2, -3)
391
+ q = apply_rotary_emb_(q, theta)
392
+ k = apply_rotary_emb_(k, theta)
393
+ flops.op(flops.op_attention, q.shape, k.shape, v.shape)
394
+ x = F.scaled_dot_product_attention(q, k, v, scale=1.0)
395
+ x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2])
396
+ x = self.dropout(x)
397
+ x = self.out_proj(x)
398
+ return x + skip
399
+
400
+
401
+ class NeighborhoodSelfAttentionBlock(nn.Module):
402
+ def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0):
403
+ super().__init__()
404
+ self.d_head = d_head
405
+ self.n_heads = d_model // d_head
406
+ self.kernel_size = kernel_size
407
+ self.norm = AdaRMSNorm(d_model, cond_features)
408
+ self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
409
+ self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
410
+ self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
411
+ self.dropout = nn.Dropout(dropout)
412
+ self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
413
+
414
+ def extra_repr(self):
415
+ return f"d_head={self.d_head}, kernel_size={self.kernel_size}"
416
+
417
+ def forward(self, x, pos, cond):
418
+ skip = x
419
+ x = self.norm(x, cond)
420
+ qkv = self.qkv_proj(x)
421
+ if natten is None:
422
+ raise ModuleNotFoundError("natten is required for neighborhood attention")
423
+ if natten.has_fused_na():
424
+ q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head)
425
+ q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6)
426
+ theta = self.pos_emb(pos)
427
+ q = apply_rotary_emb_(q, theta)
428
+ k = apply_rotary_emb_(k, theta)
429
+ flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
430
+ x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0)
431
+ x = rearrange(x, "n h w nh e -> n h w (nh e)")
432
+ else:
433
+ q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
434
+ q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
435
+ theta = self.pos_emb(pos).movedim(-2, -4)
436
+ q = apply_rotary_emb_(q, theta)
437
+ k = apply_rotary_emb_(k, theta)
438
+ flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
439
+ qk = natten.functional.na2d_qk(q, k, self.kernel_size)
440
+ a = torch.softmax(qk, dim=-1).to(v.dtype)
441
+ x = natten.functional.na2d_av(a, v, self.kernel_size)
442
+ x = rearrange(x, "n nh h w e -> n h w (nh e)")
443
+ x = self.dropout(x)
444
+ x = self.out_proj(x)
445
+ return x + skip
446
+
447
+
448
+ class ShiftedWindowSelfAttentionBlock(nn.Module):
449
+ def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dropout=0.0):
450
+ super().__init__()
451
+ self.d_head = d_head
452
+ self.n_heads = d_model // d_head
453
+ self.window_size = window_size
454
+ self.window_shift = window_shift
455
+ self.norm = AdaRMSNorm(d_model, cond_features)
456
+ self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
457
+ self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
458
+ self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
459
+ self.dropout = nn.Dropout(dropout)
460
+ self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
461
+
462
+ def extra_repr(self):
463
+ return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}"
464
+
465
+ def forward(self, x, pos, cond):
466
+ skip = x
467
+ x = self.norm(x, cond)
468
+ qkv = self.qkv_proj(x)
469
+ q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
470
+ q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
471
+ theta = self.pos_emb(pos).movedim(-2, -4)
472
+ q = apply_rotary_emb_(q, theta)
473
+ k = apply_rotary_emb_(k, theta)
474
+ x = apply_window_attention(self.window_size, self.window_shift, q, k, v, scale=1.0)
475
+ x = rearrange(x, "n nh h w e -> n h w (nh e)")
476
+ x = self.dropout(x)
477
+ x = self.out_proj(x)
478
+ return x + skip
479
+
480
+
481
+ class FeedForwardBlock(nn.Module):
482
+ def __init__(self, d_model, d_ff, cond_features, dropout=0.0):
483
+ super().__init__()
484
+ self.norm = AdaRMSNorm(d_model, cond_features)
485
+ self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
486
+ self.dropout = nn.Dropout(dropout)
487
+ self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))
488
+
489
+ def forward(self, x, cond):
490
+ skip = x
491
+ x = self.norm(x, cond)
492
+ x = self.up_proj(x)
493
+ x = self.dropout(x)
494
+ x = self.down_proj(x)
495
+ return x + skip
496
+
497
+
498
+ class GlobalTransformerLayer(nn.Module):
499
+ def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0):
500
+ super().__init__()
501
+ self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout)
502
+ self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
503
+
504
+ def forward(self, x, pos, cond):
505
+ x = checkpoint(self.self_attn, x, pos, cond)
506
+ x = checkpoint(self.ff, x, cond)
507
+ return x
508
+
509
+
510
+ class NeighborhoodTransformerLayer(nn.Module):
511
+ def __init__(self, d_model, d_ff, d_head, cond_features, kernel_size, dropout=0.0):
512
+ super().__init__()
513
+ self.self_attn = NeighborhoodSelfAttentionBlock(d_model, d_head, cond_features, kernel_size, dropout=dropout)
514
+ self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
515
+
516
+ def forward(self, x, pos, cond):
517
+ x = checkpoint(self.self_attn, x, pos, cond)
518
+ x = checkpoint(self.ff, x, cond)
519
+ return x
520
+
521
+
522
+ class ShiftedWindowTransformerLayer(nn.Module):
523
+ def __init__(self, d_model, d_ff, d_head, cond_features, window_size, index, dropout=0.0):
524
+ super().__init__()
525
+ window_shift = window_size // 2 if index % 2 == 1 else 0
526
+ self.self_attn = ShiftedWindowSelfAttentionBlock(d_model, d_head, cond_features, window_size, window_shift, dropout=dropout)
527
+ self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
528
+
529
+ def forward(self, x, pos, cond):
530
+ x = checkpoint(self.self_attn, x, pos, cond)
531
+ x = checkpoint(self.ff, x, cond)
532
+ return x
533
+
534
+
535
+ class NoAttentionTransformerLayer(nn.Module):
536
+ def __init__(self, d_model, d_ff, cond_features, dropout=0.0):
537
+ super().__init__()
538
+ self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
539
+
540
+ def forward(self, x, pos, cond):
541
+ x = checkpoint(self.ff, x, cond)
542
+ return x
543
+
544
+
545
+ class Level(nn.ModuleList):
546
+ def forward(self, x, *args, **kwargs):
547
+ for layer in self:
548
+ x = layer(x, *args, **kwargs)
549
+ return x
550
+
551
+
552
+ # Mapping network
553
+
554
+ class MappingFeedForwardBlock(nn.Module):
555
+ def __init__(self, d_model, d_ff, dropout=0.0):
556
+ super().__init__()
557
+ self.norm = RMSNorm(d_model)
558
+ self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
559
+ self.dropout = nn.Dropout(dropout)
560
+ self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))
561
+
562
+ def forward(self, x):
563
+ skip = x
564
+ x = self.norm(x)
565
+ x = self.up_proj(x)
566
+ x = self.dropout(x)
567
+ x = self.down_proj(x)
568
+ return x + skip
569
+
570
+
571
+ class MappingNetwork(nn.Module):
572
+ def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
573
+ super().__init__()
574
+ self.in_norm = RMSNorm(d_model)
575
+ self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)])
576
+ self.out_norm = RMSNorm(d_model)
577
+
578
+ def forward(self, x):
579
+ x = self.in_norm(x)
580
+ for block in self.blocks:
581
+ x = block(x)
582
+ x = self.out_norm(x)
583
+ return x
584
+
585
+
586
+ # Token merging and splitting
587
+
588
+ class TokenMerge(nn.Module):
589
+ def __init__(self, in_features, out_features, patch_size=(2, 2)):
590
+ super().__init__()
591
+ self.h = patch_size[0]
592
+ self.w = patch_size[1]
593
+ self.proj = apply_wd(Linear(in_features * self.h * self.w, out_features, bias=False))
594
+
595
+ def forward(self, x):
596
+ x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w)
597
+ return self.proj(x)
598
+
599
+
600
+ class TokenSplitWithoutSkip(nn.Module):
601
+ def __init__(self, in_features, out_features, patch_size=(2, 2)):
602
+ super().__init__()
603
+ self.h = patch_size[0]
604
+ self.w = patch_size[1]
605
+ self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False))
606
+
607
+ def forward(self, x):
608
+ x = self.proj(x)
609
+ return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w)
610
+
611
+
612
+ class TokenSplit(nn.Module):
613
+ def __init__(self, in_features, out_features, patch_size=(2, 2)):
614
+ super().__init__()
615
+ self.h = patch_size[0]
616
+ self.w = patch_size[1]
617
+ self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False))
618
+ self.fac = nn.Parameter(torch.ones(1) * 0.5)
619
+
620
+ def forward(self, x, skip):
621
+ x = self.proj(x)
622
+ x = rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w)
623
+ return torch.lerp(skip, x, self.fac.to(x.dtype))
624
+
625
+
626
+ # Configuration
627
+
628
+ @dataclass
629
+ class GlobalAttentionSpec:
630
+ d_head: int
631
+
632
+
633
+ @dataclass
634
+ class NeighborhoodAttentionSpec:
635
+ d_head: int
636
+ kernel_size: int
637
+
638
+
639
+ @dataclass
640
+ class ShiftedWindowAttentionSpec:
641
+ d_head: int
642
+ window_size: int
643
+
644
+
645
+ @dataclass
646
+ class NoAttentionSpec:
647
+ pass
648
+
649
+
650
+ @dataclass
651
+ class LevelSpec:
652
+ depth: int
653
+ width: int
654
+ d_ff: int
655
+ self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec]
656
+ dropout: float
657
+
658
+
659
+ @dataclass
660
+ class MappingSpec:
661
+ depth: int
662
+ width: int
663
+ d_ff: int
664
+ dropout: float
665
+
666
+
667
+ # Model class
668
+
669
+ class ImageTransformerDenoiserModelV2(nn.Module):
670
+ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0, degradation_params_dim=None):
671
+ super().__init__()
672
+ self.num_classes = num_classes
673
+ self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size)
674
+ self.mapping_width = mapping.width
675
+ self.time_emb = FourierFeatures(1, mapping.width)
676
+ self.time_in_proj = Linear(mapping.width, mapping.width, bias=False)
677
+ self.aug_emb = FourierFeatures(9, mapping.width)
678
+ self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False)
679
+ self.degradation_proj = Linear(degradation_params_dim, mapping.width, bias=False) if degradation_params_dim else None
680
+ self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None
681
+ self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None
682
+ self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout), "mapping")
683
+
684
+ self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
685
+ for i, spec in enumerate(levels):
686
+ if isinstance(spec.self_attn, GlobalAttentionSpec):
687
+ layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
688
+ elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
689
+ layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
690
+ elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
691
+ layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
692
+ elif isinstance(spec.self_attn, NoAttentionSpec):
693
+ layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
694
+ else:
695
+ raise ValueError(f"unsupported self attention spec {spec.self_attn}")
696
+
697
+ if i < len(levels) - 1:
698
+ self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
699
+ self.up_levels.append(Level([layer_factory(i + spec.depth) for i in range(spec.depth)]))
700
+ else:
701
+ self.mid_level = Level([layer_factory(i) for i in range(spec.depth)])
702
+
703
+ self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])])
704
+ self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])])
705
+
706
+ self.out_norm = RMSNorm(levels[0].width)
707
+ self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channels, patch_size)
708
+ nn.init.zeros_(self.patch_out.proj.weight)
709
+
710
+ def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3):
711
+ wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self)
712
+ no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self)
713
+ mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self)
714
+ mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self)
715
+ groups = [
716
+ {"params": list(wd), "lr": base_lr},
717
+ {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0},
718
+ {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale},
719
+ {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0}
720
+ ]
721
+ return groups
722
+
723
+ def forward(self, x, sigma=None, aug_cond=None, class_cond=None, mapping_cond=None, degradation_params=None):
724
+ # Patching
725
+ x = x.movedim(-3, -1)
726
+ x = self.patch_in(x)
727
+ # TODO: pixel aspect ratio for nonsquare patches
728
+ pos = make_axial_pos(x.shape[-3], x.shape[-2], device=x.device).view(x.shape[-3], x.shape[-2], 2)
729
+
730
+ # Mapping network
731
+ if class_cond is None and self.class_emb is not None:
732
+ raise ValueError("class_cond must be specified if num_classes > 0")
733
+ if mapping_cond is None and self.mapping_cond_in_proj is not None:
734
+ raise ValueError("mapping_cond must be specified if mapping_cond_dim > 0")
735
+
736
+ # c_noise = torch.log(sigma) / 4
737
+ # c_noise = (sigma * 2.0 - 1.0)
738
+ # c_noise = sigma * 2 - 1
739
+ if sigma is not None:
740
+ time_emb = self.time_in_proj(self.time_emb(sigma[..., None]))
741
+ else:
742
+ time_emb = self.time_in_proj(torch.ones(1, 1, device=x.device, dtype=x.dtype).expand(x.shape[0], self.mapping_width))
743
+ # time_emb = self.time_in_proj(sigma[..., None])
744
+
745
+ aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond
746
+ aug_emb = self.aug_in_proj(self.aug_emb(aug_cond))
747
+ class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0
748
+ mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0
749
+ degradation_emb = self.degradation_proj(degradation_params) if degradation_params is not None else 0
750
+ cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb + degradation_emb)
751
+
752
+ # Hourglass transformer
753
+ skips, poses = [], []
754
+ for down_level, merge in zip(self.down_levels, self.merges):
755
+ x = down_level(x, pos, cond)
756
+ skips.append(x)
757
+ poses.append(pos)
758
+ x = merge(x)
759
+ pos = downscale_pos(pos)
760
+
761
+ x = self.mid_level(x, pos, cond)
762
+
763
+ for up_level, split, skip, pos in reversed(list(zip(self.up_levels, self.splits, skips, poses))):
764
+ x = split(x, skip)
765
+ x = up_level(x, pos, cond)
766
+
767
+ # Unpatching
768
+ x = self.out_norm(x)
769
+ x = self.patch_out(x)
770
+ x = x.movedim(-1, -3)
771
+
772
+ return x
arch/swinir/__init__.py ADDED
File without changes
arch/swinir/swinir.py ADDED
@@ -0,0 +1,904 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+ # Borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py)
6
+
7
+ import math
8
+ from typing import Set
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.act = act_layer()
24
+ self.fc2 = nn.Linear(hidden_features, out_features)
25
+ self.drop = nn.Dropout(drop)
26
+
27
+ def forward(self, x):
28
+ x = self.fc1(x)
29
+ x = self.act(x)
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ x = self.drop(x)
33
+ return x
34
+
35
+
36
+ def window_partition(x, window_size):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size (int): window size
41
+
42
+ Returns:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
47
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48
+ return windows
49
+
50
+
51
+ def window_reverse(windows, window_size, H, W):
52
+ """
53
+ Args:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ window_size (int): Window size
56
+ H (int): Height of image
57
+ W (int): Width of image
58
+
59
+ Returns:
60
+ x: (B, H, W, C)
61
+ """
62
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
63
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
64
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
65
+ return x
66
+
67
+
68
+ class WindowAttention(nn.Module):
69
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
70
+ It supports both of shifted and non-shifted window.
71
+
72
+ Args:
73
+ dim (int): Number of input channels.
74
+ window_size (tuple[int]): The height and width of the window.
75
+ num_heads (int): Number of attention heads.
76
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
+ """
81
+
82
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+
84
+ super().__init__()
85
+ self.dim = dim
86
+ self.window_size = window_size # Wh, Ww
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+
91
+ # define a parameter table of relative position bias
92
+ self.relative_position_bias_table = nn.Parameter(
93
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(self.window_size[0])
97
+ coords_w = torch.arange(self.window_size[1])
98
+ # coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ # Fix: Pass indexing="ij" to avoid warning
100
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
101
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
102
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
103
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
104
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
105
+ relative_coords[:, :, 1] += self.window_size[1] - 1
106
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
107
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+
110
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
111
+ self.attn_drop = nn.Dropout(attn_drop)
112
+ self.proj = nn.Linear(dim, dim)
113
+
114
+ self.proj_drop = nn.Dropout(proj_drop)
115
+
116
+ trunc_normal_(self.relative_position_bias_table, std=.02)
117
+ self.softmax = nn.Softmax(dim=-1)
118
+
119
+ def forward(self, x, mask=None):
120
+ """
121
+ Args:
122
+ x: input features with shape of (num_windows*B, N, C)
123
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
124
+ """
125
+ B_, N, C = x.shape
126
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
127
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
128
+
129
+ q = q * self.scale
130
+ attn = (q @ k.transpose(-2, -1))
131
+
132
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
134
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
135
+ attn = attn + relative_position_bias.unsqueeze(0)
136
+
137
+ if mask is not None:
138
+ nW = mask.shape[0]
139
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
140
+ attn = attn.view(-1, self.num_heads, N, N)
141
+ attn = self.softmax(attn)
142
+ else:
143
+ attn = self.softmax(attn)
144
+
145
+ attn = self.attn_drop(attn)
146
+
147
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
148
+ x = self.proj(x)
149
+ x = self.proj_drop(x)
150
+ return x
151
+
152
+ def extra_repr(self) -> str:
153
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
154
+
155
+ def flops(self, N):
156
+ # calculate flops for 1 window with token length of N
157
+ flops = 0
158
+ # qkv = self.qkv(x)
159
+ flops += N * self.dim * 3 * self.dim
160
+ # attn = (q @ k.transpose(-2, -1))
161
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
162
+ # x = (attn @ v)
163
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
164
+ # x = self.proj(x)
165
+ flops += N * self.dim * self.dim
166
+ return flops
167
+
168
+
169
+ class SwinTransformerBlock(nn.Module):
170
+ r""" Swin Transformer Block.
171
+
172
+ Args:
173
+ dim (int): Number of input channels.
174
+ input_resolution (tuple[int]): Input resulotion.
175
+ num_heads (int): Number of attention heads.
176
+ window_size (int): Window size.
177
+ shift_size (int): Shift size for SW-MSA.
178
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
179
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
180
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
181
+ drop (float, optional): Dropout rate. Default: 0.0
182
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
183
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
184
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
185
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
186
+ """
187
+
188
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
189
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
190
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
191
+ super().__init__()
192
+ self.dim = dim
193
+ self.input_resolution = input_resolution
194
+ self.num_heads = num_heads
195
+ self.window_size = window_size
196
+ self.shift_size = shift_size
197
+ self.mlp_ratio = mlp_ratio
198
+ if min(self.input_resolution) <= self.window_size:
199
+ # if window size is larger than input resolution, we don't partition windows
200
+ self.shift_size = 0
201
+ self.window_size = min(self.input_resolution)
202
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
203
+
204
+ self.norm1 = norm_layer(dim)
205
+ self.attn = WindowAttention(
206
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
207
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
208
+
209
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
210
+ self.norm2 = norm_layer(dim)
211
+ mlp_hidden_dim = int(dim * mlp_ratio)
212
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
213
+
214
+ if self.shift_size > 0:
215
+ attn_mask = self.calculate_mask(self.input_resolution)
216
+ else:
217
+ attn_mask = None
218
+
219
+ self.register_buffer("attn_mask", attn_mask)
220
+
221
+ def calculate_mask(self, x_size):
222
+ # calculate attention mask for SW-MSA
223
+ H, W = x_size
224
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
225
+ h_slices = (slice(0, -self.window_size),
226
+ slice(-self.window_size, -self.shift_size),
227
+ slice(-self.shift_size, None))
228
+ w_slices = (slice(0, -self.window_size),
229
+ slice(-self.window_size, -self.shift_size),
230
+ slice(-self.shift_size, None))
231
+ cnt = 0
232
+ for h in h_slices:
233
+ for w in w_slices:
234
+ img_mask[:, h, w, :] = cnt
235
+ cnt += 1
236
+
237
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
238
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
239
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
240
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
241
+
242
+ return attn_mask
243
+
244
+ def forward(self, x, x_size):
245
+ H, W = x_size
246
+ B, L, C = x.shape
247
+ # assert L == H * W, "input feature has wrong size"
248
+
249
+ shortcut = x
250
+ x = self.norm1(x)
251
+ x = x.view(B, H, W, C)
252
+
253
+ # cyclic shift
254
+ if self.shift_size > 0:
255
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
256
+ else:
257
+ shifted_x = x
258
+
259
+ # partition windows
260
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
261
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
262
+
263
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
264
+ if self.input_resolution == x_size:
265
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
266
+ else:
267
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
268
+
269
+ # merge windows
270
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
271
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
272
+
273
+ # reverse cyclic shift
274
+ if self.shift_size > 0:
275
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
276
+ else:
277
+ x = shifted_x
278
+ x = x.view(B, H * W, C)
279
+
280
+ # FFN
281
+ x = shortcut + self.drop_path(x)
282
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
283
+
284
+ return x
285
+
286
+ def extra_repr(self) -> str:
287
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
288
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
289
+
290
+ def flops(self):
291
+ flops = 0
292
+ H, W = self.input_resolution
293
+ # norm1
294
+ flops += self.dim * H * W
295
+ # W-MSA/SW-MSA
296
+ nW = H * W / self.window_size / self.window_size
297
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
298
+ # mlp
299
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
300
+ # norm2
301
+ flops += self.dim * H * W
302
+ return flops
303
+
304
+
305
+ class PatchMerging(nn.Module):
306
+ r""" Patch Merging Layer.
307
+
308
+ Args:
309
+ input_resolution (tuple[int]): Resolution of input feature.
310
+ dim (int): Number of input channels.
311
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
312
+ """
313
+
314
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
315
+ super().__init__()
316
+ self.input_resolution = input_resolution
317
+ self.dim = dim
318
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
319
+ self.norm = norm_layer(4 * dim)
320
+
321
+ def forward(self, x):
322
+ """
323
+ x: B, H*W, C
324
+ """
325
+ H, W = self.input_resolution
326
+ B, L, C = x.shape
327
+ assert L == H * W, "input feature has wrong size"
328
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
329
+
330
+ x = x.view(B, H, W, C)
331
+
332
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
333
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
334
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
335
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
336
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
337
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
338
+
339
+ x = self.norm(x)
340
+ x = self.reduction(x)
341
+
342
+ return x
343
+
344
+ def extra_repr(self) -> str:
345
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
346
+
347
+ def flops(self):
348
+ H, W = self.input_resolution
349
+ flops = H * W * self.dim
350
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
351
+ return flops
352
+
353
+
354
+ class BasicLayer(nn.Module):
355
+ """ A basic Swin Transformer layer for one stage.
356
+
357
+ Args:
358
+ dim (int): Number of input channels.
359
+ input_resolution (tuple[int]): Input resolution.
360
+ depth (int): Number of blocks.
361
+ num_heads (int): Number of attention heads.
362
+ window_size (int): Local window size.
363
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
364
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
365
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
366
+ drop (float, optional): Dropout rate. Default: 0.0
367
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
368
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
369
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
370
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
371
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
372
+ """
373
+
374
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
375
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
376
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
377
+
378
+ super().__init__()
379
+ self.dim = dim
380
+ self.input_resolution = input_resolution
381
+ self.depth = depth
382
+ self.use_checkpoint = use_checkpoint
383
+
384
+ # build blocks
385
+ self.blocks = nn.ModuleList([
386
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
387
+ num_heads=num_heads, window_size=window_size,
388
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
389
+ mlp_ratio=mlp_ratio,
390
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
391
+ drop=drop, attn_drop=attn_drop,
392
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
393
+ norm_layer=norm_layer)
394
+ for i in range(depth)])
395
+
396
+ # patch merging layer
397
+ if downsample is not None:
398
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
399
+ else:
400
+ self.downsample = None
401
+
402
+ def forward(self, x, x_size):
403
+ for blk in self.blocks:
404
+ if self.use_checkpoint:
405
+ x = checkpoint.checkpoint(blk, x, x_size)
406
+ else:
407
+ x = blk(x, x_size)
408
+ if self.downsample is not None:
409
+ x = self.downsample(x)
410
+ return x
411
+
412
+ def extra_repr(self) -> str:
413
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
414
+
415
+ def flops(self):
416
+ flops = 0
417
+ for blk in self.blocks:
418
+ flops += blk.flops()
419
+ if self.downsample is not None:
420
+ flops += self.downsample.flops()
421
+ return flops
422
+
423
+
424
+ class RSTB(nn.Module):
425
+ """Residual Swin Transformer Block (RSTB).
426
+
427
+ Args:
428
+ dim (int): Number of input channels.
429
+ input_resolution (tuple[int]): Input resolution.
430
+ depth (int): Number of blocks.
431
+ num_heads (int): Number of attention heads.
432
+ window_size (int): Local window size.
433
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
434
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
435
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
436
+ drop (float, optional): Dropout rate. Default: 0.0
437
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
438
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
439
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
440
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
441
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
442
+ img_size: Input image size.
443
+ patch_size: Patch size.
444
+ resi_connection: The convolutional block before residual connection.
445
+ """
446
+
447
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
448
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
449
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
450
+ img_size=224, patch_size=4, resi_connection='1conv'):
451
+ super(RSTB, self).__init__()
452
+
453
+ self.dim = dim
454
+ self.input_resolution = input_resolution
455
+
456
+ self.residual_group = BasicLayer(dim=dim,
457
+ input_resolution=input_resolution,
458
+ depth=depth,
459
+ num_heads=num_heads,
460
+ window_size=window_size,
461
+ mlp_ratio=mlp_ratio,
462
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
463
+ drop=drop, attn_drop=attn_drop,
464
+ drop_path=drop_path,
465
+ norm_layer=norm_layer,
466
+ downsample=downsample,
467
+ use_checkpoint=use_checkpoint)
468
+
469
+ if resi_connection == '1conv':
470
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
471
+ elif resi_connection == '3conv':
472
+ # to save parameters and memory
473
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
474
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
475
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
476
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
477
+
478
+ self.patch_embed = PatchEmbed(
479
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
480
+ norm_layer=None)
481
+
482
+ self.patch_unembed = PatchUnEmbed(
483
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
484
+ norm_layer=None)
485
+
486
+ def forward(self, x, x_size):
487
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
488
+
489
+ def flops(self):
490
+ flops = 0
491
+ flops += self.residual_group.flops()
492
+ H, W = self.input_resolution
493
+ flops += H * W * self.dim * self.dim * 9
494
+ flops += self.patch_embed.flops()
495
+ flops += self.patch_unembed.flops()
496
+
497
+ return flops
498
+
499
+
500
+ class PatchEmbed(nn.Module):
501
+ r""" Image to Patch Embedding
502
+
503
+ Args:
504
+ img_size (int): Image size. Default: 224.
505
+ patch_size (int): Patch token size. Default: 4.
506
+ in_chans (int): Number of input image channels. Default: 3.
507
+ embed_dim (int): Number of linear projection output channels. Default: 96.
508
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
509
+ """
510
+
511
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
512
+ super().__init__()
513
+ img_size = to_2tuple(img_size)
514
+ patch_size = to_2tuple(patch_size)
515
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
516
+ self.img_size = img_size
517
+ self.patch_size = patch_size
518
+ self.patches_resolution = patches_resolution
519
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
520
+
521
+ self.in_chans = in_chans
522
+ self.embed_dim = embed_dim
523
+
524
+ if norm_layer is not None:
525
+ self.norm = norm_layer(embed_dim)
526
+ else:
527
+ self.norm = None
528
+
529
+ def forward(self, x):
530
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
531
+ if self.norm is not None:
532
+ x = self.norm(x)
533
+ return x
534
+
535
+ def flops(self):
536
+ flops = 0
537
+ H, W = self.img_size
538
+ if self.norm is not None:
539
+ flops += H * W * self.embed_dim
540
+ return flops
541
+
542
+
543
+ class PatchUnEmbed(nn.Module):
544
+ r""" Image to Patch Unembedding
545
+
546
+ Args:
547
+ img_size (int): Image size. Default: 224.
548
+ patch_size (int): Patch token size. Default: 4.
549
+ in_chans (int): Number of input image channels. Default: 3.
550
+ embed_dim (int): Number of linear projection output channels. Default: 96.
551
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
552
+ """
553
+
554
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
555
+ super().__init__()
556
+ img_size = to_2tuple(img_size)
557
+ patch_size = to_2tuple(patch_size)
558
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
559
+ self.img_size = img_size
560
+ self.patch_size = patch_size
561
+ self.patches_resolution = patches_resolution
562
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
563
+
564
+ self.in_chans = in_chans
565
+ self.embed_dim = embed_dim
566
+
567
+ def forward(self, x, x_size):
568
+ B, HW, C = x.shape
569
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
570
+ return x
571
+
572
+ def flops(self):
573
+ flops = 0
574
+ return flops
575
+
576
+
577
+ class Upsample(nn.Sequential):
578
+ """Upsample module.
579
+
580
+ Args:
581
+ scale (int): Scale factor. Supported scales: 2^n and 3.
582
+ num_feat (int): Channel number of intermediate features.
583
+ """
584
+
585
+ def __init__(self, scale, num_feat):
586
+ m = []
587
+ if (scale & (scale - 1)) == 0: # scale = 2^n
588
+ for _ in range(int(math.log(scale, 2))):
589
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
590
+ m.append(nn.PixelShuffle(2))
591
+ elif scale == 3:
592
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
593
+ m.append(nn.PixelShuffle(3))
594
+ else:
595
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
596
+ super(Upsample, self).__init__(*m)
597
+
598
+
599
+ class UpsampleOneStep(nn.Sequential):
600
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
601
+ Used in lightweight SR to save parameters.
602
+
603
+ Args:
604
+ scale (int): Scale factor. Supported scales: 2^n and 3.
605
+ num_feat (int): Channel number of intermediate features.
606
+
607
+ """
608
+
609
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
610
+ self.num_feat = num_feat
611
+ self.input_resolution = input_resolution
612
+ m = []
613
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
614
+ m.append(nn.PixelShuffle(scale))
615
+ super(UpsampleOneStep, self).__init__(*m)
616
+
617
+ def flops(self):
618
+ H, W = self.input_resolution
619
+ flops = H * W * self.num_feat * 3 * 9
620
+ return flops
621
+
622
+
623
+ class SwinIR(nn.Module):
624
+ r""" SwinIR
625
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
626
+
627
+ Args:
628
+ img_size (int | tuple(int)): Input image size. Default 64
629
+ patch_size (int | tuple(int)): Patch size. Default: 1
630
+ in_chans (int): Number of input image channels. Default: 3
631
+ embed_dim (int): Patch embedding dimension. Default: 96
632
+ depths (tuple(int)): Depth of each Swin Transformer layer.
633
+ num_heads (tuple(int)): Number of attention heads in different layers.
634
+ window_size (int): Window size. Default: 7
635
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
636
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
637
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
638
+ drop_rate (float): Dropout rate. Default: 0
639
+ attn_drop_rate (float): Attention dropout rate. Default: 0
640
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
641
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
642
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
643
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
644
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
645
+ sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
646
+ img_range: Image range. 1. or 255.
647
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
648
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
649
+ """
650
+
651
+ def __init__(
652
+ self,
653
+ img_size=64,
654
+ patch_size=1,
655
+ in_chans=3,
656
+ num_out_ch=3,
657
+ embed_dim=96,
658
+ depths=[6, 6, 6, 6],
659
+ num_heads=[6, 6, 6, 6],
660
+ window_size=7,
661
+ mlp_ratio=4.,
662
+ qkv_bias=True,
663
+ qk_scale=None,
664
+ drop_rate=0.,
665
+ attn_drop_rate=0.,
666
+ drop_path_rate=0.1,
667
+ norm_layer=nn.LayerNorm,
668
+ ape=False,
669
+ patch_norm=True,
670
+ use_checkpoint=False,
671
+ sf=4,
672
+ img_range=1.,
673
+ upsampler='',
674
+ resi_connection='1conv',
675
+ unshuffle=False,
676
+ unshuffle_scale=None,
677
+ hq_key: str = "jpg",
678
+ lq_key: str = "hint",
679
+ learning_rate: float = None,
680
+ weight_decay: float = None
681
+ ) -> "SwinIR":
682
+ super(SwinIR, self).__init__()
683
+ num_in_ch = in_chans * (unshuffle_scale ** 2) if unshuffle else in_chans
684
+ num_feat = 64
685
+ self.img_range = img_range
686
+ if in_chans == 3:
687
+ rgb_mean = (0.4488, 0.4371, 0.4040)
688
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
689
+ else:
690
+ self.mean = torch.zeros(1, 1, 1, 1)
691
+ self.upscale = sf
692
+ self.upsampler = upsampler
693
+ self.window_size = window_size
694
+ self.unshuffle_scale = unshuffle_scale
695
+ self.unshuffle = unshuffle
696
+
697
+ #####################################################################################################
698
+ ################################### 1, shallow feature extraction ###################################
699
+ if unshuffle:
700
+ assert unshuffle_scale is not None
701
+ self.conv_first = nn.Sequential(
702
+ nn.PixelUnshuffle(sf),
703
+ nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
704
+ )
705
+ else:
706
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
707
+
708
+ #####################################################################################################
709
+ ################################### 2, deep feature extraction ######################################
710
+ self.num_layers = len(depths)
711
+ self.embed_dim = embed_dim
712
+ self.ape = ape
713
+ self.patch_norm = patch_norm
714
+ self.num_features = embed_dim
715
+ self.mlp_ratio = mlp_ratio
716
+
717
+ # split image into non-overlapping patches
718
+ self.patch_embed = PatchEmbed(
719
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
720
+ norm_layer=norm_layer if self.patch_norm else None
721
+ )
722
+ num_patches = self.patch_embed.num_patches
723
+ patches_resolution = self.patch_embed.patches_resolution
724
+ self.patches_resolution = patches_resolution
725
+
726
+ # merge non-overlapping patches into image
727
+ self.patch_unembed = PatchUnEmbed(
728
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
729
+ norm_layer=norm_layer if self.patch_norm else None
730
+ )
731
+
732
+ # absolute position embedding
733
+ if self.ape:
734
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
735
+ trunc_normal_(self.absolute_pos_embed, std=.02)
736
+
737
+ self.pos_drop = nn.Dropout(p=drop_rate)
738
+
739
+ # stochastic depth
740
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
741
+
742
+ # build Residual Swin Transformer blocks (RSTB)
743
+ self.layers = nn.ModuleList()
744
+ for i_layer in range(self.num_layers):
745
+ layer = RSTB(
746
+ dim=embed_dim,
747
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
748
+ depth=depths[i_layer],
749
+ num_heads=num_heads[i_layer],
750
+ window_size=window_size,
751
+ mlp_ratio=self.mlp_ratio,
752
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
753
+ drop=drop_rate, attn_drop=attn_drop_rate,
754
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
755
+ norm_layer=norm_layer,
756
+ downsample=None,
757
+ use_checkpoint=use_checkpoint,
758
+ img_size=img_size,
759
+ patch_size=patch_size,
760
+ resi_connection=resi_connection
761
+ )
762
+ self.layers.append(layer)
763
+ self.norm = norm_layer(self.num_features)
764
+
765
+ # build the last conv layer in deep feature extraction
766
+ if resi_connection == '1conv':
767
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
768
+ elif resi_connection == '3conv':
769
+ # to save parameters and memory
770
+ self.conv_after_body = nn.Sequential(
771
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
772
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
773
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
774
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
775
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
776
+ )
777
+
778
+ #####################################################################################################
779
+ ################################ 3, high quality image reconstruction ################################
780
+ if self.upsampler == 'pixelshuffle':
781
+ # for classical SR
782
+ self.conv_before_upsample = nn.Sequential(
783
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
784
+ nn.LeakyReLU(inplace=True)
785
+ )
786
+ self.upsample = Upsample(sf, num_feat)
787
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
788
+ elif self.upsampler == 'pixelshuffledirect':
789
+ # for lightweight SR (to save parameters)
790
+ self.upsample = UpsampleOneStep(
791
+ sf, embed_dim, num_out_ch,
792
+ (patches_resolution[0], patches_resolution[1])
793
+ )
794
+ elif self.upsampler == 'nearest+conv':
795
+ # for real-world SR (less artifacts)
796
+ self.conv_before_upsample = nn.Sequential(
797
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
798
+ nn.LeakyReLU(inplace=True)
799
+ )
800
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
801
+ if self.upscale == 4:
802
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
803
+ elif self.upscale == 8:
804
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
805
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
806
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
807
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
808
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
809
+ else:
810
+ # for image denoising and JPEG compression artifact reduction
811
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
812
+
813
+ self.apply(self._init_weights)
814
+
815
+ def _init_weights(self, m: nn.Module) -> None:
816
+ if isinstance(m, nn.Linear):
817
+ trunc_normal_(m.weight, std=.02)
818
+ if isinstance(m, nn.Linear) and m.bias is not None:
819
+ nn.init.constant_(m.bias, 0)
820
+ elif isinstance(m, nn.LayerNorm):
821
+ nn.init.constant_(m.bias, 0)
822
+ nn.init.constant_(m.weight, 1.0)
823
+
824
+ # TODO: What's this ?
825
+ @torch.jit.ignore
826
+ def no_weight_decay(self) -> Set[str]:
827
+ return {'absolute_pos_embed'}
828
+
829
+ @torch.jit.ignore
830
+ def no_weight_decay_keywords(self) -> Set[str]:
831
+ return {'relative_position_bias_table'}
832
+
833
+ def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
834
+ _, _, h, w = x.size()
835
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
836
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
837
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
838
+ return x
839
+
840
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
841
+ x_size = (x.shape[2], x.shape[3])
842
+ x = self.patch_embed(x)
843
+ if self.ape:
844
+ x = x + self.absolute_pos_embed
845
+ x = self.pos_drop(x)
846
+
847
+ for layer in self.layers:
848
+ x = layer(x, x_size)
849
+
850
+ x = self.norm(x) # B L C
851
+ x = self.patch_unembed(x, x_size)
852
+
853
+ return x
854
+
855
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
856
+ H, W = x.shape[2:]
857
+ x = self.check_image_size(x)
858
+
859
+ self.mean = self.mean.type_as(x)
860
+ x = (x - self.mean) * self.img_range
861
+
862
+ if self.upsampler == 'pixelshuffle':
863
+ # for classical SR
864
+ x = self.conv_first(x)
865
+ x = self.conv_after_body(self.forward_features(x)) + x
866
+ x = self.conv_before_upsample(x)
867
+ x = self.conv_last(self.upsample(x))
868
+ elif self.upsampler == 'pixelshuffledirect':
869
+ # for lightweight SR
870
+ x = self.conv_first(x)
871
+ x = self.conv_after_body(self.forward_features(x)) + x
872
+ x = self.upsample(x)
873
+ elif self.upsampler == 'nearest+conv':
874
+ # for real-world SR
875
+ x = self.conv_first(x)
876
+ x = self.conv_after_body(self.forward_features(x)) + x
877
+ x = self.conv_before_upsample(x)
878
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
879
+ if self.upscale == 4:
880
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
881
+ elif self.upscale == 8:
882
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
883
+ x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
884
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
885
+ else:
886
+ # for image denoising and JPEG compression artifact reduction
887
+ x_first = self.conv_first(x)
888
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
889
+ x = x + self.conv_last(res)
890
+
891
+ x = x / self.img_range + self.mean
892
+
893
+ return x[:, :, :H * self.upscale, :W * self.upscale]
894
+
895
+ def flops(self) -> int:
896
+ flops = 0
897
+ H, W = self.patches_resolution
898
+ flops += H * W * 3 * self.embed_dim * 9
899
+ flops += self.patch_embed.flops()
900
+ for i, layer in enumerate(self.layers):
901
+ flops += layer.flops()
902
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
903
+ flops += self.upsample.flops()
904
+ return flops
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ffmpeg
2
+ libsm6
3
+ libxext6
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.2
2
+ facexlib==0.2.5
3
+ realesrgan==0.2.5
4
+ numpy
5
+ opencv-python
6
+ torchvision
7
+ pytorch-lightning==2.4.0
8
+ scipy
9
+ tqdm
10
+ lmdb
11
+ pyyaml
12
+ basicsr==1.4.2
13
+ yapf
14
+ dctorch
15
+ einops
16
+ torch-ema==0.3
17
+ huggingface_hub==0.24.5
18
+ natten==0.17.1
19
+ wandb
20
+ timm
21
+ huggingface_hub==0.24.5
utils/__init__.py ADDED
File without changes
utils/basicsr_custom.py ADDED
@@ -0,0 +1,954 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/degradations.py
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ # https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
4
+
5
+ import math
6
+ import random
7
+ import re
8
+ from abc import ABCMeta, abstractmethod
9
+ from pathlib import Path
10
+ from typing import List, Dict
11
+ from typing import Mapping, Any
12
+ from typing import Optional, Union
13
+
14
+ import cv2
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image
18
+ from scipy import special
19
+ from scipy.stats import multivariate_normal
20
+ from torch import Tensor
21
+ # from torchvision.transforms.functional_tensor import rgb_to_grayscale
22
+ from torchvision.transforms._functional_tensor import rgb_to_grayscale
23
+
24
+
25
+ # -------------------------------------------------------------------- #
26
+ # --------------------------- blur kernels --------------------------- #
27
+ # -------------------------------------------------------------------- #
28
+
29
+
30
+ # --------------------------- util functions --------------------------- #
31
+ def sigma_matrix2(sig_x, sig_y, theta):
32
+ """Calculate the rotated sigma matrix (two dimensional matrix).
33
+
34
+ Args:
35
+ sig_x (float):
36
+ sig_y (float):
37
+ theta (float): Radian measurement.
38
+
39
+ Returns:
40
+ ndarray: Rotated sigma matrix.
41
+ """
42
+ d_matrix = np.array([[sig_x ** 2, 0], [0, sig_y ** 2]])
43
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
44
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
45
+
46
+
47
+ def mesh_grid(kernel_size):
48
+ """Generate the mesh grid, centering at zero.
49
+
50
+ Args:
51
+ kernel_size (int):
52
+
53
+ Returns:
54
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
55
+ xx (ndarray): with the shape (kernel_size, kernel_size)
56
+ yy (ndarray): with the shape (kernel_size, kernel_size)
57
+ """
58
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
59
+ xx, yy = np.meshgrid(ax, ax)
60
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
61
+ 1))).reshape(kernel_size, kernel_size, 2)
62
+ return xy, xx, yy
63
+
64
+
65
+ def pdf2(sigma_matrix, grid):
66
+ """Calculate PDF of the bivariate Gaussian distribution.
67
+
68
+ Args:
69
+ sigma_matrix (ndarray): with the shape (2, 2)
70
+ grid (ndarray): generated by :func:`mesh_grid`,
71
+ with the shape (K, K, 2), K is the kernel size.
72
+
73
+ Returns:
74
+ kernel (ndarrray): un-normalized kernel.
75
+ """
76
+ inverse_sigma = np.linalg.inv(sigma_matrix)
77
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
78
+ return kernel
79
+
80
+
81
+ def cdf2(d_matrix, grid):
82
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
83
+ Used in skewed Gaussian distribution.
84
+
85
+ Args:
86
+ d_matrix (ndarrasy): skew matrix.
87
+ grid (ndarray): generated by :func:`mesh_grid`,
88
+ with the shape (K, K, 2), K is the kernel size.
89
+
90
+ Returns:
91
+ cdf (ndarray): skewed cdf.
92
+ """
93
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
94
+ grid = np.dot(grid, d_matrix)
95
+ cdf = rv.cdf(grid)
96
+ return cdf
97
+
98
+
99
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
100
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
101
+
102
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
103
+
104
+ Args:
105
+ kernel_size (int):
106
+ sig_x (float):
107
+ sig_y (float):
108
+ theta (float): Radian measurement.
109
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
110
+ with the shape (K, K, 2), K is the kernel size. Default: None
111
+ isotropic (bool):
112
+
113
+ Returns:
114
+ kernel (ndarray): normalized kernel.
115
+ """
116
+ if grid is None:
117
+ grid, _, _ = mesh_grid(kernel_size)
118
+ if isotropic:
119
+ sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
120
+ else:
121
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
122
+ kernel = pdf2(sigma_matrix, grid)
123
+ kernel = kernel / np.sum(kernel)
124
+ return kernel
125
+
126
+
127
+ def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
128
+ """Generate a bivariate generalized Gaussian kernel.
129
+
130
+ ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
131
+
132
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
133
+
134
+ Args:
135
+ kernel_size (int):
136
+ sig_x (float):
137
+ sig_y (float):
138
+ theta (float): Radian measurement.
139
+ beta (float): shape parameter, beta = 1 is the normal distribution.
140
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
141
+ with the shape (K, K, 2), K is the kernel size. Default: None
142
+
143
+ Returns:
144
+ kernel (ndarray): normalized kernel.
145
+ """
146
+ if grid is None:
147
+ grid, _, _ = mesh_grid(kernel_size)
148
+ if isotropic:
149
+ sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
150
+ else:
151
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
152
+ inverse_sigma = np.linalg.inv(sigma_matrix)
153
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
154
+ kernel = kernel / np.sum(kernel)
155
+ return kernel
156
+
157
+
158
+ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
159
+ """Generate a plateau-like anisotropic kernel.
160
+
161
+ 1 / (1+x^(beta))
162
+
163
+ Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
164
+
165
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
166
+
167
+ Args:
168
+ kernel_size (int):
169
+ sig_x (float):
170
+ sig_y (float):
171
+ theta (float): Radian measurement.
172
+ beta (float): shape parameter, beta = 1 is the normal distribution.
173
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
174
+ with the shape (K, K, 2), K is the kernel size. Default: None
175
+
176
+ Returns:
177
+ kernel (ndarray): normalized kernel.
178
+ """
179
+ if grid is None:
180
+ grid, _, _ = mesh_grid(kernel_size)
181
+ if isotropic:
182
+ sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
183
+ else:
184
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
185
+ inverse_sigma = np.linalg.inv(sigma_matrix)
186
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
187
+ kernel = kernel / np.sum(kernel)
188
+ return kernel
189
+
190
+
191
+ def random_bivariate_Gaussian(kernel_size,
192
+ sigma_x_range,
193
+ sigma_y_range,
194
+ rotation_range,
195
+ noise_range=None,
196
+ isotropic=True):
197
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
198
+
199
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
200
+
201
+ Args:
202
+ kernel_size (int):
203
+ sigma_x_range (tuple): [0.6, 5]
204
+ sigma_y_range (tuple): [0.6, 5]
205
+ rotation range (tuple): [-math.pi, math.pi]
206
+ noise_range(tuple, optional): multiplicative kernel noise,
207
+ [0.75, 1.25]. Default: None
208
+
209
+ Returns:
210
+ kernel (ndarray):
211
+ """
212
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
213
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
214
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
215
+ if isotropic is False:
216
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
217
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
218
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
219
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
220
+ else:
221
+ sigma_y = sigma_x
222
+ rotation = 0
223
+
224
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
225
+
226
+ # add multiplicative noise
227
+ if noise_range is not None:
228
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
229
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
230
+ kernel = kernel * noise
231
+ kernel = kernel / np.sum(kernel)
232
+ return kernel
233
+
234
+
235
+ def random_bivariate_generalized_Gaussian(kernel_size,
236
+ sigma_x_range,
237
+ sigma_y_range,
238
+ rotation_range,
239
+ beta_range,
240
+ noise_range=None,
241
+ isotropic=True):
242
+ """Randomly generate bivariate generalized Gaussian kernels.
243
+
244
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
245
+
246
+ Args:
247
+ kernel_size (int):
248
+ sigma_x_range (tuple): [0.6, 5]
249
+ sigma_y_range (tuple): [0.6, 5]
250
+ rotation range (tuple): [-math.pi, math.pi]
251
+ beta_range (tuple): [0.5, 8]
252
+ noise_range(tuple, optional): multiplicative kernel noise,
253
+ [0.75, 1.25]. Default: None
254
+
255
+ Returns:
256
+ kernel (ndarray):
257
+ """
258
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
259
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
260
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
261
+ if isotropic is False:
262
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
263
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
264
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
265
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
266
+ else:
267
+ sigma_y = sigma_x
268
+ rotation = 0
269
+
270
+ # assume beta_range[0] < 1 < beta_range[1]
271
+ if np.random.uniform() < 0.5:
272
+ beta = np.random.uniform(beta_range[0], 1)
273
+ else:
274
+ beta = np.random.uniform(1, beta_range[1])
275
+
276
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
277
+
278
+ # add multiplicative noise
279
+ if noise_range is not None:
280
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
281
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
282
+ kernel = kernel * noise
283
+ kernel = kernel / np.sum(kernel)
284
+ return kernel
285
+
286
+
287
+ def random_bivariate_plateau(kernel_size,
288
+ sigma_x_range,
289
+ sigma_y_range,
290
+ rotation_range,
291
+ beta_range,
292
+ noise_range=None,
293
+ isotropic=True):
294
+ """Randomly generate bivariate plateau kernels.
295
+
296
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
297
+
298
+ Args:
299
+ kernel_size (int):
300
+ sigma_x_range (tuple): [0.6, 5]
301
+ sigma_y_range (tuple): [0.6, 5]
302
+ rotation range (tuple): [-math.pi/2, math.pi/2]
303
+ beta_range (tuple): [1, 4]
304
+ noise_range(tuple, optional): multiplicative kernel noise,
305
+ [0.75, 1.25]. Default: None
306
+
307
+ Returns:
308
+ kernel (ndarray):
309
+ """
310
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
311
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
312
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
313
+ if isotropic is False:
314
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
315
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
316
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
317
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
318
+ else:
319
+ sigma_y = sigma_x
320
+ rotation = 0
321
+
322
+ # TODO: this may be not proper
323
+ if np.random.uniform() < 0.5:
324
+ beta = np.random.uniform(beta_range[0], 1)
325
+ else:
326
+ beta = np.random.uniform(1, beta_range[1])
327
+
328
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
329
+ # add multiplicative noise
330
+ if noise_range is not None:
331
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
332
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
333
+ kernel = kernel * noise
334
+ kernel = kernel / np.sum(kernel)
335
+
336
+ return kernel
337
+
338
+
339
+ def random_mixed_kernels(kernel_list,
340
+ kernel_prob,
341
+ kernel_size=21,
342
+ sigma_x_range=(0.6, 5),
343
+ sigma_y_range=(0.6, 5),
344
+ rotation_range=(-math.pi, math.pi),
345
+ betag_range=(0.5, 8),
346
+ betap_range=(0.5, 8),
347
+ noise_range=None):
348
+ """Randomly generate mixed kernels.
349
+
350
+ Args:
351
+ kernel_list (tuple): a list name of kernel types,
352
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
353
+ 'plateau_aniso']
354
+ kernel_prob (tuple): corresponding kernel probability for each
355
+ kernel type
356
+ kernel_size (int):
357
+ sigma_x_range (tuple): [0.6, 5]
358
+ sigma_y_range (tuple): [0.6, 5]
359
+ rotation range (tuple): [-math.pi, math.pi]
360
+ beta_range (tuple): [0.5, 8]
361
+ noise_range(tuple, optional): multiplicative kernel noise,
362
+ [0.75, 1.25]. Default: None
363
+
364
+ Returns:
365
+ kernel (ndarray):
366
+ """
367
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
368
+ if kernel_type == 'iso':
369
+ kernel = random_bivariate_Gaussian(
370
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
371
+ elif kernel_type == 'aniso':
372
+ kernel = random_bivariate_Gaussian(
373
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
374
+ elif kernel_type == 'generalized_iso':
375
+ kernel = random_bivariate_generalized_Gaussian(
376
+ kernel_size,
377
+ sigma_x_range,
378
+ sigma_y_range,
379
+ rotation_range,
380
+ betag_range,
381
+ noise_range=noise_range,
382
+ isotropic=True)
383
+ elif kernel_type == 'generalized_aniso':
384
+ kernel = random_bivariate_generalized_Gaussian(
385
+ kernel_size,
386
+ sigma_x_range,
387
+ sigma_y_range,
388
+ rotation_range,
389
+ betag_range,
390
+ noise_range=noise_range,
391
+ isotropic=False)
392
+ elif kernel_type == 'plateau_iso':
393
+ kernel = random_bivariate_plateau(
394
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
395
+ elif kernel_type == 'plateau_aniso':
396
+ kernel = random_bivariate_plateau(
397
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
398
+ return kernel
399
+
400
+
401
+ np.seterr(divide='ignore', invalid='ignore')
402
+
403
+
404
+ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
405
+ """2D sinc filter
406
+
407
+ Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
408
+
409
+ Args:
410
+ cutoff (float): cutoff frequency in radians (pi is max)
411
+ kernel_size (int): horizontal and vertical size, must be odd.
412
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
413
+ """
414
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
415
+ kernel = np.fromfunction(
416
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
417
+ (x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2)) / (2 * np.pi * np.sqrt(
418
+ (x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2)), [kernel_size, kernel_size])
419
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff ** 2 / (4 * np.pi)
420
+ kernel = kernel / np.sum(kernel)
421
+ if pad_to > kernel_size:
422
+ pad_size = (pad_to - kernel_size) // 2
423
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
424
+ return kernel
425
+
426
+
427
+ # ------------------------------------------------------------- #
428
+ # --------------------------- noise --------------------------- #
429
+ # ------------------------------------------------------------- #
430
+
431
+ # ----------------------- Gaussian Noise ----------------------- #
432
+
433
+ def instantiate_from_config(config: Mapping[str, Any]) -> Any:
434
+ if not "target" in config:
435
+ raise KeyError("Expected key `target` to instantiate.")
436
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
437
+
438
+
439
+ class BaseStorageBackend(metaclass=ABCMeta):
440
+ """Abstract class of storage backends.
441
+
442
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
443
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
444
+ as texts.
445
+ """
446
+
447
+ @property
448
+ def name(self) -> str:
449
+ return self.__class__.__name__
450
+
451
+ @abstractmethod
452
+ def get(self, filepath: str) -> bytes:
453
+ pass
454
+
455
+
456
+ class PetrelBackend(BaseStorageBackend):
457
+ """Petrel storage backend (for internal use).
458
+
459
+ PetrelBackend supports reading and writing data to multiple clusters.
460
+ If the file path contains the cluster name, PetrelBackend will read data
461
+ from specified cluster or write data to it. Otherwise, PetrelBackend will
462
+ access the default cluster.
463
+
464
+ Args:
465
+ path_mapping (dict, optional): Path mapping dict from local path to
466
+ Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
467
+ ``filepath`` will be replaced by ``dst``. Default: None.
468
+ enable_mc (bool, optional): Whether to enable memcached support.
469
+ Default: True.
470
+ conf_path (str, optional): Config path of Petrel client. Default: None.
471
+ `New in version 1.7.1`.
472
+
473
+ Examples:
474
+ >>> filepath1 = 's3://path/of/file'
475
+ >>> filepath2 = 'cluster-name:s3://path/of/file'
476
+ >>> client = PetrelBackend()
477
+ >>> client.get(filepath1) # get data from default cluster
478
+ >>> client.get(filepath2) # get data from 'cluster-name' cluster
479
+ """
480
+
481
+ def __init__(self,
482
+ path_mapping: Optional[dict] = None,
483
+ enable_mc: bool = False,
484
+ conf_path: str = None):
485
+ try:
486
+ from petrel_client import client
487
+ except ImportError:
488
+ raise ImportError('Please install petrel_client to enable '
489
+ 'PetrelBackend.')
490
+
491
+ self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc)
492
+ assert isinstance(path_mapping, dict) or path_mapping is None
493
+ self.path_mapping = path_mapping
494
+
495
+ def _map_path(self, filepath: Union[str, Path]) -> str:
496
+ """Map ``filepath`` to a string path whose prefix will be replaced by
497
+ :attr:`self.path_mapping`.
498
+
499
+ Args:
500
+ filepath (str): Path to be mapped.
501
+ """
502
+ filepath = str(filepath)
503
+ if self.path_mapping is not None:
504
+ for k, v in self.path_mapping.items():
505
+ filepath = filepath.replace(k, v, 1)
506
+ return filepath
507
+
508
+ def _format_path(self, filepath: str) -> str:
509
+ """Convert a ``filepath`` to standard format of petrel oss.
510
+
511
+ If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
512
+ environment, the ``filepath`` will be the format of
513
+ 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
514
+ above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
515
+
516
+ Args:
517
+ filepath (str): Path to be formatted.
518
+ """
519
+ return re.sub(r'\\+', '/', filepath)
520
+
521
+ def get(self, filepath: Union[str, Path]) -> bytes:
522
+ """Read data from a given ``filepath`` with 'rb' mode.
523
+
524
+ Args:
525
+ filepath (str or Path): Path to read data.
526
+
527
+ Returns:
528
+ bytes: The loaded bytes.
529
+ """
530
+ filepath = self._map_path(filepath)
531
+ filepath = self._format_path(filepath)
532
+ value = self._client.Get(filepath)
533
+ return value
534
+
535
+
536
+ class HardDiskBackend(BaseStorageBackend):
537
+ """Raw hard disks storage backend."""
538
+
539
+ def get(self, filepath: Union[str, Path]) -> bytes:
540
+ """Read data from a given ``filepath`` with 'rb' mode.
541
+
542
+ Args:
543
+ filepath (str or Path): Path to read data.
544
+
545
+ Returns:
546
+ bytes: Expected bytes object.
547
+ """
548
+ with open(filepath, 'rb') as f:
549
+ value_buf = f.read()
550
+ return value_buf
551
+
552
+
553
+ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
554
+ """Generate Gaussian noise.
555
+
556
+ Args:
557
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
558
+ sigma (float): Noise scale (measured in range 255). Default: 10.
559
+
560
+ Returns:
561
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
562
+ float32.
563
+ """
564
+ if gray_noise:
565
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
566
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
567
+ else:
568
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
569
+ return noise
570
+
571
+
572
+ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
573
+ """Add Gaussian noise.
574
+
575
+ Args:
576
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
577
+ sigma (float): Noise scale (measured in range 255). Default: 10.
578
+
579
+ Returns:
580
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
581
+ float32.
582
+ """
583
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
584
+ out = img + noise
585
+ if clip and rounds:
586
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
587
+ elif clip:
588
+ out = np.clip(out, 0, 1)
589
+ elif rounds:
590
+ out = (out * 255.0).round() / 255.
591
+ return out
592
+
593
+
594
+ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
595
+ """Add Gaussian noise (PyTorch version).
596
+
597
+ Args:
598
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
599
+ scale (float | Tensor): Noise scale. Default: 1.0.
600
+
601
+ Returns:
602
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
603
+ float32.
604
+ """
605
+ b, _, h, w = img.size()
606
+ if not isinstance(sigma, (float, int)):
607
+ sigma = sigma.view(img.size(0), 1, 1, 1)
608
+ if isinstance(gray_noise, (float, int)):
609
+ cal_gray_noise = gray_noise > 0
610
+ else:
611
+ gray_noise = gray_noise.view(b, 1, 1, 1)
612
+ cal_gray_noise = torch.sum(gray_noise) > 0
613
+
614
+ if cal_gray_noise:
615
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
616
+ noise_gray = noise_gray.view(b, 1, h, w)
617
+
618
+ # always calculate color noise
619
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
620
+
621
+ if cal_gray_noise:
622
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
623
+ return noise
624
+
625
+
626
+ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
627
+ """Add Gaussian noise (PyTorch version).
628
+
629
+ Args:
630
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
631
+ scale (float | Tensor): Noise scale. Default: 1.0.
632
+
633
+ Returns:
634
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
635
+ float32.
636
+ """
637
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
638
+ out = img + noise
639
+ if clip and rounds:
640
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
641
+ elif clip:
642
+ out = torch.clamp(out, 0, 1)
643
+ elif rounds:
644
+ out = (out * 255.0).round() / 255.
645
+ return out
646
+
647
+
648
+ # ----------------------- Random Gaussian Noise ----------------------- #
649
+ def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
650
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
651
+ if np.random.uniform() < gray_prob:
652
+ gray_noise = True
653
+ else:
654
+ gray_noise = False
655
+ return generate_gaussian_noise(img, sigma, gray_noise)
656
+
657
+
658
+ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
659
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
660
+ out = img + noise
661
+ if clip and rounds:
662
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
663
+ elif clip:
664
+ out = np.clip(out, 0, 1)
665
+ elif rounds:
666
+ out = (out * 255.0).round() / 255.
667
+ return out
668
+
669
+
670
+ def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
671
+ sigma = torch.rand(
672
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
673
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
674
+ gray_noise = (gray_noise < gray_prob).float()
675
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
676
+
677
+
678
+ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
679
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
680
+ out = img + noise
681
+ if clip and rounds:
682
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
683
+ elif clip:
684
+ out = torch.clamp(out, 0, 1)
685
+ elif rounds:
686
+ out = (out * 255.0).round() / 255.
687
+ return out
688
+
689
+
690
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
691
+
692
+
693
+ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
694
+ """Generate poisson noise.
695
+
696
+ Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
697
+
698
+ Args:
699
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
700
+ scale (float): Noise scale. Default: 1.0.
701
+ gray_noise (bool): Whether generate gray noise. Default: False.
702
+
703
+ Returns:
704
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
705
+ float32.
706
+ """
707
+ if gray_noise:
708
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
709
+ # round and clip image for counting vals correctly
710
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
711
+ vals = len(np.unique(img))
712
+ vals = 2 ** np.ceil(np.log2(vals))
713
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
714
+ noise = out - img
715
+ if gray_noise:
716
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
717
+ return noise * scale
718
+
719
+
720
+ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
721
+ """Add poisson noise.
722
+
723
+ Args:
724
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
725
+ scale (float): Noise scale. Default: 1.0.
726
+ gray_noise (bool): Whether generate gray noise. Default: False.
727
+
728
+ Returns:
729
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
730
+ float32.
731
+ """
732
+ noise = generate_poisson_noise(img, scale, gray_noise)
733
+ out = img + noise
734
+ if clip and rounds:
735
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
736
+ elif clip:
737
+ out = np.clip(out, 0, 1)
738
+ elif rounds:
739
+ out = (out * 255.0).round() / 255.
740
+ return out
741
+
742
+
743
+ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
744
+ """Generate a batch of poisson noise (PyTorch version)
745
+
746
+ Args:
747
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
748
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
749
+ Default: 1.0.
750
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
751
+ 0 for False, 1 for True. Default: 0.
752
+
753
+ Returns:
754
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
755
+ float32.
756
+ """
757
+ b, _, h, w = img.size()
758
+ if isinstance(gray_noise, (float, int)):
759
+ cal_gray_noise = gray_noise > 0
760
+ else:
761
+ gray_noise = gray_noise.view(b, 1, 1, 1)
762
+ cal_gray_noise = torch.sum(gray_noise) > 0
763
+ if cal_gray_noise:
764
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
765
+ # round and clip image for counting vals correctly
766
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
767
+ # use for-loop to get the unique values for each sample
768
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
769
+ vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
770
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
771
+ out = torch.poisson(img_gray * vals) / vals
772
+ noise_gray = out - img_gray
773
+ noise_gray = noise_gray.expand(b, 3, h, w)
774
+
775
+ # always calculate color noise
776
+ # round and clip image for counting vals correctly
777
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
778
+ # use for-loop to get the unique values for each sample
779
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
780
+ vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
781
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
782
+ out = torch.poisson(img * vals) / vals
783
+ noise = out - img
784
+ if cal_gray_noise:
785
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
786
+ if not isinstance(scale, (float, int)):
787
+ scale = scale.view(b, 1, 1, 1)
788
+ return noise * scale
789
+
790
+
791
+ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
792
+ """Add poisson noise to a batch of images (PyTorch version).
793
+
794
+ Args:
795
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
796
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
797
+ Default: 1.0.
798
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
799
+ 0 for False, 1 for True. Default: 0.
800
+
801
+ Returns:
802
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
803
+ float32.
804
+ """
805
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
806
+ out = img + noise
807
+ if clip and rounds:
808
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
809
+ elif clip:
810
+ out = torch.clamp(out, 0, 1)
811
+ elif rounds:
812
+ out = (out * 255.0).round() / 255.
813
+ return out
814
+
815
+
816
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
817
+
818
+
819
+ def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
820
+ scale = np.random.uniform(scale_range[0], scale_range[1])
821
+ if np.random.uniform() < gray_prob:
822
+ gray_noise = True
823
+ else:
824
+ gray_noise = False
825
+ return generate_poisson_noise(img, scale, gray_noise)
826
+
827
+
828
+ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
829
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
830
+ out = img + noise
831
+ if clip and rounds:
832
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
833
+ elif clip:
834
+ out = np.clip(out, 0, 1)
835
+ elif rounds:
836
+ out = (out * 255.0).round() / 255.
837
+ return out
838
+
839
+
840
+ def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
841
+ scale = torch.rand(
842
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
843
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
844
+ gray_noise = (gray_noise < gray_prob).float()
845
+ return generate_poisson_noise_pt(img, scale, gray_noise)
846
+
847
+
848
+ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
849
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
850
+ out = img + noise
851
+ if clip and rounds:
852
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
853
+ elif clip:
854
+ out = torch.clamp(out, 0, 1)
855
+ elif rounds:
856
+ out = (out * 255.0).round() / 255.
857
+ return out
858
+
859
+
860
+ # ------------------------------------------------------------------------ #
861
+ # --------------------------- JPEG compression --------------------------- #
862
+ # ------------------------------------------------------------------------ #
863
+
864
+
865
+ def add_jpg_compression(img, quality=90):
866
+ """Add JPG compression artifacts.
867
+
868
+ Args:
869
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
870
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
871
+ best quality. Default: 90.
872
+
873
+ Returns:
874
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
875
+ float32.
876
+ """
877
+ img = np.clip(img, 0, 1)
878
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
879
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
880
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
881
+ return img
882
+
883
+
884
+ def random_add_jpg_compression(img, quality_range=(90, 100)):
885
+ """Randomly add JPG compression artifacts.
886
+
887
+ Args:
888
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
889
+ quality_range (tuple[float] | list[float]): JPG compression quality
890
+ range. 0 for lowest quality, 100 for best quality.
891
+ Default: (90, 100).
892
+
893
+ Returns:
894
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
895
+ float32.
896
+ """
897
+ quality = np.random.uniform(quality_range[0], quality_range[1])
898
+ return add_jpg_compression(img, int(quality))
899
+
900
+
901
+ def load_file_list(file_list_path: str) -> List[Dict[str, str]]:
902
+ files = []
903
+ with open(file_list_path, "r") as fin:
904
+ for line in fin:
905
+ path = line.strip()
906
+ if path:
907
+ files.append({"image_path": path, "prompt": ""})
908
+ return files
909
+
910
+
911
+ # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py
912
+ def center_crop_arr(pil_image, image_size):
913
+ # We are not on a new enough PIL to support the `reducing_gap`
914
+ # argument, which uses BOX downsampling at powers of two first.
915
+ # Thus, we do it by hand to improve downsample quality.
916
+ while min(*pil_image.size) >= 2 * image_size:
917
+ pil_image = pil_image.resize(
918
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
919
+ )
920
+
921
+ scale = image_size / min(*pil_image.size)
922
+ pil_image = pil_image.resize(
923
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
924
+ )
925
+
926
+ arr = np.array(pil_image)
927
+ crop_y = (arr.shape[0] - image_size) // 2
928
+ crop_x = (arr.shape[1] - image_size) // 2
929
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
930
+
931
+
932
+ # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py
933
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
934
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
935
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
936
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
937
+
938
+ # We are not on a new enough PIL to support the `reducing_gap`
939
+ # argument, which uses BOX downsampling at powers of two first.
940
+ # Thus, we do it by hand to improve downsample quality.
941
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
942
+ pil_image = pil_image.resize(
943
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
944
+ )
945
+
946
+ scale = smaller_dim_size / min(*pil_image.size)
947
+ pil_image = pil_image.resize(
948
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
949
+ )
950
+
951
+ arr = np.array(pil_image)
952
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
953
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
954
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
utils/create_arch.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from arch.hourglass import image_transformer_v2 as itv2
2
+ from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
3
+ from arch.swinir.swinir import SwinIR
4
+
5
+
6
+ def create_arch(arch, condition_channels=0):
7
+ # arch should be, e.g., swinir_XL, or hdit_XL
8
+ arch_name, arch_size = arch.split('_')
9
+ arch_config = arch_configs[arch_name][arch_size].copy()
10
+ arch_config['in_channels'] += condition_channels
11
+ return arch_name_to_object[arch_name](**arch_config)
12
+
13
+
14
+ arch_configs = {
15
+ 'hdit': {
16
+ "ImageNet256Sp4": {
17
+ 'in_channels': 3,
18
+ 'out_channels': 3,
19
+ 'widths': [256, 512, 1024],
20
+ 'depths': [2, 2, 8],
21
+ 'patch_size': [4, 4],
22
+ 'self_attns': [
23
+ {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
24
+ {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
25
+ {"type": "global", "d_head": 64}
26
+ ],
27
+ 'mapping_depth': 2,
28
+ 'mapping_width': 768,
29
+ 'dropout_rate': [0, 0, 0],
30
+ 'mapping_dropout_rate': 0.0
31
+ },
32
+ "XL2": {
33
+ 'in_channels': 3,
34
+ 'out_channels': 3,
35
+ 'widths': [384, 768],
36
+ 'depths': [2, 11],
37
+ 'patch_size': [4, 4],
38
+ 'self_attns': [
39
+ {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
40
+ {"type": "global", "d_head": 64}
41
+ ],
42
+ 'mapping_depth': 2,
43
+ 'mapping_width': 768,
44
+ 'dropout_rate': [0, 0],
45
+ 'mapping_dropout_rate': 0.0
46
+ }
47
+
48
+ },
49
+ 'swinir': {
50
+ "M": {
51
+ 'in_channels': 3,
52
+ 'out_channels': 3,
53
+ 'embed_dim': 120,
54
+ 'depths': [6, 6, 6, 6, 6],
55
+ 'num_heads': [6, 6, 6, 6, 6],
56
+ 'resi_connection': '1conv',
57
+ 'sf': 8
58
+
59
+ },
60
+ "L": {
61
+ 'in_channels': 3,
62
+ 'out_channels': 3,
63
+ 'embed_dim': 180,
64
+ 'depths': [6, 6, 6, 6, 6, 6, 6, 6],
65
+ 'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
66
+ 'resi_connection': '1conv',
67
+ 'sf': 8
68
+ },
69
+ },
70
+ }
71
+
72
+
73
+ def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
74
+ sf):
75
+ return SwinIR(
76
+ img_size=64,
77
+ patch_size=1,
78
+ in_chans=in_channels,
79
+ num_out_ch=out_channels,
80
+ embed_dim=embed_dim,
81
+ depths=depths,
82
+ num_heads=num_heads,
83
+ window_size=8,
84
+ mlp_ratio=2,
85
+ sf=sf,
86
+ img_range=1.0,
87
+ upsampler="nearest+conv",
88
+ resi_connection=resi_connection,
89
+ unshuffle=True,
90
+ unshuffle_scale=8
91
+ )
92
+
93
+
94
+ def create_hdit_model(widths,
95
+ depths,
96
+ self_attns,
97
+ dropout_rate,
98
+ mapping_depth,
99
+ mapping_width,
100
+ mapping_dropout_rate,
101
+ in_channels,
102
+ out_channels,
103
+ patch_size
104
+ ):
105
+ assert len(widths) == len(depths)
106
+ assert len(widths) == len(self_attns)
107
+ assert len(widths) == len(dropout_rate)
108
+ mapping_d_ff = mapping_width * 3
109
+ d_ffs = []
110
+ for width in widths:
111
+ d_ffs.append(width * 3)
112
+
113
+ levels = []
114
+ for depth, width, d_ff, self_attn, dropout in zip(depths, widths, d_ffs, self_attns, dropout_rate):
115
+ if self_attn['type'] == 'global':
116
+ self_attn = itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
117
+ elif self_attn['type'] == 'neighborhood':
118
+ self_attn = itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
119
+ elif self_attn['type'] == 'shifted-window':
120
+ self_attn = itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
121
+ elif self_attn['type'] == 'none':
122
+ self_attn = itv2.NoAttentionSpec()
123
+ else:
124
+ raise ValueError(f'unsupported self attention type {self_attn["type"]}')
125
+ levels.append(itv2.LevelSpec(depth, width, d_ff, self_attn, dropout))
126
+ mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_d_ff, mapping_dropout_rate)
127
+ model = ImageTransformerDenoiserModelV2(
128
+ levels=levels,
129
+ mapping=mapping,
130
+ in_channels=in_channels,
131
+ out_channels=out_channels,
132
+ patch_size=patch_size,
133
+ num_classes=0,
134
+ mapping_cond_dim=0,
135
+ )
136
+
137
+ return model
138
+
139
+
140
+ arch_name_to_object = {
141
+ 'hdit': create_hdit_model,
142
+ 'swinir': create_swinir_model,
143
+ }
utils/create_degradation.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from basicsr.data import degradations as degradations
8
+ from basicsr.data.transforms import augment
9
+ from basicsr.utils import img2tensor
10
+ from torch.nn.functional import interpolate
11
+ from torchvision.transforms import Compose
12
+ from utils.basicsr_custom import (
13
+ random_mixed_kernels,
14
+ random_add_gaussian_noise,
15
+ random_add_jpg_compression,
16
+ )
17
+
18
+
19
+ def create_degradation(degradation):
20
+ if degradation == 'sr_bicubic_x8_gaussian_noise_005':
21
+ return Compose([
22
+ partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
23
+ partial(add_gaussian_noise, std=0.05),
24
+ partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
25
+ partial(torch.clip, min=0, max=1),
26
+ partial(torch.squeeze, dim=0),
27
+ lambda x: (x, None)
28
+
29
+ ])
30
+ elif degradation == 'gaussian_noise_035':
31
+ return Compose([
32
+ partial(add_gaussian_noise, std=0.35),
33
+ partial(torch.clip, min=0, max=1),
34
+ partial(torch.squeeze, dim=0),
35
+ lambda x: (x, None)
36
+
37
+ ])
38
+ elif degradation == 'colorization_gaussian_noise_025':
39
+ return Compose([
40
+ lambda x: torch.mean(x, dim=0, keepdim=True),
41
+ partial(add_gaussian_noise, std=0.25),
42
+ partial(torch.clip, min=0, max=1),
43
+ lambda x: (x, None)
44
+ ])
45
+ elif degradation == 'random_inpainting_gaussian_noise_01':
46
+ def inpainting_dps(x):
47
+ total = x.shape[1] ** 2
48
+ # random pixel sampling
49
+ l, h = [0.9, 0.9]
50
+ prob = np.random.uniform(l, h)
51
+ mask_vec = torch.ones([1, x.shape[1] * x.shape[1]])
52
+ samples = np.random.choice(x.shape[1] * x.shape[1], int(total * prob), replace=False)
53
+ mask_vec[:, samples] = 0
54
+ mask_b = mask_vec.view(1, x.shape[1], x.shape[1])
55
+ mask_b = mask_b.repeat(3, 1, 1)
56
+ mask = torch.ones_like(x, device=x.device)
57
+ mask[:, ...] = mask_b
58
+ return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None
59
+
60
+ return inpainting_dps
61
+ elif degradation == 'difface':
62
+ def deg(x):
63
+ blur_kernel_size = 41
64
+ kernel_list = ['iso', 'aniso']
65
+ kernel_prob = [0.5, 0.5]
66
+ blur_sigma = [0.1, 15]
67
+ downsample_range = [0.8, 32]
68
+ noise_range = [0, 20]
69
+ jpeg_range = [30, 100]
70
+ gt_gray = True
71
+ gray_prob = 0.01
72
+ x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)
73
+ # random horizontal flip
74
+ img_gt = augment(x.copy(), hflip=True, rotation=False)
75
+ h, w, _ = img_gt.shape
76
+
77
+ # ------------------------ generate lq image ------------------------ #
78
+ # blur
79
+ kernel = degradations.random_mixed_kernels(
80
+ kernel_list,
81
+ kernel_prob,
82
+ blur_kernel_size,
83
+ blur_sigma,
84
+ blur_sigma, [-math.pi, math.pi],
85
+ noise_range=None)
86
+ img_lq = cv2.filter2D(img_gt, -1, kernel)
87
+ # downsample
88
+ scale = np.random.uniform(downsample_range[0], downsample_range[1])
89
+ img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
90
+ # noise
91
+ if noise_range is not None:
92
+ img_lq = random_add_gaussian_noise(img_lq, noise_range)
93
+ # jpeg compression
94
+ if jpeg_range is not None:
95
+ img_lq = random_add_jpg_compression(img_lq, jpeg_range)
96
+
97
+ # resize to original size
98
+ img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
99
+
100
+ # random color jitter (only for lq)
101
+ # if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
102
+ # img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
103
+ # random to gray (only for lq)
104
+ if np.random.uniform() < gray_prob:
105
+ img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
106
+ img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
107
+ if gt_gray: # whether convert GT to gray images
108
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
109
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
110
+
111
+ # BGR to RGB, HWC to CHW, numpy to tensor
112
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
113
+
114
+ # random color jitter (pytorch version) (only for lq)
115
+ # if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
116
+ # brightness = self.opt.get('brightness', (0.5, 1.5))
117
+ # contrast = self.opt.get('contrast', (0.5, 1.5))
118
+ # saturation = self.opt.get('saturation', (0, 1.5))
119
+ # hue = self.opt.get('hue', (-0.1, 0.1))
120
+ # img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
121
+
122
+ # round and clip
123
+ img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
124
+
125
+ return img_lq, img_gt.clip(0, 1)
126
+
127
+ return deg
128
+ else:
129
+ raise NotImplementedError()
130
+
131
+
132
+ def down_scale(x, scale_factor, mode):
133
+ with torch.no_grad():
134
+ return interpolate(x.unsqueeze(0),
135
+ scale_factor=scale_factor,
136
+ mode=mode,
137
+ antialias=True,
138
+ align_corners=False).clip(0, 1)
139
+
140
+
141
+ def add_gaussian_noise(x, std):
142
+ with torch.no_grad():
143
+ x = x + torch.randn_like(x) * std
144
+ return x
utils/img_utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from torchvision.utils import make_grid
2
+
3
+
4
+ def create_grid(img, normalize=False, num_images=5):
5
+ return make_grid(img[:num_images], padding=0, normalize=normalize, nrow=16)