# Meloo's picture
# Upload 4 files
# 188d68e verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class SimpleSAFM(nn.Module):
    """Simplified spatially-adaptive feature modulation.

    The input is projected by a 3x3 convolution and split channel-wise into
    two halves.  The first half is max-pooled to roughly 1/8 of the spatial
    resolution, filtered by a depth-wise 3x3 convolution, bilinearly resized
    back to the input resolution and, after GELU, used to gate the un-pooled
    first half.  The untouched second half and the gated first half are
    concatenated (in that order) and fused by a 1x1 convolution.

    Args:
        dim: Number of input and output channels.  Must be even because the
            projected features are split into two equal halves.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        if dim % 2 != 0:
            # chunk(2) would give halves of different width and the
            # depth-wise conv (dim // 2 channels) could not consume them.
            raise ValueError(f'dim must be even, got {dim}')
        half = dim // 2
        self.proj = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)
        self.dwconv = nn.Conv2d(half, half, 3, 1, 1, groups=half, bias=False)
        self.out = nn.Conv2d(dim, dim, 1, 1, 0, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        """Modulate ``x`` of shape (N, dim, H, W); the output has the same shape."""
        h, w = x.shape[-2:]
        x0, x1 = self.proj(x).chunk(2, dim=1)

        # Clamp to at least 1x1: for H or W below 8 the plain ``// 8`` would
        # request an empty pooled map and the following conv would fail.
        pooled_size = (max(h // 8, 1), max(w // 8, 1))
        x2 = F.adaptive_max_pool2d(x0, pooled_size)
        x2 = self.dwconv(x2)
        # align_corners=False is the default behaviour, spelled out explicitly.
        x2 = F.interpolate(x2, size=(h, w), mode='bilinear', align_corners=False)
        x2 = self.act(x2) * x0

        x = torch.cat([x1, x2], dim=1)
        return self.out(self.act(x))
class CCM(nn.Module):
    """Convolutional channel mixer.

    A bias-free 3x3 convolution widens the features to
    ``int(dim * ffn_scale)`` channels, GELU is applied, and a bias-free 1x1
    convolution brings them back to ``dim`` channels.

    Args:
        dim: Number of input and output channels.
        ffn_scale: Multiplier for the hidden channel count.
    """

    def __init__(self, dim, ffn_scale):
        super().__init__()
        hidden = int(dim * ffn_scale)
        # Layers are built in execution order and kept at Sequential
        # positions 0/1/2 so that checkpoint keys stay ``conv.0`` / ``conv.2``.
        expand = nn.Conv2d(dim, hidden, kernel_size=3, stride=1, padding=1, bias=False)
        activation = nn.GELU()
        project = nn.Conv2d(hidden, dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Return the channel-mixed features; spatial size is unchanged."""
        mixed = self.conv(x)
        return mixed
class AttBlock(nn.Module):
    """One feature block: ``SimpleSAFM`` followed by ``CCM``.

    The two sub-modules are applied back to back; this block adds no
    normalisation and no skip connection of its own.

    Args:
        dim: Channel count passed to both sub-modules.
        ffn_scale: Hidden-width multiplier passed to ``CCM``.
    """

    def __init__(self, dim, ffn_scale):
        super().__init__()
        # Attribute names are part of the checkpoint key layout.
        self.conv1 = SimpleSAFM(dim=dim)
        self.conv2 = CCM(dim=dim, ffn_scale=ffn_scale)

    def forward(self, x):
        """Apply the modulation stage, then the channel mixer."""
        return self.conv2(self.conv1(x))
