rahulved commited on
Commit
a446101
1 Parent(s): 80151f4

Changed app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -22
app.py CHANGED
@@ -13,37 +13,30 @@ from PIL import Image
13
  network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
14
  network_pkl_d = 'network-snapshot.pkl'
15
  with open(network_pkl_d, 'rb') as f:
16
- G = pickle.load(f)['G_ema'] # torch.nn.Module
17
 
18
  with open(network_pkl_a, 'rb') as f:
19
- G_a = pickle.load(f)['G_ema'] # torch.nn.Module
20
 
21
  def gen_image(text):
22
  if text=='show me':
23
- return gen_image_d()
24
  else:
25
- return gen_image_a()
26
-
27
- def gen_image_a():
28
- z = torch.randn([1, G_a.z_dim]) # latent codes
29
- c = None # class labels (not used in this example)
30
- img = G_a(z, c)
31
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
32
- image=PIL.Image.fromarray(img[0].numpy(), 'RGB')
33
- transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
34
- upscaled_image = transform(image)
35
- return upscaled_image
36
 
37
- def gen_image_d():
38
- z = torch.randn([1, G.z_dim]) # latent codes
39
  c = None # class labels (not used in this example)
40
- img = G(z, c)
41
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
42
- image=PIL.Image.fromarray(img[0].numpy(), 'RGB')
 
 
43
  transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
44
- upscaled_image = transform(image)
45
  return upscaled_image
46
-
 
47
  demo = gr.Interface(
48
  fn=gen_image,
49
  inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
@@ -52,5 +45,5 @@ demo = gr.Interface(
52
  description="Enter text to generate an image using a custom PyTorch model."
53
  )
54
 
55
- #if __name__ == "__main__":
56
- demo.launch()
 
13
  network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
14
  network_pkl_d = 'network-snapshot.pkl'
15
  with open(network_pkl_d, 'rb') as f:
16
+ G_d = pickle.load(f)['G_ema'].cpu() # torch.nn.Module
17
 
18
  with open(network_pkl_a, 'rb') as f:
19
+ G_a = pickle.load(f)['G_ema'].cpu() # torch.nn.Module
20
 
21
  def gen_image(text):
22
  if text=='show me':
23
+ return gen_image_helper(G_d)
24
  else:
25
+ return gen_image_helper(G_a)
 
 
 
 
 
 
 
 
 
 
26
 
27
+ def gen_image_helper(model):
28
+ z = torch.randn([1, model.z_dim]).cpu() # latent codes
29
  c = None # class labels (not used in this example)
30
+ img = model(z, c)
31
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
32
+ #um = torch..nn.Upsample(scale_factor=2, mode='bilinear')
33
+ #img=um(img)
34
+ image=PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
35
  transform = transforms.Resize((image.height * 2, image.width * 2), interpolation=transforms.InterpolationMode.BILINEAR)
36
+ upscaled_image = transform(image)
37
  return upscaled_image
38
+
39
+
40
  demo = gr.Interface(
41
  fn=gen_image,
42
  inputs=gr.Textbox(lines=2, placeholder="Prompt here..."),
 
45
  description="Enter text to generate an image using a custom PyTorch model."
46
  )
47
 
48
+ if __name__ == "__main__":
49
+ demo.launch()