# swin2sr / app.py
# Hugging Face Space by jjourney1125 — commit 19bb687 ("Init commit"), 1.72 kB.
import os
import shutil
import subprocess
import sys
import urllib.request

import cv2
import gradio as gr
from PIL import Image
import torch
# Swin2SR checkpoint used by inference(). If it is not already on disk it is
# fetched from the project's v0.0.1 GitHub release on startup.
model_path = 'experiments/pretrained_models/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth'
if os.path.exists(model_path):
    print(f'loading model from {model_path}')
else:
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    url = f'https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{os.path.basename(model_path)}'
    print(f'downloading model {model_path}')
    # urlopen follows GitHub's release redirect and raises HTTPError on a
    # non-2xx status, so an error page is never saved as the checkpoint.
    # Stream into a temporary file and rename only after a complete download:
    # a truncated file at model_path would otherwise pass the exists() check
    # above on the next start and never be re-fetched.
    tmp_path = model_path + '.part'
    with urllib.request.urlopen(url, timeout=60) as response, open(tmp_path, 'wb') as out:
        shutil.copyfileobj(response, out)
    os.replace(tmp_path, model_path)

# Input folder that inference() writes the uploaded image into and passes to
# main_test_swin2sr.py as --folder_lq.
os.makedirs("test", exist_ok=True)
def inference(img):
    """Upscale one image with Swin2SR and return the path of the result.

    ``img`` is the array handed over by the gradio "image" input component
    (assumed RGB, gradio's default for numpy images — confirm if the
    component configuration changes). It is saved as ``test/1.png``, then
    ``main_test_swin2sr.py`` is run over the ``test`` folder with the x4
    real_sr settings. The path of the upscaled image is returned so gradio
    can display it.

    Raises:
        subprocess.CalledProcessError: if the super-resolution script exits
            with a non-zero status. Failing loudly avoids returning a stale
            result file left behind by an earlier request.
    """
    # cv2.imwrite expects BGR channel order; gradio supplies RGB.
    cv2.imwrite("test/1.png", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    # Argument list (no shell) and the current interpreter, rather than
    # whatever `python` happens to be first on PATH.
    subprocess.run(
        [
            sys.executable, 'main_test_swin2sr.py',
            '--task', 'real_sr',
            '--model_path', model_path,
            '--folder_lq', 'test',
            '--scale', '4',
        ],
        check=True,
    )
    return 'results/swin2sr_real_sr_x4/1_Swin2SR.png'
# Text shown around the demo UI.
title = "Swin2SR"
description = "Gradio demo for Swin2SR."
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2209.11345' target='_blank'>"
    "Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration</a>"
    " | <a href='https://github.com/mv-lab/swin2sr' target='_blank'>Github Repo</a></p>"
)
# Example input offered in the UI (expects butterflyx4.png next to this script).
examples = [['butterflyx4.png']]

# One image in, one image out: inference() returns the path of the upscaled file.
# NOTE(review): the queue is presumably also relied on to serialize requests,
# since inference() reuses fixed file names — confirm the queue's concurrency
# setting for the installed gradio version.
gr.Interface(
    inference,
    "image",
    "image",
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(enable_queue=True, share=True)