File size: 1,692 Bytes
71507ee a1788b2 71507ee a1788b2 f6a5c7c 71507ee d351f5c de1822e d351f5c de1822e 03ee8ab d351f5c de1822e f6a5c7c de1822e f6a5c7c 1d0b4c4 f6a5c7c a446101 d351f5c a446101 d351f5c a446101 de1822e 03ee8ab a446101 03ee8ab a446101 de1822e 03ee8ab a446101 03ee8ab a446101 03ee8ab d351f5c c49f38d 03ee8ab de1822e 03ee8ab a446101 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import pickle
import torchvision.transforms as transforms
from PIL import Image
import os
# Checkpoint paths, relative to the current working directory.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'  # default generator (G_a)
network_pkl_d = 'network-snapshot.pkl'            # alternate generator (G_d)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device = ", device)


def _load_g_ema(pkl_path):
    """Unpickle *pkl_path* and return its 'G_ema' entry moved to ``device``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point these
    paths at trusted checkpoint files.
    """
    with open(pkl_path, 'rb') as fh:
        snapshot = pickle.load(fh)
    return snapshot['G_ema'].to(device)


G_d = _load_g_ema(network_pkl_d)
G_a = _load_g_ema(network_pkl_a)

# Prompt text that switches generation to G_d; None when SHOW_TEXT is unset.
cl_text = os.getenv('SHOW_TEXT')
def gen_image(text):
text = text.strip().lower()
if text==cl_text:
return gen_image_helper(G_d)
else:
return gen_image_helper(G_a)
def gen_image_helper(model, scale=2):
    """Sample one random image from *model* and upscale it.

    Args:
        model: Generator exposing ``z_dim`` and callable as ``model(z, c)``.
            Assumed to live on the module-level ``device`` and to return an
            NCHW tensor with values roughly in [-1, 1] (implied by the 127.5/128
            mapping below); TODO confirm for custom checkpoints.
        scale: Positive integer upscaling factor, applied with bilinear
            resampling. The default of 2 matches the previous fixed behaviour.

    Returns:
        An RGB ``PIL.Image.Image`` whose width and height are ``scale`` times
        the generator's output size.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale!r}")

    # Inference only: skip autograd bookkeeping so no graph is kept per request.
    with torch.no_grad():
        # Draw the latent directly on the target device (no CPU-then-copy).
        z = torch.randn([1, model.z_dim], device=device)
        c = None  # class labels (not used by this unconditional setup)
        img = model(z, c)
        # NCHW in ~[-1, 1] -> NHWC uint8 in [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    image = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    upscale = transforms.Resize(
        (image.height * scale, image.width * scale),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return upscale(image)
# Gradio UI: a two-line prompt box in, a single PIL image out.
prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
image_out = gr.Image(type="pil")
demo = gr.Interface(
    fn=gen_image,
    inputs=prompt_box,
    outputs=image_out,
    title="Text to Image Generator",
    description="Enter any text to generate an image of an animal",
)


def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()
|