import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import pickle
import torchvision.transforms as transforms
from PIL import Image
import os
# --- Module-level setup: runs once at import time -------------------------
# Checkpoint files, resolved relative to the current working directory.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device = ",device)

# NOTE(review): unpickling executes arbitrary code embedded in the file, so
# only checkpoints from a trusted source must ever be placed at these paths.
# Only the 'G_ema' entry of each pickle is kept; the rest of the loaded
# object is discarded as soon as the expression finishes.
with open(network_pkl_d, 'rb') as f:
    G_d = pickle.load(f)['G_ema']
G_d = G_d.to(device)

with open(network_pkl_a, 'rb') as f:
    G_a = pickle.load(f)['G_ema']
G_a = G_a.to(device)

# Prompt that makes gen_image() sample from G_d instead of G_a.
# None when the SHOW_TEXT environment variable is not set.
cl_text = os.environ.get('SHOW_TEXT')
def gen_image(text):
    """Return a freshly sampled image, choosing the generator from the prompt.

    The prompt is stripped of surrounding whitespace and lower-cased, then
    compared with the module-level ``cl_text`` (the value of the
    ``SHOW_TEXT`` environment variable). On a match the image is sampled
    from ``G_d``; for every other prompt it is sampled from ``G_a``. The
    prompt is not used for anything else.

    Args:
        text: Prompt string coming from the UI. ``None`` is treated as an
            empty prompt instead of raising ``AttributeError``.

    Returns:
        Whatever ``gen_image_helper`` returns for the selected generator.
    """
    prompt = (text or "").strip().lower()

    # Normalise the configured trigger exactly like the prompt. Without this
    # a SHOW_TEXT value containing capitals or padding could never match,
    # because the prompt side is always stripped and lower-cased.
    trigger = cl_text.strip().lower() if cl_text is not None else None

    if trigger is not None and prompt == trigger:
        return gen_image_helper(G_d)
    return gen_image_helper(G_a)
def gen_image_helper(model, scale=2):
    """Sample one random image from ``model`` and return it upscaled.

    Args:
        model: Generator that is called as ``model(z, c)`` and exposes a
            ``z_dim`` attribute giving the latent size.
        scale: Factor applied to both height and width of the generated
            image before it is returned. Defaults to 2, the factor that
            was previously hard-coded.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, resized with bilinear
        interpolation to ``scale`` times the generated resolution.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")

    z = torch.randn([1, model.z_dim]).to(device)  # one random latent code
    c = None  # class labels are not used

    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass, whatever the requires_grad state of the parameters.
    with torch.no_grad():
        img = model(z, c)
        # Move axis 1 last and map to 8-bit values.
        # NOTE(review): assumes the output is (N, C, H, W) with C == 3 and
        # values roughly in [-1, 1] -- confirm against the checkpoints.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    image = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

    # Resize expects integer (height, width); round so that a fractional
    # scale still yields a valid size.
    target_size = (int(round(image.height * scale)), int(round(image.width * scale)))
    transform = transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR)
    upscaled_image = transform(image)
    return upscaled_image
demo = gr.Interface(
inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
title="Text to Image Generator",
description="Enter any text to generate an image of an animal"
if __name__ == "__main__":