SuperResolution / app.py
Hu
change PIL load image
e831a7b
raw
history blame
4.57 kB
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
# Text shown by the gradio Interface (title, description and footer article)
# plus the example images offered below the input widget.
title = "Super Resolution with CNN"
description = """
Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!<br>
Detailed training and dataset can be found on my [github repo](https://github.com/susuhu/super-resolution).<br>
"""
# The emoji are written as \N{...} escapes so the source stays ASCII and they
# cannot be garbled again by an encoding round-trip (the previous literals had
# been mis-decoded into "πŸ“œ" / "πŸ“¦").
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>\N{SCROLL} <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>\N{PACKAGE} Dataset <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""
# Each inner list is one example row: one value per Interface input.
examples = [
    ["LR_image.png"],
    ["barbara.png"],
]
class SRCNNModel(nn.Module):
    """Three-convolution SRCNN network working on a single-channel image.

    conv1 (9x9, 1->64) and conv2 (1x1, 64->32) are each followed by a ReLU;
    conv3 (5x5, 32->1) produces the output with no activation. Every layer
    is padded so the spatial size of the output equals that of the input.
    """

    def __init__(self):
        super().__init__()
        # The attribute names conv1/conv2/conv3 are the state_dict keys the
        # pretrained checkpoint is loaded into, so they must not be renamed.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x):
        """Run the network on a 1-channel tensor, e.g. shape (N, 1, H, W)."""
        features = self.conv1(x).relu()
        mapped = self.conv2(features).relu()
        return self.conv3(mapped)
def pred_SRCNN(model, image, device, scale_factor=2):
    """Upscale an image with SRCNN and, for comparison, with plain bicubic.

    Only the luma (Y) channel is refined by the network; the chroma (Cb, Cr)
    channels are bicubically upscaled and merged back with the predicted Y.

    Args:
        model: SRCNN module. It receives the bicubic-upscaled Y channel as a
            (1, 1, H, W) float tensor scaled to [0, 1]; its output is
            multiplied by 255 and saturated to 8-bit pixel values.
        image: low-resolution input, either a NumPy array (what gradio
            passes) or a PIL image.
        device: torch device (or device string) to run the model on.
        scale_factor: upscaling factor applied to both width and height;
            the target size is rounded to whole pixels.

    Returns:
        Tuple ``(out_final, image_bicubic, image)``: the SRCNN result as an
        RGB PIL image, the bicubic-upscaled input, and the input as a PIL
        image.
    """
    model.to(device)
    model.eval()

    # gradio hands over a NumPy array; a PIL image is accepted unchanged.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # PIL reports size as (width, height); torchvision's Resize wants (h, w).
    width, height = image.size
    upscale = transforms.Resize(
        (round(height * scale_factor), round(width * scale_factor)),
        interpolation=transforms.InterpolationMode.BICUBIC,
    )

    # Split into luma/chroma and bicubically upscale every channel.
    y, cb, cr = image.convert("YCbCr").split()
    y_bicubic, cb_bicubic, cr_bicubic = (upscale(ch) for ch in (y, cb, cr))

    # (H, W) luma image -> (1, 1, H, W) float tensor in [0, 1].
    y_tensor = transforms.ToTensor()(y_bicubic).unsqueeze(0).to(device)
    with torch.no_grad():
        y_pred = model(y_tensor)[0, 0].cpu().numpy()

    # Back to 8-bit pixels. Saturate BEFORE casting: a float -> uint8 cast
    # does not clamp out-of-range values, which would show up as speckles.
    y_pred = np.clip(np.rint(y_pred * 255.0), 0, 255).astype(np.uint8)
    y_pred_PIL = Image.fromarray(y_pred)  # 2-D uint8 array -> mode "L"

    # Recombine the refined luma with the bicubic chroma.
    out_final = Image.merge("YCbCr", [y_pred_PIL, cb_bicubic, cr_bicubic]).convert(
        "RGB"
    )
    image_bicubic = upscale(image)
    return out_final, image_bicubic, image
# Build the SRCNN model once at import time; every request reuses it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = SRCNNModel().to(device)
# NOTE: torch.load unpickles the checkpoint file, so it must come from a
# trusted source. map_location lets GPU-saved weights load on a CPU host.
model.load_state_dict(torch.load("SRCNNmodel_trained.pt", map_location=device))
model.eval()
# def image_grid(imgs, rows, cols):
# '''
# imgs:list of PILImage
# '''
# assert len(imgs) == rows*cols
# w, h = imgs[0].size
# grid = Image.new('RGB', size=(cols*w, rows*h))
# grid_w, grid_h = grid.size
# for i, img in enumerate(imgs):
# grid.paste(img, box=(i%cols*w, i//cols*h))
# return grid
def super_reso(image):
    """Gradio callback: return ``(srcnn_image, bicubic_image)`` for *image*.

    ``image`` is whatever the gradio Image input delivers: a NumPy array, or
    presumably ``None`` when the form is submitted without an upload. In the
    ``None`` case both outputs are left empty instead of raising.
    """
    if image is None:
        return None, None
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out_final, image_bicubic, _ = pred_SRCNN(
            model=model, image=image, device=device
        )
    return out_final, image_bicubic
# Build and launch the web UI: one uploaded image in, two images out
# (the SRCNN reconstruction and the bicubic baseline for comparison).
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces;
# switch to gr.Image if the pinned gradio version is upgraded.
gr.Interface(
    fn=super_reso,
    inputs=gr.inputs.Image(label="Upload image"),
    outputs=[
        gr.outputs.Image(label="Convolutional neural network"),
        gr.outputs.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()