File size: 1,820 Bytes
71507ee
a1788b2
 
 
 
71507ee
a1788b2
 
 
 
 
71507ee
d351f5c
 
 
 
03ee8ab
d351f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03ee8ab
 
 
d351f5c
03ee8ab
 
 
 
 
 
d351f5c
c49f38d
03ee8ab
 
 
 
d351f5c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy
import pickle

import torchvision.transforms as transforms
from PIL import Image

# Checkpoint file names; being relative, they are resolved against the
# current working directory when opened.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'


def _load_g_ema(pkl_path):
    """Unpickle the checkpoint at *pkl_path* and return its 'G_ema' entry."""
    # NOTE: unpickling executes code embedded in the file, so only trusted
    # checkpoints should ever be loaded here.
    with open(pkl_path, 'rb') as fh:
        checkpoint = pickle.load(fh)
    return checkpoint['G_ema']  # torch.nn.Module


# Generators are loaded once at import time and shared by every request.
G = _load_g_ema(network_pkl_d)
G_a = _load_g_ema(network_pkl_a)
    
def gen_image(text):
    """Route a prompt to one of the two image generators.

    Args:
        text: prompt string supplied by the caller.

    Returns:
        The result of ``gen_image_d()`` when *text* is exactly ``'show me'``;
        otherwise the result of ``gen_image_a()``. The prompt content is not
        used for anything beyond this selection.
    """
    wants_snapshot_model = text == 'show me'
    generate = gen_image_d if wants_snapshot_model else gen_image_a
    return generate()
        
def gen_image_a():
    """Sample one random image from ``G_a`` and return it at twice the size.

    A single latent vector of length ``G_a.z_dim`` is drawn from a standard
    normal distribution and passed through the module-level generator
    ``G_a`` with no class label. The output is converted to an 8-bit RGB
    PIL image, and its height and width are doubled with bilinear
    interpolation.

    Returns:
        PIL.Image.Image: the upscaled RGB image.
    """
    # Inference only: disabling autograd avoids building a graph on every
    # request and guarantees the result can be handed to NumPy.
    with torch.no_grad():
        z = torch.randn([1, G_a.z_dim])  # latent codes
        c = None  # class labels (not used in this example)
        img = G_a(z, c)
        # Move axis 1 (assumed to be channels, i.e. NCHW output -- TODO
        # confirm) to the end, then scale/shift and clamp into uint8 range.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')
    # Resize takes (height, width); double both dimensions.
    transform = transforms.Resize(
        (image.height * 2, image.width * 2),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    upscaled_image = transform(image)
    return upscaled_image

def gen_image_d():
    """Sample one random image from ``G`` and return it at twice the size.

    A single latent vector of length ``G.z_dim`` is drawn from a standard
    normal distribution and passed through the module-level generator ``G``
    with no class label. The output is converted to an 8-bit RGB PIL image,
    and its height and width are doubled with bilinear interpolation.

    Returns:
        PIL.Image.Image: the upscaled RGB image.
    """
    # Inference only: disabling autograd avoids building a graph on every
    # request and guarantees the result can be handed to NumPy.
    with torch.no_grad():
        z = torch.randn([1, G.z_dim])  # latent codes
        c = None  # class labels (not used in this example)
        img = G(z, c)
        # Move axis 1 (assumed to be channels, i.e. NCHW output -- TODO
        # confirm) to the end, then scale/shift and clamp into uint8 range.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    image = PIL.Image.fromarray(img[0].numpy(), 'RGB')
    # Resize takes (height, width); double both dimensions.
    transform = transforms.Resize(
        (image.height * 2, image.width * 2),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    upscaled_image = transform(image)
    return upscaled_image
    
# Web UI: a two-line prompt box feeds gen_image, whose PIL result is shown
# in an image component.
prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
image_view = gr.Image(type="pil")

demo = gr.Interface(
    fn=gen_image,
    inputs=prompt_box,
    outputs=image_view,
    title="Text to Image Generator",
    description="Enter text to generate an image using a custom PyTorch model.",
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()