d_proj / app.py
rahulved
Fixed app.py
d351f5c
raw
history blame
1.82 kB
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import pickle
import torchvision.transforms as transforms
from PIL import Image
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'


def _load_g_ema(pkl_path):
    """Open the pickled network snapshot at *pkl_path* and return its 'G_ema' entry.

    NOTE: unpickling can execute arbitrary code, so only point this at
    snapshot files you trust.
    """
    with open(pkl_path, 'rb') as handle:
        snapshot = pickle.load(handle)
    return snapshot['G_ema']


# Load order is kept as before: the custom snapshot ("d") first, then the
# AFHQv2 StyleGAN3 release ("a").
G = _load_g_ema(network_pkl_d)  # torch.nn.Module
G_a = _load_g_ema(network_pkl_a)  # torch.nn.Module
def gen_image(text):
    """Return a generated image for the Gradio textbox value *text*.

    The text does not condition the image; it only selects which generator
    runs. The phrase 'show me' (ignoring surrounding whitespace) routes to
    ``gen_image_d``. Any other value, including an empty or missing one,
    routes to ``gen_image_a``.
    """
    # The textbox is multi-line (lines=2), so a trailing newline or space is
    # easy to enter by accident. Strip it before comparing, and treat None
    # like an empty prompt instead of failing.
    prompt = (text or '').strip()
    if prompt == 'show me':
        return gen_image_d()
    return gen_image_a()
def _synthesize_upscaled(generator):
    """Sample one random image from *generator* and return it as a 2x-upscaled PIL image.

    Draws a latent ``z`` of shape ``[1, generator.z_dim]`` from a standard
    normal and calls ``generator(z, None)`` with no class label. The output
    is converted to an 8-bit RGB ``PIL.Image`` and enlarged to twice its
    height and width with bilinear interpolation.
    """
    z = torch.randn([1, generator.z_dim])  # latent codes
    c = None  # class labels (not used in this example)
    # Inference only. no_grad skips building an autograd graph and guarantees
    # the output does not require grad; .numpy() below would raise otherwise.
    with torch.no_grad():
        img = generator(z, c)
    # Move channels last for PIL, then map to uint8 pixels. The *127.5 + 128
    # scaling presumably assumes outputs roughly in [-1, 1]; confirm for the
    # custom snapshot.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # img[0]: the single sample in the batch of one.
    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')
    upscale = transforms.Resize(
        (image.height * 2, image.width * 2),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return upscale(image)


def gen_image_a():
    """Return one random 2x-upscaled image from the generator ``G_a`` (AFHQv2 pickle)."""
    return _synthesize_upscaled(G_a)


def gen_image_d():
    """Return one random 2x-upscaled image from the generator ``G`` (custom snapshot)."""
    return _synthesize_upscaled(G)
# A single textbox in and a single PIL image out; gen_image decides which
# generator produces the picture.
_prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
_image_out = gr.Image(type="pil")

demo = gr.Interface(
    fn=gen_image,
    inputs=_prompt_box,
    outputs=_image_out,
    title="Text to Image Generator",
    description="Enter text to generate an image using a custom PyTorch model.",
)


def _main():
    """Start the Gradio app defined by ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()