|
import gradio as gr |
|
import dnnlib |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
|
|
import legacy |
|
import pickle |
|
|
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'


def _load_generator(path):
    """Open the network pickle at *path* and return its ``'G_ema'`` entry.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only point this at checkpoints you trust.
    """
    with open(path, 'rb') as handle:
        networks = pickle.load(handle)
    return networks['G_ema']


# Generator used by gen_image_d (loaded first, as before).
G = _load_generator(network_pkl_d)
# Generator used by gen_image_a.
G_a = _load_generator(network_pkl_a)
|
|
|
def gen_image(text):
    """Gradio callback: choose which generator to sample from based on *text*.

    The phrase ``"show me"`` (surrounding whitespace ignored) routes to
    ``gen_image_d``; any other input, including an empty value, routes to
    ``gen_image_a``.

    Args:
        text: Contents of the prompt textbox. Normally a ``str``; ``None`` is
            treated the same as an empty string.

    Returns:
        The image returned by the selected generator function.
    """
    # A multi-line Textbox easily picks up a trailing newline or space, which
    # would otherwise silently defeat the exact-match check.
    if (text or "").strip() == "show me":
        return gen_image_d()
    return gen_image_a()
|
|
|
def _generate_upscaled(generator, scale=2):
    """Sample one random image from *generator* and upscale it by *scale*.

    Args:
        generator: A network loaded from a ``'G_ema'`` pickle entry. It must
            expose ``z_dim`` and ``c_dim`` and be callable as ``generator(z, c)``.
            The permute below assumes an (N, C, H, W) output.
        scale: Integer factor applied to both height and width.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, ``scale`` times the native size.
    """
    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        z = torch.randn([1, generator.z_dim])
        # Zero label of shape [1, c_dim], as StyleGAN's own generation script
        # passes. For unconditional nets (c_dim == 0) this is an empty tensor;
        # conditional nets would reject the previous ``None``.
        c = torch.zeros([1, generator.c_dim])
        img = generator(z, c)
    # Channels-last uint8. The *127.5 + 128 mapping assumes outputs roughly in
    # [-1, 1]; the clamp guards against overshoot.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')
    # torchvision's Resize takes (height, width).
    resize = transforms.Resize(
        (image.height * scale, image.width * scale),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return resize(image)


def gen_image_a():
    """Return a 2x bilinear-upscaled random sample from ``G_a``."""
    return _generate_upscaled(G_a)


def gen_image_d():
    """Return a 2x bilinear-upscaled random sample from ``G``."""
    return _generate_upscaled(G)
|
|
|
# Input: a two-line free-text box; output: the generated PIL image.
_prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
_result_view = gr.Image(type="pil")

demo = gr.Interface(
    fn=gen_image,
    inputs=_prompt_box,
    outputs=_result_view,
    title="Text to Image Generator",
    description="Enter text to generate an image using a custom PyTorch model.",
)
|
|
|
def main():
    """Start the Gradio server hosting ``demo``."""
    demo.launch()


if __name__ == "__main__":
    main()
|
|