import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import numpy as np

from archs.model import UNet

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# auxiliary transforms
pil_to_tensor = transforms.ToTensor()

# load the checkpoint for the run we want to demo
model = UNet()
checkpoints = torch.load('./models/chk_6000.pt', map_location=device)
model.load_state_dict(checkpoints['model_state_dict'])
model = model.to(device)
model.eval()


def load_img(filename):
    img = Image.open(filename).convert("RGB")
    img_tensor = pil_to_tensor(img)
    return img_tensor


def process_img(image):
    # PIL image -> float32 tensor in [0, 1], shape (1, C, H, W)
    img = np.array(image)
    img = img / 255.
    img = img.astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # tensor -> uint8 image
    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0., 1.)
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(restored_img)  # alternatively: (image, Image.fromarray(restored_img))


title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"

description = '''
## [Inpainting for Autonomous Driving](https://github.com/cidautai)

[Javier Abad Hernández](https://github.com/javierabad01)

Fundación Cidaut

> **Disclaimer:** please remember this is not a product, so you will notice some limitations.
**This demo expects an image with some degradations.**
Due to GPU memory limitations, the app might crash if you feed it a high-resolution image (2K, 4K).
'''

examples = [
    ['examples/inputs/1.png'],
    ['examples/inputs/2.png'],
    ['examples/inputs/3.png'],
    ['examples/inputs/4.png'],
    ['examples/inputs/5.png'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type='pil', label='input')],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()
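# The description above warns that high-resolution (2K/4K) inputs may exhaust GPU memory.
# Below is a minimal, optional sketch of a pre-resize guard; it is not part of the original
# app, and the max_side value of 1280 is an assumed cap, not a measured limit for this model.
def downscale_if_needed(image, max_side=1280):
    """Shrink the image so its longest side is at most max_side, keeping the aspect ratio."""
    w, h = image.size
    scale = max_side / max(w, h)
    if scale < 1.0:
        image = image.resize((int(w * scale), int(h * scale)))
    return image

# Usage sketch: call it at the top of process_img, e.g. `image = downscale_if_needed(image)`,
# so oversized inputs are reduced before being converted to a tensor and sent to the model.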