rahulved
commited on
Commit
•
a446101
1
Parent(s):
80151f4
Changed app.py
Browse files
app.py
CHANGED
@@ -13,37 +13,30 @@ from PIL import Image
|
|
13 |
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
|
14 |
network_pkl_d = 'network-snapshot.pkl'
|
15 |
with open(network_pkl_d, 'rb') as f:
|
16 |
-
|
17 |
|
18 |
with open(network_pkl_a, 'rb') as f:
|
19 |
-
G_a = pickle.load(f)['G_ema'] # torch.nn.Module
|
20 |
|
21 |
def gen_image(text):
|
22 |
if text=='show me':
|
23 |
-
return
|
24 |
else:
|
25 |
-
return
|
26 |
-
|
27 |
-
def gen_image_a():
|
28 |
-
z = torch.randn([1, G_a.z_dim]) # latent codes
|
29 |
-
c = None # class labels (not used in this example)
|
30 |
-
img = G_a(z, c)
|
31 |
-
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
32 |
-
image=PIL.Image.fromarray(img[0].numpy(), 'RGB')
|
33 |
-
transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
|
34 |
-
upscaled_image = transform(image)
|
35 |
-
return upscaled_image
|
36 |
|
37 |
-
def
|
38 |
-
z = torch.randn([1,
|
39 |
c = None # class labels (not used in this example)
|
40 |
-
img =
|
41 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
42 |
-
|
|
|
|
|
43 |
transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
|
44 |
-
upscaled_image = transform(image)
|
45 |
return upscaled_image
|
46 |
-
|
|
|
47 |
demo = gr.Interface(
|
48 |
fn=gen_image,
|
49 |
inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
|
@@ -52,5 +45,5 @@ demo = gr.Interface(
|
|
52 |
description="Enter text to generate an image using a custom PyTorch model."
|
53 |
)
|
54 |
|
55 |
-
|
56 |
-
demo.launch()
|
|
|
13 |
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
|
14 |
network_pkl_d = 'network-snapshot.pkl'
|
15 |
with open(network_pkl_d, 'rb') as f:
|
16 |
+
G_d = pickle.load(f)['G_ema'].cpu() # torch.nn.Module
|
17 |
|
18 |
with open(network_pkl_a, 'rb') as f:
|
19 |
+
G_a = pickle.load(f)['G_ema'].cpu() # torch.nn.Module
|
20 |
|
21 |
def gen_image(text):
|
22 |
if text=='show me':
|
23 |
+
return gen_image_helper(G_d)
|
24 |
else:
|
25 |
+
return gen_image_helper(G_a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
def gen_image_helper(model):
|
28 |
+
z = torch.randn([1, model.z_dim]).cpu() # latent codes
|
29 |
c = None # class labels (not used in this example)
|
30 |
+
img = model(z, c)
|
31 |
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
32 |
+
#um = torch..nn.Upsample(scale_factor=2, mode='bilinear')
|
33 |
+
#img=um(img)
|
34 |
+
image=PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
|
35 |
transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
|
36 |
+
upscaled_image = transform(image)
|
37 |
return upscaled_image
|
38 |
+
|
39 |
+
|
40 |
demo = gr.Interface(
|
41 |
fn=gen_image,
|
42 |
inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
|
|
|
45 |
description="Enter text to generate an image using a custom PyTorch model."
|
46 |
)
|
47 |
|
48 |
+
if __name__ == "__main__":
|
49 |
+
demo.launch()
|