|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from typing import Any |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch import nn |
|
from torch.nn import functional as F_torch |
|
from torchvision import models |
|
from torchvision import transforms |
|
from torchvision.models.feature_extraction import create_feature_extractor |
|
|
|
# Public API of this module.
__all__ = [
    "SRResNet",
    "Discriminator",
    "srresnet_x4",
    "discriminator",
    "content_loss",
]
|
|
|
|
|
class SRResNet(nn.Module):
    """SRResNet generator network for single-image super-resolution.

    Architecture from `Photo-Realistic Single Image Super-Resolution Using a
    Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>`:
    a shallow feature extractor, a trunk of residual blocks with a long skip
    connection, pixel-shuffle upsampling, and a reconstruction layer.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels in the output image.
        channels: Number of feature channels used throughout the trunk.
        num_rcb: Number of residual convolution blocks in the trunk.
        upscale_factor: Spatial upsampling factor; must be 2, 3, 4 or 8.

    Raises:
        ValueError: If ``upscale_factor`` is not one of 2, 3, 4 or 8.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            num_rcb: int,
            upscale_factor: int
    ) -> None:
        super(SRResNet, self).__init__()

        # Low-level feature extraction with a large 9x9 receptive field.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Trunk of residual convolution blocks.
        trunk = []
        for _ in range(num_rcb):
            trunk.append(_ResidualConvBlock(channels))
        self.trunk = nn.Sequential(*trunk)

        # Post-trunk conv/BN whose output is added back to conv1's output
        # (the long skip connection in _forward_impl).
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

        # Upsampling: powers of two are realized as repeated x2 pixel-shuffle
        # stages; x3 is a single stage. Previously any other factor silently
        # produced an EMPTY upsampling stage (no error, no upsampling), so
        # unsupported factors are now rejected explicitly.
        upsampling = []
        if upscale_factor in (2, 4, 8):
            for _ in range(int(math.log(upscale_factor, 2))):
                upsampling.append(_UpsampleBlock(channels, 2))
        elif upscale_factor == 3:
            upsampling.append(_UpsampleBlock(channels, 3))
        else:
            raise ValueError(
                f"Unsupported upscale_factor {upscale_factor}; expected 2, 3, 4 or 8."
            )
        self.upsampling = nn.Sequential(*upsampling)

        # Reconstruction layer mapping features back to image space.
        self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))

        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the generator; output values are clamped to [0, 1]."""
        out1 = self.conv1(x)
        out = self.trunk(out1)
        out2 = self.conv2(out)
        # Long skip connection around the whole trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)

        # Outputs are images normalized to [0, 1].
        out = torch.clamp_(out, 0.0, 1.0)

        return out

    def _initialize_weights(self) -> None:
        """Kaiming-initialize convolutions; BatchNorm scales start at 1."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
|
|
|
class Discriminator(nn.Module):
    """SRGAN discriminator: a VGG-style CNN that scores 96x96 RGB patches.

    Eight conv layers alternate stride 1 / stride 2, doubling channels from
    64 to 512 while halving spatial size four times (96 -> 6), followed by a
    two-layer classifier producing one raw logit per image (no sigmoid; pair
    with a logits-based loss).
    """

    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # Input: 3 x 96 x 96. Only the first conv keeps a bias; all
            # following convs are bias-free because BatchNorm absorbs it.
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),

            # 64 x 48 x 48
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),

            # 128 x 24 x 24
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),

            # 256 x 12 x 12
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            # 512 x 6 x 6
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        # The flattened feature size (512 * 6 * 6) is only valid for 96x96
        # inputs, hence the shape check in forward().
        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Score a batch of 96x96 images; returns a (N, 1) logit tensor.

        Raises:
            ValueError: If the spatial size of ``x`` is not 96x96.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let mis-sized inputs reach the classifier
        # and fail with an opaque shape error.
        if x.shape[2] != 96 or x.shape[3] != 96:
            raise ValueError("Image shape must equal 96x96")

        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out
|
|
|
|
|
class _ResidualConvBlock(nn.Module): |
|
def __init__(self, channels: int) -> None: |
|
super(_ResidualConvBlock, self).__init__() |
|
self.rcb = nn.Sequential( |
|
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), |
|
nn.BatchNorm2d(channels), |
|
nn.PReLU(), |
|
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), |
|
nn.BatchNorm2d(channels), |
|
) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
identity = x |
|
|
|
out = self.rcb(x) |
|
|
|
out = torch.add(out, identity) |
|
|
|
return out |
|
|
|
|
|
class _UpsampleBlock(nn.Module): |
|
def __init__(self, channels: int, upscale_factor: int) -> None: |
|
super(_UpsampleBlock, self).__init__() |
|
self.upsample_block = nn.Sequential( |
|
nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)), |
|
nn.PixelShuffle(2), |
|
nn.PReLU(), |
|
) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
out = self.upsample_block(x) |
|
|
|
return out |
|
|
|
|
|
class _ContentLoss(nn.Module):
    """VGG19-based perceptual (content) loss.

    Compares super-resolved and ground-truth images in the feature space of a
    frozen, ImageNet-pretrained VGG19 rather than in pixel space; features
    taken from later layers emphasize texture content.

    Paper reference list:
    -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
    -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
    -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.

    Args:
        feature_model_extractor_node: Name of the VGG19 graph node whose
            output is used as the feature representation.
        feature_model_normalize_mean: Per-channel mean for input normalization.
        feature_model_normalize_std: Per-channel std for input normalization.
    """

    def __init__(
            self,
            feature_model_extractor_node: str,
            feature_model_normalize_mean: list,
            feature_model_normalize_std: list
    ) -> None:
        super(_ContentLoss, self).__init__()
        # Remember which node's output to read back from the extractor.
        self.feature_model_extractor_node = feature_model_extractor_node

        # Pretrained backbone, truncated at the requested node and frozen so
        # the loss never trains the feature network.
        backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        self.feature_extractor = create_feature_extractor(backbone, [feature_model_extractor_node])
        self.feature_extractor.eval()
        for parameter in self.feature_extractor.parameters():
            parameter.requires_grad = False

        # Input normalization matching the backbone's training statistics.
        self.normalize = transforms.Normalize(feature_model_normalize_mean, feature_model_normalize_std)

    def forward(self, sr_tensor: Tensor, gt_tensor: Tensor) -> Tensor:
        """Return the MSE between VGG19 features of SR and GT images."""
        node = self.feature_model_extractor_node
        sr_feature = self.feature_extractor(self.normalize(sr_tensor))[node]
        gt_feature = self.feature_extractor(self.normalize(gt_tensor))[node]

        return F_torch.mse_loss(sr_feature, gt_feature)
|
|
|
|
|
def srresnet_x4(**kwargs: Any) -> SRResNet:
    """Build an SRResNet generator configured for 4x super-resolution.

    All keyword arguments are forwarded to :class:`SRResNet`.
    """
    return SRResNet(upscale_factor=4, **kwargs)
|
|
|
|
|
def discriminator() -> Discriminator:
    """Build the SRGAN discriminator with its default architecture."""
    return Discriminator()
|
|
|
|
|
def content_loss(**kwargs: Any) -> _ContentLoss:
    """Build a VGG19-based content loss.

    All keyword arguments are forwarded to :class:`_ContentLoss`.
    """
    # Return directly rather than binding a local that shadows this function.
    return _ContentLoss(**kwargs)
|
|