# SRex2 / app.py
# ShahzebKhoso's picture
# Added citations to original paper
# c2d9412 verified
import gradio as gr
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
from PIL import Image
import numpy as np
import torch
import requests
# Hugging Face checkpoint shared by the image processor and the model. Keeping
# it in a single constant guarantees both are always loaded from the same
# checkpoint (previously the id was duplicated and could drift apart).
MODEL_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"

# Load the processor and model once at import time so every request reuses
# them; the model is explicitly placed on the CPU.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID).to("cpu")
# Define the function for super-resolution
def super_resolve(image):
    """Upscale *image* with the module-level Swin2SR ``model``.

    Args:
        image: The uploaded picture as a PIL image (the Gradio input is
            declared with ``type="pil"``), or ``None`` when the user submits
            without uploading anything.

    Returns:
        A PIL RGB image whose size is the input size multiplied by the
        model's upscale factor (``model.config.upscale``), or ``None`` if no
        image was given.
    """
    if image is None:
        # Gradio passes None for an empty input; returning None clears the
        # output component instead of raising inside the processor.
        return None

    # The model works on 3-channel input; RGBA / grayscale / palette uploads
    # would otherwise reach the processor with an unexpected channel count.
    image = image.convert("RGB")
    width, height = image.size  # PIL reports (width, height)

    # Preprocess (rescale + pad) and move the tensors to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # reconstruction is (batch, C, H, W); take the single batch item instead
    # of squeeze(), which would also drop a size-1 spatial dimension.
    reconstructed = outputs.reconstruction[0].cpu()

    # The processor pads the bottom/right edges before inference, so the raw
    # output is larger than scale * input and ends in a mirrored border.
    # Crop it back to exactly the upscaled original size.
    scale = model.config.upscale
    reconstructed = reconstructed[:, : height * scale, : width * scale]

    image_np = reconstructed.permute(1, 2, 0).numpy()  # CxHxW -> HxWxC
    # Clamp to [0, 1] and round (rather than truncate) to 8-bit values.
    image_np = np.clip(image_np, 0.0, 1.0)
    image_np = np.rint(image_np * 255.0).astype(np.uint8)
    return Image.fromarray(image_np)
# Create the Gradio interface: one PIL image in, one PIL image out.
inputs = gr.Image(type="pil", label="Upload an Image")
outputs = gr.Image(type="pil", label="Super-Resolved Image")

# Markdown shown under the title. Blank lines are required so that the
# "Citations" label, the numbered list and the closing sentence render as
# separate blocks instead of being merged into neighbouring paragraphs.
dsc = """
Upload an image to generate a high-resolution version using the Swin2SR model.

**Citations:**

1. Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte, *Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration*, Proceedings of the European Conference on Computer Vision (ECCV) Workshops, 2022.
2. Jingyun Liang, Jiezhang Cao, Guolei Sun, Kai Zhang, Luc Van Gool, Radu Timofte, *SwinIR: Image Restoration Using Swin Transformer*, arXiv preprint arXiv:2108.10257, 2021.

For more details, refer to the original papers.
"""

demo = gr.Interface(
    fn=super_resolve,
    inputs=inputs,
    outputs=outputs,
    title="Image Super-Resolution with Swin2SR",
    description=dsc,
)
# Bound to a name so tooling (e.g. `gradio app.py` reload mode) can find the
# app; it is still launched at import time, exactly as before.
demo.launch()