hysts (HF staff) committed
Commit ce50ac7
Parent: 4049f95
Files changed (5):
  1. .pre-commit-config.yaml +4 -13
  2. README.md +4 -1
  3. app.py +108 -148
  4. dualstylegan.py +15 -22
  5. requirements.txt +7 -7
.pre-commit-config.yaml CHANGED
@@ -1,4 +1,4 @@
-exclude: ^(DualStyleGAN|patch)
+exclude: ^patch
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.2.0
@@ -21,26 +21,17 @@ repos:
   - id: docformatter
     args: ['--in-place']
 - repo: https://github.com/pycqa/isort
-  rev: 5.10.1
+  rev: 5.12.0
   hooks:
   - id: isort
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.812
+  rev: v0.991
   hooks:
   - id: mypy
     args: ['--ignore-missing-imports']
+    additional_dependencies: ['types-python-slugify']
 - repo: https://github.com/google/yapf
   rev: v0.32.0
   hooks:
   - id: yapf
     args: ['--parallel', '--in-place']
-- repo: https://github.com/kynan/nbstripout
-  rev: 0.5.0
-  hooks:
-  - id: nbstripout
-    args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
-- repo: https://github.com/nbQA-dev/nbQA
-  rev: 1.3.1
-  hooks:
-  - id: nbqa-isort
-  - id: nbqa-yapf
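Note on the mypy hook above: pre-commit runs each hook in its own isolated environment, so stub packages have to be listed under additional_dependencies for mypy to see them. A minimal sketch of what the types-python-slugify stubs enable (that this app calls slugify at all is an assumption made purely for illustration):

from slugify import slugify  # typed via the types-python-slugify stub package


def style_image_filename(style_name: str) -> str:
    # With the stubs installed, mypy checks slugify(str) -> str here; without
    # them, --ignore-missing-imports would treat the import as Any and skip it.
    return f'{slugify(style_name)}.jpg'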
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: 😻
 colorFrom: purple
 colorTo: red
 sdk: gradio
-sdk_version: 3.0.17
+sdk_version: 3.35.2
 app_file: app.py
 pinned: false
+suggested_hardware: t4-small
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+
+https://arxiv.org/abs/2203.13248
app.py CHANGED
@@ -13,19 +13,6 @@ DESCRIPTION = '''# Portrait Style Transfer with <a href="https://github.com/will
 
 <img id="overview" alt="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" />
 '''
-FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" />'
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--device', type=str, default='cpu')
-    parser.add_argument('--theme', type=str)
-    parser.add_argument('--share', action='store_true')
-    parser.add_argument('--port', type=int)
-    parser.add_argument('--disable-queue',
-                        dest='enable_queue',
-                        action='store_false')
-    return parser.parse_args()
 
 
 def get_style_image_url(style_name: str) -> str:
@@ -83,154 +70,127 @@ def set_example_weights(example: list) -> list[dict]:
     ]
 
 
-def main():
-    args = parse_args()
-    model = Model(device=args.device)
+model = Model()
 
-    with gr.Blocks(theme=args.theme, css='style.css') as demo:
-        gr.Markdown(DESCRIPTION)
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
 
-        with gr.Box():
-            gr.Markdown('''## Step 1 (Preprocess Input Image)
+    with gr.Box():
+        gr.Markdown('''## Step 1 (Preprocess Input Image)
 
 - Drop an image containing a near-frontal face to the **Input Image**.
 - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
 - Hit the **Detect & Align Face** button.
 - Hit the **Reconstruct Face** button.
 - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
 ''')
-            with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        input_image = gr.Image(label='Input Image',
-                                               type='file')
-                    with gr.Row():
-                        detect_button = gr.Button('Detect & Align Face')
-                with gr.Column():
-                    with gr.Row():
-                        aligned_face = gr.Image(label='Aligned Face',
-                                                type='numpy',
-                                                interactive=False)
-                    with gr.Row():
-                        reconstruct_button = gr.Button('Reconstruct Face')
-                with gr.Column():
-                    reconstructed_face = gr.Image(label='Reconstructed Face',
-                                                  type='numpy')
-                    instyle = gr.Variable()
-
-            with gr.Row():
-                paths = sorted(pathlib.Path('images').glob('*.jpg'))
-                example_images = gr.Dataset(components=[input_image],
-                                            samples=[[path.as_posix()]
-                                                     for path in paths])
-
-        with gr.Box():
-            gr.Markdown('''## Step 2 (Select Style Image)
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_image = gr.Image(label='Input Image',
+                                           type='filepath')
+                with gr.Row():
+                    detect_button = gr.Button('Detect & Align Face')
+            with gr.Column():
+                with gr.Row():
+                    aligned_face = gr.Image(label='Aligned Face',
+                                            type='numpy',
+                                            interactive=False)
+                with gr.Row():
+                    reconstruct_button = gr.Button('Reconstruct Face')
+            with gr.Column():
+                reconstructed_face = gr.Image(label='Reconstructed Face',
+                                              type='numpy')
+                instyle = gr.Variable()
+
+        with gr.Row():
+            paths = sorted(pathlib.Path('images').glob('*.jpg'))
+            gr.Examples(examples=[[path.as_posix()] for path in paths],
+                        inputs=input_image)
+
+    with gr.Box():
+        gr.Markdown('''## Step 2 (Select Style Image)
 
 - Select **Style Type**.
 - Select **Style Image Index** from the image table below.
 ''')
-            with gr.Row():
-                with gr.Column():
-                    style_type = gr.Radio(model.style_types,
-                                          label='Style Type')
-                    text = get_style_image_markdown_text('cartoon')
-                    style_image = gr.Markdown(value=text)
-                    style_index = gr.Slider(0,
-                                            316,
-                                            value=26,
-                                            step=1,
-                                            label='Style Image Index')
-
-            with gr.Row():
-                example_styles = gr.Dataset(
-                    components=[style_type, style_index],
-                    samples=[
-                        ['cartoon', 26],
-                        ['caricature', 65],
-                        ['arcane', 63],
-                        ['pixar', 80],
-                    ])
-
-        with gr.Box():
-            gr.Markdown('''## Step 3 (Generate Style Transferred Image)
+        with gr.Row():
+            with gr.Column():
+                style_type = gr.Radio(label='Style Type',
+                                      choices=model.style_types)
+                text = get_style_image_markdown_text('cartoon')
+                style_image = gr.Markdown(value=text)
+                style_index = gr.Slider(label='Style Image Index',
+                                        minimum=0,
+                                        maximum=316,
+                                        step=1,
+                                        value=26)
+
+        with gr.Row():
+            gr.Examples(examples=[
+                ['cartoon', 26],
+                ['caricature', 65],
+                ['arcane', 63],
+                ['pixar', 80],
+            ],
+                        inputs=[style_type, style_index])
+
+    with gr.Box():
+        gr.Markdown('''## Step 3 (Generate Style Transferred Image)
 
 - Adjust **Structure Weight** and **Color Weight**.
 - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
 - Hit the **Generate** button.
 ''')
-            with gr.Row():
-                with gr.Column():
-                    with gr.Row():
-                        structure_weight = gr.Slider(0,
-                                                     1,
-                                                     value=0.6,
-                                                     step=0.1,
-                                                     label='Structure Weight')
-                    with gr.Row():
-                        color_weight = gr.Slider(0,
-                                                 1,
-                                                 value=1,
-                                                 step=0.1,
-                                                 label='Color Weight')
-                    with gr.Row():
-                        structure_only = gr.Checkbox(label='Structure Only')
-                    with gr.Row():
-                        generate_button = gr.Button('Generate')
-
-                with gr.Column():
-                    result = gr.Image(label='Result')
-
-            with gr.Row():
-                example_weights = gr.Dataset(
-                    components=[structure_weight, color_weight],
-                    samples=[
-                        [0.6, 1.0],
-                        [0.3, 1.0],
-                        [0.0, 1.0],
-                        [1.0, 0.0],
-                    ])
-
-        gr.Markdown(FOOTER)
-
-        detect_button.click(fn=model.detect_and_align_face,
-                            inputs=input_image,
-                            outputs=aligned_face)
-        reconstruct_button.click(fn=model.reconstruct_face,
-                                 inputs=aligned_face,
-                                 outputs=[reconstructed_face, instyle])
-        style_type.change(fn=update_slider,
-                          inputs=style_type,
-                          outputs=style_index)
-        style_type.change(fn=update_style_image,
-                          inputs=style_type,
-                          outputs=style_image)
-        generate_button.click(fn=model.generate,
-                              inputs=[
-                                  style_type,
-                                  style_index,
-                                  structure_weight,
-                                  color_weight,
-                                  structure_only,
-                                  instyle,
-                              ],
-                              outputs=result)
-        example_images.click(fn=set_example_image,
-                             inputs=example_images,
-                             outputs=example_images.components)
-        example_styles.click(fn=set_example_styles,
-                             inputs=example_styles,
-                             outputs=example_styles.components)
-        example_weights.click(fn=set_example_weights,
-                              inputs=example_weights,
-                              outputs=example_weights.components)
-
-        demo.launch(
-            enable_queue=args.enable_queue,
-            server_port=args.port,
-            share=args.share,
-        )
-
-
-if __name__ == '__main__':
-    main()
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    structure_weight = gr.Slider(label='Structure Weight',
+                                                 minimum=0,
+                                                 maximum=1,
+                                                 step=0.1,
+                                                 value=0.6)
+                with gr.Row():
+                    color_weight = gr.Slider(label='Color Weight',
+                                             minimum=0,
+                                             maximum=1,
+                                             step=0.1,
+                                             value=1)
+                with gr.Row():
+                    structure_only = gr.Checkbox(label='Structure Only')
+                with gr.Row():
+                    generate_button = gr.Button('Generate')
+
+            with gr.Column():
+                result = gr.Image(label='Result')
+
+        with gr.Row():
+            gr.Examples(examples=[
+                [0.6, 1.0],
+                [0.3, 1.0],
+                [0.0, 1.0],
+                [1.0, 0.0],
+            ],
+                        inputs=[structure_weight, color_weight])
+
+    detect_button.click(fn=model.detect_and_align_face,
+                        inputs=input_image,
+                        outputs=aligned_face)
+    reconstruct_button.click(fn=model.reconstruct_face,
+                             inputs=aligned_face,
+                             outputs=[reconstructed_face, instyle])
+    style_type.change(fn=update_slider, inputs=style_type, outputs=style_index)
+    style_type.change(fn=update_style_image,
+                      inputs=style_type,
+                      outputs=style_image)
+    generate_button.click(fn=model.generate,
+                          inputs=[
+                              style_type,
+                              style_index,
+                              structure_weight,
+                              color_weight,
+                              structure_only,
+                              instyle,
+                          ],
+                          outputs=result)
+demo.queue(max_size=10).launch()
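The three gr.Dataset blocks and their set_example_* click handlers collapse into gr.Examples, which wires the click-to-fill behavior itself. A minimal sketch of the migrated pattern (component names and values here are illustrative, not taken from the app):

import gradio as gr

with gr.Blocks() as demo:
    style = gr.Radio(label='Style Type', choices=['cartoon', 'pixar'])
    index = gr.Slider(label='Style Image Index', minimum=0, maximum=316, step=1)
    # Clicking an example row fills `inputs` directly; the old gr.Dataset
    # approach needed an explicit .click(fn=set_example_styles, ...) hookup.
    gr.Examples(examples=[['cartoon', 26], ['pixar', 80]],
                inputs=[style, index])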
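Relatedly, gr.Image(type='file') becomes type='filepath', so event handlers now receive a plain path string rather than a tempfile wrapper; that is why detect_and_align_face in dualstylegan.py below drops the image.name access. A small sketch of the new contract (handler and component names are illustrative):

import gradio as gr


def report_path(image_path: str) -> str:
    # With type='filepath' the handler gets the uploaded file's path directly;
    # the old type='file' passed a file object, hence the previous image.name.
    return image_path


with gr.Blocks() as demo:
    image = gr.Image(label='Input Image', type='filepath')
    path_box = gr.Textbox(label='Path')
    image.change(fn=report_path, inputs=image, outputs=path_box)

Launching changes in the same spirit: queueing is now requested explicitly with demo.queue(max_size=10).launch() instead of the removed --disable-queue/enable_queue plumbing, and port/share/theme are left to the Spaces host.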
dualstylegan.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import argparse
 import os
 import pathlib
+import shlex
 import subprocess
 import sys
-from typing import Callable, Union
+from typing import Callable
 
 import dlib
 import huggingface_hub
@@ -15,9 +16,9 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 
-if os.getenv('SYSTEM') == 'spaces':
+if os.getenv('SYSTEM') == 'spaces' and not torch.cuda.is_available():
     with open('patch') as f:
-        subprocess.run('patch -p1'.split(), cwd='DualStyleGAN', stdin=f)
+        subprocess.run(shlex.split('patch -p1'), cwd='DualStyleGAN', stdin=f)
 
 app_dir = pathlib.Path(__file__).parent
 submodule_dir = app_dir / 'DualStyleGAN'
@@ -27,13 +28,11 @@ from model.dualstylegan import DualStyleGAN
 from model.encoder.align_all_parallel import align_face
 from model.encoder.psp import pSp
 
-HF_TOKEN = os.environ['HF_TOKEN']
-MODEL_REPO = 'hysts/DualStyleGAN'
-
 
 class Model:
-    def __init__(self, device: Union[torch.device, str]):
-        self.device = torch.device(device)
+    def __init__(self):
+        self.device = torch.device(
+            'cuda:0' if torch.cuda.is_available() else 'cpu')
         self.landmark_model = self._create_dlib_landmark_model()
         self.encoder = self._load_encoder()
         self.transform = self._create_transform()
@@ -59,15 +58,13 @@ class Model:
     @staticmethod
     def _create_dlib_landmark_model():
         path = huggingface_hub.hf_hub_download(
-            'hysts/dlib_face_landmark_model',
-            'shape_predictor_68_face_landmarks.dat',
-            use_auth_token=HF_TOKEN)
+            'public-data/dlib_face_landmark_model',
+            'shape_predictor_68_face_landmarks.dat')
         return dlib.shape_predictor(path)
 
     def _load_encoder(self) -> nn.Module:
-        ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                                    'models/encoder.pt',
-                                                    use_auth_token=HF_TOKEN)
+        ckpt_path = huggingface_hub.hf_hub_download('public-data/DualStyleGAN',
+                                                    'models/encoder.pt')
         ckpt = torch.load(ckpt_path, map_location='cpu')
         opts = ckpt['opts']
         opts['device'] = self.device.type
@@ -91,9 +88,7 @@ class Model:
     def _load_generator(self, style_type: str) -> nn.Module:
         model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
         ckpt_path = huggingface_hub.hf_hub_download(
-            MODEL_REPO,
-            f'models/{style_type}/generator.pt',
-            use_auth_token=HF_TOKEN)
+            'public-data/DualStyleGAN', f'models/{style_type}/generator.pt')
         ckpt = torch.load(ckpt_path, map_location='cpu')
         model.load_state_dict(ckpt['g_ema'])
         model.to(self.device)
@@ -107,14 +102,12 @@ class Model:
         else:
             filename = 'exstyle_code.npy'
         path = huggingface_hub.hf_hub_download(
-            MODEL_REPO,
-            f'models/{style_type}/{filename}',
-            use_auth_token=HF_TOKEN)
+            'public-data/DualStyleGAN', f'models/{style_type}/{filename}')
         exstyles = np.load(path, allow_pickle=True).item()
         return exstyles
 
-    def detect_and_align_face(self, image) -> np.ndarray:
-        image = align_face(filepath=image.name, predictor=self.landmark_model)
+    def detect_and_align_face(self, image: str) -> np.ndarray:
+        image = align_face(filepath=image, predictor=self.landmark_model)
         return image
 
     @staticmethod
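Two hardening tweaks above are worth flagging: the CPU patch is now applied only when CUDA is absent, and the subprocess command is tokenized with shlex.split rather than str.split. The two agree on 'patch -p1' but diverge once an argument carries quoting, as this sketch shows (the quoted filename is hypothetical):

import shlex

cmd = "patch -p1 --input 'my patch.diff'"
print(cmd.split())       # ['patch', '-p1', '--input', "'my", "patch.diff'"]
print(shlex.split(cmd))  # ['patch', '-p1', '--input', 'my patch.diff']

Moving the checkpoints to the public public-data/* repos likewise removes every use_auth_token=HF_TOKEN argument, so the Space no longer needs an HF_TOKEN secret to start.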
requirements.txt CHANGED
@@ -1,7 +1,7 @@
-dlib==19.23.0
-numpy==1.22.3
-opencv-python-headless==4.5.5.62
-Pillow==9.0.1
-scipy==1.8.0
-torch==1.11.0
-torchvision==0.12.0
+dlib==19.24.2
+numpy==1.23.5
+opencv-python-headless==4.8.0.74
+Pillow==9.5.0
+scipy==1.10.1
+torch==2.0.1
+torchvision==0.15.2
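The pins move in lockstep: torchvision 0.15.2 is the release paired with torch 2.0.1. A quick post-install sanity check (the expected output reflects these pins, possibly with a local-build suffix such as +cu118):

import torch
import torchvision

# Expect the pinned pair from requirements.txt: 2.0.1 and 0.15.2.
print(torch.__version__, torchvision.__version__)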