# d_proj / app.py
# rahulved
# Changed network snapshot
# de1822e
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
import pickle
import torchvision.transforms as transforms
from PIL import Image
import os
# Generator checkpoints, opened relative to the current working directory.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'  # used for ordinary prompts
network_pkl_d = 'network-snapshot.pkl'  # used when the prompt matches SHOW_TEXT

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device = ", device)


def _load_g_ema(pkl_path):
    """Unpickle *pkl_path* and return its ``'G_ema'`` entry moved to ``device``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    trusted checkpoint files.
    """
    with open(pkl_path, 'rb') as fh:
        checkpoint = pickle.load(fh)
    return checkpoint['G_ema'].to(device)


# Load order matches the original: the snapshot first, then the AFHQ model.
G_d = _load_g_ema(network_pkl_d)
G_a = _load_g_ema(network_pkl_a)

# Prompt that switches generation to G_d (see gen_image); None when unset.
cl_text = os.getenv('SHOW_TEXT')
def gen_image(text):
    """Generate an image for a user prompt.

    The prompt only acts as a switch between the two loaded generators. When
    it equals the ``SHOW_TEXT`` environment value (ignoring case and
    surrounding whitespace), the image comes from ``G_d``. Any other prompt
    uses ``G_a``.

    Args:
        text: Prompt from the Gradio textbox. ``None`` is treated as an empty
            prompt.

    Returns:
        The PIL image produced by ``gen_image_helper``.
    """
    prompt = (text or "").strip().lower()
    # Normalize the configured trigger exactly like the prompt. Otherwise a
    # SHOW_TEXT value containing capitals or padding could never match.
    trigger = cl_text.strip().lower() if cl_text is not None else ""
    # An empty or unset trigger must not let an empty prompt select G_d.
    if trigger and prompt == trigger:
        return gen_image_helper(G_d)
    return gen_image_helper(G_a)
def gen_image_helper(model, scale=2):
    """Sample one random image from *model* and upscale it bilinearly.

    Args:
        model: Generator taken from a checkpoint's ``'G_ema'`` entry. It must
            expose ``z_dim`` and be callable as ``model(z, c)``. Its output is
            assumed to be NCHW with values roughly in [-1, 1], which the
            ``* 127.5 + 128`` mapping below relies on (TODO confirm for new
            snapshots).
        scale: Integer factor applied to both height and width. Defaults to 2,
            the previous hard-coded behavior.

    Returns:
        An RGB ``PIL.Image.Image`` of size (width * scale, height * scale).
    """
    # Pure inference: skip autograd bookkeeping so a request does not build
    # and hold an activation graph.
    with torch.no_grad():
        # Sample the latent directly on the target device instead of creating
        # it on the CPU and copying it over.
        z = torch.randn([1, model.z_dim], device=device)
        img = model(z, None)  # None: class labels are not used
        # NCHW -> NHWC, map to [0, 255] and quantize to 8-bit.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    image = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    # torchvision Resize takes (height, width).
    resize = transforms.Resize(
        (image.height * scale, image.width * scale),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return resize(image)
# Gradio UI: a two-line prompt box in, a PIL image out, served by gen_image.
_prompt_input = gr.Textbox(lines=2, placeholder="Prompt here...")
_image_output = gr.Image(type="pil")
demo = gr.Interface(
    fn=gen_image,
    inputs=_prompt_input,
    outputs=_image_output,
    title="Text to Image Generator",
    description="Enter any text to generate an image of an animal",
)
def _main():
    """Start the Gradio app with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()