hysts HF staff commited on
Commit
b13b851
1 Parent(s): 21a85ee
Files changed (5) hide show
  1. .pre-commit-config.yaml +46 -0
  2. .style.yapf +5 -0
  3. app.py +93 -138
  4. model.py +112 -0
  5. style.css +11 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^stylegan3
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
app.py CHANGED
@@ -3,172 +3,127 @@
3
  from __future__ import annotations
4
 
5
  import argparse
6
- import functools
7
- import os
8
- import pickle
9
- import sys
10
 
11
  import gradio as gr
12
  import numpy as np
13
- import torch
14
- import torch.nn as nn
15
- from huggingface_hub import hf_hub_download
16
 
17
- sys.path.insert(0, 'stylegan3')
18
 
19
- TITLE = 'NVlabs/stylegan3'
20
- DESCRIPTION = '''This is an unofficial demo for https://github.com/NVlabs/stylegan3.
21
 
22
  Expected execution time on Hugging Face Spaces: 50s
23
  '''
24
- SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/StyleGAN3/resolve/main/samples'
25
- ARTICLE = f'''## Generated images
26
- - truncation: 0.7
27
- ### AFHQv2
28
- - size: 512x512
29
- - seed: 0-99
30
- ![AFHQv2 samples]({SAMPLE_IMAGE_DIR}/afhqv2.jpg)
31
- ### FFHQ
32
- - size: 1024x1024
33
- - seed: 0-99
34
- ![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
35
- ### FFHQ-U
36
- - size: 1024x1024
37
- - seed: 0-99
38
- ![FFHQ-U samples]({SAMPLE_IMAGE_DIR}/ffhq-u.jpg)
39
- ### MetFaces
40
- - size: 1024x1024
41
- - seed: 0-99
42
- ![MetFaces samples]({SAMPLE_IMAGE_DIR}/metfaces.jpg)
43
- ### MetFaces-U
44
- - size: 1024x1024
45
- - seed: 0-99
46
- ![MetFaces-U samples]({SAMPLE_IMAGE_DIR}/metfaces-u.jpg)
47
-
48
- <center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan3" alt="visitor badge"/></center>
49
- '''
50
-
51
- TOKEN = os.environ['TOKEN']
52
 
53
 
54
  def parse_args() -> argparse.Namespace:
55
  parser = argparse.ArgumentParser()
56
  parser.add_argument('--device', type=str, default='cpu')
57
  parser.add_argument('--theme', type=str)
58
- parser.add_argument('--live', action='store_true')
59
  parser.add_argument('--share', action='store_true')
60
  parser.add_argument('--port', type=int)
61
  parser.add_argument('--disable-queue',
62
  dest='enable_queue',
63
  action='store_false')
64
- parser.add_argument('--allow-flagging', type=str, default='never')
65
  return parser.parse_args()
66
 
67
 
68
- def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
69
- mat = np.eye(3)
70
- sin = np.sin(angle / 360 * np.pi * 2)
71
- cos = np.cos(angle / 360 * np.pi * 2)
72
- mat[0][0] = cos
73
- mat[0][1] = sin
74
- mat[0][2] = translate[0]
75
- mat[1][0] = -sin
76
- mat[1][1] = cos
77
- mat[1][2] = translate[1]
78
- return mat
79
-
80
-
81
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
82
- return torch.from_numpy(np.random.RandomState(seed).randn(
83
- 1, z_dim)).to(device).float()
84
-
85
-
86
- @torch.inference_mode()
87
- def generate_image(model_name: str, seed: int, truncation_psi: float,
88
- tx: float, ty: float, angle: float,
89
- model_dict: dict[str, nn.Module],
90
- device: torch.device) -> np.ndarray:
91
- model = model_dict[model_name]
92
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
93
-
94
- z = generate_z(model.z_dim, seed, device)
95
- label = torch.zeros([1, model.c_dim], device=device)
96
-
97
- mat = make_transform((tx, ty), angle)
98
- mat = np.linalg.inv(mat)
99
- model.synthesis.input.transform.copy_(torch.from_numpy(mat))
100
-
101
- out = model(z, label, truncation_psi=truncation_psi)
102
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
103
- return out[0].cpu().numpy()
104
 
