Upload 202 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +84 -0
- basicsr/__init__.py +12 -0
- basicsr/__pycache__/__init__.cpython-39.pyc +0 -0
- basicsr/__pycache__/test.cpython-39.pyc +0 -0
- basicsr/__pycache__/train.cpython-39.pyc +0 -0
- basicsr/__pycache__/version.cpython-39.pyc +0 -0
- basicsr/archs/__init__.py +24 -0
- basicsr/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/arch_util.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/basicvsr_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/basicvsrpp_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/dfdnet_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/dfdnet_util.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/discriminator_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/duf_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/ecbsr_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/edsr_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/edvr_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/hifacegan_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/hifacegan_util.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/rcan_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/ridnet_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/rrdbnet_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/spynet_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/srresnet_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/srvgg_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/stylegan2_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/swinir_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/tof_arch.cpython-39.pyc +0 -0
- basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc +0 -0
- basicsr/archs/arch_util.py +313 -0
- basicsr/archs/basicvsr_arch.py +336 -0
- basicsr/archs/basicvsrpp_arch.py +417 -0
- basicsr/archs/dfdnet_arch.py +169 -0
- basicsr/archs/dfdnet_util.py +162 -0
- basicsr/archs/discriminator_arch.py +150 -0
- basicsr/archs/duf_arch.py +276 -0
- basicsr/archs/ecbsr_arch.py +275 -0
- basicsr/archs/edsr_arch.py +61 -0
- basicsr/archs/edvr_arch.py +382 -0
- basicsr/archs/hifacegan_arch.py +260 -0
- basicsr/archs/hifacegan_util.py +255 -0
- basicsr/archs/inception.py +307 -0
- basicsr/archs/rcan_arch.py +135 -0
- basicsr/archs/ridnet_arch.py +180 -0
- basicsr/archs/rrdbnet_arch.py +119 -0
- basicsr/archs/spynet_arch.py +96 -0
- basicsr/archs/srresnet_arch.py +65 -0
- basicsr/archs/srvgg_arch.py +70 -0
app.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import insightface
|
2 |
+
import os
|
3 |
+
import onnxruntime
|
4 |
+
import cv2
|
5 |
+
import gfpgan
|
6 |
+
import tempfile
|
7 |
+
import time
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
|
11 |
+
class Predictor:
|
12 |
+
def __init__(self):
|
13 |
+
self.setup()
|
14 |
+
|
15 |
+
def setup(self):
|
16 |
+
os.makedirs('models', exist_ok=True)
|
17 |
+
os.chdir('models')
|
18 |
+
if not os.path.exists('GFPGANv1.4.pth'):
|
19 |
+
os.system(
|
20 |
+
'wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
|
21 |
+
)
|
22 |
+
if not os.path.exists('inswapper_128.onnx'):
|
23 |
+
os.system(
|
24 |
+
'wget https://huggingface.co/ashleykleynhans/inswapper/resolve/main/inswapper_128.onnx'
|
25 |
+
)
|
26 |
+
os.chdir('..')
|
27 |
+
|
28 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
29 |
+
self.face_swapper = insightface.model_zoo.get_model('models/inswapper_128.onnx',
|
30 |
+
providers=onnxruntime.get_available_providers())
|
31 |
+
self.face_enhancer = gfpgan.GFPGANer(model_path='models/GFPGANv1.4.pth', upscale=1)
|
32 |
+
self.face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
|
33 |
+
self.face_analyser.prepare(ctx_id=0, det_size=(640, 640))
|
34 |
+
|
35 |
+
def get_face(self, img_data):
|
36 |
+
analysed = self.face_analyser.get(img_data)
|
37 |
+
try:
|
38 |
+
largest = max(analysed, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
|
39 |
+
return largest
|
40 |
+
except:
|
41 |
+
print("No face found")
|
42 |
+
return None
|
43 |
+
|
44 |
+
def predict(self, input_image, swap_image):
|
45 |
+
"""Run a single prediction on the model"""
|
46 |
+
try:
|
47 |
+
frame = cv2.imread(input_image.name)
|
48 |
+
face = self.get_face(frame)
|
49 |
+
source_face = self.get_face(cv2.imread(swap_image.name))
|
50 |
+
try:
|
51 |
+
print(frame.shape, face.shape, source_face.shape)
|
52 |
+
except:
|
53 |
+
print("printing shapes failed.")
|
54 |
+
result = self.face_swapper.get(frame, face, source_face, paste_back=True)
|
55 |
+
|
56 |
+
_, _, result = self.face_enhancer.enhance(
|
57 |
+
result,
|
58 |
+
paste_back=True
|
59 |
+
)
|
60 |
+
out_path = tempfile.mkdtemp() + f"/{str(int(time.time()))}.jpg"
|
61 |
+
cv2.imwrite(out_path, result)
|
62 |
+
return out_path
|
63 |
+
except Exception as e:
|
64 |
+
print(f"{e}")
|
65 |
+
return None
|
66 |
+
|
67 |
+
|
68 |
+
# Instantiate the Predictor class
|
69 |
+
predictor = Predictor()
|
70 |
+
title = "Swap Faces Using Our Model!!!"
|
71 |
+
|
72 |
+
# Create Gradio Interface
|
73 |
+
iface = gr.Interface(
|
74 |
+
fn=predictor.predict,
|
75 |
+
inputs=[
|
76 |
+
gr.inputs.Image(type="file", label="Target Image"),
|
77 |
+
gr.inputs.Image(type="file", label="Swap Image")
|
78 |
+
],
|
79 |
+
outputs=gr.outputs.Image(type="file", label="Result"),
|
80 |
+
title=title
|
81 |
+
)
|
82 |
+
|
83 |
+
# Launch the Gradio Interface
|
84 |
+
iface.launch()
|
basicsr/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/xinntao/BasicSR
|
2 |
+
# flake8: noqa
|
3 |
+
from .archs import *
|
4 |
+
from .data import *
|
5 |
+
from .losses import *
|
6 |
+
from .metrics import *
|
7 |
+
from .models import *
|
8 |
+
from .ops import *
|
9 |
+
from .test import *
|
10 |
+
from .train import *
|
11 |
+
from .utils import *
|
12 |
+
from .version import __gitsha__, __version__
|
basicsr/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (345 Bytes). View file
|
|
basicsr/__pycache__/test.cpython-39.pyc
ADDED
Binary file (1.58 kB). View file
|
|
basicsr/__pycache__/train.cpython-39.pyc
ADDED
Binary file (6.36 kB). View file
|
|
basicsr/__pycache__/version.cpython-39.pyc
ADDED
Binary file (207 Bytes). View file
|
|
basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
from copy import deepcopy
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
from basicsr.utils import get_root_logger, scandir
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
|
8 |
+
__all__ = ['build_network']
|
9 |
+
|
10 |
+
# automatically scan and import arch modules for registry
|
11 |
+
# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
|
12 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
13 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
14 |
+
# import all the arch modules
|
15 |
+
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
|
16 |
+
|
17 |
+
|
18 |
+
def build_network(opt):
|
19 |
+
opt = deepcopy(opt)
|
20 |
+
network_type = opt.pop('type')
|
21 |
+
net = ARCH_REGISTRY.get(network_type)(**opt)
|
22 |
+
logger = get_root_logger()
|
23 |
+
logger.info(f'Network [{net.__class__.__name__}] is created.')
|
24 |
+
return net
|
basicsr/archs/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.11 kB). View file
|
|
basicsr/archs/__pycache__/arch_util.cpython-39.pyc
ADDED
Binary file (10.7 kB). View file
|
|
basicsr/archs/__pycache__/basicvsr_arch.cpython-39.pyc
ADDED
Binary file (10.4 kB). View file
|
|
basicsr/archs/__pycache__/basicvsrpp_arch.cpython-39.pyc
ADDED
Binary file (12.9 kB). View file
|
|
basicsr/archs/__pycache__/dfdnet_arch.cpython-39.pyc
ADDED
Binary file (5.4 kB). View file
|
|
basicsr/archs/__pycache__/dfdnet_util.cpython-39.pyc
ADDED
Binary file (5.55 kB). View file
|
|
basicsr/archs/__pycache__/discriminator_arch.cpython-39.pyc
ADDED
Binary file (4.91 kB). View file
|
|
basicsr/archs/__pycache__/duf_arch.cpython-39.pyc
ADDED
Binary file (9.21 kB). View file
|
|
basicsr/archs/__pycache__/ecbsr_arch.cpython-39.pyc
ADDED
Binary file (8.34 kB). View file
|
|
basicsr/archs/__pycache__/edsr_arch.cpython-39.pyc
ADDED
Binary file (2.28 kB). View file
|
|
basicsr/archs/__pycache__/edvr_arch.cpython-39.pyc
ADDED
Binary file (11.3 kB). View file
|
|
basicsr/archs/__pycache__/hifacegan_arch.cpython-39.pyc
ADDED
Binary file (7.54 kB). View file
|
|
basicsr/archs/__pycache__/hifacegan_util.cpython-39.pyc
ADDED
Binary file (8.43 kB). View file
|
|
basicsr/archs/__pycache__/rcan_arch.cpython-39.pyc
ADDED
Binary file (4.96 kB). View file
|
|
basicsr/archs/__pycache__/ridnet_arch.cpython-39.pyc
ADDED
Binary file (6.48 kB). View file
|
|
basicsr/archs/__pycache__/rrdbnet_arch.cpython-39.pyc
ADDED
Binary file (4.41 kB). View file
|
|
basicsr/archs/__pycache__/spynet_arch.cpython-39.pyc
ADDED
Binary file (3.87 kB). View file
|
|
basicsr/archs/__pycache__/srresnet_arch.cpython-39.pyc
ADDED
Binary file (2.48 kB). View file
|
|
basicsr/archs/__pycache__/srvgg_arch.cpython-39.pyc
ADDED
Binary file (2.38 kB). View file
|
|
basicsr/archs/__pycache__/stylegan2_arch.cpython-39.pyc
ADDED
Binary file (25.3 kB). View file
|
|
basicsr/archs/__pycache__/stylegan2_bilinear_arch.cpython-39.pyc
ADDED
Binary file (18 kB). View file
|
|
basicsr/archs/__pycache__/swinir_arch.cpython-39.pyc
ADDED
Binary file (28.6 kB). View file
|
|
basicsr/archs/__pycache__/tof_arch.cpython-39.pyc
ADDED
Binary file (6.24 kB). View file
|
|
basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc
ADDED
Binary file (4.8 kB). View file
|
|
basicsr/archs/arch_util.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import warnings
|
6 |
+
from distutils.version import LooseVersion
|
7 |
+
from itertools import repeat
|
8 |
+
from torch import nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.nn import init as init
|
11 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
12 |
+
|
13 |
+
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
|
14 |
+
from basicsr.utils import get_root_logger
|
15 |
+
|
16 |
+
|
17 |
+
@torch.no_grad()
|
18 |
+
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
19 |
+
"""Initialize network weights.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
23 |
+
scale (float): Scale initialized weights, especially for residual
|
24 |
+
blocks. Default: 1.
|
25 |
+
bias_fill (float): The value to fill bias. Default: 0
|
26 |
+
kwargs (dict): Other arguments for initialization function.
|
27 |
+
"""
|
28 |
+
if not isinstance(module_list, list):
|
29 |
+
module_list = [module_list]
|
30 |
+
for module in module_list:
|
31 |
+
for m in module.modules():
|
32 |
+
if isinstance(m, nn.Conv2d):
|
33 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
34 |
+
m.weight.data *= scale
|
35 |
+
if m.bias is not None:
|
36 |
+
m.bias.data.fill_(bias_fill)
|
37 |
+
elif isinstance(m, nn.Linear):
|
38 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
39 |
+
m.weight.data *= scale
|
40 |
+
if m.bias is not None:
|
41 |
+
m.bias.data.fill_(bias_fill)
|
42 |
+
elif isinstance(m, _BatchNorm):
|
43 |
+
init.constant_(m.weight, 1)
|
44 |
+
if m.bias is not None:
|
45 |
+
m.bias.data.fill_(bias_fill)
|
46 |
+
|
47 |
+
|
48 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
49 |
+
"""Make layers by stacking the same blocks.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
basic_block (nn.module): nn.module class for basic block.
|
53 |
+
num_basic_block (int): number of blocks.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
57 |
+
"""
|
58 |
+
layers = []
|
59 |
+
for _ in range(num_basic_block):
|
60 |
+
layers.append(basic_block(**kwarg))
|
61 |
+
return nn.Sequential(*layers)
|
62 |
+
|
63 |
+
|
64 |
+
class ResidualBlockNoBN(nn.Module):
|
65 |
+
"""Residual block without BN.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
num_feat (int): Channel number of intermediate features.
|
69 |
+
Default: 64.
|
70 |
+
res_scale (float): Residual scale. Default: 1.
|
71 |
+
pytorch_init (bool): If set to True, use pytorch default init,
|
72 |
+
otherwise, use default_init_weights. Default: False.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
|
76 |
+
super(ResidualBlockNoBN, self).__init__()
|
77 |
+
self.res_scale = res_scale
|
78 |
+
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
79 |
+
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
80 |
+
self.relu = nn.ReLU(inplace=True)
|
81 |
+
|
82 |
+
if not pytorch_init:
|
83 |
+
default_init_weights([self.conv1, self.conv2], 0.1)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
identity = x
|
87 |
+
out = self.conv2(self.relu(self.conv1(x)))
|
88 |
+
return identity + out * self.res_scale
|
89 |
+
|
90 |
+
|
91 |
+
class Upsample(nn.Sequential):
|
92 |
+
"""Upsample module.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
96 |
+
num_feat (int): Channel number of intermediate features.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, scale, num_feat):
|
100 |
+
m = []
|
101 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
102 |
+
for _ in range(int(math.log(scale, 2))):
|
103 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
104 |
+
m.append(nn.PixelShuffle(2))
|
105 |
+
elif scale == 3:
|
106 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
107 |
+
m.append(nn.PixelShuffle(3))
|
108 |
+
else:
|
109 |
+
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
|
110 |
+
super(Upsample, self).__init__(*m)
|
111 |
+
|
112 |
+
|
113 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
|
114 |
+
"""Warp an image or feature map with optical flow.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
x (Tensor): Tensor with size (n, c, h, w).
|
118 |
+
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
|
119 |
+
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
|
120 |
+
padding_mode (str): 'zeros' or 'border' or 'reflection'.
|
121 |
+
Default: 'zeros'.
|
122 |
+
align_corners (bool): Before pytorch 1.3, the default value is
|
123 |
+
align_corners=True. After pytorch 1.3, the default value is
|
124 |
+
align_corners=False. Here, we use the True as default.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Tensor: Warped image or feature map.
|
128 |
+
"""
|
129 |
+
assert x.size()[-2:] == flow.size()[1:3]
|
130 |
+
_, _, h, w = x.size()
|
131 |
+
# create mesh grid
|
132 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
|
133 |
+
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
|
134 |
+
grid.requires_grad = False
|
135 |
+
|
136 |
+
vgrid = grid + flow
|
137 |
+
# scale grid to [-1,1]
|
138 |
+
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
|
139 |
+
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
|
140 |
+
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
141 |
+
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
142 |
+
|
143 |
+
# TODO, what if align_corners=False
|
144 |
+
return output
|
145 |
+
|
146 |
+
|
147 |
+
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
|
148 |
+
"""Resize a flow according to ratio or shape.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
|
152 |
+
size_type (str): 'ratio' or 'shape'.
|
153 |
+
sizes (list[int | float]): the ratio for resizing or the final output
|
154 |
+
shape.
|
155 |
+
1) The order of ratio should be [ratio_h, ratio_w]. For
|
156 |
+
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
|
157 |
+
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
|
158 |
+
ratio > 1.0).
|
159 |
+
2) The order of output_size should be [out_h, out_w].
|
160 |
+
interp_mode (str): The mode of interpolation for resizing.
|
161 |
+
Default: 'bilinear'.
|
162 |
+
align_corners (bool): Whether align corners. Default: False.
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
Tensor: Resized flow.
|
166 |
+
"""
|
167 |
+
_, _, flow_h, flow_w = flow.size()
|
168 |
+
if size_type == 'ratio':
|
169 |
+
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
|
170 |
+
elif size_type == 'shape':
|
171 |
+
output_h, output_w = sizes[0], sizes[1]
|
172 |
+
else:
|
173 |
+
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
|
174 |
+
|
175 |
+
input_flow = flow.clone()
|
176 |
+
ratio_h = output_h / flow_h
|
177 |
+
ratio_w = output_w / flow_w
|
178 |
+
input_flow[:, 0, :, :] *= ratio_w
|
179 |
+
input_flow[:, 1, :, :] *= ratio_h
|
180 |
+
resized_flow = F.interpolate(
|
181 |
+
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
|
182 |
+
return resized_flow
|
183 |
+
|
184 |
+
|
185 |
+
# TODO: may write a cpp file
|
186 |
+
def pixel_unshuffle(x, scale):
|
187 |
+
""" Pixel unshuffle.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
x (Tensor): Input feature with shape (b, c, hh, hw).
|
191 |
+
scale (int): Downsample ratio.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Tensor: the pixel unshuffled feature.
|
195 |
+
"""
|
196 |
+
b, c, hh, hw = x.size()
|
197 |
+
out_channel = c * (scale**2)
|
198 |
+
assert hh % scale == 0 and hw % scale == 0
|
199 |
+
h = hh // scale
|
200 |
+
w = hw // scale
|
201 |
+
x_view = x.view(b, c, h, scale, w, scale)
|
202 |
+
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
203 |
+
|
204 |
+
|
205 |
+
class DCNv2Pack(ModulatedDeformConvPack):
|
206 |
+
"""Modulated deformable conv for deformable alignment.
|
207 |
+
|
208 |
+
Different from the official DCNv2Pack, which generates offsets and masks
|
209 |
+
from the preceding features, this DCNv2Pack takes another different
|
210 |
+
features to generate offsets and masks.
|
211 |
+
|
212 |
+
``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
|
213 |
+
"""
|
214 |
+
|
215 |
+
def forward(self, x, feat):
|
216 |
+
out = self.conv_offset(feat)
|
217 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
218 |
+
offset = torch.cat((o1, o2), dim=1)
|
219 |
+
mask = torch.sigmoid(mask)
|
220 |
+
|
221 |
+
offset_absmean = torch.mean(torch.abs(offset))
|
222 |
+
if offset_absmean > 50:
|
223 |
+
logger = get_root_logger()
|
224 |
+
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
|
225 |
+
|
226 |
+
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
|
227 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
|
228 |
+
self.dilation, mask)
|
229 |
+
else:
|
230 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
|
231 |
+
self.dilation, self.groups, self.deformable_groups)
|
232 |
+
|
233 |
+
|
234 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
235 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
236 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
237 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
238 |
+
def norm_cdf(x):
|
239 |
+
# Computes standard normal cumulative distribution function
|
240 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
241 |
+
|
242 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
243 |
+
warnings.warn(
|
244 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
245 |
+
'The distribution of values may be incorrect.',
|
246 |
+
stacklevel=2)
|
247 |
+
|
248 |
+
with torch.no_grad():
|
249 |
+
# Values are generated by using a truncated uniform distribution and
|
250 |
+
# then using the inverse CDF for the normal distribution.
|
251 |
+
# Get upper and lower cdf values
|
252 |
+
low = norm_cdf((a - mean) / std)
|
253 |
+
up = norm_cdf((b - mean) / std)
|
254 |
+
|
255 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
256 |
+
# [2l-1, 2u-1].
|
257 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
258 |
+
|
259 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
260 |
+
# standard normal
|
261 |
+
tensor.erfinv_()
|
262 |
+
|
263 |
+
# Transform to proper mean, std
|
264 |
+
tensor.mul_(std * math.sqrt(2.))
|
265 |
+
tensor.add_(mean)
|
266 |
+
|
267 |
+
# Clamp to ensure it's in the proper range
|
268 |
+
tensor.clamp_(min=a, max=b)
|
269 |
+
return tensor
|
270 |
+
|
271 |
+
|
272 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
273 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
274 |
+
normal distribution.
|
275 |
+
|
276 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
277 |
+
|
278 |
+
The values are effectively drawn from the
|
279 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
280 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
281 |
+
the bounds. The method used for generating the random values works
|
282 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
tensor: an n-dimensional `torch.Tensor`
|
286 |
+
mean: the mean of the normal distribution
|
287 |
+
std: the standard deviation of the normal distribution
|
288 |
+
a: the minimum cutoff value
|
289 |
+
b: the maximum cutoff value
|
290 |
+
|
291 |
+
Examples:
|
292 |
+
>>> w = torch.empty(3, 5)
|
293 |
+
>>> nn.init.trunc_normal_(w)
|
294 |
+
"""
|
295 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
296 |
+
|
297 |
+
|
298 |
+
# From PyTorch
|
299 |
+
def _ntuple(n):
|
300 |
+
|
301 |
+
def parse(x):
|
302 |
+
if isinstance(x, collections.abc.Iterable):
|
303 |
+
return x
|
304 |
+
return tuple(repeat(x, n))
|
305 |
+
|
306 |
+
return parse
|
307 |
+
|
308 |
+
|
309 |
+
to_1tuple = _ntuple(1)
|
310 |
+
to_2tuple = _ntuple(2)
|
311 |
+
to_3tuple = _ntuple(3)
|
312 |
+
to_4tuple = _ntuple(4)
|
313 |
+
to_ntuple = _ntuple
|
basicsr/archs/basicvsr_arch.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
|
7 |
+
from .edvr_arch import PCDAlignment, TSAFusion
|
8 |
+
from .spynet_arch import SpyNet
|
9 |
+
|
10 |
+
|
11 |
+
@ARCH_REGISTRY.register()
|
12 |
+
class BasicVSR(nn.Module):
|
13 |
+
"""A recurrent network for video SR. Now only x4 is supported.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
num_feat (int): Number of channels. Default: 64.
|
17 |
+
num_block (int): Number of residual blocks for each branch. Default: 15
|
18 |
+
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, num_feat=64, num_block=15, spynet_path=None):
|
22 |
+
super().__init__()
|
23 |
+
self.num_feat = num_feat
|
24 |
+
|
25 |
+
# alignment
|
26 |
+
self.spynet = SpyNet(spynet_path)
|
27 |
+
|
28 |
+
# propagation
|
29 |
+
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
|
30 |
+
self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
|
31 |
+
|
32 |
+
# reconstruction
|
33 |
+
self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
|
34 |
+
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
|
35 |
+
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
|
36 |
+
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
|
37 |
+
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
|
38 |
+
|
39 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
40 |
+
|
41 |
+
# activation functions
|
42 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
43 |
+
|
44 |
+
def get_flow(self, x):
|
45 |
+
b, n, c, h, w = x.size()
|
46 |
+
|
47 |
+
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
|
48 |
+
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
|
49 |
+
|
50 |
+
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
|
51 |
+
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
|
52 |
+
|
53 |
+
return flows_forward, flows_backward
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
"""Forward function of BasicVSR.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
|
60 |
+
"""
|
61 |
+
flows_forward, flows_backward = self.get_flow(x)
|
62 |
+
b, n, _, h, w = x.size()
|
63 |
+
|
64 |
+
# backward branch
|
65 |
+
out_l = []
|
66 |
+
feat_prop = x.new_zeros(b, self.num_feat, h, w)
|
67 |
+
for i in range(n - 1, -1, -1):
|
68 |
+
x_i = x[:, i, :, :, :]
|
69 |
+
if i < n - 1:
|
70 |
+
flow = flows_backward[:, i, :, :, :]
|
71 |
+
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
|
72 |
+
feat_prop = torch.cat([x_i, feat_prop], dim=1)
|
73 |
+
feat_prop = self.backward_trunk(feat_prop)
|
74 |
+
out_l.insert(0, feat_prop)
|
75 |
+
|
76 |
+
# forward branch
|
77 |
+
feat_prop = torch.zeros_like(feat_prop)
|
78 |
+
for i in range(0, n):
|
79 |
+
x_i = x[:, i, :, :, :]
|
80 |
+
if i > 0:
|
81 |
+
flow = flows_forward[:, i - 1, :, :, :]
|
82 |
+
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
|
83 |
+
|
84 |
+
feat_prop = torch.cat([x_i, feat_prop], dim=1)
|
85 |
+
feat_prop = self.forward_trunk(feat_prop)
|
86 |
+
|
87 |
+
# upsample
|
88 |
+
out = torch.cat([out_l[i], feat_prop], dim=1)
|
89 |
+
out = self.lrelu(self.fusion(out))
|
90 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
91 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
92 |
+
out = self.lrelu(self.conv_hr(out))
|
93 |
+
out = self.conv_last(out)
|
94 |
+
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
|
95 |
+
out += base
|
96 |
+
out_l[i] = out
|
97 |
+
|
98 |
+
return torch.stack(out_l, dim=1)
|
99 |
+
|
100 |
+
|
101 |
+
class ConvResidualBlocks(nn.Module):
|
102 |
+
"""Conv and residual block used in BasicVSR.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
num_in_ch (int): Number of input channels. Default: 3.
|
106 |
+
num_out_ch (int): Number of output channels. Default: 64.
|
107 |
+
num_block (int): Number of residual blocks. Default: 15.
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
|
111 |
+
super().__init__()
|
112 |
+
self.main = nn.Sequential(
|
113 |
+
nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
114 |
+
make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
|
115 |
+
|
116 |
+
def forward(self, fea):
|
117 |
+
return self.main(fea)
|
118 |
+
|
119 |
+
|
120 |
+
@ARCH_REGISTRY.register()
|
121 |
+
class IconVSR(nn.Module):
|
122 |
+
"""IconVSR, proposed also in the BasicVSR paper.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
num_feat (int): Number of channels. Default: 64.
|
126 |
+
num_block (int): Number of residual blocks for each branch. Default: 15.
|
127 |
+
keyframe_stride (int): Keyframe stride. Default: 5.
|
128 |
+
temporal_padding (int): Temporal padding. Default: 2.
|
129 |
+
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
|
130 |
+
edvr_path (str): Path to the pretrained EDVR model. Default: None.
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self,
|
134 |
+
num_feat=64,
|
135 |
+
num_block=15,
|
136 |
+
keyframe_stride=5,
|
137 |
+
temporal_padding=2,
|
138 |
+
spynet_path=None,
|
139 |
+
edvr_path=None):
|
140 |
+
super().__init__()
|
141 |
+
|
142 |
+
self.num_feat = num_feat
|
143 |
+
self.temporal_padding = temporal_padding
|
144 |
+
self.keyframe_stride = keyframe_stride
|
145 |
+
|
146 |
+
# keyframe_branch
|
147 |
+
self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
|
148 |
+
# alignment
|
149 |
+
self.spynet = SpyNet(spynet_path)
|
150 |
+
|
151 |
+
# propagation
|
152 |
+
self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
|
153 |
+
self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
|
154 |
+
|
155 |
+
self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
|
156 |
+
self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
|
157 |
+
|
158 |
+
# reconstruction
|
159 |
+
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
|
160 |
+
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
|
161 |
+
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
|
162 |
+
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
|
163 |
+
|
164 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
165 |
+
|
166 |
+
# activation functions
|
167 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
168 |
+
|
169 |
+
def pad_spatial(self, x):
|
170 |
+
"""Apply padding spatially.
|
171 |
+
|
172 |
+
Since the PCD module in EDVR requires that the resolution is a multiple
|
173 |
+
of 4, we apply padding to the input LR images if their resolution is
|
174 |
+
not divisible by 4.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
x (Tensor): Input LR sequence with shape (n, t, c, h, w).
|
178 |
+
Returns:
|
179 |
+
Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
|
180 |
+
"""
|
181 |
+
n, t, c, h, w = x.size()
|
182 |
+
|
183 |
+
pad_h = (4 - h % 4) % 4
|
184 |
+
pad_w = (4 - w % 4) % 4
|
185 |
+
|
186 |
+
# padding
|
187 |
+
x = x.view(-1, c, h, w)
|
188 |
+
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
|
189 |
+
|
190 |
+
return x.view(n, t, c, h + pad_h, w + pad_w)
|
191 |
+
|
192 |
+
def get_flow(self, x):
|
193 |
+
b, n, c, h, w = x.size()
|
194 |
+
|
195 |
+
x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
|
196 |
+
x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
|
197 |
+
|
198 |
+
flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
|
199 |
+
flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
|
200 |
+
|
201 |
+
return flows_forward, flows_backward
|
202 |
+
|
203 |
+
def get_keyframe_feature(self, x, keyframe_idx):
|
204 |
+
if self.temporal_padding == 2:
|
205 |
+
x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
|
206 |
+
elif self.temporal_padding == 3:
|
207 |
+
x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
|
208 |
+
x = torch.cat(x, dim=1)
|
209 |
+
|
210 |
+
num_frames = 2 * self.temporal_padding + 1
|
211 |
+
feats_keyframe = {}
|
212 |
+
for i in keyframe_idx:
|
213 |
+
feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
|
214 |
+
return feats_keyframe
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
b, n, _, h_input, w_input = x.size()
|
218 |
+
|
219 |
+
x = self.pad_spatial(x)
|
220 |
+
h, w = x.shape[3:]
|
221 |
+
|
222 |
+
keyframe_idx = list(range(0, n, self.keyframe_stride))
|
223 |
+
if keyframe_idx[-1] != n - 1:
|
224 |
+
keyframe_idx.append(n - 1) # last frame is a keyframe
|
225 |
+
|
226 |
+
# compute flow and keyframe features
|
227 |
+
flows_forward, flows_backward = self.get_flow(x)
|
228 |
+
feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
|
229 |
+
|
230 |
+
# backward branch
|
231 |
+
out_l = []
|
232 |
+
feat_prop = x.new_zeros(b, self.num_feat, h, w)
|
233 |
+
for i in range(n - 1, -1, -1):
|
234 |
+
x_i = x[:, i, :, :, :]
|
235 |
+
if i < n - 1:
|
236 |
+
flow = flows_backward[:, i, :, :, :]
|
237 |
+
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
|
238 |
+
if i in keyframe_idx:
|
239 |
+
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
|
240 |
+
feat_prop = self.backward_fusion(feat_prop)
|
241 |
+
feat_prop = torch.cat([x_i, feat_prop], dim=1)
|
242 |
+
feat_prop = self.backward_trunk(feat_prop)
|
243 |
+
out_l.insert(0, feat_prop)
|
244 |
+
|
245 |
+
# forward branch
|
246 |
+
feat_prop = torch.zeros_like(feat_prop)
|
247 |
+
for i in range(0, n):
|
248 |
+
x_i = x[:, i, :, :, :]
|
249 |
+
if i > 0:
|
250 |
+
flow = flows_forward[:, i - 1, :, :, :]
|
251 |
+
feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
|
252 |
+
if i in keyframe_idx:
|
253 |
+
feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
|
254 |
+
feat_prop = self.forward_fusion(feat_prop)
|
255 |
+
|
256 |
+
feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
|
257 |
+
feat_prop = self.forward_trunk(feat_prop)
|
258 |
+
|
259 |
+
# upsample
|
260 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
|
261 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
262 |
+
out = self.lrelu(self.conv_hr(out))
|
263 |
+
out = self.conv_last(out)
|
264 |
+
base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
|
265 |
+
out += base
|
266 |
+
out_l[i] = out
|
267 |
+
|
268 |
+
return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
|
269 |
+
|
270 |
+
|
271 |
+
class EDVRFeatureExtractor(nn.Module):
|
272 |
+
"""EDVR feature extractor used in IconVSR.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
num_input_frame (int): Number of input frames.
|
276 |
+
num_feat (int): Number of feature channels
|
277 |
+
load_path (str): Path to the pretrained weights of EDVR. Default: None.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, num_input_frame, num_feat, load_path):
|
281 |
+
|
282 |
+
super(EDVRFeatureExtractor, self).__init__()
|
283 |
+
|
284 |
+
self.center_frame_idx = num_input_frame // 2
|
285 |
+
|
286 |
+
# extract pyramid features
|
287 |
+
self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
|
288 |
+
self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
|
289 |
+
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
290 |
+
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
291 |
+
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
292 |
+
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
293 |
+
|
294 |
+
# pcd and tsa module
|
295 |
+
self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
|
296 |
+
self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
|
297 |
+
|
298 |
+
# activation function
|
299 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
300 |
+
|
301 |
+
if load_path:
|
302 |
+
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
|
303 |
+
|
304 |
+
def forward(self, x):
|
305 |
+
b, n, c, h, w = x.size()
|
306 |
+
|
307 |
+
# extract features for each frame
|
308 |
+
# L1
|
309 |
+
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
|
310 |
+
feat_l1 = self.feature_extraction(feat_l1)
|
311 |
+
# L2
|
312 |
+
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
|
313 |
+
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
|
314 |
+
# L3
|
315 |
+
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
|
316 |
+
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
|
317 |
+
|
318 |
+
feat_l1 = feat_l1.view(b, n, -1, h, w)
|
319 |
+
feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
|
320 |
+
feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
|
321 |
+
|
322 |
+
# PCD alignment
|
323 |
+
ref_feat_l = [ # reference feature list
|
324 |
+
feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
|
325 |
+
feat_l3[:, self.center_frame_idx, :, :, :].clone()
|
326 |
+
]
|
327 |
+
aligned_feat = []
|
328 |
+
for i in range(n):
|
329 |
+
nbr_feat_l = [ # neighboring feature list
|
330 |
+
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
|
331 |
+
]
|
332 |
+
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
|
333 |
+
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
|
334 |
+
|
335 |
+
# TSA fusion
|
336 |
+
return self.fusion(aligned_feat)
|
basicsr/archs/basicvsrpp_arch.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
from basicsr.archs.arch_util import flow_warp
|
8 |
+
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
|
9 |
+
from basicsr.archs.spynet_arch import SpyNet
|
10 |
+
from basicsr.ops.dcn import ModulatedDeformConvPack
|
11 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
12 |
+
|
13 |
+
|
14 |
+
@ARCH_REGISTRY.register()
|
15 |
+
class BasicVSRPlusPlus(nn.Module):
|
16 |
+
"""BasicVSR++ network structure.
|
17 |
+
|
18 |
+
Support either x4 upsampling or same size output. Since DCN is used in this
|
19 |
+
model, it can only be used with CUDA enabled. If CUDA is not enabled,
|
20 |
+
feature alignment will be skipped. Besides, we adopt the official DCN
|
21 |
+
implementation and the version of torch need to be higher than 1.9.
|
22 |
+
|
23 |
+
``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
|
24 |
+
|
25 |
+
Args:
|
26 |
+
mid_channels (int, optional): Channel number of the intermediate
|
27 |
+
features. Default: 64.
|
28 |
+
num_blocks (int, optional): The number of residual blocks in each
|
29 |
+
propagation branch. Default: 7.
|
30 |
+
max_residue_magnitude (int): The maximum magnitude of the offset
|
31 |
+
residue (Eq. 6 in paper). Default: 10.
|
32 |
+
is_low_res_input (bool, optional): Whether the input is low-resolution
|
33 |
+
or not. If False, the output resolution is equal to the input
|
34 |
+
resolution. Default: True.
|
35 |
+
spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
|
36 |
+
cpu_cache_length (int, optional): When the length of sequence is larger
|
37 |
+
than this value, the intermediate features are sent to CPU. This
|
38 |
+
saves GPU memory, but slows down the inference speed. You can
|
39 |
+
increase this number if you have a GPU with large memory.
|
40 |
+
Default: 100.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
mid_channels=64,
|
45 |
+
num_blocks=7,
|
46 |
+
max_residue_magnitude=10,
|
47 |
+
is_low_res_input=True,
|
48 |
+
spynet_path=None,
|
49 |
+
cpu_cache_length=100):
|
50 |
+
|
51 |
+
super().__init__()
|
52 |
+
self.mid_channels = mid_channels
|
53 |
+
self.is_low_res_input = is_low_res_input
|
54 |
+
self.cpu_cache_length = cpu_cache_length
|
55 |
+
|
56 |
+
# optical flow
|
57 |
+
self.spynet = SpyNet(spynet_path)
|
58 |
+
|
59 |
+
# feature extraction module
|
60 |
+
if is_low_res_input:
|
61 |
+
self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
|
62 |
+
else:
|
63 |
+
self.feat_extract = nn.Sequential(
|
64 |
+
nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
65 |
+
nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
66 |
+
ConvResidualBlocks(mid_channels, mid_channels, 5))
|
67 |
+
|
68 |
+
# propagation branches
|
69 |
+
self.deform_align = nn.ModuleDict()
|
70 |
+
self.backbone = nn.ModuleDict()
|
71 |
+
modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
|
72 |
+
for i, module in enumerate(modules):
|
73 |
+
if torch.cuda.is_available():
|
74 |
+
self.deform_align[module] = SecondOrderDeformableAlignment(
|
75 |
+
2 * mid_channels,
|
76 |
+
mid_channels,
|
77 |
+
3,
|
78 |
+
padding=1,
|
79 |
+
deformable_groups=16,
|
80 |
+
max_residue_magnitude=max_residue_magnitude)
|
81 |
+
self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
|
82 |
+
|
83 |
+
# upsampling module
|
84 |
+
self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
|
85 |
+
|
86 |
+
self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
|
87 |
+
self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
|
88 |
+
|
89 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
90 |
+
|
91 |
+
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
|
92 |
+
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
|
93 |
+
self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
|
94 |
+
|
95 |
+
# activation function
|
96 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
97 |
+
|
98 |
+
# check if the sequence is augmented by flipping
|
99 |
+
self.is_mirror_extended = False
|
100 |
+
|
101 |
+
if len(self.deform_align) > 0:
|
102 |
+
self.is_with_alignment = True
|
103 |
+
else:
|
104 |
+
self.is_with_alignment = False
|
105 |
+
warnings.warn('Deformable alignment module is not added. '
|
106 |
+
'Probably your CUDA is not configured correctly. DCN can only '
|
107 |
+
'be used with CUDA enabled. Alignment is skipped now.')
|
108 |
+
|
109 |
+
def check_if_mirror_extended(self, lqs):
|
110 |
+
"""Check whether the input is a mirror-extended sequence.
|
111 |
+
|
112 |
+
If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
|
116 |
+
"""
|
117 |
+
|
118 |
+
if lqs.size(1) % 2 == 0:
|
119 |
+
lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
|
120 |
+
if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
|
121 |
+
self.is_mirror_extended = True
|
122 |
+
|
123 |
+
def compute_flow(self, lqs):
|
124 |
+
"""Compute optical flow using SPyNet for feature alignment.
|
125 |
+
|
126 |
+
Note that if the input is an mirror-extended sequence, 'flows_forward'
|
127 |
+
is not needed, since it is equal to 'flows_backward.flip(1)'.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
131 |
+
shape (n, t, c, h, w).
|
132 |
+
|
133 |
+
Return:
|
134 |
+
tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
|
135 |
+
(current to previous). 'flows_backward' corresponds to the flows used for backward-time \
|
136 |
+
propagation (current to next).
|
137 |
+
"""
|
138 |
+
|
139 |
+
n, t, c, h, w = lqs.size()
|
140 |
+
lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
|
141 |
+
lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
|
142 |
+
|
143 |
+
flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
|
144 |
+
|
145 |
+
if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
|
146 |
+
flows_forward = flows_backward.flip(1)
|
147 |
+
else:
|
148 |
+
flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
|
149 |
+
|
150 |
+
if self.cpu_cache:
|
151 |
+
flows_backward = flows_backward.cpu()
|
152 |
+
flows_forward = flows_forward.cpu()
|
153 |
+
|
154 |
+
return flows_forward, flows_backward
|
155 |
+
|
156 |
+
def propagate(self, feats, flows, module_name):
|
157 |
+
"""Propagate the latent features throughout the sequence.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
feats dict(list[tensor]): Features from previous branches. Each
|
161 |
+
component is a list of tensors with shape (n, c, h, w).
|
162 |
+
flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
|
163 |
+
module_name (str): The name of the propgation branches. Can either
|
164 |
+
be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
|
165 |
+
|
166 |
+
Return:
|
167 |
+
dict(list[tensor]): A dictionary containing all the propagated \
|
168 |
+
features. Each key in the dictionary corresponds to a \
|
169 |
+
propagation branch, which is represented by a list of tensors.
|
170 |
+
"""
|
171 |
+
|
172 |
+
n, t, _, h, w = flows.size()
|
173 |
+
|
174 |
+
frame_idx = range(0, t + 1)
|
175 |
+
flow_idx = range(-1, t)
|
176 |
+
mapping_idx = list(range(0, len(feats['spatial'])))
|
177 |
+
mapping_idx += mapping_idx[::-1]
|
178 |
+
|
179 |
+
if 'backward' in module_name:
|
180 |
+
frame_idx = frame_idx[::-1]
|
181 |
+
flow_idx = frame_idx
|
182 |
+
|
183 |
+
feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
|
184 |
+
for i, idx in enumerate(frame_idx):
|
185 |
+
feat_current = feats['spatial'][mapping_idx[idx]]
|
186 |
+
if self.cpu_cache:
|
187 |
+
feat_current = feat_current.cuda()
|
188 |
+
feat_prop = feat_prop.cuda()
|
189 |
+
# second-order deformable alignment
|
190 |
+
if i > 0 and self.is_with_alignment:
|
191 |
+
flow_n1 = flows[:, flow_idx[i], :, :, :]
|
192 |
+
if self.cpu_cache:
|
193 |
+
flow_n1 = flow_n1.cuda()
|
194 |
+
|
195 |
+
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
|
196 |
+
|
197 |
+
# initialize second-order features
|
198 |
+
feat_n2 = torch.zeros_like(feat_prop)
|
199 |
+
flow_n2 = torch.zeros_like(flow_n1)
|
200 |
+
cond_n2 = torch.zeros_like(cond_n1)
|
201 |
+
|
202 |
+
if i > 1: # second-order features
|
203 |
+
feat_n2 = feats[module_name][-2]
|
204 |
+
if self.cpu_cache:
|
205 |
+
feat_n2 = feat_n2.cuda()
|
206 |
+
|
207 |
+
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
|
208 |
+
if self.cpu_cache:
|
209 |
+
flow_n2 = flow_n2.cuda()
|
210 |
+
|
211 |
+
flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
|
212 |
+
cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
|
213 |
+
|
214 |
+
# flow-guided deformable convolution
|
215 |
+
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
|
216 |
+
feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
|
217 |
+
feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
|
218 |
+
|
219 |
+
# concatenate and residual blocks
|
220 |
+
feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
|
221 |
+
if self.cpu_cache:
|
222 |
+
feat = [f.cuda() for f in feat]
|
223 |
+
|
224 |
+
feat = torch.cat(feat, dim=1)
|
225 |
+
feat_prop = feat_prop + self.backbone[module_name](feat)
|
226 |
+
feats[module_name].append(feat_prop)
|
227 |
+
|
228 |
+
if self.cpu_cache:
|
229 |
+
feats[module_name][-1] = feats[module_name][-1].cpu()
|
230 |
+
torch.cuda.empty_cache()
|
231 |
+
|
232 |
+
if 'backward' in module_name:
|
233 |
+
feats[module_name] = feats[module_name][::-1]
|
234 |
+
|
235 |
+
return feats
|
236 |
+
|
237 |
+
def upsample(self, lqs, feats):
|
238 |
+
"""Compute the output image given the features.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
242 |
+
shape (n, t, c, h, w).
|
243 |
+
feats (dict): The features from the propagation branches.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
247 |
+
"""
|
248 |
+
|
249 |
+
outputs = []
|
250 |
+
num_outputs = len(feats['spatial'])
|
251 |
+
|
252 |
+
mapping_idx = list(range(0, num_outputs))
|
253 |
+
mapping_idx += mapping_idx[::-1]
|
254 |
+
|
255 |
+
for i in range(0, lqs.size(1)):
|
256 |
+
hr = [feats[k].pop(0) for k in feats if k != 'spatial']
|
257 |
+
hr.insert(0, feats['spatial'][mapping_idx[i]])
|
258 |
+
hr = torch.cat(hr, dim=1)
|
259 |
+
if self.cpu_cache:
|
260 |
+
hr = hr.cuda()
|
261 |
+
|
262 |
+
hr = self.reconstruction(hr)
|
263 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
|
264 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
|
265 |
+
hr = self.lrelu(self.conv_hr(hr))
|
266 |
+
hr = self.conv_last(hr)
|
267 |
+
if self.is_low_res_input:
|
268 |
+
hr += self.img_upsample(lqs[:, i, :, :, :])
|
269 |
+
else:
|
270 |
+
hr += lqs[:, i, :, :, :]
|
271 |
+
|
272 |
+
if self.cpu_cache:
|
273 |
+
hr = hr.cpu()
|
274 |
+
torch.cuda.empty_cache()
|
275 |
+
|
276 |
+
outputs.append(hr)
|
277 |
+
|
278 |
+
return torch.stack(outputs, dim=1)
|
279 |
+
|
280 |
+
def forward(self, lqs):
|
281 |
+
"""Forward function for BasicVSR++.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
285 |
+
shape (n, t, c, h, w).
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
289 |
+
"""
|
290 |
+
|
291 |
+
n, t, c, h, w = lqs.size()
|
292 |
+
|
293 |
+
# whether to cache the features in CPU
|
294 |
+
self.cpu_cache = True if t > self.cpu_cache_length else False
|
295 |
+
|
296 |
+
if self.is_low_res_input:
|
297 |
+
lqs_downsample = lqs.clone()
|
298 |
+
else:
|
299 |
+
lqs_downsample = F.interpolate(
|
300 |
+
lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
|
301 |
+
|
302 |
+
# check whether the input is an extended sequence
|
303 |
+
self.check_if_mirror_extended(lqs)
|
304 |
+
|
305 |
+
feats = {}
|
306 |
+
# compute spatial features
|
307 |
+
if self.cpu_cache:
|
308 |
+
feats['spatial'] = []
|
309 |
+
for i in range(0, t):
|
310 |
+
feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
|
311 |
+
feats['spatial'].append(feat)
|
312 |
+
torch.cuda.empty_cache()
|
313 |
+
else:
|
314 |
+
feats_ = self.feat_extract(lqs.view(-1, c, h, w))
|
315 |
+
h, w = feats_.shape[2:]
|
316 |
+
feats_ = feats_.view(n, t, -1, h, w)
|
317 |
+
feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
|
318 |
+
|
319 |
+
# compute optical flow using the low-res inputs
|
320 |
+
assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
|
321 |
+
'The height and width of low-res inputs must be at least 64, '
|
322 |
+
f'but got {h} and {w}.')
|
323 |
+
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
|
324 |
+
|
325 |
+
# feature propgation
|
326 |
+
for iter_ in [1, 2]:
|
327 |
+
for direction in ['backward', 'forward']:
|
328 |
+
module = f'{direction}_{iter_}'
|
329 |
+
|
330 |
+
feats[module] = []
|
331 |
+
|
332 |
+
if direction == 'backward':
|
333 |
+
flows = flows_backward
|
334 |
+
elif flows_forward is not None:
|
335 |
+
flows = flows_forward
|
336 |
+
else:
|
337 |
+
flows = flows_backward.flip(1)
|
338 |
+
|
339 |
+
feats = self.propagate(feats, flows, module)
|
340 |
+
if self.cpu_cache:
|
341 |
+
del flows
|
342 |
+
torch.cuda.empty_cache()
|
343 |
+
|
344 |
+
return self.upsample(lqs, feats)
|
345 |
+
|
346 |
+
|
347 |
+
class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
|
348 |
+
"""Second-order deformable alignment module.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
in_channels (int): Same as nn.Conv2d.
|
352 |
+
out_channels (int): Same as nn.Conv2d.
|
353 |
+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
354 |
+
stride (int or tuple[int]): Same as nn.Conv2d.
|
355 |
+
padding (int or tuple[int]): Same as nn.Conv2d.
|
356 |
+
dilation (int or tuple[int]): Same as nn.Conv2d.
|
357 |
+
groups (int): Same as nn.Conv2d.
|
358 |
+
bias (bool or str): If specified as `auto`, it will be decided by the
|
359 |
+
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
|
360 |
+
False.
|
361 |
+
max_residue_magnitude (int): The maximum magnitude of the offset
|
362 |
+
residue (Eq. 6 in paper). Default: 10.
|
363 |
+
"""
|
364 |
+
|
365 |
+
def __init__(self, *args, **kwargs):
|
366 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
367 |
+
|
368 |
+
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
|
369 |
+
|
370 |
+
self.conv_offset = nn.Sequential(
|
371 |
+
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
|
372 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
373 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
374 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
375 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
376 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
377 |
+
nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
|
378 |
+
)
|
379 |
+
|
380 |
+
self.init_offset()
|
381 |
+
|
382 |
+
def init_offset(self):
|
383 |
+
|
384 |
+
def _constant_init(module, val, bias=0):
|
385 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
386 |
+
nn.init.constant_(module.weight, val)
|
387 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
388 |
+
nn.init.constant_(module.bias, bias)
|
389 |
+
|
390 |
+
_constant_init(self.conv_offset[-1], val=0, bias=0)
|
391 |
+
|
392 |
+
def forward(self, x, extra_feat, flow_1, flow_2):
|
393 |
+
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
|
394 |
+
out = self.conv_offset(extra_feat)
|
395 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
396 |
+
|
397 |
+
# offset
|
398 |
+
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
399 |
+
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
|
400 |
+
offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
|
401 |
+
offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
|
402 |
+
offset = torch.cat([offset_1, offset_2], dim=1)
|
403 |
+
|
404 |
+
# mask
|
405 |
+
mask = torch.sigmoid(mask)
|
406 |
+
|
407 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
|
408 |
+
self.dilation, mask)
|
409 |
+
|
410 |
+
|
411 |
+
# if __name__ == '__main__':
|
412 |
+
# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
|
413 |
+
# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
|
414 |
+
# input = torch.rand(1, 2, 3, 64, 64).cuda()
|
415 |
+
# output = model(input)
|
416 |
+
# print('===================')
|
417 |
+
# print(output.shape)
|
basicsr/archs/dfdnet_arch.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn.utils.spectral_norm import spectral_norm
|
6 |
+
|
7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
8 |
+
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
|
9 |
+
from .vgg_arch import VGGFeatureExtractor
|
10 |
+
|
11 |
+
|
12 |
+
class SFTUpBlock(nn.Module):
|
13 |
+
"""Spatial feature transform (SFT) with upsampling block.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
in_channel (int): Number of input channels.
|
17 |
+
out_channel (int): Number of output channels.
|
18 |
+
kernel_size (int): Kernel size in convolutions. Default: 3.
|
19 |
+
padding (int): Padding in convolutions. Default: 1.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
|
23 |
+
super(SFTUpBlock, self).__init__()
|
24 |
+
self.conv1 = nn.Sequential(
|
25 |
+
Blur(in_channel),
|
26 |
+
spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
|
27 |
+
nn.LeakyReLU(0.04, True),
|
28 |
+
# The official codes use two LeakyReLU here, so 0.04 for equivalent
|
29 |
+
)
|
30 |
+
self.convup = nn.Sequential(
|
31 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
|
32 |
+
spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
|
33 |
+
nn.LeakyReLU(0.2, True),
|
34 |
+
)
|
35 |
+
|
36 |
+
# for SFT scale and shift
|
37 |
+
self.scale_block = nn.Sequential(
|
38 |
+
spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
|
39 |
+
spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
|
40 |
+
self.shift_block = nn.Sequential(
|
41 |
+
spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
|
42 |
+
spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
|
43 |
+
# The official codes use sigmoid for shift block, do not know why
|
44 |
+
|
45 |
+
def forward(self, x, updated_feat):
|
46 |
+
out = self.conv1(x)
|
47 |
+
# SFT
|
48 |
+
scale = self.scale_block(updated_feat)
|
49 |
+
shift = self.shift_block(updated_feat)
|
50 |
+
out = out * scale + shift
|
51 |
+
# upsample
|
52 |
+
out = self.convup(out)
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
@ARCH_REGISTRY.register()
|
57 |
+
class DFDNet(nn.Module):
|
58 |
+
"""DFDNet: Deep Face Dictionary Network.
|
59 |
+
|
60 |
+
It only processes faces with 512x512 size.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
num_feat (int): Number of feature channels.
|
64 |
+
dict_path (str): Path to the facial component dictionary.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, num_feat, dict_path):
|
68 |
+
super().__init__()
|
69 |
+
self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
|
70 |
+
# part_sizes: [80, 80, 50, 110]
|
71 |
+
channel_sizes = [128, 256, 512, 512]
|
72 |
+
self.feature_sizes = np.array([256, 128, 64, 32])
|
73 |
+
self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
|
74 |
+
self.flag_dict_device = False
|
75 |
+
|
76 |
+
# dict
|
77 |
+
self.dict = torch.load(dict_path)
|
78 |
+
|
79 |
+
# vgg face extractor
|
80 |
+
self.vgg_extractor = VGGFeatureExtractor(
|
81 |
+
layer_name_list=self.vgg_layers,
|
82 |
+
vgg_type='vgg19',
|
83 |
+
use_input_norm=True,
|
84 |
+
range_norm=True,
|
85 |
+
requires_grad=False)
|
86 |
+
|
87 |
+
# attention block for fusing dictionary features and input features
|
88 |
+
self.attn_blocks = nn.ModuleDict()
|
89 |
+
for idx, feat_size in enumerate(self.feature_sizes):
|
90 |
+
for name in self.parts:
|
91 |
+
self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
|
92 |
+
|
93 |
+
# multi scale dilation block
|
94 |
+
self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
|
95 |
+
|
96 |
+
# upsampling and reconstruction
|
97 |
+
self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
|
98 |
+
self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
|
99 |
+
self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
|
100 |
+
self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
|
101 |
+
self.upsample4 = nn.Sequential(
|
102 |
+
spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
|
103 |
+
UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
|
104 |
+
|
105 |
+
def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
|
106 |
+
"""swap the features from the dictionary."""
|
107 |
+
# get the original vgg features
|
108 |
+
part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
|
109 |
+
# resize original vgg features
|
110 |
+
part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
|
111 |
+
# use adaptive instance normalization to adjust color and illuminations
|
112 |
+
dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
|
113 |
+
# get similarity scores
|
114 |
+
similarity_score = F.conv2d(part_resize_feat, dict_feat)
|
115 |
+
similarity_score = F.softmax(similarity_score.view(-1), dim=0)
|
116 |
+
# select the most similar features in the dict (after norm)
|
117 |
+
select_idx = torch.argmax(similarity_score)
|
118 |
+
swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
|
119 |
+
# attention
|
120 |
+
attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
|
121 |
+
attn_feat = attn * swap_feat
|
122 |
+
# update features
|
123 |
+
updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
|
124 |
+
return updated_feat
|
125 |
+
|
126 |
+
def put_dict_to_device(self, x):
|
127 |
+
if self.flag_dict_device is False:
|
128 |
+
for k, v in self.dict.items():
|
129 |
+
for kk, vv in v.items():
|
130 |
+
self.dict[k][kk] = vv.to(x)
|
131 |
+
self.flag_dict_device = True
|
132 |
+
|
133 |
+
def forward(self, x, part_locations):
|
134 |
+
"""
|
135 |
+
Now only support testing with batch size = 0.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
x (Tensor): Input faces with shape (b, c, 512, 512).
|
139 |
+
part_locations (list[Tensor]): Part locations.
|
140 |
+
"""
|
141 |
+
self.put_dict_to_device(x)
|
142 |
+
# extract vggface features
|
143 |
+
vgg_features = self.vgg_extractor(x)
|
144 |
+
# update vggface features using the dictionary for each part
|
145 |
+
updated_vgg_features = []
|
146 |
+
batch = 0 # only supports testing with batch size = 0
|
147 |
+
for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
|
148 |
+
dict_features = self.dict[f'{f_size}']
|
149 |
+
vgg_feat = vgg_features[vgg_layer]
|
150 |
+
updated_feat = vgg_feat.clone()
|
151 |
+
|
152 |
+
# swap features from dictionary
|
153 |
+
for part_idx, part_name in enumerate(self.parts):
|
154 |
+
location = (part_locations[part_idx][batch] // (512 / f_size)).int()
|
155 |
+
updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
|
156 |
+
f_size)
|
157 |
+
|
158 |
+
updated_vgg_features.append(updated_feat)
|
159 |
+
|
160 |
+
vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
|
161 |
+
# use updated vgg features to modulate the upsampled features with
|
162 |
+
# SFT (Spatial Feature Transform) scaling and shifting manner.
|
163 |
+
upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
|
164 |
+
upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
|
165 |
+
upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
|
166 |
+
upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
|
167 |
+
out = self.upsample4(upsampled_feat)
|
168 |
+
|
169 |
+
return out
|
basicsr/archs/dfdnet_util.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.nn.utils.spectral_norm import spectral_norm
|
6 |
+
|
7 |
+
|
8 |
+
class BlurFunctionBackward(Function):
|
9 |
+
|
10 |
+
@staticmethod
|
11 |
+
def forward(ctx, grad_output, kernel, kernel_flip):
|
12 |
+
ctx.save_for_backward(kernel, kernel_flip)
|
13 |
+
grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
|
14 |
+
return grad_input
|
15 |
+
|
16 |
+
@staticmethod
|
17 |
+
def backward(ctx, gradgrad_output):
|
18 |
+
kernel, _ = ctx.saved_tensors
|
19 |
+
grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
|
20 |
+
return grad_input, None, None
|
21 |
+
|
22 |
+
|
23 |
+
class BlurFunction(Function):
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def forward(ctx, x, kernel, kernel_flip):
|
27 |
+
ctx.save_for_backward(kernel, kernel_flip)
|
28 |
+
output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
|
29 |
+
return output
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def backward(ctx, grad_output):
|
33 |
+
kernel, kernel_flip = ctx.saved_tensors
|
34 |
+
grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
|
35 |
+
return grad_input, None, None
|
36 |
+
|
37 |
+
|
38 |
+
blur = BlurFunction.apply
|
39 |
+
|
40 |
+
|
41 |
+
class Blur(nn.Module):
|
42 |
+
|
43 |
+
def __init__(self, channel):
|
44 |
+
super().__init__()
|
45 |
+
kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
|
46 |
+
kernel = kernel.view(1, 1, 3, 3)
|
47 |
+
kernel = kernel / kernel.sum()
|
48 |
+
kernel_flip = torch.flip(kernel, [2, 3])
|
49 |
+
|
50 |
+
self.kernel = kernel.repeat(channel, 1, 1, 1)
|
51 |
+
self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
|
55 |
+
|
56 |
+
|
57 |
+
def calc_mean_std(feat, eps=1e-5):
|
58 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
feat (Tensor): 4D tensor.
|
62 |
+
eps (float): A small value added to the variance to avoid
|
63 |
+
divide-by-zero. Default: 1e-5.
|
64 |
+
"""
|
65 |
+
size = feat.size()
|
66 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
67 |
+
n, c = size[:2]
|
68 |
+
feat_var = feat.view(n, c, -1).var(dim=2) + eps
|
69 |
+
feat_std = feat_var.sqrt().view(n, c, 1, 1)
|
70 |
+
feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
|
71 |
+
return feat_mean, feat_std
|
72 |
+
|
73 |
+
|
74 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
75 |
+
"""Adaptive instance normalization.
|
76 |
+
|
77 |
+
Adjust the reference features to have the similar color and illuminations
|
78 |
+
as those in the degradate features.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
content_feat (Tensor): The reference feature.
|
82 |
+
style_feat (Tensor): The degradate features.
|
83 |
+
"""
|
84 |
+
size = content_feat.size()
|
85 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
86 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
87 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
88 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
89 |
+
|
90 |
+
|
91 |
+
def AttentionBlock(in_channel):
|
92 |
+
return nn.Sequential(
|
93 |
+
spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
|
94 |
+
spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
|
95 |
+
|
96 |
+
|
97 |
+
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
|
98 |
+
"""Conv block used in MSDilationBlock."""
|
99 |
+
|
100 |
+
return nn.Sequential(
|
101 |
+
spectral_norm(
|
102 |
+
nn.Conv2d(
|
103 |
+
in_channels,
|
104 |
+
out_channels,
|
105 |
+
kernel_size=kernel_size,
|
106 |
+
stride=stride,
|
107 |
+
dilation=dilation,
|
108 |
+
padding=((kernel_size - 1) // 2) * dilation,
|
109 |
+
bias=bias)),
|
110 |
+
nn.LeakyReLU(0.2),
|
111 |
+
spectral_norm(
|
112 |
+
nn.Conv2d(
|
113 |
+
out_channels,
|
114 |
+
out_channels,
|
115 |
+
kernel_size=kernel_size,
|
116 |
+
stride=stride,
|
117 |
+
dilation=dilation,
|
118 |
+
padding=((kernel_size - 1) // 2) * dilation,
|
119 |
+
bias=bias)),
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
class MSDilationBlock(nn.Module):
|
124 |
+
"""Multi-scale dilation block."""
|
125 |
+
|
126 |
+
def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
|
127 |
+
super(MSDilationBlock, self).__init__()
|
128 |
+
|
129 |
+
self.conv_blocks = nn.ModuleList()
|
130 |
+
for i in range(4):
|
131 |
+
self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
|
132 |
+
self.conv_fusion = spectral_norm(
|
133 |
+
nn.Conv2d(
|
134 |
+
in_channels * 4,
|
135 |
+
in_channels,
|
136 |
+
kernel_size=kernel_size,
|
137 |
+
stride=1,
|
138 |
+
padding=(kernel_size - 1) // 2,
|
139 |
+
bias=bias))
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
out = []
|
143 |
+
for i in range(4):
|
144 |
+
out.append(self.conv_blocks[i](x))
|
145 |
+
out = torch.cat(out, 1)
|
146 |
+
out = self.conv_fusion(out) + x
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
class UpResBlock(nn.Module):
|
151 |
+
|
152 |
+
def __init__(self, in_channel):
|
153 |
+
super(UpResBlock, self).__init__()
|
154 |
+
self.body = nn.Sequential(
|
155 |
+
nn.Conv2d(in_channel, in_channel, 3, 1, 1),
|
156 |
+
nn.LeakyReLU(0.2, True),
|
157 |
+
nn.Conv2d(in_channel, in_channel, 3, 1, 1),
|
158 |
+
)
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
out = x + self.body(x)
|
162 |
+
return out
|
basicsr/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn as nn
|
2 |
+
from torch.nn import functional as F
|
3 |
+
from torch.nn.utils import spectral_norm
|
4 |
+
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
@ARCH_REGISTRY.register()
|
9 |
+
class VGGStyleDiscriminator(nn.Module):
|
10 |
+
"""VGG style discriminator with input size 128 x 128 or 256 x 256.
|
11 |
+
|
12 |
+
It is used to train SRGAN, ESRGAN, and VideoGAN.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
16 |
+
num_feat (int): Channel number of base intermediate features.Default: 64.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_in_ch, num_feat, input_size=128):
|
20 |
+
super(VGGStyleDiscriminator, self).__init__()
|
21 |
+
self.input_size = input_size
|
22 |
+
assert self.input_size == 128 or self.input_size == 256, (
|
23 |
+
f'input size must be 128 or 256, but received {input_size}')
|
24 |
+
|
25 |
+
self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
|
26 |
+
self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
|
27 |
+
self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
|
28 |
+
|
29 |
+
self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
|
30 |
+
self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
|
31 |
+
self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
|
32 |
+
self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
|
33 |
+
|
34 |
+
self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
|
35 |
+
self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
|
36 |
+
self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
|
37 |
+
self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
|
38 |
+
|
39 |
+
self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
|
40 |
+
self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
41 |
+
self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
|
42 |
+
self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
43 |
+
|
44 |
+
self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
|
45 |
+
self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
46 |
+
self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
|
47 |
+
self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
48 |
+
|
49 |
+
if self.input_size == 256:
|
50 |
+
self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
|
51 |
+
self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
52 |
+
self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
|
53 |
+
self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
|
54 |
+
|
55 |
+
self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
|
56 |
+
self.linear2 = nn.Linear(100, 1)
|
57 |
+
|
58 |
+
# activation function
|
59 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
|
63 |
+
|
64 |
+
feat = self.lrelu(self.conv0_0(x))
|
65 |
+
feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
|
66 |
+
|
67 |
+
feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
|
68 |
+
feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
|
69 |
+
|
70 |
+
feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
|
71 |
+
feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
|
72 |
+
|
73 |
+
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
|
74 |
+
feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
|
75 |
+
|
76 |
+
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
|
77 |
+
feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
|
78 |
+
|
79 |
+
if self.input_size == 256:
|
80 |
+
feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
|
81 |
+
feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
|
82 |
+
|
83 |
+
# spatial size: (4, 4)
|
84 |
+
feat = feat.view(feat.size(0), -1)
|
85 |
+
feat = self.lrelu(self.linear1(feat))
|
86 |
+
out = self.linear2(feat)
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
@ARCH_REGISTRY.register(suffix='basicsr')
|
91 |
+
class UNetDiscriminatorSN(nn.Module):
|
92 |
+
"""Defines a U-Net discriminator with spectral normalization (SN)
|
93 |
+
|
94 |
+
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
95 |
+
|
96 |
+
Arg:
|
97 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
98 |
+
num_feat (int): Channel number of base intermediate features. Default: 64.
|
99 |
+
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
103 |
+
super(UNetDiscriminatorSN, self).__init__()
|
104 |
+
self.skip_connection = skip_connection
|
105 |
+
norm = spectral_norm
|
106 |
+
# the first convolution
|
107 |
+
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
108 |
+
# downsample
|
109 |
+
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
110 |
+
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
111 |
+
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
112 |
+
# upsample
|
113 |
+
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
114 |
+
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
115 |
+
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
116 |
+
# extra convolutions
|
117 |
+
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
118 |
+
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
119 |
+
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
# downsample
|
123 |
+
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
124 |
+
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
125 |
+
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
126 |
+
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
|
127 |
+
|
128 |
+
# upsample
|
129 |
+
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
|
130 |
+
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
|
131 |
+
|
132 |
+
if self.skip_connection:
|
133 |
+
x4 = x4 + x2
|
134 |
+
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
135 |
+
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
|
136 |
+
|
137 |
+
if self.skip_connection:
|
138 |
+
x5 = x5 + x1
|
139 |
+
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
|
140 |
+
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
|
141 |
+
|
142 |
+
if self.skip_connection:
|
143 |
+
x6 = x6 + x0
|
144 |
+
|
145 |
+
# extra convolutions
|
146 |
+
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
147 |
+
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
148 |
+
out = self.conv9(out)
|
149 |
+
|
150 |
+
return out
|
basicsr/archs/duf_arch.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
|
8 |
+
|
9 |
+
class DenseBlocksTemporalReduce(nn.Module):
|
10 |
+
"""A concatenation of 3 dense blocks with reduction in temporal dimension.
|
11 |
+
|
12 |
+
Note that the output temporal dimension is 6 fewer the input temporal dimension, since there are 3 blocks.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_feat (int): Number of channels in the blocks. Default: 64.
|
16 |
+
num_grow_ch (int): Growing factor of the dense blocks. Default: 32
|
17 |
+
adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
|
18 |
+
Set to false if you want to train from scratch. Default: False.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
|
22 |
+
super(DenseBlocksTemporalReduce, self).__init__()
|
23 |
+
if adapt_official_weights:
|
24 |
+
eps = 1e-3
|
25 |
+
momentum = 1e-3
|
26 |
+
else: # pytorch default values
|
27 |
+
eps = 1e-05
|
28 |
+
momentum = 0.1
|
29 |
+
|
30 |
+
self.temporal_reduce1 = nn.Sequential(
|
31 |
+
nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
32 |
+
nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
|
33 |
+
nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
34 |
+
nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
|
35 |
+
|
36 |
+
self.temporal_reduce2 = nn.Sequential(
|
37 |
+
nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
38 |
+
nn.Conv3d(
|
39 |
+
num_feat + num_grow_ch,
|
40 |
+
num_feat + num_grow_ch, (1, 1, 1),
|
41 |
+
stride=(1, 1, 1),
|
42 |
+
padding=(0, 0, 0),
|
43 |
+
bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
44 |
+
nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
|
45 |
+
|
46 |
+
self.temporal_reduce3 = nn.Sequential(
|
47 |
+
nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
48 |
+
nn.Conv3d(
|
49 |
+
num_feat + 2 * num_grow_ch,
|
50 |
+
num_feat + 2 * num_grow_ch, (1, 1, 1),
|
51 |
+
stride=(1, 1, 1),
|
52 |
+
padding=(0, 0, 0),
|
53 |
+
bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
|
54 |
+
nn.ReLU(inplace=True),
|
55 |
+
nn.Conv3d(
|
56 |
+
num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
"""
|
60 |
+
Args:
|
61 |
+
x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
|
65 |
+
"""
|
66 |
+
x1 = self.temporal_reduce1(x)
|
67 |
+
x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
|
68 |
+
|
69 |
+
x2 = self.temporal_reduce2(x1)
|
70 |
+
x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
|
71 |
+
|
72 |
+
x3 = self.temporal_reduce3(x2)
|
73 |
+
x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
|
74 |
+
|
75 |
+
return x3
|
76 |
+
|
77 |
+
|
78 |
+
class DenseBlocks(nn.Module):
|
79 |
+
""" A concatenation of N dense blocks.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
num_feat (int): Number of channels in the blocks. Default: 64.
|
83 |
+
num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
|
84 |
+
num_block (int): Number of dense blocks. The values are:
|
85 |
+
DUF-S (16 layers): 3
|
86 |
+
DUF-M (18 layers): 9
|
87 |
+
DUF-L (52 layers): 21
|
88 |
+
adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
|
89 |
+
Set to false if you want to train from scratch. Default: False.
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
|
93 |
+
super(DenseBlocks, self).__init__()
|
94 |
+
if adapt_official_weights:
|
95 |
+
eps = 1e-3
|
96 |
+
momentum = 1e-3
|
97 |
+
else: # pytorch default values
|
98 |
+
eps = 1e-05
|
99 |
+
momentum = 0.1
|
100 |
+
|
101 |
+
self.dense_blocks = nn.ModuleList()
|
102 |
+
for i in range(0, num_block):
|
103 |
+
self.dense_blocks.append(
|
104 |
+
nn.Sequential(
|
105 |
+
nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
|
106 |
+
nn.Conv3d(
|
107 |
+
num_feat + i * num_grow_ch,
|
108 |
+
num_feat + i * num_grow_ch, (1, 1, 1),
|
109 |
+
stride=(1, 1, 1),
|
110 |
+
padding=(0, 0, 0),
|
111 |
+
bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
|
112 |
+
nn.ReLU(inplace=True),
|
113 |
+
nn.Conv3d(
|
114 |
+
num_feat + i * num_grow_ch,
|
115 |
+
num_grow_ch, (3, 3, 3),
|
116 |
+
stride=(1, 1, 1),
|
117 |
+
padding=(1, 1, 1),
|
118 |
+
bias=True)))
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
"""
|
122 |
+
Args:
|
123 |
+
x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
|
127 |
+
"""
|
128 |
+
for i in range(0, len(self.dense_blocks)):
|
129 |
+
y = self.dense_blocks[i](x)
|
130 |
+
x = torch.cat((x, y), 1)
|
131 |
+
return x
|
132 |
+
|
133 |
+
|
134 |
+
class DynamicUpsamplingFilter(nn.Module):
|
135 |
+
"""Dynamic upsampling filter used in DUF.
|
136 |
+
|
137 |
+
Reference: https://github.com/yhjo09/VSR-DUF
|
138 |
+
|
139 |
+
It only supports input with 3 channels. And it applies the same filters to 3 channels.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, filter_size=(5, 5)):
|
146 |
+
super(DynamicUpsamplingFilter, self).__init__()
|
147 |
+
if not isinstance(filter_size, tuple):
|
148 |
+
raise TypeError(f'The type of filter_size must be tuple, but got type{filter_size}')
|
149 |
+
if len(filter_size) != 2:
|
150 |
+
raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
|
151 |
+
# generate a local expansion filter, similar to im2col
|
152 |
+
self.filter_size = filter_size
|
153 |
+
filter_prod = np.prod(filter_size)
|
154 |
+
expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
|
155 |
+
self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
|
156 |
+
|
157 |
+
def forward(self, x, filters):
|
158 |
+
"""Forward function for DynamicUpsamplingFilter.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
|
162 |
+
filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
|
163 |
+
filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
|
164 |
+
upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
|
165 |
+
e.g., for x 4 upsampling, upsampling_square= 4*4 = 16
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
|
169 |
+
"""
|
170 |
+
n, filter_prod, upsampling_square, h, w = filters.size()
|
171 |
+
kh, kw = self.filter_size
|
172 |
+
expanded_input = F.conv2d(
|
173 |
+
x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
|
174 |
+
expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
|
175 |
+
2) # (n, h, w, 3, filter_prod)
|
176 |
+
filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square]
|
177 |
+
out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
|
178 |
+
return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
|
179 |
+
|
180 |
+
|
181 |
+
@ARCH_REGISTRY.register()
|
182 |
+
class DUF(nn.Module):
|
183 |
+
"""Network architecture for DUF
|
184 |
+
|
185 |
+
``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
|
186 |
+
|
187 |
+
Reference: https://github.com/yhjo09/VSR-DUF
|
188 |
+
|
189 |
+
For all the models below, 'adapt_official_weights' is only necessary when
|
190 |
+
loading the weights converted from the official TensorFlow weights.
|
191 |
+
Please set it to False if you are training the model from scratch.
|
192 |
+
|
193 |
+
There are three models with different model size: DUF16Layers, DUF28Layers,
|
194 |
+
and DUF52Layers. This class is the base class for these models.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
scale (int): The upsampling factor. Default: 4.
|
198 |
+
num_layer (int): The number of layers. Default: 52.
|
199 |
+
adapt_official_weights_weights (bool): Whether to adapt the weights
|
200 |
+
translated from the official implementation. Set to false if you
|
201 |
+
want to train from scratch. Default: False.
|
202 |
+
"""
|
203 |
+
|
204 |
+
def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
|
205 |
+
super(DUF, self).__init__()
|
206 |
+
self.scale = scale
|
207 |
+
if adapt_official_weights:
|
208 |
+
eps = 1e-3
|
209 |
+
momentum = 1e-3
|
210 |
+
else: # pytorch default values
|
211 |
+
eps = 1e-05
|
212 |
+
momentum = 0.1
|
213 |
+
|
214 |
+
self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
|
215 |
+
self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
|
216 |
+
|
217 |
+
if num_layer == 16:
|
218 |
+
num_block = 3
|
219 |
+
num_grow_ch = 32
|
220 |
+
elif num_layer == 28:
|
221 |
+
num_block = 9
|
222 |
+
num_grow_ch = 16
|
223 |
+
elif num_layer == 52:
|
224 |
+
num_block = 21
|
225 |
+
num_grow_ch = 16
|
226 |
+
else:
|
227 |
+
raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
|
228 |
+
|
229 |
+
self.dense_block1 = DenseBlocks(
|
230 |
+
num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
|
231 |
+
adapt_official_weights=adapt_official_weights) # T = 7
|
232 |
+
self.dense_block2 = DenseBlocksTemporalReduce(
|
233 |
+
64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
|
234 |
+
channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
|
235 |
+
self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
|
236 |
+
self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
|
237 |
+
|
238 |
+
self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
|
239 |
+
self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
|
240 |
+
|
241 |
+
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
|
242 |
+
self.conv3d_f2 = nn.Conv3d(
|
243 |
+
512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
|
244 |
+
|
245 |
+
def forward(self, x):
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
x (Tensor): Input with shape (b, 7, c, h, w)
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
Tensor: Output with shape (b, c, h * scale, w * scale)
|
252 |
+
"""
|
253 |
+
num_batches, num_imgs, _, h, w = x.size()
|
254 |
+
|
255 |
+
x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
|
256 |
+
x_center = x[:, :, num_imgs // 2, :, :]
|
257 |
+
|
258 |
+
x = self.conv3d1(x)
|
259 |
+
x = self.dense_block1(x)
|
260 |
+
x = self.dense_block2(x)
|
261 |
+
x = F.relu(self.bn3d2(x), inplace=True)
|
262 |
+
x = F.relu(self.conv3d2(x), inplace=True)
|
263 |
+
|
264 |
+
# residual image
|
265 |
+
res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
|
266 |
+
|
267 |
+
# filter
|
268 |
+
filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
|
269 |
+
filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
|
270 |
+
|
271 |
+
# dynamic filter
|
272 |
+
out = self.dynamic_filter(x_center, filter_)
|
273 |
+
out += res.squeeze_(2)
|
274 |
+
out = F.pixel_shuffle(out, self.scale)
|
275 |
+
|
276 |
+
return out
|
basicsr/archs/ecbsr_arch.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
class SeqConv3x3(nn.Module):
|
9 |
+
"""The re-parameterizable block used in the ECBSR architecture.
|
10 |
+
|
11 |
+
``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
|
12 |
+
|
13 |
+
Reference: https://github.com/xindongzhang/ECBSR
|
14 |
+
|
15 |
+
Args:
|
16 |
+
seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
|
17 |
+
in_channels (int): Channel number of input.
|
18 |
+
out_channels (int): Channel number of output.
|
19 |
+
depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
|
23 |
+
super(SeqConv3x3, self).__init__()
|
24 |
+
self.seq_type = seq_type
|
25 |
+
self.in_channels = in_channels
|
26 |
+
self.out_channels = out_channels
|
27 |
+
|
28 |
+
if self.seq_type == 'conv1x1-conv3x3':
|
29 |
+
self.mid_planes = int(out_channels * depth_multiplier)
|
30 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
|
31 |
+
self.k0 = conv0.weight
|
32 |
+
self.b0 = conv0.bias
|
33 |
+
|
34 |
+
conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
|
35 |
+
self.k1 = conv1.weight
|
36 |
+
self.b1 = conv1.bias
|
37 |
+
|
38 |
+
elif self.seq_type == 'conv1x1-sobelx':
|
39 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
40 |
+
self.k0 = conv0.weight
|
41 |
+
self.b0 = conv0.bias
|
42 |
+
|
43 |
+
# init scale and bias
|
44 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
45 |
+
self.scale = nn.Parameter(scale)
|
46 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
47 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
48 |
+
self.bias = nn.Parameter(bias)
|
49 |
+
# init mask
|
50 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
51 |
+
for i in range(self.out_channels):
|
52 |
+
self.mask[i, 0, 0, 0] = 1.0
|
53 |
+
self.mask[i, 0, 1, 0] = 2.0
|
54 |
+
self.mask[i, 0, 2, 0] = 1.0
|
55 |
+
self.mask[i, 0, 0, 2] = -1.0
|
56 |
+
self.mask[i, 0, 1, 2] = -2.0
|
57 |
+
self.mask[i, 0, 2, 2] = -1.0
|
58 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
59 |
+
|
60 |
+
elif self.seq_type == 'conv1x1-sobely':
|
61 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
62 |
+
self.k0 = conv0.weight
|
63 |
+
self.b0 = conv0.bias
|
64 |
+
|
65 |
+
# init scale and bias
|
66 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
67 |
+
self.scale = nn.Parameter(torch.FloatTensor(scale))
|
68 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
69 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
70 |
+
self.bias = nn.Parameter(torch.FloatTensor(bias))
|
71 |
+
# init mask
|
72 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
73 |
+
for i in range(self.out_channels):
|
74 |
+
self.mask[i, 0, 0, 0] = 1.0
|
75 |
+
self.mask[i, 0, 0, 1] = 2.0
|
76 |
+
self.mask[i, 0, 0, 2] = 1.0
|
77 |
+
self.mask[i, 0, 2, 0] = -1.0
|
78 |
+
self.mask[i, 0, 2, 1] = -2.0
|
79 |
+
self.mask[i, 0, 2, 2] = -1.0
|
80 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
81 |
+
|
82 |
+
elif self.seq_type == 'conv1x1-laplacian':
|
83 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
84 |
+
self.k0 = conv0.weight
|
85 |
+
self.b0 = conv0.bias
|
86 |
+
|
87 |
+
# init scale and bias
|
88 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
89 |
+
self.scale = nn.Parameter(torch.FloatTensor(scale))
|
90 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
91 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
92 |
+
self.bias = nn.Parameter(torch.FloatTensor(bias))
|
93 |
+
# init mask
|
94 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
95 |
+
for i in range(self.out_channels):
|
96 |
+
self.mask[i, 0, 0, 1] = 1.0
|
97 |
+
self.mask[i, 0, 1, 0] = 1.0
|
98 |
+
self.mask[i, 0, 1, 2] = 1.0
|
99 |
+
self.mask[i, 0, 2, 1] = 1.0
|
100 |
+
self.mask[i, 0, 1, 1] = -4.0
|
101 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
102 |
+
else:
|
103 |
+
raise ValueError('The type of seqconv is not supported!')
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
if self.seq_type == 'conv1x1-conv3x3':
|
107 |
+
# conv-1x1
|
108 |
+
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
|
109 |
+
# explicitly padding with bias
|
110 |
+
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
|
111 |
+
b0_pad = self.b0.view(1, -1, 1, 1)
|
112 |
+
y0[:, :, 0:1, :] = b0_pad
|
113 |
+
y0[:, :, -1:, :] = b0_pad
|
114 |
+
y0[:, :, :, 0:1] = b0_pad
|
115 |
+
y0[:, :, :, -1:] = b0_pad
|
116 |
+
# conv-3x3
|
117 |
+
y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
|
118 |
+
else:
|
119 |
+
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
|
120 |
+
# explicitly padding with bias
|
121 |
+
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
|
122 |
+
b0_pad = self.b0.view(1, -1, 1, 1)
|
123 |
+
y0[:, :, 0:1, :] = b0_pad
|
124 |
+
y0[:, :, -1:, :] = b0_pad
|
125 |
+
y0[:, :, :, 0:1] = b0_pad
|
126 |
+
y0[:, :, :, -1:] = b0_pad
|
127 |
+
# conv-3x3
|
128 |
+
y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
|
129 |
+
return y1
|
130 |
+
|
131 |
+
def rep_params(self):
|
132 |
+
device = self.k0.get_device()
|
133 |
+
if device < 0:
|
134 |
+
device = None
|
135 |
+
|
136 |
+
if self.seq_type == 'conv1x1-conv3x3':
|
137 |
+
# re-param conv kernel
|
138 |
+
rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
|
139 |
+
# re-param conv bias
|
140 |
+
rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
|
141 |
+
rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
|
142 |
+
else:
|
143 |
+
tmp = self.scale * self.mask
|
144 |
+
k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
|
145 |
+
for i in range(self.out_channels):
|
146 |
+
k1[i, i, :, :] = tmp[i, 0, :, :]
|
147 |
+
b1 = self.bias
|
148 |
+
# re-param conv kernel
|
149 |
+
rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
|
150 |
+
# re-param conv bias
|
151 |
+
rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
|
152 |
+
rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
|
153 |
+
return rep_weight, rep_bias
|
154 |
+
|
155 |
+
|
156 |
+
class ECB(nn.Module):
|
157 |
+
"""The ECB block used in the ECBSR architecture.
|
158 |
+
|
159 |
+
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
|
160 |
+
Ref git repo: https://github.com/xindongzhang/ECBSR
|
161 |
+
|
162 |
+
Args:
|
163 |
+
in_channels (int): Channel number of input.
|
164 |
+
out_channels (int): Channel number of output.
|
165 |
+
depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
|
166 |
+
act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
|
167 |
+
with_idt (bool): Whether to use identity connection. Default: False.
|
168 |
+
"""
|
169 |
+
|
170 |
+
def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
|
171 |
+
super(ECB, self).__init__()
|
172 |
+
|
173 |
+
self.depth_multiplier = depth_multiplier
|
174 |
+
self.in_channels = in_channels
|
175 |
+
self.out_channels = out_channels
|
176 |
+
self.act_type = act_type
|
177 |
+
|
178 |
+
if with_idt and (self.in_channels == self.out_channels):
|
179 |
+
self.with_idt = True
|
180 |
+
else:
|
181 |
+
self.with_idt = False
|
182 |
+
|
183 |
+
self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
|
184 |
+
self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
|
185 |
+
self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
|
186 |
+
self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
|
187 |
+
self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
|
188 |
+
|
189 |
+
if self.act_type == 'prelu':
|
190 |
+
self.act = nn.PReLU(num_parameters=self.out_channels)
|
191 |
+
elif self.act_type == 'relu':
|
192 |
+
self.act = nn.ReLU(inplace=True)
|
193 |
+
elif self.act_type == 'rrelu':
|
194 |
+
self.act = nn.RReLU(lower=-0.05, upper=0.05)
|
195 |
+
elif self.act_type == 'softplus':
|
196 |
+
self.act = nn.Softplus()
|
197 |
+
elif self.act_type == 'linear':
|
198 |
+
pass
|
199 |
+
else:
|
200 |
+
raise ValueError('The type of activation if not support!')
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
if self.training:
|
204 |
+
y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
|
205 |
+
if self.with_idt:
|
206 |
+
y += x
|
207 |
+
else:
|
208 |
+
rep_weight, rep_bias = self.rep_params()
|
209 |
+
y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
|
210 |
+
if self.act_type != 'linear':
|
211 |
+
y = self.act(y)
|
212 |
+
return y
|
213 |
+
|
214 |
+
def rep_params(self):
|
215 |
+
weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
|
216 |
+
weight1, bias1 = self.conv1x1_3x3.rep_params()
|
217 |
+
weight2, bias2 = self.conv1x1_sbx.rep_params()
|
218 |
+
weight3, bias3 = self.conv1x1_sby.rep_params()
|
219 |
+
weight4, bias4 = self.conv1x1_lpl.rep_params()
|
220 |
+
rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
|
221 |
+
bias0 + bias1 + bias2 + bias3 + bias4)
|
222 |
+
|
223 |
+
if self.with_idt:
|
224 |
+
device = rep_weight.get_device()
|
225 |
+
if device < 0:
|
226 |
+
device = None
|
227 |
+
weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
|
228 |
+
for i in range(self.out_channels):
|
229 |
+
weight_idt[i, i, 1, 1] = 1.0
|
230 |
+
bias_idt = 0.0
|
231 |
+
rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
|
232 |
+
return rep_weight, rep_bias
|
233 |
+
|
234 |
+
|
235 |
+
@ARCH_REGISTRY.register()
|
236 |
+
class ECBSR(nn.Module):
|
237 |
+
"""ECBSR architecture.
|
238 |
+
|
239 |
+
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
|
240 |
+
Ref git repo: https://github.com/xindongzhang/ECBSR
|
241 |
+
|
242 |
+
Args:
|
243 |
+
num_in_ch (int): Channel number of inputs.
|
244 |
+
num_out_ch (int): Channel number of outputs.
|
245 |
+
num_block (int): Block number in the trunk network.
|
246 |
+
num_channel (int): Channel number.
|
247 |
+
with_idt (bool): Whether use identity in convolution layers.
|
248 |
+
act_type (str): Activation type.
|
249 |
+
scale (int): Upsampling factor.
|
250 |
+
"""
|
251 |
+
|
252 |
+
def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
|
253 |
+
super(ECBSR, self).__init__()
|
254 |
+
self.num_in_ch = num_in_ch
|
255 |
+
self.scale = scale
|
256 |
+
|
257 |
+
backbone = []
|
258 |
+
backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
|
259 |
+
for _ in range(num_block):
|
260 |
+
backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
|
261 |
+
backbone += [
|
262 |
+
ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
|
263 |
+
]
|
264 |
+
|
265 |
+
self.backbone = nn.Sequential(*backbone)
|
266 |
+
self.upsampler = nn.PixelShuffle(scale)
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
if self.num_in_ch > 1:
|
270 |
+
shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
|
271 |
+
else:
|
272 |
+
shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
|
273 |
+
y = self.backbone(x) + shortcut
|
274 |
+
y = self.upsampler(y)
|
275 |
+
return y
|
basicsr/archs/edsr_arch.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
|
4 |
+
from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
@ARCH_REGISTRY.register()
|
9 |
+
class EDSR(nn.Module):
|
10 |
+
"""EDSR network structure.
|
11 |
+
|
12 |
+
Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
|
13 |
+
Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
|
14 |
+
|
15 |
+
Args:
|
16 |
+
num_in_ch (int): Channel number of inputs.
|
17 |
+
num_out_ch (int): Channel number of outputs.
|
18 |
+
num_feat (int): Channel number of intermediate features.
|
19 |
+
Default: 64.
|
20 |
+
num_block (int): Block number in the trunk network. Default: 16.
|
21 |
+
upscale (int): Upsampling factor. Support 2^n and 3.
|
22 |
+
Default: 4.
|
23 |
+
res_scale (float): Used to scale the residual in residual block.
|
24 |
+
Default: 1.
|
25 |
+
img_range (float): Image range. Default: 255.
|
26 |
+
rgb_mean (tuple[float]): Image mean in RGB orders.
|
27 |
+
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self,
|
31 |
+
num_in_ch,
|
32 |
+
num_out_ch,
|
33 |
+
num_feat=64,
|
34 |
+
num_block=16,
|
35 |
+
upscale=4,
|
36 |
+
res_scale=1,
|
37 |
+
img_range=255.,
|
38 |
+
rgb_mean=(0.4488, 0.4371, 0.4040)):
|
39 |
+
super(EDSR, self).__init__()
|
40 |
+
|
41 |
+
self.img_range = img_range
|
42 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
43 |
+
|
44 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
45 |
+
self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
|
46 |
+
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
47 |
+
self.upsample = Upsample(upscale, num_feat)
|
48 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
self.mean = self.mean.type_as(x)
|
52 |
+
|
53 |
+
x = (x - self.mean) * self.img_range
|
54 |
+
x = self.conv_first(x)
|
55 |
+
res = self.conv_after_body(self.body(x))
|
56 |
+
res += x
|
57 |
+
|
58 |
+
x = self.conv_last(self.upsample(res))
|
59 |
+
x = x / self.img_range + self.mean
|
60 |
+
|
61 |
+
return x
|
basicsr/archs/edvr_arch.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
|
7 |
+
|
8 |
+
|
9 |
+
class PCDAlignment(nn.Module):
|
10 |
+
"""Alignment module using Pyramid, Cascading and Deformable convolution
|
11 |
+
(PCD). It is used in EDVR.
|
12 |
+
|
13 |
+
``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
|
14 |
+
|
15 |
+
Args:
|
16 |
+
num_feat (int): Channel number of middle features. Default: 64.
|
17 |
+
deformable_groups (int): Deformable groups. Defaults: 8.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, num_feat=64, deformable_groups=8):
|
21 |
+
super(PCDAlignment, self).__init__()
|
22 |
+
|
23 |
+
# Pyramid has three levels:
|
24 |
+
# L3: level 3, 1/4 spatial size
|
25 |
+
# L2: level 2, 1/2 spatial size
|
26 |
+
# L1: level 1, original spatial size
|
27 |
+
self.offset_conv1 = nn.ModuleDict()
|
28 |
+
self.offset_conv2 = nn.ModuleDict()
|
29 |
+
self.offset_conv3 = nn.ModuleDict()
|
30 |
+
self.dcn_pack = nn.ModuleDict()
|
31 |
+
self.feat_conv = nn.ModuleDict()
|
32 |
+
|
33 |
+
# Pyramids
|
34 |
+
for i in range(3, 0, -1):
|
35 |
+
level = f'l{i}'
|
36 |
+
self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
|
37 |
+
if i == 3:
|
38 |
+
self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
39 |
+
else:
|
40 |
+
self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
|
41 |
+
self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
42 |
+
self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
|
43 |
+
|
44 |
+
if i < 3:
|
45 |
+
self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
|
46 |
+
|
47 |
+
# Cascading dcn
|
48 |
+
self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
|
49 |
+
self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
50 |
+
self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
|
51 |
+
|
52 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
53 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
54 |
+
|
55 |
+
def forward(self, nbr_feat_l, ref_feat_l):
|
56 |
+
"""Align neighboring frame features to the reference frame features.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
nbr_feat_l (list[Tensor]): Neighboring feature list. It
|
60 |
+
contains three pyramid levels (L1, L2, L3),
|
61 |
+
each with shape (b, c, h, w).
|
62 |
+
ref_feat_l (list[Tensor]): Reference feature list. It
|
63 |
+
contains three pyramid levels (L1, L2, L3),
|
64 |
+
each with shape (b, c, h, w).
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Tensor: Aligned features.
|
68 |
+
"""
|
69 |
+
# Pyramids
|
70 |
+
upsampled_offset, upsampled_feat = None, None
|
71 |
+
for i in range(3, 0, -1):
|
72 |
+
level = f'l{i}'
|
73 |
+
offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
|
74 |
+
offset = self.lrelu(self.offset_conv1[level](offset))
|
75 |
+
if i == 3:
|
76 |
+
offset = self.lrelu(self.offset_conv2[level](offset))
|
77 |
+
else:
|
78 |
+
offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
|
79 |
+
offset = self.lrelu(self.offset_conv3[level](offset))
|
80 |
+
|
81 |
+
feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
|
82 |
+
if i < 3:
|
83 |
+
feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
|
84 |
+
if i > 1:
|
85 |
+
feat = self.lrelu(feat)
|
86 |
+
|
87 |
+
if i > 1: # upsample offset and features
|
88 |
+
# x2: when we upsample the offset, we should also enlarge
|
89 |
+
# the magnitude.
|
90 |
+
upsampled_offset = self.upsample(offset) * 2
|
91 |
+
upsampled_feat = self.upsample(feat)
|
92 |
+
|
93 |
+
# Cascading
|
94 |
+
offset = torch.cat([feat, ref_feat_l[0]], dim=1)
|
95 |
+
offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
|
96 |
+
feat = self.lrelu(self.cas_dcnpack(feat, offset))
|
97 |
+
return feat
|
98 |
+
|
99 |
+
|
100 |
+
class TSAFusion(nn.Module):
|
101 |
+
"""Temporal Spatial Attention (TSA) fusion module.
|
102 |
+
|
103 |
+
Temporal: Calculate the correlation between center frame and
|
104 |
+
neighboring frames;
|
105 |
+
Spatial: It has 3 pyramid levels, the attention is similar to SFT.
|
106 |
+
(SFT: Recovering realistic texture in image super-resolution by deep
|
107 |
+
spatial feature transform.)
|
108 |
+
|
109 |
+
Args:
|
110 |
+
num_feat (int): Channel number of middle features. Default: 64.
|
111 |
+
num_frame (int): Number of frames. Default: 5.
|
112 |
+
center_frame_idx (int): The index of center frame. Default: 2.
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
|
116 |
+
super(TSAFusion, self).__init__()
|
117 |
+
self.center_frame_idx = center_frame_idx
|
118 |
+
# temporal attention (before fusion conv)
|
119 |
+
self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
120 |
+
self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
121 |
+
self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
|
122 |
+
|
123 |
+
# spatial attention (after fusion conv)
|
124 |
+
self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
|
125 |
+
self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
|
126 |
+
self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
|
127 |
+
self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
|
128 |
+
self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
129 |
+
self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
|
130 |
+
self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
131 |
+
self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
|
132 |
+
self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
|
133 |
+
self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
134 |
+
self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
|
135 |
+
self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
|
136 |
+
|
137 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
138 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
139 |
+
|
140 |
+
def forward(self, aligned_feat):
|
141 |
+
"""
|
142 |
+
Args:
|
143 |
+
aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
Tensor: Features after TSA with the shape (b, c, h, w).
|
147 |
+
"""
|
148 |
+
b, t, c, h, w = aligned_feat.size()
|
149 |
+
# temporal attention
|
150 |
+
embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
|
151 |
+
embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
|
152 |
+
embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
|
153 |
+
|
154 |
+
corr_l = [] # correlation list
|
155 |
+
for i in range(t):
|
156 |
+
emb_neighbor = embedding[:, i, :, :, :]
|
157 |
+
corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
|
158 |
+
corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
|
159 |
+
corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
|
160 |
+
corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
|
161 |
+
corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
|
162 |
+
aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
|
163 |
+
|
164 |
+
# fusion
|
165 |
+
feat = self.lrelu(self.feat_fusion(aligned_feat))
|
166 |
+
|
167 |
+
# spatial attention
|
168 |
+
attn = self.lrelu(self.spatial_attn1(aligned_feat))
|
169 |
+
attn_max = self.max_pool(attn)
|
170 |
+
attn_avg = self.avg_pool(attn)
|
171 |
+
attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
|
172 |
+
# pyramid levels
|
173 |
+
attn_level = self.lrelu(self.spatial_attn_l1(attn))
|
174 |
+
attn_max = self.max_pool(attn_level)
|
175 |
+
attn_avg = self.avg_pool(attn_level)
|
176 |
+
attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
|
177 |
+
attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
|
178 |
+
attn_level = self.upsample(attn_level)
|
179 |
+
|
180 |
+
attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
|
181 |
+
attn = self.lrelu(self.spatial_attn4(attn))
|
182 |
+
attn = self.upsample(attn)
|
183 |
+
attn = self.spatial_attn5(attn)
|
184 |
+
attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
|
185 |
+
attn = torch.sigmoid(attn)
|
186 |
+
|
187 |
+
# after initialization, * 2 makes (attn * 2) to be close to 1.
|
188 |
+
feat = feat * attn * 2 + attn_add
|
189 |
+
return feat
|
190 |
+
|
191 |
+
|
192 |
+
class PredeblurModule(nn.Module):
|
193 |
+
"""Pre-dublur module.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
num_in_ch (int): Channel number of input image. Default: 3.
|
197 |
+
num_feat (int): Channel number of intermediate features. Default: 64.
|
198 |
+
hr_in (bool): Whether the input has high resolution. Default: False.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
|
202 |
+
super(PredeblurModule, self).__init__()
|
203 |
+
self.hr_in = hr_in
|
204 |
+
|
205 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
206 |
+
if self.hr_in:
|
207 |
+
# downsample x4 by stride conv
|
208 |
+
self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
209 |
+
self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
210 |
+
|
211 |
+
# generate feature pyramid
|
212 |
+
self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
213 |
+
self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
214 |
+
|
215 |
+
self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
|
216 |
+
self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
|
217 |
+
self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
|
218 |
+
self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
|
219 |
+
|
220 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
221 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
feat_l1 = self.lrelu(self.conv_first(x))
|
225 |
+
if self.hr_in:
|
226 |
+
feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
|
227 |
+
feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
|
228 |
+
|
229 |
+
# generate feature pyramid
|
230 |
+
feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
|
231 |
+
feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
|
232 |
+
|
233 |
+
feat_l3 = self.upsample(self.resblock_l3(feat_l3))
|
234 |
+
feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
|
235 |
+
feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
|
236 |
+
|
237 |
+
for i in range(2):
|
238 |
+
feat_l1 = self.resblock_l1[i](feat_l1)
|
239 |
+
feat_l1 = feat_l1 + feat_l2
|
240 |
+
for i in range(2, 5):
|
241 |
+
feat_l1 = self.resblock_l1[i](feat_l1)
|
242 |
+
return feat_l1
|
243 |
+
|
244 |
+
|
245 |
+
@ARCH_REGISTRY.register()
|
246 |
+
class EDVR(nn.Module):
|
247 |
+
"""EDVR network structure for video super-resolution.
|
248 |
+
|
249 |
+
Now only support X4 upsampling factor.
|
250 |
+
|
251 |
+
``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
|
252 |
+
|
253 |
+
Args:
|
254 |
+
num_in_ch (int): Channel number of input image. Default: 3.
|
255 |
+
num_out_ch (int): Channel number of output image. Default: 3.
|
256 |
+
num_feat (int): Channel number of intermediate features. Default: 64.
|
257 |
+
num_frame (int): Number of input frames. Default: 5.
|
258 |
+
deformable_groups (int): Deformable groups. Defaults: 8.
|
259 |
+
num_extract_block (int): Number of blocks for feature extraction.
|
260 |
+
Default: 5.
|
261 |
+
num_reconstruct_block (int): Number of blocks for reconstruction.
|
262 |
+
Default: 10.
|
263 |
+
center_frame_idx (int): The index of center frame. Frame counting from
|
264 |
+
0. Default: Middle of input frames.
|
265 |
+
hr_in (bool): Whether the input has high resolution. Default: False.
|
266 |
+
with_predeblur (bool): Whether has predeblur module.
|
267 |
+
Default: False.
|
268 |
+
with_tsa (bool): Whether has TSA module. Default: True.
|
269 |
+
"""
|
270 |
+
|
271 |
+
def __init__(self,
|
272 |
+
num_in_ch=3,
|
273 |
+
num_out_ch=3,
|
274 |
+
num_feat=64,
|
275 |
+
num_frame=5,
|
276 |
+
deformable_groups=8,
|
277 |
+
num_extract_block=5,
|
278 |
+
num_reconstruct_block=10,
|
279 |
+
center_frame_idx=None,
|
280 |
+
hr_in=False,
|
281 |
+
with_predeblur=False,
|
282 |
+
with_tsa=True):
|
283 |
+
super(EDVR, self).__init__()
|
284 |
+
if center_frame_idx is None:
|
285 |
+
self.center_frame_idx = num_frame // 2
|
286 |
+
else:
|
287 |
+
self.center_frame_idx = center_frame_idx
|
288 |
+
self.hr_in = hr_in
|
289 |
+
self.with_predeblur = with_predeblur
|
290 |
+
self.with_tsa = with_tsa
|
291 |
+
|
292 |
+
# extract features for each frame
|
293 |
+
if self.with_predeblur:
|
294 |
+
self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
|
295 |
+
self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
|
296 |
+
else:
|
297 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
298 |
+
|
299 |
+
# extract pyramid features
|
300 |
+
self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
|
301 |
+
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
302 |
+
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
303 |
+
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
|
304 |
+
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
305 |
+
|
306 |
+
# pcd and tsa module
|
307 |
+
self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
|
308 |
+
if self.with_tsa:
|
309 |
+
self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
|
310 |
+
else:
|
311 |
+
self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
|
312 |
+
|
313 |
+
# reconstruction
|
314 |
+
self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
|
315 |
+
# upsample
|
316 |
+
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
|
317 |
+
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
|
318 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
319 |
+
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
|
320 |
+
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
|
321 |
+
|
322 |
+
# activation function
|
323 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
b, t, c, h, w = x.size()
|
327 |
+
if self.hr_in:
|
328 |
+
assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
|
329 |
+
else:
|
330 |
+
assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
|
331 |
+
|
332 |
+
x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
|
333 |
+
|
334 |
+
# extract features for each frame
|
335 |
+
# L1
|
336 |
+
if self.with_predeblur:
|
337 |
+
feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
|
338 |
+
if self.hr_in:
|
339 |
+
h, w = h // 4, w // 4
|
340 |
+
else:
|
341 |
+
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
|
342 |
+
|
343 |
+
feat_l1 = self.feature_extraction(feat_l1)
|
344 |
+
# L2
|
345 |
+
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
|
346 |
+
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
|
347 |
+
# L3
|
348 |
+
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
|
349 |
+
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
|
350 |
+
|
351 |
+
feat_l1 = feat_l1.view(b, t, -1, h, w)
|
352 |
+
feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
|
353 |
+
feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
|
354 |
+
|
355 |
+
# PCD alignment
|
356 |
+
ref_feat_l = [ # reference feature list
|
357 |
+
feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
|
358 |
+
feat_l3[:, self.center_frame_idx, :, :, :].clone()
|
359 |
+
]
|
360 |
+
aligned_feat = []
|
361 |
+
for i in range(t):
|
362 |
+
nbr_feat_l = [ # neighboring feature list
|
363 |
+
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
|
364 |
+
]
|
365 |
+
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
|
366 |
+
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
|
367 |
+
|
368 |
+
if not self.with_tsa:
|
369 |
+
aligned_feat = aligned_feat.view(b, -1, h, w)
|
370 |
+
feat = self.fusion(aligned_feat)
|
371 |
+
|
372 |
+
out = self.reconstruction(feat)
|
373 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
374 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
375 |
+
out = self.lrelu(self.conv_hr(out))
|
376 |
+
out = self.conv_last(out)
|
377 |
+
if self.hr_in:
|
378 |
+
base = x_center
|
379 |
+
else:
|
380 |
+
base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
|
381 |
+
out += base
|
382 |
+
return out
|
basicsr/archs/hifacegan_arch.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
|
8 |
+
|
9 |
+
|
10 |
+
class SPADEGenerator(BaseNetwork):
|
11 |
+
"""Generator with SPADEResBlock"""
|
12 |
+
|
13 |
+
def __init__(self,
|
14 |
+
num_in_ch=3,
|
15 |
+
num_feat=64,
|
16 |
+
use_vae=False,
|
17 |
+
z_dim=256,
|
18 |
+
crop_size=512,
|
19 |
+
norm_g='spectralspadesyncbatch3x3',
|
20 |
+
is_train=True,
|
21 |
+
init_train_phase=3): # progressive training disabled
|
22 |
+
super().__init__()
|
23 |
+
self.nf = num_feat
|
24 |
+
self.input_nc = num_in_ch
|
25 |
+
self.is_train = is_train
|
26 |
+
self.train_phase = init_train_phase
|
27 |
+
|
28 |
+
self.scale_ratio = 5 # hardcoded now
|
29 |
+
self.sw = crop_size // (2**self.scale_ratio)
|
30 |
+
self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
|
31 |
+
|
32 |
+
if use_vae:
|
33 |
+
# In case of VAE, we will sample from random z vector
|
34 |
+
self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
|
35 |
+
else:
|
36 |
+
# Otherwise, we make the network deterministic by starting with
|
37 |
+
# downsampled segmentation map instead of random z
|
38 |
+
self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
|
39 |
+
|
40 |
+
self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
|
41 |
+
|
42 |
+
self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
|
43 |
+
self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
|
44 |
+
|
45 |
+
self.ups = nn.ModuleList([
|
46 |
+
SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
|
47 |
+
SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
|
48 |
+
SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
|
49 |
+
SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
|
50 |
+
])
|
51 |
+
|
52 |
+
self.to_rgbs = nn.ModuleList([
|
53 |
+
nn.Conv2d(8 * self.nf, 3, 3, padding=1),
|
54 |
+
nn.Conv2d(4 * self.nf, 3, 3, padding=1),
|
55 |
+
nn.Conv2d(2 * self.nf, 3, 3, padding=1),
|
56 |
+
nn.Conv2d(1 * self.nf, 3, 3, padding=1)
|
57 |
+
])
|
58 |
+
|
59 |
+
self.up = nn.Upsample(scale_factor=2)
|
60 |
+
|
61 |
+
def encode(self, input_tensor):
|
62 |
+
"""
|
63 |
+
Encode input_tensor into feature maps, can be overridden in derived classes
|
64 |
+
Default: nearest downsampling of 2**5 = 32 times
|
65 |
+
"""
|
66 |
+
h, w = input_tensor.size()[-2:]
|
67 |
+
sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
|
68 |
+
x = F.interpolate(input_tensor, size=(sh, sw))
|
69 |
+
return self.fc(x)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
# In oroginal SPADE, seg means a segmentation map, but here we use x instead.
|
73 |
+
seg = x
|
74 |
+
|
75 |
+
x = self.encode(x)
|
76 |
+
x = self.head_0(x, seg)
|
77 |
+
|
78 |
+
x = self.up(x)
|
79 |
+
x = self.g_middle_0(x, seg)
|
80 |
+
x = self.g_middle_1(x, seg)
|
81 |
+
|
82 |
+
if self.is_train:
|
83 |
+
phase = self.train_phase + 1
|
84 |
+
else:
|
85 |
+
phase = len(self.to_rgbs)
|
86 |
+
|
87 |
+
for i in range(phase):
|
88 |
+
x = self.up(x)
|
89 |
+
x = self.ups[i](x, seg)
|
90 |
+
|
91 |
+
x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
|
92 |
+
x = torch.tanh(x)
|
93 |
+
|
94 |
+
return x
|
95 |
+
|
96 |
+
def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
|
97 |
+
"""
|
98 |
+
A helper class for subspace visualization. Input and seg are different images.
|
99 |
+
For the first n levels (including encoder) we use input, for the rest we use seg.
|
100 |
+
|
101 |
+
If mode = 'progressive', the output's like: AAABBB
|
102 |
+
If mode = 'one_plug', the output's like: AAABAA
|
103 |
+
If mode = 'one_ablate', the output's like: BBBABB
|
104 |
+
"""
|
105 |
+
|
106 |
+
if seg is None:
|
107 |
+
return self.forward(input_x)
|
108 |
+
|
109 |
+
if self.is_train:
|
110 |
+
phase = self.train_phase + 1
|
111 |
+
else:
|
112 |
+
phase = len(self.to_rgbs)
|
113 |
+
|
114 |
+
if mode == 'progressive':
|
115 |
+
n = max(min(n, 4 + phase), 0)
|
116 |
+
guide_list = [input_x] * n + [seg] * (4 + phase - n)
|
117 |
+
elif mode == 'one_plug':
|
118 |
+
n = max(min(n, 4 + phase - 1), 0)
|
119 |
+
guide_list = [seg] * (4 + phase)
|
120 |
+
guide_list[n] = input_x
|
121 |
+
elif mode == 'one_ablate':
|
122 |
+
if n > 3 + phase:
|
123 |
+
return self.forward(input_x)
|
124 |
+
guide_list = [input_x] * (4 + phase)
|
125 |
+
guide_list[n] = seg
|
126 |
+
|
127 |
+
x = self.encode(guide_list[0])
|
128 |
+
x = self.head_0(x, guide_list[1])
|
129 |
+
|
130 |
+
x = self.up(x)
|
131 |
+
x = self.g_middle_0(x, guide_list[2])
|
132 |
+
x = self.g_middle_1(x, guide_list[3])
|
133 |
+
|
134 |
+
for i in range(phase):
|
135 |
+
x = self.up(x)
|
136 |
+
x = self.ups[i](x, guide_list[4 + i])
|
137 |
+
|
138 |
+
x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
|
139 |
+
x = torch.tanh(x)
|
140 |
+
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
@ARCH_REGISTRY.register()
|
145 |
+
class HiFaceGAN(SPADEGenerator):
|
146 |
+
"""
|
147 |
+
HiFaceGAN: SPADEGenerator with a learnable feature encoder
|
148 |
+
Current encoder design: LIPEncoder
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self,
|
152 |
+
num_in_ch=3,
|
153 |
+
num_feat=64,
|
154 |
+
use_vae=False,
|
155 |
+
z_dim=256,
|
156 |
+
crop_size=512,
|
157 |
+
norm_g='spectralspadesyncbatch3x3',
|
158 |
+
is_train=True,
|
159 |
+
init_train_phase=3):
|
160 |
+
super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
|
161 |
+
self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
|
162 |
+
|
163 |
+
def encode(self, input_tensor):
|
164 |
+
return self.lip_encoder(input_tensor)
|
165 |
+
|
166 |
+
|
167 |
+
@ARCH_REGISTRY.register()
|
168 |
+
class HiFaceGANDiscriminator(BaseNetwork):
|
169 |
+
"""
|
170 |
+
Inspired by pix2pixHD multiscale discriminator.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
174 |
+
num_out_ch (int): Channel number of outputs. Default: 3.
|
175 |
+
conditional_d (bool): Whether use conditional discriminator.
|
176 |
+
Default: True.
|
177 |
+
num_d (int): Number of Multiscale discriminators. Default: 3.
|
178 |
+
n_layers_d (int): Number of downsample layers in each D. Default: 4.
|
179 |
+
num_feat (int): Channel number of base intermediate features.
|
180 |
+
Default: 64.
|
181 |
+
norm_d (str): String to determine normalization layers in D.
|
182 |
+
Choices: [spectral][instance/batch/syncbatch]
|
183 |
+
Default: 'spectralinstance'.
|
184 |
+
keep_features (bool): Keep intermediate features for matching loss, etc.
|
185 |
+
Default: True.
|
186 |
+
"""
|
187 |
+
|
188 |
+
def __init__(self,
|
189 |
+
num_in_ch=3,
|
190 |
+
num_out_ch=3,
|
191 |
+
conditional_d=True,
|
192 |
+
num_d=2,
|
193 |
+
n_layers_d=4,
|
194 |
+
num_feat=64,
|
195 |
+
norm_d='spectralinstance',
|
196 |
+
keep_features=True):
|
197 |
+
super().__init__()
|
198 |
+
self.num_d = num_d
|
199 |
+
|
200 |
+
input_nc = num_in_ch
|
201 |
+
if conditional_d:
|
202 |
+
input_nc += num_out_ch
|
203 |
+
|
204 |
+
for i in range(num_d):
|
205 |
+
subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
|
206 |
+
self.add_module(f'discriminator_{i}', subnet_d)
|
207 |
+
|
208 |
+
def downsample(self, x):
|
209 |
+
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
210 |
+
|
211 |
+
# Returns list of lists of discriminator outputs.
|
212 |
+
# The final result is of size opt.num_d x opt.n_layers_D
|
213 |
+
def forward(self, x):
|
214 |
+
result = []
|
215 |
+
for _, _net_d in self.named_children():
|
216 |
+
out = _net_d(x)
|
217 |
+
result.append(out)
|
218 |
+
x = self.downsample(x)
|
219 |
+
|
220 |
+
return result
|
221 |
+
|
222 |
+
|
223 |
+
class NLayerDiscriminator(BaseNetwork):
|
224 |
+
"""Defines the PatchGAN discriminator with the specified arguments."""
|
225 |
+
|
226 |
+
def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
|
227 |
+
super().__init__()
|
228 |
+
kw = 4
|
229 |
+
padw = int(np.ceil((kw - 1.0) / 2))
|
230 |
+
nf = num_feat
|
231 |
+
self.keep_features = keep_features
|
232 |
+
|
233 |
+
norm_layer = get_nonspade_norm_layer(norm_d)
|
234 |
+
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
|
235 |
+
|
236 |
+
for n in range(1, n_layers_d):
|
237 |
+
nf_prev = nf
|
238 |
+
nf = min(nf * 2, 512)
|
239 |
+
stride = 1 if n == n_layers_d - 1 else 2
|
240 |
+
sequence += [[
|
241 |
+
norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
|
242 |
+
nn.LeakyReLU(0.2, False)
|
243 |
+
]]
|
244 |
+
|
245 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
246 |
+
|
247 |
+
# We divide the layers into groups to extract intermediate layer outputs
|
248 |
+
for n in range(len(sequence)):
|
249 |
+
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
results = [x]
|
253 |
+
for submodel in self.children():
|
254 |
+
intermediate_output = submodel(results[-1])
|
255 |
+
results.append(intermediate_output)
|
256 |
+
|
257 |
+
if self.keep_features:
|
258 |
+
return results[1:]
|
259 |
+
else:
|
260 |
+
return results[-1]
|
basicsr/archs/hifacegan_util.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn import init
|
6 |
+
# Warning: spectral norm could be buggy
|
7 |
+
# under eval mode and multi-GPU inference
|
8 |
+
# A workaround is sticking to single-GPU inference and train mode
|
9 |
+
from torch.nn.utils import spectral_norm
|
10 |
+
|
11 |
+
|
12 |
+
class SPADE(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, config_text, norm_nc, label_nc):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
assert config_text.startswith('spade')
|
18 |
+
parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
|
19 |
+
param_free_norm_type = str(parsed.group(1))
|
20 |
+
ks = int(parsed.group(2))
|
21 |
+
|
22 |
+
if param_free_norm_type == 'instance':
|
23 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc)
|
24 |
+
elif param_free_norm_type == 'syncbatch':
|
25 |
+
print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
|
26 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc)
|
27 |
+
elif param_free_norm_type == 'batch':
|
28 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
29 |
+
else:
|
30 |
+
raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
|
31 |
+
|
32 |
+
# The dimension of the intermediate embedding space. Yes, hardcoded.
|
33 |
+
nhidden = 128 if norm_nc > 128 else norm_nc
|
34 |
+
|
35 |
+
pw = ks // 2
|
36 |
+
self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
|
37 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
|
38 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
|
39 |
+
|
40 |
+
def forward(self, x, segmap):
|
41 |
+
|
42 |
+
# Part 1. generate parameter-free normalized activations
|
43 |
+
normalized = self.param_free_norm(x)
|
44 |
+
|
45 |
+
# Part 2. produce scaling and bias conditioned on semantic map
|
46 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
47 |
+
actv = self.mlp_shared(segmap)
|
48 |
+
gamma = self.mlp_gamma(actv)
|
49 |
+
beta = self.mlp_beta(actv)
|
50 |
+
|
51 |
+
# apply scale and bias
|
52 |
+
out = normalized * gamma + beta
|
53 |
+
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class SPADEResnetBlock(nn.Module):
|
58 |
+
"""
|
59 |
+
ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
|
60 |
+
it takes in the segmentation map as input, learns the skip connection if necessary,
|
61 |
+
and applies normalization first and then convolution.
|
62 |
+
This architecture seemed like a standard architecture for unconditional or
|
63 |
+
class-conditional GAN architecture using residual block.
|
64 |
+
The code was inspired from https://github.com/LMescheder/GAN_stability.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
|
68 |
+
super().__init__()
|
69 |
+
# Attributes
|
70 |
+
self.learned_shortcut = (fin != fout)
|
71 |
+
fmiddle = min(fin, fout)
|
72 |
+
|
73 |
+
# create conv layers
|
74 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
|
75 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
|
76 |
+
if self.learned_shortcut:
|
77 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
78 |
+
|
79 |
+
# apply spectral norm if specified
|
80 |
+
if 'spectral' in norm_g:
|
81 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
82 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
83 |
+
if self.learned_shortcut:
|
84 |
+
self.conv_s = spectral_norm(self.conv_s)
|
85 |
+
|
86 |
+
# define normalization layers
|
87 |
+
spade_config_str = norm_g.replace('spectral', '')
|
88 |
+
self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
|
89 |
+
self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
|
90 |
+
if self.learned_shortcut:
|
91 |
+
self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
|
92 |
+
|
93 |
+
# note the resnet block with SPADE also takes in |seg|,
|
94 |
+
# the semantic segmentation map as input
|
95 |
+
def forward(self, x, seg):
|
96 |
+
x_s = self.shortcut(x, seg)
|
97 |
+
dx = self.conv_0(self.act(self.norm_0(x, seg)))
|
98 |
+
dx = self.conv_1(self.act(self.norm_1(dx, seg)))
|
99 |
+
out = x_s + dx
|
100 |
+
return out
|
101 |
+
|
102 |
+
def shortcut(self, x, seg):
|
103 |
+
if self.learned_shortcut:
|
104 |
+
x_s = self.conv_s(self.norm_s(x, seg))
|
105 |
+
else:
|
106 |
+
x_s = x
|
107 |
+
return x_s
|
108 |
+
|
109 |
+
def act(self, x):
|
110 |
+
return F.leaky_relu(x, 2e-1)
|
111 |
+
|
112 |
+
|
113 |
+
class BaseNetwork(nn.Module):
|
114 |
+
""" A basis for hifacegan archs with custom initialization """
|
115 |
+
|
116 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
117 |
+
|
118 |
+
def init_func(m):
|
119 |
+
classname = m.__class__.__name__
|
120 |
+
if classname.find('BatchNorm2d') != -1:
|
121 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
122 |
+
init.normal_(m.weight.data, 1.0, gain)
|
123 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
124 |
+
init.constant_(m.bias.data, 0.0)
|
125 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
126 |
+
if init_type == 'normal':
|
127 |
+
init.normal_(m.weight.data, 0.0, gain)
|
128 |
+
elif init_type == 'xavier':
|
129 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
130 |
+
elif init_type == 'xavier_uniform':
|
131 |
+
init.xavier_uniform_(m.weight.data, gain=1.0)
|
132 |
+
elif init_type == 'kaiming':
|
133 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
134 |
+
elif init_type == 'orthogonal':
|
135 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
136 |
+
elif init_type == 'none': # uses pytorch's default init method
|
137 |
+
m.reset_parameters()
|
138 |
+
else:
|
139 |
+
raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
|
140 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
141 |
+
init.constant_(m.bias.data, 0.0)
|
142 |
+
|
143 |
+
self.apply(init_func)
|
144 |
+
|
145 |
+
# propagate to children
|
146 |
+
for m in self.children():
|
147 |
+
if hasattr(m, 'init_weights'):
|
148 |
+
m.init_weights(init_type, gain)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
pass
|
152 |
+
|
153 |
+
|
154 |
+
def lip2d(x, logit, kernel=3, stride=2, padding=1):
|
155 |
+
weight = logit.exp()
|
156 |
+
return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
|
157 |
+
|
158 |
+
|
159 |
+
class SoftGate(nn.Module):
|
160 |
+
COEFF = 12.0
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
return torch.sigmoid(x).mul(self.COEFF)
|
164 |
+
|
165 |
+
|
166 |
+
class SimplifiedLIP(nn.Module):
|
167 |
+
|
168 |
+
def __init__(self, channels):
|
169 |
+
super(SimplifiedLIP, self).__init__()
|
170 |
+
self.logit = nn.Sequential(
|
171 |
+
nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
|
172 |
+
SoftGate())
|
173 |
+
|
174 |
+
def init_layer(self):
|
175 |
+
self.logit[0].weight.data.fill_(0.0)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
frac = lip2d(x, self.logit(x))
|
179 |
+
return frac
|
180 |
+
|
181 |
+
|
182 |
+
class LIPEncoder(BaseNetwork):
|
183 |
+
"""Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
|
184 |
+
|
185 |
+
def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
|
186 |
+
super().__init__()
|
187 |
+
self.sw = sw
|
188 |
+
self.sh = sh
|
189 |
+
self.max_ratio = 16
|
190 |
+
# 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
|
191 |
+
kw = 3
|
192 |
+
pw = (kw - 1) // 2
|
193 |
+
|
194 |
+
model = [
|
195 |
+
nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
|
196 |
+
norm_layer(ngf),
|
197 |
+
nn.ReLU(),
|
198 |
+
]
|
199 |
+
cur_ratio = 1
|
200 |
+
for i in range(n_2xdown):
|
201 |
+
next_ratio = min(cur_ratio * 2, self.max_ratio)
|
202 |
+
model += [
|
203 |
+
SimplifiedLIP(ngf * cur_ratio),
|
204 |
+
nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
|
205 |
+
norm_layer(ngf * next_ratio),
|
206 |
+
]
|
207 |
+
cur_ratio = next_ratio
|
208 |
+
if i < n_2xdown - 1:
|
209 |
+
model += [nn.ReLU(inplace=True)]
|
210 |
+
|
211 |
+
self.model = nn.Sequential(*model)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
return self.model(x)
|
215 |
+
|
216 |
+
|
217 |
+
def get_nonspade_norm_layer(norm_type='instance'):
|
218 |
+
# helper function to get # output channels of the previous layer
|
219 |
+
def get_out_channel(layer):
|
220 |
+
if hasattr(layer, 'out_channels'):
|
221 |
+
return getattr(layer, 'out_channels')
|
222 |
+
return layer.weight.size(0)
|
223 |
+
|
224 |
+
# this function will be returned
|
225 |
+
def add_norm_layer(layer):
|
226 |
+
nonlocal norm_type
|
227 |
+
if norm_type.startswith('spectral'):
|
228 |
+
layer = spectral_norm(layer)
|
229 |
+
subnorm_type = norm_type[len('spectral'):]
|
230 |
+
|
231 |
+
if subnorm_type == 'none' or len(subnorm_type) == 0:
|
232 |
+
return layer
|
233 |
+
|
234 |
+
# remove bias in the previous layer, which is meaningless
|
235 |
+
# since it has no effect after normalization
|
236 |
+
if getattr(layer, 'bias', None) is not None:
|
237 |
+
delattr(layer, 'bias')
|
238 |
+
layer.register_parameter('bias', None)
|
239 |
+
|
240 |
+
if subnorm_type == 'batch':
|
241 |
+
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
|
242 |
+
elif subnorm_type == 'sync_batch':
|
243 |
+
print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
|
244 |
+
# norm_layer = SynchronizedBatchNorm2d(
|
245 |
+
# get_out_channel(layer), affine=True)
|
246 |
+
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
|
247 |
+
elif subnorm_type == 'instance':
|
248 |
+
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
|
249 |
+
else:
|
250 |
+
raise ValueError(f'normalization layer {subnorm_type} is not recognized')
|
251 |
+
|
252 |
+
return nn.Sequential(layer, norm_layer)
|
253 |
+
|
254 |
+
print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
|
255 |
+
return add_norm_layer
|
basicsr/archs/inception.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
|
2 |
+
# For FID metric
|
3 |
+
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.utils.model_zoo import load_url
|
9 |
+
from torchvision import models
|
10 |
+
|
11 |
+
# Inception weights ported to Pytorch from
|
12 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
13 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
14 |
+
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
15 |
+
|
16 |
+
|
17 |
+
class InceptionV3(nn.Module):
|
18 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
19 |
+
|
20 |
+
# Index of default block of inception to return,
|
21 |
+
# corresponds to output of final average pooling
|
22 |
+
DEFAULT_BLOCK_INDEX = 3
|
23 |
+
|
24 |
+
# Maps feature dimensionality to their output blocks indices
|
25 |
+
BLOCK_INDEX_BY_DIM = {
|
26 |
+
64: 0, # First max pooling features
|
27 |
+
192: 1, # Second max pooling features
|
28 |
+
768: 2, # Pre-aux classifier features
|
29 |
+
2048: 3 # Final average pooling features
|
30 |
+
}
|
31 |
+
|
32 |
+
def __init__(self,
|
33 |
+
output_blocks=(DEFAULT_BLOCK_INDEX),
|
34 |
+
resize_input=True,
|
35 |
+
normalize_input=True,
|
36 |
+
requires_grad=False,
|
37 |
+
use_fid_inception=True):
|
38 |
+
"""Build pretrained InceptionV3.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
output_blocks (list[int]): Indices of blocks to return features of.
|
42 |
+
Possible values are:
|
43 |
+
- 0: corresponds to output of first max pooling
|
44 |
+
- 1: corresponds to output of second max pooling
|
45 |
+
- 2: corresponds to output which is fed to aux classifier
|
46 |
+
- 3: corresponds to output of final average pooling
|
47 |
+
resize_input (bool): If true, bilinearly resizes input to width and
|
48 |
+
height 299 before feeding input to model. As the network
|
49 |
+
without fully connected layers is fully convolutional, it
|
50 |
+
should be able to handle inputs of arbitrary size, so resizing
|
51 |
+
might not be strictly needed. Default: True.
|
52 |
+
normalize_input (bool): If true, scales the input from range (0, 1)
|
53 |
+
to the range the pretrained Inception network expects,
|
54 |
+
namely (-1, 1). Default: True.
|
55 |
+
requires_grad (bool): If true, parameters of the model require
|
56 |
+
gradients. Possibly useful for finetuning the network.
|
57 |
+
Default: False.
|
58 |
+
use_fid_inception (bool): If true, uses the pretrained Inception
|
59 |
+
model used in Tensorflow's FID implementation.
|
60 |
+
If false, uses the pretrained Inception model available in
|
61 |
+
torchvision. The FID Inception model has different weights
|
62 |
+
and a slightly different structure from torchvision's
|
63 |
+
Inception model. If you want to compute FID scores, you are
|
64 |
+
strongly advised to set this parameter to true to get
|
65 |
+
comparable results. Default: True.
|
66 |
+
"""
|
67 |
+
super(InceptionV3, self).__init__()
|
68 |
+
|
69 |
+
self.resize_input = resize_input
|
70 |
+
self.normalize_input = normalize_input
|
71 |
+
self.output_blocks = sorted(output_blocks)
|
72 |
+
self.last_needed_block = max(output_blocks)
|
73 |
+
|
74 |
+
assert self.last_needed_block <= 3, ('Last possible output block index is 3')
|
75 |
+
|
76 |
+
self.blocks = nn.ModuleList()
|
77 |
+
|
78 |
+
if use_fid_inception:
|
79 |
+
inception = fid_inception_v3()
|
80 |
+
else:
|
81 |
+
try:
|
82 |
+
inception = models.inception_v3(pretrained=True, init_weights=False)
|
83 |
+
except TypeError:
|
84 |
+
# pytorch < 1.5 does not have init_weights for inception_v3
|
85 |
+
inception = models.inception_v3(pretrained=True)
|
86 |
+
|
87 |
+
# Block 0: input to maxpool1
|
88 |
+
block0 = [
|
89 |
+
inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
|
90 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
91 |
+
]
|
92 |
+
self.blocks.append(nn.Sequential(*block0))
|
93 |
+
|
94 |
+
# Block 1: maxpool1 to maxpool2
|
95 |
+
if self.last_needed_block >= 1:
|
96 |
+
block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
|
97 |
+
self.blocks.append(nn.Sequential(*block1))
|
98 |
+
|
99 |
+
# Block 2: maxpool2 to aux classifier
|
100 |
+
if self.last_needed_block >= 2:
|
101 |
+
block2 = [
|
102 |
+
inception.Mixed_5b,
|
103 |
+
inception.Mixed_5c,
|
104 |
+
inception.Mixed_5d,
|
105 |
+
inception.Mixed_6a,
|
106 |
+
inception.Mixed_6b,
|
107 |
+
inception.Mixed_6c,
|
108 |
+
inception.Mixed_6d,
|
109 |
+
inception.Mixed_6e,
|
110 |
+
]
|
111 |
+
self.blocks.append(nn.Sequential(*block2))
|
112 |
+
|
113 |
+
# Block 3: aux classifier to final avgpool
|
114 |
+
if self.last_needed_block >= 3:
|
115 |
+
block3 = [
|
116 |
+
inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
|
117 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
118 |
+
]
|
119 |
+
self.blocks.append(nn.Sequential(*block3))
|
120 |
+
|
121 |
+
for param in self.parameters():
|
122 |
+
param.requires_grad = requires_grad
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
"""Get Inception feature maps.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
x (Tensor): Input tensor of shape (b, 3, h, w).
|
129 |
+
Values are expected to be in range (-1, 1). You can also input
|
130 |
+
(0, 1) with setting normalize_input = True.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
list[Tensor]: Corresponding to the selected output block, sorted
|
134 |
+
ascending by index.
|
135 |
+
"""
|
136 |
+
output = []
|
137 |
+
|
138 |
+
if self.resize_input:
|
139 |
+
x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
|
140 |
+
|
141 |
+
if self.normalize_input:
|
142 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
143 |
+
|
144 |
+
for idx, block in enumerate(self.blocks):
|
145 |
+
x = block(x)
|
146 |
+
if idx in self.output_blocks:
|
147 |
+
output.append(x)
|
148 |
+
|
149 |
+
if idx == self.last_needed_block:
|
150 |
+
break
|
151 |
+
|
152 |
+
return output
|
153 |
+
|
154 |
+
|
155 |
+
def fid_inception_v3():
|
156 |
+
"""Build pretrained Inception model for FID computation.
|
157 |
+
|
158 |
+
The Inception model for FID computation uses a different set of weights
|
159 |
+
and has a slightly different structure than torchvision's Inception.
|
160 |
+
|
161 |
+
This method first constructs torchvision's Inception and then patches the
|
162 |
+
necessary parts that are different in the FID Inception model.
|
163 |
+
"""
|
164 |
+
try:
|
165 |
+
inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
|
166 |
+
except TypeError:
|
167 |
+
# pytorch < 1.5 does not have init_weights for inception_v3
|
168 |
+
inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
|
169 |
+
|
170 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
171 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
172 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
173 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
174 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
175 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
176 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
177 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
178 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
179 |
+
|
180 |
+
if os.path.exists(LOCAL_FID_WEIGHTS):
|
181 |
+
state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
|
182 |
+
else:
|
183 |
+
state_dict = load_url(FID_WEIGHTS_URL, progress=True)
|
184 |
+
|
185 |
+
inception.load_state_dict(state_dict)
|
186 |
+
return inception
|
187 |
+
|
188 |
+
|
189 |
+
class FIDInceptionA(models.inception.InceptionA):
|
190 |
+
"""InceptionA block patched for FID computation"""
|
191 |
+
|
192 |
+
def __init__(self, in_channels, pool_features):
|
193 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
branch1x1 = self.branch1x1(x)
|
197 |
+
|
198 |
+
branch5x5 = self.branch5x5_1(x)
|
199 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
200 |
+
|
201 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
202 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
203 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
204 |
+
|
205 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
206 |
+
# its average calculation
|
207 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
|
208 |
+
branch_pool = self.branch_pool(branch_pool)
|
209 |
+
|
210 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
211 |
+
return torch.cat(outputs, 1)
|
212 |
+
|
213 |
+
|
214 |
+
class FIDInceptionC(models.inception.InceptionC):
|
215 |
+
"""InceptionC block patched for FID computation"""
|
216 |
+
|
217 |
+
def __init__(self, in_channels, channels_7x7):
|
218 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
branch1x1 = self.branch1x1(x)
|
222 |
+
|
223 |
+
branch7x7 = self.branch7x7_1(x)
|
224 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
225 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
226 |
+
|
227 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
228 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
229 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
230 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
231 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
232 |
+
|
233 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
234 |
+
# its average calculation
|
235 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
|
236 |
+
branch_pool = self.branch_pool(branch_pool)
|
237 |
+
|
238 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
239 |
+
return torch.cat(outputs, 1)
|
240 |
+
|
241 |
+
|
242 |
+
class FIDInceptionE_1(models.inception.InceptionE):
|
243 |
+
"""First InceptionE block patched for FID computation"""
|
244 |
+
|
245 |
+
def __init__(self, in_channels):
|
246 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
branch1x1 = self.branch1x1(x)
|
250 |
+
|
251 |
+
branch3x3 = self.branch3x3_1(x)
|
252 |
+
branch3x3 = [
|
253 |
+
self.branch3x3_2a(branch3x3),
|
254 |
+
self.branch3x3_2b(branch3x3),
|
255 |
+
]
|
256 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
257 |
+
|
258 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
259 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
260 |
+
branch3x3dbl = [
|
261 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
262 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
263 |
+
]
|
264 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
265 |
+
|
266 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
267 |
+
# its average calculation
|
268 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
|
269 |
+
branch_pool = self.branch_pool(branch_pool)
|
270 |
+
|
271 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
272 |
+
return torch.cat(outputs, 1)
|
273 |
+
|
274 |
+
|
275 |
+
class FIDInceptionE_2(models.inception.InceptionE):
|
276 |
+
"""Second InceptionE block patched for FID computation"""
|
277 |
+
|
278 |
+
def __init__(self, in_channels):
|
279 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
280 |
+
|
281 |
+
def forward(self, x):
|
282 |
+
branch1x1 = self.branch1x1(x)
|
283 |
+
|
284 |
+
branch3x3 = self.branch3x3_1(x)
|
285 |
+
branch3x3 = [
|
286 |
+
self.branch3x3_2a(branch3x3),
|
287 |
+
self.branch3x3_2b(branch3x3),
|
288 |
+
]
|
289 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
290 |
+
|
291 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
292 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
293 |
+
branch3x3dbl = [
|
294 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
295 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
296 |
+
]
|
297 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
298 |
+
|
299 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
300 |
+
# pooling. This is likely an error in this specific Inception
|
301 |
+
# implementation, as other Inception models use average pooling here
|
302 |
+
# (which matches the description in the paper).
|
303 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
304 |
+
branch_pool = self.branch_pool(branch_pool)
|
305 |
+
|
306 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
307 |
+
return torch.cat(outputs, 1)
|
basicsr/archs/rcan_arch.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from .arch_util import Upsample, make_layer
|
6 |
+
|
7 |
+
|
8 |
+
class ChannelAttention(nn.Module):
|
9 |
+
"""Channel attention used in RCAN.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
num_feat (int): Channel number of intermediate features.
|
13 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, num_feat, squeeze_factor=16):
|
17 |
+
super(ChannelAttention, self).__init__()
|
18 |
+
self.attention = nn.Sequential(
|
19 |
+
nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
|
20 |
+
nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
y = self.attention(x)
|
24 |
+
return x * y
|
25 |
+
|
26 |
+
|
27 |
+
class RCAB(nn.Module):
|
28 |
+
"""Residual Channel Attention Block (RCAB) used in RCAN.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
num_feat (int): Channel number of intermediate features.
|
32 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
33 |
+
res_scale (float): Scale the residual. Default: 1.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
|
37 |
+
super(RCAB, self).__init__()
|
38 |
+
self.res_scale = res_scale
|
39 |
+
|
40 |
+
self.rcab = nn.Sequential(
|
41 |
+
nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
|
42 |
+
ChannelAttention(num_feat, squeeze_factor))
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
res = self.rcab(x) * self.res_scale
|
46 |
+
return res + x
|
47 |
+
|
48 |
+
|
49 |
+
class ResidualGroup(nn.Module):
|
50 |
+
"""Residual Group of RCAB.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
num_feat (int): Channel number of intermediate features.
|
54 |
+
num_block (int): Block number in the body network.
|
55 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
56 |
+
res_scale (float): Scale the residual. Default: 1.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
|
60 |
+
super(ResidualGroup, self).__init__()
|
61 |
+
|
62 |
+
self.residual_group = make_layer(
|
63 |
+
RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
|
64 |
+
self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
res = self.conv(self.residual_group(x))
|
68 |
+
return res + x
|
69 |
+
|
70 |
+
|
71 |
+
@ARCH_REGISTRY.register()
|
72 |
+
class RCAN(nn.Module):
|
73 |
+
"""Residual Channel Attention Networks.
|
74 |
+
|
75 |
+
``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
|
76 |
+
|
77 |
+
Reference: https://github.com/yulunzhang/RCAN
|
78 |
+
|
79 |
+
Args:
|
80 |
+
num_in_ch (int): Channel number of inputs.
|
81 |
+
num_out_ch (int): Channel number of outputs.
|
82 |
+
num_feat (int): Channel number of intermediate features.
|
83 |
+
Default: 64.
|
84 |
+
num_group (int): Number of ResidualGroup. Default: 10.
|
85 |
+
num_block (int): Number of RCAB in ResidualGroup. Default: 16.
|
86 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
87 |
+
upscale (int): Upsampling factor. Support 2^n and 3.
|
88 |
+
Default: 4.
|
89 |
+
res_scale (float): Used to scale the residual in residual block.
|
90 |
+
Default: 1.
|
91 |
+
img_range (float): Image range. Default: 255.
|
92 |
+
rgb_mean (tuple[float]): Image mean in RGB orders.
|
93 |
+
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self,
|
97 |
+
num_in_ch,
|
98 |
+
num_out_ch,
|
99 |
+
num_feat=64,
|
100 |
+
num_group=10,
|
101 |
+
num_block=16,
|
102 |
+
squeeze_factor=16,
|
103 |
+
upscale=4,
|
104 |
+
res_scale=1,
|
105 |
+
img_range=255.,
|
106 |
+
rgb_mean=(0.4488, 0.4371, 0.4040)):
|
107 |
+
super(RCAN, self).__init__()
|
108 |
+
|
109 |
+
self.img_range = img_range
|
110 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
111 |
+
|
112 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
113 |
+
self.body = make_layer(
|
114 |
+
ResidualGroup,
|
115 |
+
num_group,
|
116 |
+
num_feat=num_feat,
|
117 |
+
num_block=num_block,
|
118 |
+
squeeze_factor=squeeze_factor,
|
119 |
+
res_scale=res_scale)
|
120 |
+
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
121 |
+
self.upsample = Upsample(upscale, num_feat)
|
122 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
self.mean = self.mean.type_as(x)
|
126 |
+
|
127 |
+
x = (x - self.mean) * self.img_range
|
128 |
+
x = self.conv_first(x)
|
129 |
+
res = self.conv_after_body(self.body(x))
|
130 |
+
res += x
|
131 |
+
|
132 |
+
x = self.conv_last(self.upsample(res))
|
133 |
+
x = x / self.img_range + self.mean
|
134 |
+
|
135 |
+
return x
|
basicsr/archs/ridnet_arch.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from .arch_util import ResidualBlockNoBN, make_layer
|
6 |
+
|
7 |
+
|
8 |
+
class MeanShift(nn.Conv2d):
|
9 |
+
""" Data normalization with mean and std.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
rgb_range (int): Maximum value of RGB.
|
13 |
+
rgb_mean (list[float]): Mean for RGB channels.
|
14 |
+
rgb_std (list[float]): Std for RGB channels.
|
15 |
+
sign (int): For subtraction, sign is -1, for addition, sign is 1.
|
16 |
+
Default: -1.
|
17 |
+
requires_grad (bool): Whether to update the self.weight and self.bias.
|
18 |
+
Default: True.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
|
22 |
+
super(MeanShift, self).__init__(3, 3, kernel_size=1)
|
23 |
+
std = torch.Tensor(rgb_std)
|
24 |
+
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
|
25 |
+
self.weight.data.div_(std.view(3, 1, 1, 1))
|
26 |
+
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
|
27 |
+
self.bias.data.div_(std)
|
28 |
+
self.requires_grad = requires_grad
|
29 |
+
|
30 |
+
|
31 |
+
class EResidualBlockNoBN(nn.Module):
|
32 |
+
"""Enhanced Residual block without BN.
|
33 |
+
|
34 |
+
There are three convolution layers in residual branch.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, in_channels, out_channels):
|
38 |
+
super(EResidualBlockNoBN, self).__init__()
|
39 |
+
|
40 |
+
self.body = nn.Sequential(
|
41 |
+
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
|
42 |
+
nn.ReLU(inplace=True),
|
43 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
nn.Conv2d(out_channels, out_channels, 1, 1, 0),
|
46 |
+
)
|
47 |
+
self.relu = nn.ReLU(inplace=True)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
out = self.body(x)
|
51 |
+
out = self.relu(out + x)
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class MergeRun(nn.Module):
|
56 |
+
""" Merge-and-run unit.
|
57 |
+
|
58 |
+
This unit contains two branches with different dilated convolutions,
|
59 |
+
followed by a convolution to process the concatenated features.
|
60 |
+
|
61 |
+
Paper: Real Image Denoising with Feature Attention
|
62 |
+
Ref git repo: https://github.com/saeed-anwar/RIDNet
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
|
66 |
+
super(MergeRun, self).__init__()
|
67 |
+
|
68 |
+
self.dilation1 = nn.Sequential(
|
69 |
+
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
|
70 |
+
nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
|
71 |
+
self.dilation2 = nn.Sequential(
|
72 |
+
nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
|
73 |
+
nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
|
74 |
+
|
75 |
+
self.aggregation = nn.Sequential(
|
76 |
+
nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
dilation1 = self.dilation1(x)
|
80 |
+
dilation2 = self.dilation2(x)
|
81 |
+
out = torch.cat([dilation1, dilation2], dim=1)
|
82 |
+
out = self.aggregation(out)
|
83 |
+
out = out + x
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
class ChannelAttention(nn.Module):
|
88 |
+
"""Channel attention.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
num_feat (int): Channel number of intermediate features.
|
92 |
+
squeeze_factor (int): Channel squeeze factor. Default:
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, mid_channels, squeeze_factor=16):
|
96 |
+
super(ChannelAttention, self).__init__()
|
97 |
+
self.attention = nn.Sequential(
|
98 |
+
nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
|
99 |
+
nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
y = self.attention(x)
|
103 |
+
return x * y
|
104 |
+
|
105 |
+
|
106 |
+
class EAM(nn.Module):
|
107 |
+
"""Enhancement attention modules (EAM) in RIDNet.
|
108 |
+
|
109 |
+
This module contains a merge-and-run unit, a residual block,
|
110 |
+
an enhanced residual block and a feature attention unit.
|
111 |
+
|
112 |
+
Attributes:
|
113 |
+
merge: The merge-and-run unit.
|
114 |
+
block1: The residual block.
|
115 |
+
block2: The enhanced residual block.
|
116 |
+
ca: The feature/channel attention unit.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, in_channels, mid_channels, out_channels):
|
120 |
+
super(EAM, self).__init__()
|
121 |
+
|
122 |
+
self.merge = MergeRun(in_channels, mid_channels)
|
123 |
+
self.block1 = ResidualBlockNoBN(mid_channels)
|
124 |
+
self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
|
125 |
+
self.ca = ChannelAttention(out_channels)
|
126 |
+
# The residual block in the paper contains a relu after addition.
|
127 |
+
self.relu = nn.ReLU(inplace=True)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
out = self.merge(x)
|
131 |
+
out = self.relu(self.block1(out))
|
132 |
+
out = self.block2(out)
|
133 |
+
out = self.ca(out)
|
134 |
+
return out
|
135 |
+
|
136 |
+
|
137 |
+
@ARCH_REGISTRY.register()
|
138 |
+
class RIDNet(nn.Module):
|
139 |
+
"""RIDNet: Real Image Denoising with Feature Attention.
|
140 |
+
|
141 |
+
Ref git repo: https://github.com/saeed-anwar/RIDNet
|
142 |
+
|
143 |
+
Args:
|
144 |
+
in_channels (int): Channel number of inputs.
|
145 |
+
mid_channels (int): Channel number of EAM modules.
|
146 |
+
Default: 64.
|
147 |
+
out_channels (int): Channel number of outputs.
|
148 |
+
num_block (int): Number of EAM. Default: 4.
|
149 |
+
img_range (float): Image range. Default: 255.
|
150 |
+
rgb_mean (tuple[float]): Image mean in RGB orders.
|
151 |
+
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
|
152 |
+
"""
|
153 |
+
|
154 |
+
def __init__(self,
|
155 |
+
in_channels,
|
156 |
+
mid_channels,
|
157 |
+
out_channels,
|
158 |
+
num_block=4,
|
159 |
+
img_range=255.,
|
160 |
+
rgb_mean=(0.4488, 0.4371, 0.4040),
|
161 |
+
rgb_std=(1.0, 1.0, 1.0)):
|
162 |
+
super(RIDNet, self).__init__()
|
163 |
+
|
164 |
+
self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
|
165 |
+
self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
|
166 |
+
|
167 |
+
self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
|
168 |
+
self.body = make_layer(
|
169 |
+
EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
|
170 |
+
self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
|
171 |
+
|
172 |
+
self.relu = nn.ReLU(inplace=True)
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
res = self.sub_mean(x)
|
176 |
+
res = self.tail(self.body(self.relu(self.head(res))))
|
177 |
+
res = self.add_mean(res)
|
178 |
+
|
179 |
+
out = x + res
|
180 |
+
return out
|
basicsr/archs/rrdbnet_arch.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
7 |
+
|
8 |
+
|
9 |
+
class ResidualDenseBlock(nn.Module):
|
10 |
+
"""Residual Dense Block.
|
11 |
+
|
12 |
+
Used in RRDB block in ESRGAN.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_feat (int): Channel number of intermediate features.
|
16 |
+
num_grow_ch (int): Channels for each growth.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
20 |
+
super(ResidualDenseBlock, self).__init__()
|
21 |
+
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
22 |
+
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
23 |
+
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
24 |
+
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
25 |
+
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
26 |
+
|
27 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
28 |
+
|
29 |
+
# initialization
|
30 |
+
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x1 = self.lrelu(self.conv1(x))
|
34 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
35 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
36 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
37 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
38 |
+
# Empirically, we use 0.2 to scale the residual for better performance
|
39 |
+
return x5 * 0.2 + x
|
40 |
+
|
41 |
+
|
42 |
+
class RRDB(nn.Module):
|
43 |
+
"""Residual in Residual Dense Block.
|
44 |
+
|
45 |
+
Used in RRDB-Net in ESRGAN.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
num_feat (int): Channel number of intermediate features.
|
49 |
+
num_grow_ch (int): Channels for each growth.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
53 |
+
super(RRDB, self).__init__()
|
54 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
55 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
56 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
out = self.rdb1(x)
|
60 |
+
out = self.rdb2(out)
|
61 |
+
out = self.rdb3(out)
|
62 |
+
# Empirically, we use 0.2 to scale the residual for better performance
|
63 |
+
return out * 0.2 + x
|
64 |
+
|
65 |
+
|
66 |
+
@ARCH_REGISTRY.register()
|
67 |
+
class RRDBNet(nn.Module):
|
68 |
+
"""Networks consisting of Residual in Residual Dense Block, which is used
|
69 |
+
in ESRGAN.
|
70 |
+
|
71 |
+
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
72 |
+
|
73 |
+
We extend ESRGAN for scale x2 and scale x1.
|
74 |
+
Note: This is one option for scale 1, scale 2 in RRDBNet.
|
75 |
+
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
|
76 |
+
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
num_in_ch (int): Channel number of inputs.
|
80 |
+
num_out_ch (int): Channel number of outputs.
|
81 |
+
num_feat (int): Channel number of intermediate features.
|
82 |
+
Default: 64
|
83 |
+
num_block (int): Block number in the trunk network. Defaults: 23
|
84 |
+
num_grow_ch (int): Channels for each growth. Default: 32.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
|
88 |
+
super(RRDBNet, self).__init__()
|
89 |
+
self.scale = scale
|
90 |
+
if scale == 2:
|
91 |
+
num_in_ch = num_in_ch * 4
|
92 |
+
elif scale == 1:
|
93 |
+
num_in_ch = num_in_ch * 16
|
94 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
95 |
+
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
|
96 |
+
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
97 |
+
# upsample
|
98 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
99 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
100 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
101 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
102 |
+
|
103 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
if self.scale == 2:
|
107 |
+
feat = pixel_unshuffle(x, scale=2)
|
108 |
+
elif self.scale == 1:
|
109 |
+
feat = pixel_unshuffle(x, scale=4)
|
110 |
+
else:
|
111 |
+
feat = x
|
112 |
+
feat = self.conv_first(feat)
|
113 |
+
body_feat = self.conv_body(self.body(feat))
|
114 |
+
feat = feat + body_feat
|
115 |
+
# upsample
|
116 |
+
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
117 |
+
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
118 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
119 |
+
return out
|
basicsr/archs/spynet_arch.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
from .arch_util import flow_warp
|
8 |
+
|
9 |
+
|
10 |
+
class BasicModule(nn.Module):
|
11 |
+
"""Basic Module for SpyNet.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self):
|
15 |
+
super(BasicModule, self).__init__()
|
16 |
+
|
17 |
+
self.basic_module = nn.Sequential(
|
18 |
+
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
|
19 |
+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
|
20 |
+
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
|
21 |
+
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
|
22 |
+
nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
|
23 |
+
|
24 |
+
def forward(self, tensor_input):
|
25 |
+
return self.basic_module(tensor_input)
|
26 |
+
|
27 |
+
|
28 |
+
@ARCH_REGISTRY.register()
|
29 |
+
class SpyNet(nn.Module):
|
30 |
+
"""SpyNet architecture.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
load_path (str): path for pretrained SpyNet. Default: None.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, load_path=None):
|
37 |
+
super(SpyNet, self).__init__()
|
38 |
+
self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
|
39 |
+
if load_path:
|
40 |
+
self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
|
41 |
+
|
42 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
43 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
44 |
+
|
45 |
+
def preprocess(self, tensor_input):
|
46 |
+
tensor_output = (tensor_input - self.mean) / self.std
|
47 |
+
return tensor_output
|
48 |
+
|
49 |
+
def process(self, ref, supp):
|
50 |
+
flow = []
|
51 |
+
|
52 |
+
ref = [self.preprocess(ref)]
|
53 |
+
supp = [self.preprocess(supp)]
|
54 |
+
|
55 |
+
for level in range(5):
|
56 |
+
ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
|
57 |
+
supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
|
58 |
+
|
59 |
+
flow = ref[0].new_zeros(
|
60 |
+
[ref[0].size(0), 2,
|
61 |
+
int(math.floor(ref[0].size(2) / 2.0)),
|
62 |
+
int(math.floor(ref[0].size(3) / 2.0))])
|
63 |
+
|
64 |
+
for level in range(len(ref)):
|
65 |
+
upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
|
66 |
+
|
67 |
+
if upsampled_flow.size(2) != ref[level].size(2):
|
68 |
+
upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
|
69 |
+
if upsampled_flow.size(3) != ref[level].size(3):
|
70 |
+
upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
|
71 |
+
|
72 |
+
flow = self.basic_module[level](torch.cat([
|
73 |
+
ref[level],
|
74 |
+
flow_warp(
|
75 |
+
supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
|
76 |
+
upsampled_flow
|
77 |
+
], 1)) + upsampled_flow
|
78 |
+
|
79 |
+
return flow
|
80 |
+
|
81 |
+
def forward(self, ref, supp):
|
82 |
+
assert ref.size() == supp.size()
|
83 |
+
|
84 |
+
h, w = ref.size(2), ref.size(3)
|
85 |
+
w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
|
86 |
+
h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
|
87 |
+
|
88 |
+
ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
|
89 |
+
supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
|
90 |
+
|
91 |
+
flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
|
92 |
+
|
93 |
+
flow[:, 0, :, :] *= float(w) / float(w_floor)
|
94 |
+
flow[:, 1, :, :] *= float(h) / float(h_floor)
|
95 |
+
|
96 |
+
return flow
|
basicsr/archs/srresnet_arch.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn as nn
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
|
6 |
+
|
7 |
+
|
8 |
+
@ARCH_REGISTRY.register()
|
9 |
+
class MSRResNet(nn.Module):
|
10 |
+
"""Modified SRResNet.
|
11 |
+
|
12 |
+
A compacted version modified from SRResNet in
|
13 |
+
"Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
|
14 |
+
It uses residual blocks without BN, similar to EDSR.
|
15 |
+
Currently, it supports x2, x3 and x4 upsampling scale factor.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
19 |
+
num_out_ch (int): Channel number of outputs. Default: 3.
|
20 |
+
num_feat (int): Channel number of intermediate features. Default: 64.
|
21 |
+
num_block (int): Block number in the body network. Default: 16.
|
22 |
+
upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
|
26 |
+
super(MSRResNet, self).__init__()
|
27 |
+
self.upscale = upscale
|
28 |
+
|
29 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
30 |
+
self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
|
31 |
+
|
32 |
+
# upsampling
|
33 |
+
if self.upscale in [2, 3]:
|
34 |
+
self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
|
35 |
+
self.pixel_shuffle = nn.PixelShuffle(self.upscale)
|
36 |
+
elif self.upscale == 4:
|
37 |
+
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
|
38 |
+
self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
|
39 |
+
self.pixel_shuffle = nn.PixelShuffle(2)
|
40 |
+
|
41 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
42 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
43 |
+
|
44 |
+
# activation function
|
45 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
46 |
+
|
47 |
+
# initialization
|
48 |
+
default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
|
49 |
+
if self.upscale == 4:
|
50 |
+
default_init_weights(self.upconv2, 0.1)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
feat = self.lrelu(self.conv_first(x))
|
54 |
+
out = self.body(feat)
|
55 |
+
|
56 |
+
if self.upscale == 4:
|
57 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
58 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
|
59 |
+
elif self.upscale in [2, 3]:
|
60 |
+
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
|
61 |
+
|
62 |
+
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
63 |
+
base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
|
64 |
+
out += base
|
65 |
+
return out
|
basicsr/archs/srvgg_arch.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn as nn
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
5 |
+
|
6 |
+
|
7 |
+
@ARCH_REGISTRY.register(suffix='basicsr')
|
8 |
+
class SRVGGNetCompact(nn.Module):
|
9 |
+
"""A compact VGG-style network structure for super-resolution.
|
10 |
+
|
11 |
+
It is a compact network structure, which performs upsampling in the last layer and no convolution is
|
12 |
+
conducted on the HR feature space.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
16 |
+
num_out_ch (int): Channel number of outputs. Default: 3.
|
17 |
+
num_feat (int): Channel number of intermediate features. Default: 64.
|
18 |
+
num_conv (int): Number of convolution layers in the body network. Default: 16.
|
19 |
+
upscale (int): Upsampling factor. Default: 4.
|
20 |
+
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
24 |
+
super(SRVGGNetCompact, self).__init__()
|
25 |
+
self.num_in_ch = num_in_ch
|
26 |
+
self.num_out_ch = num_out_ch
|
27 |
+
self.num_feat = num_feat
|
28 |
+
self.num_conv = num_conv
|
29 |
+
self.upscale = upscale
|
30 |
+
self.act_type = act_type
|
31 |
+
|
32 |
+
self.body = nn.ModuleList()
|
33 |
+
# the first conv
|
34 |
+
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
35 |
+
# the first activation
|
36 |
+
if act_type == 'relu':
|
37 |
+
activation = nn.ReLU(inplace=True)
|
38 |
+
elif act_type == 'prelu':
|
39 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
40 |
+
elif act_type == 'leakyrelu':
|
41 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
42 |
+
self.body.append(activation)
|
43 |
+
|
44 |
+
# the body structure
|
45 |
+
for _ in range(num_conv):
|
46 |
+
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
47 |
+
# activation
|
48 |
+
if act_type == 'relu':
|
49 |
+
activation = nn.ReLU(inplace=True)
|
50 |
+
elif act_type == 'prelu':
|
51 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
52 |
+
elif act_type == 'leakyrelu':
|
53 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
54 |
+
self.body.append(activation)
|
55 |
+
|
56 |
+
# the last conv
|
57 |
+
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
58 |
+
# upsample
|
59 |
+
self.upsampler = nn.PixelShuffle(upscale)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
out = x
|
63 |
+
for i in range(0, len(self.body)):
|
64 |
+
out = self.body[i](out)
|
65 |
+
|
66 |
+
out = self.upsampler(out)
|
67 |
+
# add the nearest upsampled image, so that the network learns the residual
|
68 |
+
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
69 |
+
out += base
|
70 |
+
return out
|