Jezia commited on
Commit
946d136
·
1 Parent(s): 5c95f00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -21
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
  import nltk
3
  nltk.download('wordnet')
4
  nltk.download('omw-1.4')
@@ -10,30 +12,74 @@ initial_class = 'dog'
10
 
11
  gan_model = BigGAN.from_pretrained(initial_archi)
12
 
13
- # Prepare a input
14
- truncation = 0.4
15
- class_vector = one_hot_from_names(initial_class, batch_size=1)
16
- noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1)
17
 
18
- # All in tensors
19
- noise_vector = torch.from_numpy(noise_vector)
20
- class_vector = torch.from_numpy(class_vector)
21
 
22
- # If you have a GPU, put everything on cuda
23
- #noise_vector = noise_vector.to('cuda')
24
- #class_vector = class_vector.to('cuda')
25
- #gan_model.to('cuda')
26
 
27
- # Generate an image
28
- with torch.no_grad():
29
- output = gan_model(noise_vector, class_vector, truncation)
30
 
31
- # If you have a GPU put back on CPU
32
- #output = output.to('cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # If you have a sixtel compatible terminal you can display the images in the terminal
35
- # (see https://github.com/saitoha/libsixel for details)
36
- #display_in_terminal(output)
37
 
38
- # Save results as png images
39
- #save_as_images(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import gradio as gr
3
+ import numpy as np
4
  import nltk
5
  nltk.download('wordnet')
6
  nltk.download('omw-1.4')
 
12
 
13
  gan_model = BigGAN.from_pretrained(initial_archi)
14
 
15
+ def generate_images (initial_archi, initial_class, batch_size):
16
+ truncation = 0.4
17
+ class_vector = one_hot_from_names(initial_class, batch_size=batch_size)
18
+ noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size)
19
 
20
+ # All in tensors
21
+ noise_vector = torch.from_numpy(noise_vector)
22
+ class_vector = torch.from_numpy(class_vector)
23
 
24
+ # If you have a GPU, put everything on cuda
25
+ noise_vector = noise_vector.to('cuda')
26
+ class_vector = class_vector.to('cuda')
27
+ gan_model.to('cuda')
28
 
29
+ # Generate an image
30
+ with torch.no_grad():
31
+ output = gan_model(noise_vector, class_vector, truncation)
32
 
33
+ # If you have a GPU put back on CPU
34
+ output = output.to('cpu')
35
+ save_as_images(output)
36
+ return output
37
+
38
+ def convert_to_images(obj):
39
+ """ Convert an output tensor from BigGAN in a list of images.
40
+ Params:
41
+ obj: tensor or numpy array of shape (batch_size, channels, height, width)
42
+ Output:
43
+ list of Pillow Images of size (height, width)
44
+ """
45
+ try:
46
+ import PIL
47
+ except ImportError:
48
+ raise ImportError("Please install Pillow to use images: pip install Pillow")
49
 
50
+ if not isinstance(obj, np.ndarray):
51
+ obj = obj.detach().numpy()
 
52
 
53
+ obj = obj.transpose((0, 2, 3, 1))
54
+ obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)
55
+
56
+ img = []
57
+ for i, out in enumerate(obj):
58
+ out_array = np.asarray(np.uint8(out), dtype=np.uint8)
59
+ img.append(PIL.Image.fromarray(out_array))
60
+ return img
61
+
62
+ def inference(initial_archi, initial_class):
63
+ output = generate_images (initial_archi, initial_class, 1)
64
+ PIL_output = convert_to_images(output)
65
+ return PIL_output[0]
66
+
67
+
68
+
69
+ title = "BigGAN"
70
+ description = "BigGAN using various architecture models to generate images."
71
+ article="Coming soon"
72
+
73
+ examples = [
74
+ ["biggan-deep-128", "dog"],
75
+ ["biggan-deep-256", 'dog'],
76
+ ["biggan-deep-512", 'dog']
77
+ ]
78
+
79
+ gr.Interface(inference,
80
+ inputs=[gr.inputs.Dropdown(["biggan-deep-128", "biggan-deep-256", "biggan-deep-512"]), "text"],
81
+ outputs= [gr.outputs.Image(type="pil",label="output")],
82
+ examples=examples,
83
+ title=title,
84
+ description=description,
85
+ article=article).launch( enable_queue=True, debug=True)