# DCGAN / app.py — by AyoubMDL
# Commit 4019e92: "Add DCGAN app" (916 bytes)
import gradio as gr
import torch
from Generator import Generator
from torchvision.utils import save_image
# Instantiate the generator network and restore its trained weights.
# NOTE(review): the meaning of the ``1`` argument is defined in Generator.py,
# which is not visible here — confirm what it controls.
generator = Generator(1)
# map_location="cpu" remaps tensors saved on another device onto the CPU.
# torch.load unpickles the checkpoint file, so only ship/load trusted files.
generator.load_state_dict(
    torch.load("./generator.pth", map_location="cpu")
)
# Put the network in inference mode (affects e.g. BatchNorm/Dropout layers).
generator.eval()
def generate(seed, num_img, latent_dim=100, out_path="fake_img.png"):
    """Sample a batch of fake images from the generator and save them as a grid.

    Args:
        seed: Seed for the latent noise. UI sliders may deliver it as a
            float, so it is coerced to ``int`` (as ``torch.manual_seed``
            would do anyway).
        num_img: Number of images to sample. Coerced to ``int`` because
            ``torch.randn`` rejects float sizes.
        latent_dim: Length of each latent vector fed to the generator.
            Defaults to 100, the size this app has always used.
        out_path: File the PNG grid is written to. Defaults to the original
            ``"fake_img.png"`` in the working directory.

    Returns:
        ``out_path``, the path of the written PNG.
    """
    # Reseeds torch's global RNG so the same seed reproduces the same images.
    torch.manual_seed(int(seed))
    z = torch.randn(int(num_img), latent_dim, 1, 1)
    # Inference only: no_grad skips building an autograd graph, which the
    # previous forward-then-detach approach still paid for.
    with torch.no_grad():
        fake_img = generator(z)
    # Keep the 4-D (batch, channel, H, W) shape: squeezing would drop a size-1
    # channel axis for single-channel output, and save_image/make_grid would
    # then interpret the whole batch as one image with num_img channels.
    save_image(fake_img, out_path, normalize=True)
    return out_path
# Web UI: two sliders feed generate(), whose returned PNG path is shown in
# an image component when the button is clicked.
with gr.Blocks() as demo:
    gr.Markdown("DCGAN model that generate fake images")

    seed_slider = gr.Slider(minimum=0, maximum=1000, label='Seed')
    count_slider = gr.Slider(minimum=4, maximum=64, step=1, label='Number of images')
    image_input = [seed_slider, count_slider]

    image_output = gr.Image()

    image_button = gr.Button("Generate")
    image_button.click(fn=generate, inputs=image_input, outputs=image_output)

# share=True requests a public Gradio share link in addition to the local server.
demo.launch(share=True)