"""Gradio demo that serves images sampled from two pickled generator networks.

Two generator checkpoints are loaded once at import time. Each request draws
a fresh random latent vector, runs it through one of the generators (chosen
by the prompt text), and returns the result as a 2x-upscaled PIL image.
"""

import pickle

import gradio as gr
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as transforms
from PIL import Image

# NOTE(review): dnnlib and legacy are not referenced by name below. They are
# presumably imported so the pickled network classes can be resolved while
# unpickling -- confirm before removing.
import dnnlib
import legacy

network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'

# Factor by which each generated image is enlarged before being returned.
_UPSCALE_FACTOR = 2


def _load_generator(path: str) -> torch.nn.Module:
    """Load the ``'G_ema'`` entry from the pickle file at *path*.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point this at checkpoint files from a trusted source.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if the pickle has no ``'G_ema'`` entry.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)['G_ema']  # torch.nn.Module


# Loaded once at import so requests do not pay the deserialization cost.
G = _load_generator(network_pkl_d)
G_a = _load_generator(network_pkl_a)


def _sample_image(generator: torch.nn.Module) -> PIL.Image.Image:
    """Sample one random image from *generator* and return it upscaled.

    The generator is called with a single standard-normal latent of width
    ``generator.z_dim`` and no class label.
    """
    z = torch.randn([1, generator.z_dim])  # latent codes
    c = None  # class labels (not used in this example)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        img = generator(z, c)

    # Move the channel axis last and convert to 8-bit values for PIL.
    # NOTE(review): the *127.5 + 128 scaling presumably assumes generator
    # output in roughly [-1, 1] -- confirm against the network.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')

    target_size = (
        image.height * _UPSCALE_FACTOR,
        image.width * _UPSCALE_FACTOR,
    )
    upscale = transforms.Resize(
        target_size,
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return upscale(image)


def gen_image_a() -> PIL.Image.Image:
    """Return an upscaled random sample from ``G_a``."""
    return _sample_image(G_a)


def gen_image_d() -> PIL.Image.Image:
    """Return an upscaled random sample from ``G``."""
    return _sample_image(G)


def gen_image(text: str) -> PIL.Image.Image:
    """Pick a generator based on the prompt and return a sampled image.

    The prompt is not used to condition the image: the exact string
    ``'show me'`` selects ``gen_image_d``; anything else selects
    ``gen_image_a``.
    """
    if text == 'show me':
        return gen_image_d()
    return gen_image_a()


demo = gr.Interface(
    fn=gen_image,
    inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
    outputs=gr.Image(type="pil"),
    title="Text to Image Generator",
    description="Enter text to generate an image using a custom PyTorch model.",
)

if __name__ == "__main__":
    demo.launch()