import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class SimpleSAFM(nn.Module):
    """Simplified spatially-adaptive feature modulation block."""

    def __init__(self, dim):
        super().__init__()

        self.proj = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)
        self.dwconv = nn.Conv2d(dim // 2, dim // 2, 3, 1, 1, groups=dim // 2, bias=False)
        self.out = nn.Conv2d(dim, dim, 1, 1, 0, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        # Project the input and split it into two halves along the channel dim.
        x0, x1 = self.proj(x).chunk(2, dim=1)

        # Coarse branch: downsample x0, apply a depthwise conv, upsample back,
        # and use the activated result to modulate x0.
        x2 = F.adaptive_max_pool2d(x0, (h // 8, w // 8))
        x2 = self.dwconv(x2)
        x2 = F.interpolate(x2, size=(h, w), mode='bilinear')
        x2 = self.act(x2) * x0

        # Fuse the modulated half with the untouched half.
        x = torch.cat([x1, x2], dim=1)
        x = self.out(self.act(x))
        return x


class CCM(nn.Module):
    """Convolutional channel mixer: a lightweight feed-forward block."""

    def __init__(self, dim, ffn_scale):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(dim, int(dim * ffn_scale), 3, 1, 1, bias=False),
            nn.GELU(),
            nn.Conv2d(int(dim * ffn_scale), dim, 1, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.conv(x)


class AttBlock(nn.Module):
    """Feature mixing block: SimpleSAFM followed by a CCM feed-forward stage."""

    def __init__(self, dim, ffn_scale):
        super().__init__()

        self.conv1 = SimpleSAFM(dim)
        self.conv2 = CCM(dim, ffn_scale)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out


class SAFMNPP(nn.Module):
    """SAFMN++ super-resolution network operating on video clips of shape (b, t, c, h, w)."""

    def __init__(self, dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4):
        super().__init__()
        self.scale = upscaling_factor

        # Shallow feature extraction.
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1, bias=False)

        # Deep feature extraction: a stack of attention blocks.
        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])

        # Reconstruction: expand channels and rearrange them into pixels.
        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1, bias=False),
            nn.PixelShuffle(upscaling_factor),
        )

    def forward(self, x):
        b = x.shape[0]
        # Fold the temporal dimension into the batch so each frame is processed independently.
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.to_feat(x)
        x = self.feats(x) + x  # global residual connection
        x = self.to_img(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b)
        return x
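

# Usage note (illustrative, not part of the original file): the model expects a
# 5D video tensor (batch, frames, channels, height, width) and upscales the
# spatial dimensions by `upscaling_factor`, e.g.
#
#   model = SAFMNPP(dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4)
#   y = model(torch.randn(1, 5, 3, 64, 64))   # -> (1, 5, 3, 256, 256)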


if __name__ == '__main__':
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    from tqdm import tqdm

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Simulate 4x super-resolution on a 30-frame clip at one quarter of 4K (3840x2160).
    scale = 4
    h, w = 3840, 2160
    x = torch.randn(1, 30, 3, h // scale, w // scale)

    model = SAFMNPP(upscaling_factor=scale)
    # Expects a checkpoint that stores the weights under the 'params' key.
    model.load_state_dict(torch.load('light_safmnpp.pth')['params'], strict=True)

    print(model)
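
    # The measurement below is a sketch, not part of the original script: the
    # otherwise unused fvcore/tqdm imports and `device` suggest a complexity and
    # runtime check, so this fills that in under that assumption.
    model = model.to(device).eval()
    x = x.to(device)

    with torch.no_grad():
        # Per-module parameter / FLOP / activation breakdown.
        flops = FlopCountAnalysis(model, x)
        acts = ActivationCountAnalysis(model, x)
        print(flop_count_table(flops, activations=acts))

        # Rough timing over a handful of repeated forward passes.
        for _ in tqdm(range(10), desc='forward'):
            out = model(x)

    print(f'Output shape: {tuple(out.shape)}')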