hysts HF staff commited on
Commit
eb5f129
β€’
1 Parent(s): 308eb31
Files changed (4) hide show
  1. .pre-commit-config.yaml +2 -12
  2. README.md +1 -1
  3. app.py +91 -120
  4. model.py +4 -3
.pre-commit-config.yaml CHANGED
@@ -21,11 +21,11 @@ repos:
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
- rev: 5.10.1
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.812
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
@@ -34,13 +34,3 @@ repos:
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
37
- - repo: https://github.com/kynan/nbstripout
38
- rev: 0.5.0
39
- hooks:
40
- - id: nbstripout
41
- args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
- - repo: https://github.com/nbQA-dev/nbQA
43
- rev: 1.3.1
44
- hooks:
45
- - id: nbqa-isort
46
- - id: nbqa-yapf
 
21
  - id: docformatter
22
  args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
  hooks:
26
  - id: isort
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
 
34
  hooks:
35
  - id: yapf
36
  args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ‘
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.0.11
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import json
7
 
8
  import gradio as gr
@@ -10,24 +9,10 @@ import numpy as np
10
 
11
  from model import Model
12
 
13
- TITLE = '# StyleGAN2'
14
- DESCRIPTION = '''This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
15
 
16
- Expected execution time on Hugging Face Spaces: 4s
17
  '''
18
- FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=hysts.stylegan2" />'
19
-
20
-
21
- def parse_args() -> argparse.Namespace:
22
- parser = argparse.ArgumentParser()
23
- parser.add_argument('--device', type=str, default='cpu')
24
- parser.add_argument('--theme', type=str)
25
- parser.add_argument('--share', action='store_true')
26
- parser.add_argument('--port', type=int)
27
- parser.add_argument('--disable-queue',
28
- dest='enable_queue',
29
- action='store_false')
30
- return parser.parse_args()
31
 
32
 
33
  def update_class_index(name: str) -> dict:
@@ -106,106 +91,92 @@ def update_class_name(model_name: str, index: int) -> dict:
106
  return gr.Textbox.update(visible=False)
107
 
108
 
109
- def main():
110
- args = parse_args()
111
- model = Model(args.device)
112
-
113
- with gr.Blocks(theme=args.theme, css='style.css') as demo:
114
- gr.Markdown(TITLE)
115
- gr.Markdown(DESCRIPTION)
116
-
117
- with gr.Tabs():
118
- with gr.TabItem('App'):
119
- with gr.Row():
120
- with gr.Column():
121
- with gr.Group():
122
- model_name = gr.Dropdown(list(
123
- model.MODEL_NAME_DICT.keys()),
124
- value='FFHQ-1024',
125
- label='Model')
126
- seed = gr.Slider(0,
127
- np.iinfo(np.uint32).max,
128
- step=1,
129
- value=0,
130
- label='Seed')
131
- psi = gr.Slider(0,
132
- 2,
133
- step=0.05,
134
- value=0.7,
135
- label='Truncation psi')
136
- class_index = gr.Slider(0,
137
- 9,
138
- step=1,
139
- value=0,
140
- label='Class Index',
141
- visible=False)
142
- class_name = gr.Textbox(
143
- value=CIFAR10_NAMES[class_index.value],
144
- label='Class Label',
145
- interactive=False,
146
- visible=False)
147
- run_button = gr.Button('Run')
148
- with gr.Column():
149
- result = gr.Image(label='Result', elem_id='result')
150
-
151
- with gr.TabItem('Sample Images'):
152
- with gr.Row():
153
- model_name2 = gr.Dropdown([
154
- 'afhq-cat',
155
- 'afhq-dog',
156
- 'afhq-wild',
157
- 'afhqv2',
158
- 'brecahad',
159
- 'celebahq',
160
- 'cifar10',
161
- 'ffhq',
162
- 'ffhq-u',
163
- 'lsun-dog',
164
- 'metfaces',
165
- 'metfaces-u',
166
- ],
167
- value='afhq-cat',
168
- label='Model')
169
- with gr.Row():
170
- text = get_sample_image_markdown(model_name2.value)
171
- sample_images = gr.Markdown(text)
172
-
173
- gr.Markdown(FOOTER)
174
-
175
- model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
176
- model_name.change(fn=update_class_index,
177
- inputs=model_name,
178
- outputs=class_index)
179
- model_name.change(fn=update_class_name,
180
- inputs=[
181
- model_name,
182
- class_index,
183
- ],
184
- outputs=class_name)
185
- class_index.change(fn=update_class_name,
186
- inputs=[
187
- model_name,
188
- class_index,
189
- ],
190
- outputs=class_name)
191
- run_button.click(fn=model.set_model_and_generate_image,
192
- inputs=[
193
- model_name,
194
- seed,
195
- psi,
196
- class_index,
197
- ],
198
- outputs=result)
199
- model_name2.change(fn=get_sample_image_markdown,
200
- inputs=model_name2,
201
- outputs=sample_images)
202
-
203
- demo.launch(
204
- enable_queue=args.enable_queue,
205
- server_port=args.port,
206
- share=args.share,
207
- )
208
-
209
-
210
- if __name__ == '__main__':
211
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import json
6
 