class SAFMNPP(nn.Module):
    """Frame-wise super-resolution network built from ``AttBlock`` stages.

    A 3x3 convolution lifts RGB frames to ``dim`` feature channels, a stack of
    ``n_blocks`` ``AttBlock`` modules refines them with one global skip
    connection around the whole stack, and a 3x3 convolution followed by
    ``PixelShuffle`` produces the upscaled 3-channel output.

    Every frame is processed independently: a clip is flattened into the
    batch dimension before the 2-D layers run and restored afterwards.

    Args:
        dim: Number of feature channels.
        n_blocks: Number of stacked ``AttBlock`` modules.
        ffn_scale: Hidden-width multiplier forwarded to each block.
        upscaling_factor: Spatial upscaling factor of the pixel shuffle.
    """

    def __init__(self, dim=32, n_blocks=2, ffn_scale=1.5, upscaling_factor=4):
        super().__init__()
        self.scale = upscaling_factor

        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1, bias=False)
        self.feats = nn.Sequential(*[AttBlock(dim, ffn_scale) for _ in range(n_blocks)])
        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1, bias=False),
            nn.PixelShuffle(upscaling_factor),
        )

    def _upscale_frames(self, frames):
        """Run the 2-D network on a (N, 3, H, W) batch of frames."""
        feat = self.to_feat(frames)
        feat = self.feats(feat) + feat  # global residual around the block stack
        return self.to_img(feat)

    def forward(self, x):
        """Upscale a clip or a batch of images.

        Args:
            x: Either a clip tensor of shape (B, T, 3, H, W) or an image batch
                of shape (N, 3, H, W).

        Returns:
            A tensor with the same leading dimensions as ``x`` and spatial
            size (H * scale, W * scale).

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() == 4:
            # Frames are independent, so an image batch needs no reshaping.
            return self._upscale_frames(x)
        if x.dim() != 5:
            raise ValueError(
                f'expected a 4-D (N, C, H, W) or 5-D (B, T, C, H, W) input, '
                f'got a {x.dim()}-D tensor'
            )

        b, t = x.shape[:2]
        # Fold time into the batch axis, process, then unfold again.
        frames = x.reshape(b * t, *x.shape[2:])
        out = self._upscale_frames(frames)
        return out.reshape(b, t, *out.shape[1:])
if __name__ == '__main__':
    ############# Test Model Complexity #############
    # Smoke script: build the model, load the released weights and print the
    # architecture.  The FLOP count and the CUDA timing benchmark are kept
    # below as commented-out, opt-in snippets.

    # Only needed by the optional snippets below; left commented out so that
    # the script does not require fvcore / tqdm just to print the model.
    # import time
    # from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis
    # from tqdm import tqdm

    # Configuration shared by the live code and the optional snippets.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    scale = 4
    h, w = 3840, 2160
    # scale = 3
    # h, w = 1920, 1080

    model = SAFMNPP(upscaling_factor=scale)
    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only machine.  NOTE: torch.load unpickles the file, so only load
    # checkpoints from a trusted source.
    checkpoint = torch.load('light_safmnpp.pth', map_location='cpu')
    model.load_state_dict(checkpoint['params'], strict=True)
    print(model)

    # ---- optional: FLOPs / activation count (needs fvcore) ----
    # x = torch.randn(1, 30, 3, h // scale, w // scale)
    # output = model(x)
    # print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    # print(output.shape)

    # ---- optional: per-frame runtime on CUDA (needs tqdm) ----
    # num_frame = 30
    # clip = 5
    # torch.cuda.current_device()
    # torch.cuda.empty_cache()
    # torch.backends.cudnn.benchmark = False
    # start = torch.cuda.Event(enable_timing=True)
    # end = torch.cuda.Event(enable_timing=True)
    # runtime = 0
    # dummy_input = torch.randn((1, num_frame, 3, h // scale, w // scale)).to(device)
    # # warm_up
    # model.eval().to(device)
    # with torch.no_grad():
    #     for _ in tqdm(range(clip)):
    #         _ = model(dummy_input)
    #     for _ in tqdm(range(clip)):
    #         start.record()
    #         _ = model(dummy_input)
    #         end.record()
    #         torch.cuda.synchronize()
    #         runtime += start.elapsed_time(end)
    # per_frame_time = runtime / (num_frame * clip)
    # print(f'{model.__class__.__name__} {num_frame * clip} Number Frames x{scale}SR Per Frame Time: {per_frame_time:.6f} ms')
    # print(f'{model.__class__.__name__} x{scale}SR FPS: {(1000 / per_frame_time):.6f} FPS')