"""Gradio demo that samples images from a pre-trained DCGAN generator.

The generator weights are read from ``./generator.pth`` onto the CPU at
import time. The UI exposes a seed and an image count; each click writes a
grid of generated images to ``fake_img.png`` and displays it.
"""

import gradio as gr
import torch
from Generator import Generator
from torchvision.utils import save_image

# Number of channels in the latent vector; z is sampled as (N, LATENT_DIM, 1, 1).
LATENT_DIM = 100
WEIGHTS_PATH = "./generator.pth"
OUTPUT_PATH = "fake_img.png"

# The meaning of the constructor argument is defined in the project-local
# Generator module (not visible here).
generator = Generator(1)
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
generator.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
)
generator.eval()


def generate(seed, num_img):
    """Generate ``num_img`` fake images and save them as a single image grid.

    Args:
        seed: Seed for torch's global RNG, so equal seeds reproduce the same
            latent vectors. Coerced to ``int`` (UI sliders may send floats).
        num_img: Number of latent vectors / images to sample. Coerced to
            ``int`` because ``torch.randn`` rejects float sizes.

    Returns:
        Path of the PNG file the grid was written to (``fake_img.png``).
    """
    seed = int(seed)
    num_img = int(num_img)

    torch.manual_seed(seed)
    z = torch.randn(num_img, LATENT_DIM, 1, 1)

    # Inference only: skip autograd bookkeeping instead of detaching afterwards.
    with torch.no_grad():
        fake_img = generator(z)

    # The batch is passed as-is (no squeeze): save_image lays a batched tensor
    # out as a grid, whereas squeezing a single-channel batch would collapse
    # the channel axis and make the batch look like one multi-channel image.
    # NOTE(review): the output path is shared, so concurrent requests would
    # overwrite each other's file.
    save_image(fake_img, OUTPUT_PATH, normalize=True)
    return OUTPUT_PATH


with gr.Blocks() as demo:
    gr.Markdown("DCGAN model that generate fake images")
    image_input = [
        gr.Slider(0, 1000, label='Seed'),
        gr.Slider(4, 64, label='Number of images', step=1),
    ]
    image_output = gr.Image()
    image_button = gr.Button("Generate")
    # Slider values are passed positionally to generate(seed, num_img).
    image_button.click(generate, inputs=image_input, outputs=image_output)

if __name__ == "__main__":
    demo.launch()