7
  import gradio as gr
 
9
 
10
  from model import Model
11
 
12
+ DESCRIPTION = '''# StyleGAN2
 
13
 
14
+ This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
15
  '''
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  def update_class_index(name: str) -> dict:
 
91
  return gr.Textbox.update(visible=False)
92
 
93
 
94
+ model = Model()
95
+
96
+ with gr.Blocks(css='style.css') as demo:
97
+ gr.Markdown(DESCRIPTION)
98
+
99
+ with gr.Tabs():
100
+ with gr.TabItem('App'):
101
+ with gr.Row():
102
+ with gr.Column():
103
+ model_name = gr.Dropdown(list(
104
+ model.MODEL_NAME_DICT.keys()),
105
+ value='FFHQ-1024',
106
+ label='Model')
107
+ seed = gr.Slider(0,
108
+ np.iinfo(np.uint32).max,
109
+ step=1,
110
+ value=0,
111
+ label='Seed')
112
+ psi = gr.Slider(0,
113
+ 2,
114
+ step=0.05,
115
+ value=0.7,
116
+ label='Truncation psi')
117
+ class_index = gr.Slider(0,
118
+ 9,
119
+ step=1,
120
+ value=0,
121
+ label='Class Index',
122
+ visible=False)
123
+ class_name = gr.Textbox(
124
+ value=CIFAR10_NAMES[class_index.value],
125
+ label='Class Label',
126
+ interactive=False,
127
+ visible=False)
128
+ run_button = gr.Button('Run')
129
+ with gr.Column():
130
+ result = gr.Image(label='Result', elem_id='result')
131
+
132
+ with gr.TabItem('Sample Images'):
133
+ with gr.Row():
134
+ model_name2 = gr.Dropdown([
135
+ 'afhq-cat',
136
+ 'afhq-dog',
137
+ 'afhq-wild',
138
+ 'afhqv2',
139
+ 'brecahad',
140
+ 'celebahq',
141
+ 'cifar10',
142
+ 'ffhq',
143
+ 'ffhq-u',
144
+ 'lsun-dog',
145
+ 'metfaces',
146
+ 'metfaces-u',
147
+ ],
148
+ value='afhq-cat',
149
+ label='Model')
150
+ with gr.Row():
151
+ text = get_sample_image_markdown(model_name2.value)
152
+ sample_images = gr.Markdown(text)
153
+
154
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
155
+ model_name.change(fn=update_class_index,
156
+ inputs=model_name,
157
+ outputs=class_index)
158
+ model_name.change(fn=update_class_name,
159
+ inputs=[
160
+ model_name,
161
+ class_index,
162
+ ],
163
+ outputs=class_name)
164
+ class_index.change(fn=update_class_name,
165
+ inputs=[
166
+ model_name,
167
+ class_index,
168
+ ],
169
+ outputs=class_name)
170
+ run_button.click(fn=model.set_model_and_generate_image,
171
+ inputs=[
172
+ model_name,
173
+ seed,
174
+ psi,
175
+ class_index,
176
+ ],
177
+ outputs=result)
178
+ model_name2.change(fn=get_sample_image_markdown,
179
+ inputs=model_name2,
180
+ outputs=sample_images)
181
+
182
+ demo.queue().launch(show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.py CHANGED
@@ -14,7 +14,7 @@ current_dir = pathlib.Path(__file__).parent
14
  submodule_dir = current_dir / 'stylegan3'
15
  sys.path.insert(0, submodule_dir.as_posix())
16
 
17
- HF_TOKEN = os.environ['HF_TOKEN']
18
 
19
 
20
  class Model:
@@ -36,8 +36,9 @@ class Model:
36
  'MetFaces-U-1024': 'stylegan2-metfacesu-1024x1024.pkl',
37
  }
38
 
39
- def __init__(self, device: str | torch.device):
40
- self.device = torch.device(device)
 
41
  self._download_all_models()
42
  self.model_name = 'FFHQ-1024'
43
  self.model = self._load_model(self.model_name)
 
14
  submodule_dir = current_dir / 'stylegan3'
15
  sys.path.insert(0, submodule_dir.as_posix())
16
 
17
+ HF_TOKEN = os.getenv('HF_TOKEN')
18
 
19
 
20
  class Model:
 
36
  'MetFaces-U-1024': 'stylegan2-metfacesu-1024x1024.pkl',
37
  }
38
 
39
+ def __init__(self):
40
+ self.device = torch.device(
41
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
42
  self._download_all_models()
43
  self.model_name = 'FFHQ-1024'
44
  self.model = self._load_model(self.model_name)