SRex2 / app.py
ShahzebKhoso's picture
Update app.py
988617d verified
raw
history blame
1.63 kB
import gradio as gr
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
from PIL import Image
import numpy as np
import torch
import requests
# Hugging Face Hub checkpoint shared by the processor and the model, kept in one
# place so the two are always loaded from the same repository.
MODEL_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"

# Load the preprocessing pipeline and the model weights once, at import time,
# then place the model on the CPU.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID)
model = model.to("cpu")
def super_resolve(image):
    """Upscale an image with the module-level Swin2SR ``processor``/``model``.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio upload widget, or ``None``
            when the user submits without uploading anything.

    Returns:
        The super-resolved image as an RGB ``PIL.Image.Image``, or ``None`` if
        no input image was given.
    """
    if image is None:
        # Gradio hands us None for an empty input widget; show an empty output
        # instead of letting the processor raise.
        return None

    # The model expects 3-channel input; RGBA, palette ("P") and grayscale
    # ("L") uploads would otherwise produce a channel-count mismatch.
    image = image.convert("RGB")
    width, height = image.size

    # Preprocess. NOTE: Swin2SRImageProcessor pads the bottom/right edges by
    # default (do_pad / pad_size), so the model output covers the padded area.
    inputs = processor(images=image, return_tensors="pt").to("cpu")

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # (1, C, H, W) -> (H, W, C) NumPy array on the CPU. squeeze(0) drops only
    # the batch axis, never a spatial axis of size 1.
    image_np = outputs.reconstruction.squeeze(0).cpu().permute(1, 2, 0).numpy()

    # Crop away the region generated from the processor's padding so the
    # result is exactly `upscale` times the input size (no mirrored border).
    scale = getattr(model.config, "upscale", 4)
    image_np = image_np[: height * scale, : width * scale]

    # Map [0, 1] floats to [0, 255] bytes, rounding instead of truncating.
    image_np = (np.clip(image_np, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return Image.fromarray(image_np)
# Build and launch the Gradio UI. `gr.inputs` / `gr.outputs` were deprecated in
# Gradio 3 and removed in Gradio 4; the unified `gr.Image` component serves as
# both the input and the output widget.
image_input = gr.Image(type="pil", label="Upload an Image")
image_output = gr.Image(type="pil", label="Super-Resolved Image")

demo = gr.Interface(
    fn=super_resolve,
    inputs=image_input,
    outputs=image_output,
    title="Image Super-Resolution with Swin2SR",
    description=(
        "Upload an image to upscale it using the Swin2SR model for "
        "real-world super-resolution."
    ),
)
demo.launch()