File size: 1,820 Bytes
71507ee a1788b2 71507ee a1788b2 71507ee d351f5c 03ee8ab d351f5c 03ee8ab d351f5c 03ee8ab d351f5c c49f38d 03ee8ab d351f5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import pickle
import torchvision.transforms as transforms
from PIL import Image
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'


def _load_g_ema(pkl_path):
    """Unpickle the network file at *pkl_path* and return its ``'G_ema'`` entry."""
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point these paths at network pickles from a trusted source.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)['G_ema']


# Same load order as before: the local snapshot first, then the AFHQv2 model.
G = _load_g_ema(network_pkl_d)  # torch.nn.Module
G_a = _load_g_ema(network_pkl_a)  # torch.nn.Module
def gen_image(text):
    """Select a generator from the prompt text and return one sampled image.

    The prompt is not used as a conditioning signal; it only chooses which
    model to sample from. The phrase ``"show me"`` (ignoring surrounding
    whitespace) samples via ``gen_image_d``; any other input, including an
    empty or missing prompt, samples via ``gen_image_a``.

    Args:
        text: Prompt string from the Gradio textbox (possibly ``None`` if no
            value is sent — NOTE(review): confirm against the Gradio version).

    Returns:
        Whatever the selected ``gen_image_*`` function returns.
    """
    # The textbox is multi-line, so tolerate stray spaces/newlines around the
    # trigger phrase instead of silently falling through to the other model.
    prompt = (text or "").strip()
    if prompt == 'show me':
        return gen_image_d()
    return gen_image_a()
def _generate_upscaled(generator, scale=2):
    """Sample one image from *generator* and return it upscaled by *scale*.

    Args:
        generator: Network called as ``generator(z, c)`` that exposes
            ``z_dim``. Its output is an (N, 3, H, W) float batch (the code
            permutes it to channels-last and reads it as RGB). Values are
            presumably in about [-1, 1], per the ``* 127.5 + 128`` mapping
            below — NOTE(review): confirm for the snapshot model.
        scale: Integer factor applied to both height and width (bilinear).

    Returns:
        An RGB ``PIL.Image.Image`` whose width and height are ``scale``
        times those of the generated image.
    """
    z = torch.randn([1, generator.z_dim])  # one random latent code
    c = None  # class labels (not used in this example)
    # Inference only: don't build an autograd graph even if the unpickled
    # parameters still have requires_grad=True.
    with torch.no_grad():
        img = generator(z, c)
    # Map to [0, 255] uint8 and move channels last, as PIL expects.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')
    # torchvision Resize takes (height, width).
    resize = transforms.Resize(
        (image.height * scale, image.width * scale),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return resize(image)


def gen_image_a():
    """Return a random 2x-upscaled sample from the AFHQv2 generator ``G_a``."""
    return _generate_upscaled(G_a)


def gen_image_d():
    """Return a random 2x-upscaled sample from the snapshot generator ``G``."""
    return _generate_upscaled(G)
# Components are created in the same order as before: input box, then output.
_prompt_box = gr.Textbox(placeholder="Prompt here...", lines=2)
_image_out = gr.Image(type="pil")

demo = gr.Interface(
    fn=gen_image,
    inputs=_prompt_box,
    outputs=_image_out,
    title="Text to Image Generator",
    description="Enter text to generate an image using a custom PyTorch model.",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    demo.launch()
|