hysts HF staff commited on
Commit
69ec5d5
1 Parent(s): f878108
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "stylegan3"]
2
+ path = stylegan3
3
+ url = https://github.com/NVlabs/stylegan3
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ sys.path.insert(0, 'stylegan3')
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ ORIGINAL_REPO_URL = 'https://github.com/NVlabs/stylegan3'
20
+ TITLE = 'StyleGAN2'
21
+ DESCRIPTION = f'This is a demo for {ORIGINAL_REPO_URL}.'
22
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/StyleGAN2/resolve/main/samples'
23
+ ARTICLE = f'''## Generated images
24
+ - truncation: 0.7
25
+ ### CIFAR-10
26
+ - size: 32x32
27
+ - class index: 0-9
28
+ - seed: 0-9
29
+ ![CIFAR-10 samples]({SAMPLE_IMAGE_DIR}/cifar10.jpg)
30
+ ### AFHQ-Cat
31
+ - size: 512x512
32
+ - seed: 0-99
33
+ ![AFHQ-Cat samples]({SAMPLE_IMAGE_DIR}/afhq-cat.jpg)
34
+ ### AFHQ-Dog
35
+ - size: 512x512
36
+ - seed: 0-99
37
+ ![AFHQ-Dog samples]({SAMPLE_IMAGE_DIR}/afhq-dog.jpg)
38
+ ### AFHQ-Wild
39
+ - size: 512x512
40
+ - seed: 0-99
41
+ ![AFHQ-Wild samples]({SAMPLE_IMAGE_DIR}/afhq-wild.jpg)
42
+ ### AFHQv2
43
+ - size: 512x512
44
+ - seed: 0-99
45
+ ![AFHQv2 samples]({SAMPLE_IMAGE_DIR}/afhqv2.jpg)
46
+ ### LSUN-Dog
47
+ - size: 256x256
48
+ - seed: 0-99
49
+ ![LSUN-Dog samples]({SAMPLE_IMAGE_DIR}/lsun-dog.jpg)
50
+ ### BreCaHAD
51
+ - size: 512x512
52
+ - seed: 0-99
53
+ ![BreCaHAD samples]({SAMPLE_IMAGE_DIR}/brecahad.jpg)
54
+ ### CelebA-HQ
55
+ - size: 256x256
56
+ - seed: 0-99
57
+ ![CelebA-HQ samples]({SAMPLE_IMAGE_DIR}/celebahq.jpg)
58
+ ### FFHQ
59
+ - size: 1024x1024
60
+ - seed: 0-99
61
+ ![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
62
+ ### FFHQ-U
63
+ - size: 1024x1024
64
+ - seed: 0-99
65
+ ![FFHQ-U samples]({SAMPLE_IMAGE_DIR}/ffhq-u.jpg)
66
+ ### MetFaces
67
+ - size: 1024x1024
68
+ - seed: 0-99
69
+ ![MetFaces samples]({SAMPLE_IMAGE_DIR}/metfaces.jpg)
70
+ ### MetFaces-U
71
+ - size: 1024x1024
72
+ - seed: 0-99
73
+ ![MetFaces-U samples]({SAMPLE_IMAGE_DIR}/metfaces-u.jpg)
74
+ '''
75
+
76
+ TOKEN = os.environ['TOKEN']
77
+
78
+
79
+ def parse_args() -> argparse.Namespace:
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument('--device', type=str, default='cpu')
82
+ parser.add_argument('--theme', type=str)
83
+ parser.add_argument('--live', action='store_true')
84
+ parser.add_argument('--share', action='store_true')
85
+ parser.add_argument('--port', type=int)
86
+ parser.add_argument('--disable-queue',
87
+ dest='enable_queue',
88
+ action='store_false')
89
+ parser.add_argument('--allow-flagging', type=str, default='never')
90
+ parser.add_argument('--allow-screenshot', action='store_true')
91
+ return parser.parse_args()
92
+
93
+
94
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
95
+ return torch.from_numpy(np.random.RandomState(seed).randn(
96
+ 1, z_dim)).to(device).float()
97
+
98
+
99
+ @torch.inference_mode()
100
+ def generate_image(model_name: str, class_index: int, seed: int,
101
+ truncation_psi: float, model_dict: dict[str, nn.Module],
102
+ device: torch.device) -> np.ndarray:
103
+ model = model_dict[model_name]
104
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
105
+
106
+ z = generate_z(model.z_dim, seed, device)
107
+ label = torch.zeros([1, model.c_dim], device=device)
108
+ class_index = round(class_index)
109
+ class_index = min(max(0, class_index), model.c_dim - 1)
110
+ class_index = torch.tensor(class_index, dtype=torch.long)
111
+ if class_index >= 0:
112
+ label[:, class_index] = 1
113
+
114
+ out = model(z, label, truncation_psi=truncation_psi)
115
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
116
+ return out[0].cpu().numpy()
117
+
118
+
119
+ def load_model(file_name: str, device: torch.device) -> nn.Module:
120
+ path = hf_hub_download('hysts/StyleGAN2',
121
+ f'models/{file_name}',
122
+ use_auth_token=TOKEN)
123
+ with open(path, 'rb') as f:
124
+ model = pickle.load(f)['G_ema']
125
+ model.eval()
126
+ model.to(device)
127
+ with torch.inference_mode():
128
+ z = torch.zeros((1, model.z_dim)).to(device)
129
+ label = torch.zeros([1, model.c_dim], device=device)
130
+ model(z, label)
131
+ return model
132
+
133
+
134
+ def main():
135
+ gr.close_all()
136
+
137
+ args = parse_args()
138
+ device = torch.device(args.device)
139
+
140
+ model_names = {
141
+ 'AFHQ-Cat-512': 'stylegan2-afhqcat-512x512.pkl',
142
+ 'AFHQ-Dog-512': 'stylegan2-afhqdog-512x512.pkl',
143
+ 'AFHQv2-512': 'stylegan2-afhqv2-512x512.pkl',
144
+ 'AFHQ-Wild-512': 'stylegan2-afhqwild-512x512.pkl',
145
+ 'BreCaHAD-512': 'stylegan2-brecahad-512x512.pkl',
146
+ 'CelebA-HQ-256': 'stylegan2-celebahq-256x256.pkl',
147
+ 'CIFAR-10': 'stylegan2-cifar10-32x32.pkl',
148
+ 'FFHQ-256': 'stylegan2-ffhq-256x256.pkl',
149
+ 'FFHQ-512': 'stylegan2-ffhq-512x512.pkl',
150
+ 'FFHQ-1024': 'stylegan2-ffhq-1024x1024.pkl',
151
+ 'FFHQ-U-256': 'stylegan2-ffhqu-256x256.pkl',
152
+ 'FFHQ-U-1024': 'stylegan2-ffhqu-1024x1024.pkl',
153
+ 'LSUN-Dog-256': 'stylegan2-lsundog-256x256.pkl',
154
+ 'MetFaces-1024': 'stylegan2-metfaces-1024x1024.pkl',
155
+ 'MetFaces-U-1024': 'stylegan2-metfacesu-1024x1024.pkl',
156
+ }
157
+
158
+ model_dict = {
159
+ name: load_model(file_name, device)
160
+ for name, file_name in model_names.items()
161
+ }
162
+
163
+ func = functools.partial(generate_image,
164
+ model_dict=model_dict,
165
+ device=device)
166
+ func = functools.update_wrapper(func, generate_image)
167
+
168
+ gr.Interface(
169
+ func,
170
+ [
171
+ gr.inputs.Radio(list(model_names.keys()),
172
+ type='value',
173
+ default='FFHQ-1024',
174
+ label='Model'),
175
+ gr.inputs.Number(default=0, label='Class index'),
176
+ gr.inputs.Number(default=0, label='Seed'),
177
+ gr.inputs.Slider(
178
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
179
+ ],
180
+ gr.outputs.Image(type='numpy', label='Output'),
181
+ title=TITLE,
182
+ description=DESCRIPTION,
183
+ article=ARTICLE,
184
+ theme=args.theme,
185
+ allow_screenshot=args.allow_screenshot,
186
+ allow_flagging=args.allow_flagging,
187
+ live=args.live,
188
+ ).launch(
189
+ enable_queue=args.enable_queue,
190
+ server_port=args.port,
191
+ share=args.share,
192
+ )
193
+
194
+
195
+ if __name__ == '__main__':
196
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ scipy==1.8.0
4
+ torch==1.11.0
5
+ torchvision==0.12.0
samples/afhq-cat.jpg ADDED

Git LFS Details

  • SHA256: af7f993e92bb43373c00d59aca0363246e28ece0b9f08a4a5f0517028a2bdf63
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
samples/afhq-dog.jpg ADDED

Git LFS Details

  • SHA256: fc5d2024c282704eef3d057ae8d87444650446f7de1e2486778515c772ee3964
  • Pointer size: 132 Bytes
  • Size of remote file: 9.38 MB
samples/afhq-wild.jpg ADDED

Git LFS Details

  • SHA256: 9887d7c429a5a22a0ed500756de998f354bec910725ca799daadd5b965784812
  • Pointer size: 133 Bytes
  • Size of remote file: 12.1 MB
samples/afhqv2.jpg ADDED

Git LFS Details

  • SHA256: 1a622769228ed2792d85ce1918a8f73aa330f8900e83d2f82fcb64b4d22334f5
  • Pointer size: 132 Bytes
  • Size of remote file: 9.69 MB
samples/brecahad.jpg ADDED

Git LFS Details

  • SHA256: 39fdaf86861d81061e057972a8ea2ae5d292c980abecefb2972e37eea494bb9c
  • Pointer size: 132 Bytes
  • Size of remote file: 8.09 MB
samples/celebahq.jpg ADDED

Git LFS Details

  • SHA256: da2636e5adde26b86c575ec29842893bc01191d4dff9633130a9f231e83705ac
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
samples/cifar10.jpg ADDED

Git LFS Details

  • SHA256: 802d030bf21b663f943ca9d4d8be370f82254d504a1eb5c6bbd6738be26d22aa
  • Pointer size: 130 Bytes
  • Size of remote file: 66.7 kB
samples/ffhq-u.jpg ADDED

Git LFS Details

  • SHA256: dfa06ebf208225d549de4703be2575d24faf73f22813585063de804e226a7f73
  • Pointer size: 133 Bytes
  • Size of remote file: 25.6 MB
samples/ffhq.jpg ADDED

Git LFS Details

  • SHA256: b9042aec50f4b2ebf587e07818e65919b4016bed35fc06a38d9ae238eef6b7c3
  • Pointer size: 133 Bytes
  • Size of remote file: 27.5 MB
samples/lsun-dog.jpg ADDED

Git LFS Details

  • SHA256: 538437f94156466be6ee0a45445d93efc39a0a1bd6ab48fd7c3210fd55f08c7e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.96 MB
samples/metfaces-u.jpg ADDED

Git LFS Details

  • SHA256: 1ffcb44051e7cd45c435dd44940f0763e235f2deabea37f5750fef14928c9ccb
  • Pointer size: 133 Bytes
  • Size of remote file: 26.6 MB
samples/metfaces.jpg ADDED

Git LFS Details

  • SHA256: 11f8f00bd9c5d2cae5c4df023267d87c032ad7b401b9d94253ecff0bc8c09d07
  • Pointer size: 133 Bytes
  • Size of remote file: 27.2 MB
stylegan3 ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit a5a69f58294509598714d1e88c9646c3d7c6ec94