Fashion_VAE / app.py
coledie
Update.
8ce78b0
raw
history blame
640 Bytes
import numpy as np
import torch
import gradio as gr
from vae import *
import matplotlib.image as mpimg
# Load the trained VAE once at import time so every request reuses it.
# NOTE(review): vae.pt appears to hold a whole pickled model object (it is
# used directly via .eval() below and its class is presumably provided by
# `from vae import *`), so it has to be unpickled with weights_only=False;
# PyTorch 2.6 changed the default to True, which rejects such checkpoints.
# Unpickling runs arbitrary code -- only do this for this bundled, trusted
# file, never for user-supplied checkpoints. (weights_only needs torch>=1.13.)
# map_location="cpu" keeps the model on the CPU, where generate_image builds
# its input tensors, even if the checkpoint was saved from a GPU.
with open("vae.pt", "rb") as file:
    vae = torch.load(file, map_location="cpu", weights_only=False)
vae.eval()  # switch the model to evaluation (inference) mode
def generate_image(filename):
    """Run an image file through the VAE and return its output as a 28x28 array.

    Args:
        filename: Path to the image file. Gradio supplies a path because the
            input component is declared with ``type="filepath"``.

    Returns:
        A 28x28 NumPy array built from the first element of the VAE's output.
    """
    image = mpimg.imread(filename)
    # Colour files decode with a trailing channel axis; keep only the first
    # channel. Single-channel files already decode to a 2-D array.
    if image.ndim == 3:
        image = image[:, :, 0]
    # matplotlib returns PNGs as floats already scaled to [0, 1] but other
    # formats (e.g. the bundled JPEG examples) as integer arrays, so only the
    # integer case is rescaled; dividing a PNG by 255 again would give the
    # model an almost black input. For uint8 this is the same "/ 255" as before.
    if np.issubdtype(image.dtype, np.integer):
        image = image / np.iinfo(image.dtype).max
    with torch.no_grad():  # inference only, no autograd graph needed
        output = vae(torch.Tensor(image))
    # assumes output[0] is the reconstruction with 28*28 elements
    # (Fashion-MNIST size) -- TODO confirm against vae.py
    grayscale = output[0].reshape((28, 28))
    return grayscale.detach().numpy()
# Bundled sample inputs shown beneath the demo: examples/0.jpg ... 9.jpg.
examples = [f"examples/{i}.jpg" for i in range(10)]

demo = gr.Interface(
    fn=generate_image,
    # Pass the callback a file path; generate_image reads the file itself.
    inputs=gr.Image(type="filepath"),
    outputs="image",
    examples=examples,
    title="VAE running on Fashion MNIST",
    description=".",
    article="...",
    # Gradio documents allow_flagging as "never", "auto" or "manual"; the
    # boolean form is deprecated (False is mapped to "never" with a warning).
    allow_flagging="never",
)
demo.launch()