|
import gradio as gr |
|
import torch |
|
from Generator import Generator |
|
from torchvision.utils import save_image |
|
|
|
|
|
# Instantiate the DCGAN generator and restore its trained weights on the CPU.
# NOTE(review): the meaning of the constructor argument `1` is defined in the
# project's Generator module — confirm (e.g. GPU count vs. channel count).
generator = Generator(1)

# NOTE(review): torch.load unpickles the checkpoint file; only point it at a
# trusted file (consider weights_only=True where the installed torch supports it).
_weights = torch.load("./generator.pth", map_location=torch.device('cpu'))
generator.load_state_dict(_weights)
del _weights  # drop the extra reference so the loaded tensors can be freed

# Evaluation mode: layers such as BatchNorm/Dropout switch to inference behavior.
generator.eval()
|
|
|
|
|
def generate(seed, num_img, output_path="fake_img.png"):
    """Sample fake images from the DCGAN generator and save them as one PNG grid.

    Args:
        seed: Seed for PyTorch's global RNG, so the same seed reproduces the
            same images. Gradio sliders may deliver numbers as floats, so the
            value is cast to ``int``.
        num_img: Number of latent vectors to sample, i.e. images in the grid.
            Cast to ``int`` because ``torch.randn`` rejects float sizes.
        output_path: File the PNG grid is written to. Defaults to
            ``"fake_img.png"``, relative to the current working directory.

    Returns:
        ``output_path``, the path of the written PNG (displayed by ``gr.Image``).
    """
    torch.manual_seed(int(seed))

    # Latent batch of shape (num_img, 100, 1, 1); 100 must match the
    # generator's latent size (assumed from the original code — TODO confirm).
    z = torch.randn(int(num_img), 100, 1, 1)

    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        fake_img = generator(z)

    # Keep the batch dimension. Assuming the generator returns (N, C, H, W)
    # (TODO confirm), save_image tiles the batch into a grid; squeezing a
    # singleton channel or batch dim would make it treat the batch as the
    # channels of a single image.
    save_image(fake_img, output_path, normalize=True)

    return output_path
|
|
|
|
|
|
|
# Gradio UI: two sliders feed generate(seed, num_img); the returned PNG path
# is shown in an image component.
with gr.Blocks() as demo:
    gr.Markdown("DCGAN model that generates fake images")

    # Order matters: values are passed positionally to generate(seed, num_img).
    image_input = [
        # step=1 so every integer seed is selectable; without it Gradio picks
        # a default step derived from the slider's range, which skips seeds.
        gr.Slider(0, 1000, label='Seed', step=1),
        gr.Slider(4, 64, label='Number of images', step=1),
    ]

    # Displays the image file whose path generate() returns.
    image_output = gr.Image()

    image_button = gr.Button("Generate")

    image_button.click(generate, inputs=image_input, outputs=image_output)

demo.launch()