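"""Gradio Space: single-image super resolution with SRCNN (scale factor 2).

The uploaded low-resolution image is upscaled with bicubic interpolation and its
luma (Y) channel is then refined by a trained SRCNN; the app shows the SRCNN
result side by side with the plain bicubic upscale.
"""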
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
title = "Super Resolution with CNN"
description = """
Your low-resolution image will be reconstructed at 2x resolution by a convolutional neural network!<br>
Training details and the dataset can be found in my [github repo](https://github.com/susuhu/super-resolution).<br>
"""
article = """
<div style='margin:20px auto;'>
<p>Sources:</p>
<p>📜 <a href="https://arxiv.org/abs/1501.00092">Image Super-Resolution Using Deep Convolutional Networks</a></p>
<p>📦 Dataset from <a href="https://github.com/eugenesiow/super-image-data">this GitHub repo</a></p>
</div>
"""
examples = [
["peperoni.png"],
["barbara.png"],
]
class SRCNNModel(nn.Module):
    def __init__(self):
        super(SRCNNModel, self).__init__()
        # Three-layer SRCNN: patch extraction (9x9), non-linear mapping (1x1)
        # and reconstruction (5x5), operating on a single (luma) channel.
        self.conv1 = nn.Conv2d(1, 64, 9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, 1, padding=0)
        self.conv3 = nn.Conv2d(32, 1, 5, padding=2)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.conv3(out)
        return out
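# Shape note: the paddings above preserve spatial size, so the model maps a luma
# tensor of shape (N, 1, H, W) to the same shape, e.g.
#   SRCNNModel()(torch.zeros(1, 1, 64, 64)).shape == torch.Size([1, 1, 64, 64])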
def pred_SRCNN(model, image, device, scale_factor=2):
"""
model: SRCNN model
image: low resolution image PILLOW image
scale_factor: scale factor for resolution
device: cuda or cpu
"""
model.to(device)
model.eval()
# open image, gradio opens image as nparray
image = Image.fromarray(image)
# split channels
y, cb, cr = image.convert("YCbCr").split()
# size will be used in image transform
original_size = y.size
# bicubic interpolate it to the original size
y_bicubic = transforms.Resize(
(original_size[1] * scale_factor, original_size[0] * scale_factor),
interpolation=transforms.InterpolationMode.BICUBIC,
)(y)
cb_bicubic = transforms.Resize(
(original_size[1] * scale_factor, original_size[0] * scale_factor),
interpolation=transforms.InterpolationMode.BICUBIC,
)(cb)
cr_bicubic = transforms.Resize(
(original_size[1] * scale_factor, original_size[0] * scale_factor),
interpolation=transforms.InterpolationMode.BICUBIC,
)(cr)
# turn it into tensor and add batch dimension
y_bicubic = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)
# get the y channel SRCNN prediction
y_pred = model(y_bicubic)
# convert it to numpy image
y_pred = y_pred[0].cpu().detach().numpy()
# convert it into regular image pixel values
y_pred = y_pred * 255
y_pred.clip(0, 255)
# conver y channel from array to PIL image format for merging
y_pred_PIL = Image.fromarray(np.uint8(y_pred[0]), mode="L")
# merge the SRCNN y channel with cb cr channels
out_final = Image.merge("YCbCr", [y_pred_PIL, cb_bicubic, cr_bicubic]).convert(
"RGB"
)
image_bicubic = transforms.Resize(
(original_size[1] * scale_factor, original_size[0] * scale_factor),
interpolation=transforms.InterpolationMode.BICUBIC,
)(image)
return out_final, image_bicubic
# load model
# print("Loading SRCNN model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SRCNNModel().to(device)
model.load_state_dict(
    torch.load("SRCNNmodel_trained.pt", map_location=device)
)
model.eval()
# print("SRCNN model loaded!")
# def image_grid(imgs, rows, cols):
# '''
# imgs:list of PILImage
# '''
# assert len(imgs) == rows*cols
# w, h = imgs[0].size
# grid = Image.new('RGB', size=(cols*w, rows*h))
# grid_w, grid_h = grid.size
# for i, img in enumerate(imgs):
# grid.paste(img, box=(i%cols*w, i//cols*h))
# return grid
def super_reso(input_image):
    # Gradio provides the image as a numpy array
    # image_array = np.asarray(image_path)
    # image = Image.fromarray(image_array, mode="RGB")
    # prediction
    with torch.no_grad():
        out_final, image_bicubic = pred_SRCNN(
            model=model, image=input_image, device=device
        )
    # grid = image_grid([out_final, image_bicubic], 1, 2)
    return out_final, image_bicubic
gr.Interface(
    fn=super_reso,
    inputs=gr.Image(label="Upload image"),
    outputs=[
        gr.Image(label="Convolutional neural network"),
        gr.Image(label="Bicubic interpolation"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()
# Known startup failure on this Space:
# TypeError: AsyncConnectionPool.__init__() got an unexpected keyword argument 'socket_options'
# (likely an httpx/httpcore version mismatch; pinning compatible versions or
# upgrading gradio in requirements.txt usually resolves it)