File size: 1,692 Bytes
71507ee
a1788b2
 
 
 
71507ee
a1788b2
 
 
 
 
f6a5c7c
71507ee
d351f5c
 
de1822e
 
 
d351f5c
de1822e
03ee8ab
d351f5c
de1822e
f6a5c7c
 
de1822e
 
f6a5c7c
 
1d0b4c4
f6a5c7c
a446101
d351f5c
a446101
d351f5c
a446101
de1822e
03ee8ab
a446101
03ee8ab
a446101
 
de1822e
03ee8ab
a446101
03ee8ab
a446101
 
03ee8ab
 
d351f5c
c49f38d
03ee8ab
de1822e
03ee8ab
 
a446101
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
import pickle

import torchvision.transforms as transforms
from PIL import Image
import os

# Pickled network snapshots: the default generator and the alternate one.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device = ", device)


def _load_g_ema(pkl_path):
    """Unpickle ``pkl_path`` and return its 'G_ema' entry moved to ``device``.

    NOTE(review): pickle executes arbitrary code while loading, so only
    trusted snapshot files must ever be placed at these paths.
    """
    with open(pkl_path, 'rb') as snapshot:
        return pickle.load(snapshot)['G_ema'].to(device)  # torch.nn.Module


# Same load order as before: the alternate network first, then the default.
G_d = _load_g_ema(network_pkl_d)
G_a = _load_g_ema(network_pkl_a)

# Prompt that selects G_d in gen_image; None when SHOW_TEXT is not set.
cl_text = os.environ.get('SHOW_TEXT')


  
def gen_image(text):
    """Generate an image, picking the generator from the prompt text.

    The prompt is stripped and lower-cased, then compared against the
    ``SHOW_TEXT`` environment value held in ``cl_text``. The configured value
    is normalised the same way, so a value containing uppercase letters or
    surrounding whitespace can still be matched.

    Args:
        text: Prompt from the UI. ``None`` is treated as an empty prompt.

    Returns:
        Whatever ``gen_image_helper`` returns for the selected generator:
        ``G_d`` when the prompt matches the configured text, ``G_a`` otherwise.
    """
    prompt = (text or "").strip().lower()
    # cl_text is None when SHOW_TEXT is unset; a str never equals None, so
    # the default generator is always chosen in that case.
    trigger = cl_text.strip().lower() if cl_text is not None else None
    model = G_d if prompt == trigger else G_a
    return gen_image_helper(model)

def gen_image_helper(model):
    """Sample one random image from ``model`` and return it upscaled 2x.

    Args:
        model: Generator exposing ``z_dim`` and callable as ``model(z, c)``.
            NOTE(review): the scaling below assumes the output is a
            channels-first batch with values roughly in [-1, 1] and three
            channels -- confirm against the loaded networks.

    Returns:
        A PIL image whose height and width are twice those of the generated
        image, resized with bilinear interpolation.
    """
    z = torch.randn([1, model.z_dim]).to(device)  # latent codes
    c = None  # class labels (not used in this example)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        img = model(z, c)
    # Move channels last, rescale to 0..255 and convert to bytes for PIL.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    image = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    transform = transforms.Resize(
        (image.height * 2, image.width * 2),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return transform(image)
          

# UI components: a two-line prompt box in, a PIL image out.
_prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
_image_out = gr.Image(type="pil")

# Wire gen_image into a simple single-input / single-output interface.
demo = gr.Interface(
    gen_image,
    _prompt_box,
    _image_out,
    title="Text to Image Generator",
    description="Enter any text to generate an image of an animal",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()