105
 
106
- def load_model(file_name: str, device: torch.device) -> nn.Module:
107
- path = hf_hub_download('hysts/StyleGAN3',
108
- f'models/{file_name}',
109
- use_auth_token=TOKEN)
110
- with open(path, 'rb') as f:
111
- model = pickle.load(f)['G_ema']
112
- model.eval()
113
- model.to(device)
114
- with torch.inference_mode():
115
- z = torch.zeros((1, model.z_dim)).to(device)
116
- label = torch.zeros([1, model.c_dim], device=device)
117
- model(z, label)
118
- return model
119
 
120
 
121
  def main():
122
  args = parse_args()
123
- device = torch.device(args.device)
124
-
125
- model_names = {
126
- 'AFHQv2-512-R': 'stylegan3-r-afhqv2-512x512.pkl',
127
- 'FFHQ-1024-R': 'stylegan3-r-ffhq-1024x1024.pkl',
128
- 'FFHQ-U-256-R': 'stylegan3-r-ffhqu-256x256.pkl',
129
- 'FFHQ-U-1024-R': 'stylegan3-r-ffhqu-1024x1024.pkl',
130
- 'MetFaces-1024-R': 'stylegan3-r-metfaces-1024x1024.pkl',
131
- 'MetFaces-U-1024-R': 'stylegan3-r-metfacesu-1024x1024.pkl',
132
- 'AFHQv2-512-T': 'stylegan3-t-afhqv2-512x512.pkl',
133
- 'FFHQ-1024-T': 'stylegan3-t-ffhq-1024x1024.pkl',
134
- 'FFHQ-U-256-T': 'stylegan3-t-ffhqu-256x256.pkl',
135
- 'FFHQ-U-1024-T': 'stylegan3-t-ffhqu-1024x1024.pkl',
136
- 'MetFaces-1024-T': 'stylegan3-t-metfaces-1024x1024.pkl',
137
- 'MetFaces-U-1024-T': 'stylegan3-t-metfacesu-1024x1024.pkl',
138
- }
139
-
140
- model_dict = {
141
- name: load_model(file_name, device)
142
- for name, file_name in model_names.items()
143
- }
144
-
145
- func = functools.partial(generate_image,
146
- model_dict=model_dict,
147
- device=device)
148
- func = functools.update_wrapper(func, generate_image)
149
-
150
- gr.Interface(
151
- func,
152
- [
153
- gr.inputs.Radio(list(model_names.keys()),
154
- type='value',
155
- default='FFHQ-1024-R',
156
- label='Model'),
157
- gr.inputs.Number(default=0, label='Seed'),
158
- gr.inputs.Slider(
159
- 0, 2, step=0.05, default=0.7, label='Truncation psi'),
160
- gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
161
- gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
162
- gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
163
- ],
164
- gr.outputs.Image(type='numpy', label='Output'),
165
- title=TITLE,
166
- description=DESCRIPTION,
167
- article=ARTICLE,
168
- theme=args.theme,
169
- allow_flagging=args.allow_flagging,
170
- live=args.live,
171
- ).launch(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  enable_queue=args.enable_queue,
173
  server_port=args.port,
174
  share=args.share,
 
3
  from __future__ import annotations
4
 
5
  import argparse
 
 
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
 
 
 
9
 
10
+ from model import Model
11
 
12
+ TITLE = '# NVlabs/stylegan3'
13
+ DESCRIPTION = '''This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
14
 
15
  Expected execution time on Hugging Face Spaces: 50s
16
  '''
17
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan3" />'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def parse_args() -> argparse.Namespace:
21
  parser = argparse.ArgumentParser()
22
  parser.add_argument('--device', type=str, default='cpu')
23
  parser.add_argument('--theme', type=str)
 
24
  parser.add_argument('--share', action='store_true')
25
  parser.add_argument('--port', type=int)
26
  parser.add_argument('--disable-queue',
27
  dest='enable_queue',
28
  action='store_false')
 
29
  return parser.parse_args()
30
 
31
 
32
+ def get_sample_image_url(name: str) -> str:
33
+ sample_image_dir = 'https://huggingface.co/spaces/hysts/StyleGAN3/resolve/main/samples'
34
+ return f'{sample_image_dir}/{name}.jpg'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
+ def get_sample_image_markdown(name: str) -> str:
38
+ url = get_sample_image_url(name)
39
+ size = 512 if name == 'afhqv2' else 1024
40
+ seed = '0-99'
41
+ return f'''
42
+ - size: {size}x{size}
43
+ - seed: {seed}
44
+ - truncation: 0.7
45
+ ![sample images]({url})'''
 
 
 
 
46
 
47
 
48
  def main():
49
  args = parse_args()
50
+ model = Model(args.device)
51
+
52
+ with gr.Blocks(theme=args.theme, css='style.css') as demo:
53
+ gr.Markdown(TITLE)
54
+ gr.Markdown(DESCRIPTION)
55
+
56
+ with gr.Tabs():
57
+ with gr.TabItem('App'):
58
+ with gr.Row():
59
+ with gr.Column():
60
+ with gr.Group():
61
+ model_name = gr.Dropdown(list(
62
+ model.MODEL_NAME_DICT.keys()),
63
+ value='FFHQ-1024-R',
64
+ label='Model')
65
+ seed = gr.Slider(0,
66
+ np.iinfo(np.uint32).max,
67
+ step=1,
68
+ value=0,
69
+ label='Seed')
70
+ psi = gr.Slider(0,
71
+ 2,
72
+ step=0.05,
73
+ value=0.7,
74
+ label='Truncation psi')
75
+ tx = gr.Slider(-1,
76
+ 1,
77
+ step=0.05,
78
+ value=0,
79
+ label='Translate X')
80
+ ty = gr.Slider(-1,
81
+ 1,
82
+ step=0.05,
83
+ value=0,
84
+ label='Translate Y')
85
+ angle = gr.Slider(-180,
86
+ 180,
87
+ step=5,
88
+ value=0,
89
+ label='Angle')
90
+ run_button = gr.Button('Run')
91
+ with gr.Column():
92
+ result = gr.Image(label='Result', elem_id='result')
93
+
94
+ with gr.TabItem('Sample Images'):
95
+ with gr.Row():
96
+ model_name2 = gr.Dropdown([
97
+ 'afhqv2',
98
+ 'ffhq',
99
+ 'ffhq-u',
100
+ 'metfaces',
101
+ 'metfaces-u',
102
+ ],
103
+ value='afhqv2',
104
+ label='Model')
105
+ with gr.Row():
106
+ text = get_sample_image_markdown(model_name2.value)
107
+ sample_images = gr.Markdown(text)
108
+
109
+ gr.Markdown(FOOTER)
110
+
111
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
112
+ run_button.click(fn=model.set_model_and_generate_image,
113
+ inputs=[
114
+ model_name,
115
+ seed,
116
+ psi,
117
+ tx,
118
+ ty,
119
+ angle,
120
+ ],
121
+ outputs=result)
122
+ model_name2.change(fn=get_sample_image_markdown,
123
+ inputs=model_name2,
124
+ outputs=sample_images)
125
+
126
+ demo.launch(
127
  enable_queue=args.enable_queue,
128
  server_port=args.port,
129
  share=args.share,
model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import pickle
6
+ import sys
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ current_dir = pathlib.Path(__file__).parent
14
+ submodule_dir = current_dir / 'stylegan3'
15
+ sys.path.insert(0, submodule_dir.as_posix())
16
+
17
+ HF_TOKEN = os.environ['HF_TOKEN']
18
+
19
+
20
+ class Model:
21
+ MODEL_NAME_DICT = {
22
+ 'AFHQv2-512-R': 'stylegan3-r-afhqv2-512x512.pkl',
23
+ 'FFHQ-1024-R': 'stylegan3-r-ffhq-1024x1024.pkl',
24
+ 'FFHQ-U-256-R': 'stylegan3-r-ffhqu-256x256.pkl',
25
+ 'FFHQ-U-1024-R': 'stylegan3-r-ffhqu-1024x1024.pkl',
26
+ 'MetFaces-1024-R': 'stylegan3-r-metfaces-1024x1024.pkl',
27
+ 'MetFaces-U-1024-R': 'stylegan3-r-metfacesu-1024x1024.pkl',
28
+ 'AFHQv2-512-T': 'stylegan3-t-afhqv2-512x512.pkl',
29
+ 'FFHQ-1024-T': 'stylegan3-t-ffhq-1024x1024.pkl',
30
+ 'FFHQ-U-256-T': 'stylegan3-t-ffhqu-256x256.pkl',
31
+ 'FFHQ-U-1024-T': 'stylegan3-t-ffhqu-1024x1024.pkl',
32
+ 'MetFaces-1024-T': 'stylegan3-t-metfaces-1024x1024.pkl',
33
+ 'MetFaces-U-1024-T': 'stylegan3-t-metfacesu-1024x1024.pkl',
34
+ }
35
+
36
+ def __init__(self, device: str | torch.device):
37
+ self.device = torch.device(device)
38
+ self._download_all_models()
39
+ self.model_name = 'FFHQ-1024-R'
40
+ self.model = self._load_model(self.model_name)
41
+
42
+ def _load_model(self, model_name: str) -> nn.Module:
43
+ file_name = self.MODEL_NAME_DICT[model_name]
44
+ path = hf_hub_download('hysts/StyleGAN3',
45
+ f'models/{file_name}',
46
+ use_auth_token=HF_TOKEN)
47
+ with open(path, 'rb') as f:
48
+ model = pickle.load(f)['G_ema']
49
+ model.eval()
50
+ model.to(self.device)
51
+ return model
52
+
53
+ def set_model(self, model_name: str) -> None:
54
+ if model_name == self.model_name:
55
+ return
56
+ self.model_name = model_name
57
+ self.model = self._load_model(model_name)
58
+
59
+ def _download_all_models(self):
60
+ for name in self.MODEL_NAME_DICT.keys():
61
+ self._load_model(name)
62
+
63
+ @staticmethod
64
+ def make_transform(translate: tuple[float, float],
65
+ angle: float) -> np.ndarray:
66
+ mat = np.eye(3)
67
+ sin = np.sin(angle / 360 * np.pi * 2)
68
+ cos = np.cos(angle / 360 * np.pi * 2)
69
+ mat[0][0] = cos
70
+ mat[0][1] = sin
71
+ mat[0][2] = translate[0]
72
+ mat[1][0] = -sin
73
+ mat[1][1] = cos
74
+ mat[1][2] = translate[1]
75
+ return mat
76
+
77
+ def generate_z(self, seed: int) -> torch.Tensor:
78
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
79
+ z = np.random.RandomState(seed).randn(1, self.model.z_dim)
80
+ return torch.from_numpy(z).float().to(self.device)
81
+
82
+ def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
83
+ tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
84
+ torch.uint8)
85
+ return tensor.cpu().numpy()
86
+
87
+ def set_transform(self, tx: float, ty: float, angle: float) -> None:
88
+ mat = self.make_transform((tx, ty), angle)
89
+ mat = np.linalg.inv(mat)
90
+ self.model.synthesis.input.transform.copy_(torch.from_numpy(mat))
91
+
92
+ @torch.inference_mode()
93
+ def generate(self, z: torch.Tensor, label: torch.Tensor,
94
+ truncation_psi: float) -> torch.Tensor:
95
+ return self.model(z, label, truncation_psi=truncation_psi)
96
+
97
+ def generate_image(self, seed: int, truncation_psi: float, tx: float,
98
+ ty: float, angle: float) -> np.ndarray:
99
+ self.set_transform(tx, ty, angle)
100
+
101
+ z = self.generate_z(seed)
102
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
103
+
104
+ out = self.generate(z, label, truncation_psi)
105
+ out = self.postprocess(out)
106
+ return out[0]
107
+
108
+ def set_model_and_generate_image(self, model_name: str, seed: int,
109
+ truncation_psi: float, tx: float,
110
+ ty: float, angle: float) -> np.ndarray:
111
+ self.set_model(model_name)
112
+ return self.generate_image(seed, truncation_psi, tx, ty, angle)
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ div#result {
5
+ max-width: 600px;
6
+ max-height: 600px;
7
+ }
8
+ img#visitor-badge {
9
+ display: block;
10
+ margin: auto;
11
+ }