hysts HF staff commited on
Commit
811cb03
1 Parent(s): d5eecec
Files changed (7) hide show
  1. .pre-commit-config.yaml +59 -35
  2. .style.yapf +0 -5
  3. .vscode/settings.json +30 -0
  4. app.py +80 -92
  5. dualstylegan.py +66 -66
  6. images/README.md +0 -1
  7. style.css +1 -4
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,61 @@
1
  exclude: ^patch
2
  repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.12.0
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- additional_dependencies: ['types-python-slugify']
33
- - repo: https://github.com/google/yapf
34
- rev: v0.32.0
35
- hooks:
36
- - id: yapf
37
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  exclude: ^patch
2
  repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.6.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ["--fix=lf"]
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.7.5
19
+ hooks:
20
+ - id: docformatter
21
+ args: ["--in-place"]
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.13.2
24
+ hooks:
25
+ - id: isort
26
+ args: ["--profile", "black"]
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v1.10.0
29
+ hooks:
30
+ - id: mypy
31
+ args: ["--ignore-missing-imports"]
32
+ additional_dependencies:
33
+ [
34
+ "types-python-slugify",
35
+ "types-requests",
36
+ "types-PyYAML",
37
+ "types-pytz",
38
+ ]
39
+ - repo: https://github.com/psf/black
40
+ rev: 24.4.2
41
+ hooks:
42
+ - id: black
43
+ language_version: python3.10
44
+ args: ["--line-length", "119"]
45
+ - repo: https://github.com/kynan/nbstripout
46
+ rev: 0.7.1
47
+ hooks:
48
+ - id: nbstripout
49
+ args:
50
+ [
51
+ "--extra-keys",
52
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
53
+ ]
54
+ - repo: https://github.com/nbQA-dev/nbQA
55
+ rev: 1.8.5
56
+ hooks:
57
+ - id: nbqa-black
58
+ - id: nbqa-pyupgrade
59
+ args: ["--py37-plus"]
60
+ - id: nbqa-isort
61
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "editor.formatOnSave": true,
3
+ "files.insertFinalNewline": false,
4
+ "[python]": {
5
+ "editor.defaultFormatter": "ms-python.black-formatter",
6
+ "editor.formatOnType": true,
7
+ "editor.codeActionsOnSave": {
8
+ "source.organizeImports": "explicit"
9
+ }
10
+ },
11
+ "[jupyter]": {
12
+ "files.insertFinalNewline": false
13
+ },
14
+ "black-formatter.args": [
15
+ "--line-length=119"
16
+ ],
17
+ "isort.args": ["--profile", "black"],
18
+ "flake8.args": [
19
+ "--max-line-length=119"
20
+ ],
21
+ "ruff.lint.args": [
22
+ "--line-length=119"
23
+ ],
24
+ "notebook.output.scrolling": true,
25
+ "notebook.formatOnCellExecution": true,
26
+ "notebook.formatOnSave.enabled": true,
27
+ "notebook.codeActionsOnSave": {
28
+ "source.organizeImports": "explicit"
29
+ }
30
+ }
app.py CHANGED
@@ -9,24 +9,24 @@ import gradio as gr
9
 
10
  from dualstylegan import Model
11
 
12
- DESCRIPTION = '''# Portrait Style Transfer with <a href="https://github.com/williamyang1991/DualStyleGAN">DualStyleGAN</a>
13
 
14
  <img id="overview" alt="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" />
15
- '''
16
 
17
 
18
  def get_style_image_url(style_name: str) -> str:
19
- base_url = 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images'
20
  filenames = {
21
- 'cartoon': 'cartoon_overview.jpg',
22
- 'caricature': 'caricature_overview.jpg',
23
- 'anime': 'anime_overview.jpg',
24
- 'arcane': 'Reconstruction_arcane_overview.jpg',
25
- 'comic': 'Reconstruction_comic_overview.jpg',
26
- 'pixar': 'Reconstruction_pixar_overview.jpg',
27
- 'slamdunk': 'Reconstruction_slamdunk_overview.jpg',
28
  }
29
- return f'{base_url}/{filenames[style_name]}'
30
 
31
 
32
  def get_style_image_markdown_text(style_name: str) -> str:
@@ -36,13 +36,13 @@ def get_style_image_markdown_text(style_name: str) -> str:
36
 
37
  def update_slider(choice: str) -> dict:
38
  max_vals = {
39
- 'cartoon': 316,
40
- 'caricature': 198,
41
- 'anime': 173,
42
- 'arcane': 99,
43
- 'comic': 100,
44
- 'pixar': 121,
45
- 'slamdunk': 119,
46
  }
47
  return gr.Slider.update(maximum=max_vals[choice])
48
 
@@ -72,125 +72,113 @@ def set_example_weights(example: list) -> list[dict]:
72
 
73
  model = Model()
74
 
75
- with gr.Blocks(css='style.css') as demo:
76
  gr.Markdown(DESCRIPTION)
77
 
78
  with gr.Box():
79
- gr.Markdown('''## Step 1 (Preprocess Input Image)
 
80
 
81
  - Drop an image containing a near-frontal face to the **Input Image**.
82
  - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
83
  - Hit the **Detect & Align Face** button.
84
  - Hit the **Reconstruct Face** button.
85
  - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
86
- ''')
 
87
  with gr.Row():
88
  with gr.Column():
89
  with gr.Row():
90
- input_image = gr.Image(label='Input Image',
91
- type='filepath')
92
  with gr.Row():
93
- detect_button = gr.Button('Detect & Align Face')
94
  with gr.Column():
95
  with gr.Row():
96
- aligned_face = gr.Image(label='Aligned Face',
97
- type='numpy',
98
- interactive=False)
99
  with gr.Row():
100
- reconstruct_button = gr.Button('Reconstruct Face')
101
  with gr.Column():
102
- reconstructed_face = gr.Image(label='Reconstructed Face',
103
- type='numpy')
104
  instyle = gr.Variable()
105
 
106
  with gr.Row():
107
- paths = sorted(pathlib.Path('images').glob('*.jpg'))
108
- gr.Examples(examples=[[path.as_posix()] for path in paths],
109
- inputs=input_image)
110
 
111
  with gr.Box():
112
- gr.Markdown('''## Step 2 (Select Style Image)
 
113
 
114
  - Select **Style Type**.
115
  - Select **Style Image Index** from the image table below.
116
- ''')
 
117
  with gr.Row():
118
  with gr.Column():
119
- style_type = gr.Radio(label='Style Type',
120
- choices=model.style_types)
121
- text = get_style_image_markdown_text('cartoon')
122
  style_image = gr.Markdown(value=text)
123
- style_index = gr.Slider(label='Style Image Index',
124
- minimum=0,
125
- maximum=316,
126
- step=1,
127
- value=26)
128
 
129
  with gr.Row():
130
- gr.Examples(examples=[
131
- ['cartoon', 26],
132
- ['caricature', 65],
133
- ['arcane', 63],
134
- ['pixar', 80],
135
- ],
136
- inputs=[style_type, style_index])
 
 
137
 
138
  with gr.Box():
139
- gr.Markdown('''## Step 3 (Generate Style Transferred Image)
 
140
 
141
  - Adjust **Structure Weight** and **Color Weight**.
142
  - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
143
  - Hit the **Generate** button.
144
- ''')
 
145
  with gr.Row():
146
  with gr.Column():
147
  with gr.Row():
148
- structure_weight = gr.Slider(label='Structure Weight',
149
- minimum=0,
150
- maximum=1,
151
- step=0.1,
152
- value=0.6)
153
  with gr.Row():
154
- color_weight = gr.Slider(label='Color Weight',
155
- minimum=0,
156
- maximum=1,
157
- step=0.1,
158
- value=1)
159
  with gr.Row():
160
- structure_only = gr.Checkbox(label='Structure Only')
161
  with gr.Row():
162
- generate_button = gr.Button('Generate')
163
 
164
  with gr.Column():
165
- result = gr.Image(label='Result')
166
 
167
  with gr.Row():
168
- gr.Examples(examples=[
169
- [0.6, 1.0],
170
- [0.3, 1.0],
171
- [0.0, 1.0],
172
- [1.0, 0.0],
173
- ],
174
- inputs=[structure_weight, color_weight])
175
-
176
- detect_button.click(fn=model.detect_and_align_face,
177
- inputs=input_image,
178
- outputs=aligned_face)
179
- reconstruct_button.click(fn=model.reconstruct_face,
180
- inputs=aligned_face,
181
- outputs=[reconstructed_face, instyle])
182
  style_type.change(fn=update_slider, inputs=style_type, outputs=style_index)
183
- style_type.change(fn=update_style_image,
184
- inputs=style_type,
185
- outputs=style_image)
186
- generate_button.click(fn=model.generate,
187
- inputs=[
188
- style_type,
189
- style_index,
190
- structure_weight,
191
- color_weight,
192
- structure_only,
193
- instyle,
194
- ],
195
- outputs=result)
196
  demo.queue(max_size=10).launch()
 
9
 
10
  from dualstylegan import Model
11
 
12
+ DESCRIPTION = """# Portrait Style Transfer with <a href="https://github.com/williamyang1991/DualStyleGAN">DualStyleGAN</a>
13
 
14
  <img id="overview" alt="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" />
15
+ """
16
 
17
 
18
  def get_style_image_url(style_name: str) -> str:
19
+ base_url = "https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images"
20
  filenames = {
21
+ "cartoon": "cartoon_overview.jpg",
22
+ "caricature": "caricature_overview.jpg",
23
+ "anime": "anime_overview.jpg",
24
+ "arcane": "Reconstruction_arcane_overview.jpg",
25
+ "comic": "Reconstruction_comic_overview.jpg",
26
+ "pixar": "Reconstruction_pixar_overview.jpg",
27
+ "slamdunk": "Reconstruction_slamdunk_overview.jpg",
28
  }
29
+ return f"{base_url}/{filenames[style_name]}"
30
 
31
 
32
  def get_style_image_markdown_text(style_name: str) -> str:
 
36
 
37
  def update_slider(choice: str) -> dict:
38
  max_vals = {
39
+ "cartoon": 316,
40
+ "caricature": 198,
41
+ "anime": 173,
42
+ "arcane": 99,
43
+ "comic": 100,
44
+ "pixar": 121,
45
+ "slamdunk": 119,
46
  }
47
  return gr.Slider.update(maximum=max_vals[choice])
48
 
 
72
 
73
  model = Model()
74
 
75
+ with gr.Blocks(css="style.css") as demo:
76
  gr.Markdown(DESCRIPTION)
77
 
78
  with gr.Box():
79
+ gr.Markdown(
80
+ """## Step 1 (Preprocess Input Image)
81
 
82
  - Drop an image containing a near-frontal face to the **Input Image**.
83
  - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
84
  - Hit the **Detect & Align Face** button.
85
  - Hit the **Reconstruct Face** button.
86
  - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
87
+ """
88
+ )
89
  with gr.Row():
90
  with gr.Column():
91
  with gr.Row():
92
+ input_image = gr.Image(label="Input Image", type="filepath")
 
93
  with gr.Row():
94
+ detect_button = gr.Button("Detect & Align Face")
95
  with gr.Column():
96
  with gr.Row():
97
+ aligned_face = gr.Image(label="Aligned Face", type="numpy", interactive=False)
 
 
98
  with gr.Row():
99
+ reconstruct_button = gr.Button("Reconstruct Face")
100
  with gr.Column():
101
+ reconstructed_face = gr.Image(label="Reconstructed Face", type="numpy")
 
102
  instyle = gr.Variable()
103
 
104
  with gr.Row():
105
+ paths = sorted(pathlib.Path("images").glob("*.jpg"))
106
+ gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)
 
107
 
108
  with gr.Box():
109
+ gr.Markdown(
110
+ """## Step 2 (Select Style Image)
111
 
112
  - Select **Style Type**.
113
  - Select **Style Image Index** from the image table below.
114
+ """
115
+ )
116
  with gr.Row():
117
  with gr.Column():
118
+ style_type = gr.Radio(label="Style Type", choices=model.style_types)
119
+ text = get_style_image_markdown_text("cartoon")
 
120
  style_image = gr.Markdown(value=text)
121
+ style_index = gr.Slider(label="Style Image Index", minimum=0, maximum=316, step=1, value=26)
 
 
 
 
122
 
123
  with gr.Row():
124
+ gr.Examples(
125
+ examples=[
126
+ ["cartoon", 26],
127
+ ["caricature", 65],
128
+ ["arcane", 63],
129
+ ["pixar", 80],
130
+ ],
131
+ inputs=[style_type, style_index],
132
+ )
133
 
134
  with gr.Box():
135
+ gr.Markdown(
136
+ """## Step 3 (Generate Style Transferred Image)
137
 
138
  - Adjust **Structure Weight** and **Color Weight**.
139
  - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
140
  - Hit the **Generate** button.
141
+ """
142
+ )
143
  with gr.Row():
144
  with gr.Column():
145
  with gr.Row():
146
+ structure_weight = gr.Slider(label="Structure Weight", minimum=0, maximum=1, step=0.1, value=0.6)
 
 
 
 
147
  with gr.Row():
148
+ color_weight = gr.Slider(label="Color Weight", minimum=0, maximum=1, step=0.1, value=1)
 
 
 
 
149
  with gr.Row():
150
+ structure_only = gr.Checkbox(label="Structure Only")
151
  with gr.Row():
152
+ generate_button = gr.Button("Generate")
153
 
154
  with gr.Column():
155
+ result = gr.Image(label="Result")
156
 
157
  with gr.Row():
158
+ gr.Examples(
159
+ examples=[
160
+ [0.6, 1.0],
161
+ [0.3, 1.0],
162
+ [0.0, 1.0],
163
+ [1.0, 0.0],
164
+ ],
165
+ inputs=[structure_weight, color_weight],
166
+ )
167
+
168
+ detect_button.click(fn=model.detect_and_align_face, inputs=input_image, outputs=aligned_face)
169
+ reconstruct_button.click(fn=model.reconstruct_face, inputs=aligned_face, outputs=[reconstructed_face, instyle])
 
 
170
  style_type.change(fn=update_slider, inputs=style_type, outputs=style_index)
171
+ style_type.change(fn=update_style_image, inputs=style_type, outputs=style_image)
172
+ generate_button.click(
173
+ fn=model.generate,
174
+ inputs=[
175
+ style_type,
176
+ style_index,
177
+ structure_weight,
178
+ color_weight,
179
+ structure_only,
180
+ instyle,
181
+ ],
182
+ outputs=result,
183
+ )
184
  demo.queue(max_size=10).launch()
dualstylegan.py CHANGED
@@ -16,12 +16,12 @@ import torch
16
  import torch.nn as nn
17
  import torchvision.transforms as T
18
 
19
- if os.getenv('SYSTEM') == 'spaces' and not torch.cuda.is_available():
20
- with open('patch') as f:
21
- subprocess.run(shlex.split('patch -p1'), cwd='DualStyleGAN', stdin=f)
22
 
23
  app_dir = pathlib.Path(__file__).parent
24
- submodule_dir = app_dir / 'DualStyleGAN'
25
  sys.path.insert(0, submodule_dir.as_posix())
26
 
27
  from model.dualstylegan import DualStyleGAN
@@ -31,44 +31,36 @@ from model.encoder.psp import pSp
31
 
32
  class Model:
33
  def __init__(self):
34
- self.device = torch.device(
35
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
36
  self.landmark_model = self._create_dlib_landmark_model()
37
  self.encoder = self._load_encoder()
38
  self.transform = self._create_transform()
39
 
40
  self.style_types = [
41
- 'cartoon',
42
- 'caricature',
43
- 'anime',
44
- 'arcane',
45
- 'comic',
46
- 'pixar',
47
- 'slamdunk',
48
  ]
49
- self.generator_dict = {
50
- style_type: self._load_generator(style_type)
51
- for style_type in self.style_types
52
- }
53
- self.exstyle_dict = {
54
- style_type: self._load_exstylecode(style_type)
55
- for style_type in self.style_types
56
- }
57
 
58
  @staticmethod
59
  def _create_dlib_landmark_model():
60
  path = huggingface_hub.hf_hub_download(
61
- 'public-data/dlib_face_landmark_model',
62
- 'shape_predictor_68_face_landmarks.dat')
63
  return dlib.shape_predictor(path)
64
 
65
  def _load_encoder(self) -> nn.Module:
66
- ckpt_path = huggingface_hub.hf_hub_download('public-data/DualStyleGAN',
67
- 'models/encoder.pt')
68
- ckpt = torch.load(ckpt_path, map_location='cpu')
69
- opts = ckpt['opts']
70
- opts['device'] = self.device.type
71
- opts['checkpoint_path'] = ckpt_path
72
  opts = argparse.Namespace(**opts)
73
  model = pSp(opts)
74
  model.to(self.device)
@@ -77,32 +69,32 @@ class Model:
77
 
78
  @staticmethod
79
  def _create_transform() -> Callable:
80
- transform = T.Compose([
81
- T.Resize(256),
82
- T.CenterCrop(256),
83
- T.ToTensor(),
84
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
85
- ])
 
 
86
  return transform
87
 
88
  def _load_generator(self, style_type: str) -> nn.Module:
89
  model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
90
- ckpt_path = huggingface_hub.hf_hub_download(
91
- 'public-data/DualStyleGAN', f'models/{style_type}/generator.pt')
92
- ckpt = torch.load(ckpt_path, map_location='cpu')
93
- model.load_state_dict(ckpt['g_ema'])
94
  model.to(self.device)
95
  model.eval()
96
  return model
97
 
98
  @staticmethod
99
  def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
100
- if style_type in ['cartoon', 'caricature', 'anime']:
101
- filename = 'refined_exstyle_code.npy'
102
  else:
103
- filename = 'exstyle_code.npy'
104
- path = huggingface_hub.hf_hub_download(
105
- 'public-data/DualStyleGAN', f'models/{style_type}/{filename}')
106
  exstyles = np.load(path, allow_pickle=True).item()
107
  return exstyles
108
 
@@ -119,24 +111,31 @@ class Model:
119
  return tensor.cpu().numpy().transpose(1, 2, 0)
120
 
121
  @torch.inference_mode()
122
- def reconstruct_face(self,
123
- image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
124
  image = PIL.Image.fromarray(image)
125
  input_data = self.transform(image).unsqueeze(0).to(self.device)
126
- img_rec, instyle = self.encoder(input_data,
127
- randomize_noise=False,
128
- return_latents=True,
129
- z_plus_latent=True,
130
- return_z_plus_latent=True,
131
- resize=False)
 
 
132
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
133
  img_rec = self.postprocess(img_rec[0])
134
  return img_rec, instyle
135
 
136
  @torch.inference_mode()
137
- def generate(self, style_type: str, style_id: int, structure_weight: float,
138
- color_weight: float, structure_only: bool,
139
- instyle: torch.Tensor) -> np.ndarray:
 
 
 
 
 
 
140
  generator = self.generator_dict[style_type]
141
  exstyles = self.exstyle_dict[style_type]
142
 
@@ -147,17 +146,18 @@ class Model:
147
  if structure_only:
148
  latent[0, 7:18] = instyle[0, 7:18]
149
  exstyle = generator.generator.style(
150
- latent.reshape(latent.shape[0] * latent.shape[1],
151
- latent.shape[2])).reshape(latent.shape)
152
-
153
- img_gen, _ = generator([instyle],
154
- exstyle,
155
- z_plus_latent=True,
156
- truncation=0.7,
157
- truncation_latent=0,
158
- use_res=True,
159
- interp_weights=[structure_weight] * 7 +
160
- [color_weight] * 11)
 
161
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
162
  img_gen = self.postprocess(img_gen[0])
163
  return img_gen
 
16
  import torch.nn as nn
17
  import torchvision.transforms as T
18
 
19
+ if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
20
+ with open("patch") as f:
21
+ subprocess.run(shlex.split("patch -p1"), cwd="DualStyleGAN", stdin=f)
22
 
23
  app_dir = pathlib.Path(__file__).parent
24
+ submodule_dir = app_dir / "DualStyleGAN"
25
  sys.path.insert(0, submodule_dir.as_posix())
26
 
27
  from model.dualstylegan import DualStyleGAN
 
31
 
32
  class Model:
33
  def __init__(self):
34
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
35
  self.landmark_model = self._create_dlib_landmark_model()
36
  self.encoder = self._load_encoder()
37
  self.transform = self._create_transform()
38
 
39
  self.style_types = [
40
+ "cartoon",
41
+ "caricature",
42
+ "anime",
43
+ "arcane",
44
+ "comic",
45
+ "pixar",
46
+ "slamdunk",
47
  ]
48
+ self.generator_dict = {style_type: self._load_generator(style_type) for style_type in self.style_types}
49
+ self.exstyle_dict = {style_type: self._load_exstylecode(style_type) for style_type in self.style_types}
 
 
 
 
 
 
50
 
51
  @staticmethod
52
  def _create_dlib_landmark_model():
53
  path = huggingface_hub.hf_hub_download(
54
+ "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
55
+ )
56
  return dlib.shape_predictor(path)
57
 
58
  def _load_encoder(self) -> nn.Module:
59
+ ckpt_path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", "models/encoder.pt")
60
+ ckpt = torch.load(ckpt_path, map_location="cpu")
61
+ opts = ckpt["opts"]
62
+ opts["device"] = self.device.type
63
+ opts["checkpoint_path"] = ckpt_path
 
64
  opts = argparse.Namespace(**opts)
65
  model = pSp(opts)
66
  model.to(self.device)
 
69
 
70
  @staticmethod
71
  def _create_transform() -> Callable:
72
+ transform = T.Compose(
73
+ [
74
+ T.Resize(256),
75
+ T.CenterCrop(256),
76
+ T.ToTensor(),
77
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
78
+ ]
79
+ )
80
  return transform
81
 
82
  def _load_generator(self, style_type: str) -> nn.Module:
83
  model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
84
+ ckpt_path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", f"models/{style_type}/generator.pt")
85
+ ckpt = torch.load(ckpt_path, map_location="cpu")
86
+ model.load_state_dict(ckpt["g_ema"])
 
87
  model.to(self.device)
88
  model.eval()
89
  return model
90
 
91
  @staticmethod
92
  def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
93
+ if style_type in ["cartoon", "caricature", "anime"]:
94
+ filename = "refined_exstyle_code.npy"
95
  else:
96
+ filename = "exstyle_code.npy"
97
+ path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", f"models/{style_type}/{filename}")
 
98
  exstyles = np.load(path, allow_pickle=True).item()
99
  return exstyles
100
 
 
111
  return tensor.cpu().numpy().transpose(1, 2, 0)
112
 
113
  @torch.inference_mode()
114
+ def reconstruct_face(self, image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
 
115
  image = PIL.Image.fromarray(image)
116
  input_data = self.transform(image).unsqueeze(0).to(self.device)
117
+ img_rec, instyle = self.encoder(
118
+ input_data,
119
+ randomize_noise=False,
120
+ return_latents=True,
121
+ z_plus_latent=True,
122
+ return_z_plus_latent=True,
123
+ resize=False,
124
+ )
125
  img_rec = torch.clamp(img_rec.detach(), -1, 1)
126
  img_rec = self.postprocess(img_rec[0])
127
  return img_rec, instyle
128
 
129
  @torch.inference_mode()
130
+ def generate(
131
+ self,
132
+ style_type: str,
133
+ style_id: int,
134
+ structure_weight: float,
135
+ color_weight: float,
136
+ structure_only: bool,
137
+ instyle: torch.Tensor,
138
+ ) -> np.ndarray:
139
  generator = self.generator_dict[style_type]
140
  exstyles = self.exstyle_dict[style_type]
141
 
 
146
  if structure_only:
147
  latent[0, 7:18] = instyle[0, 7:18]
148
  exstyle = generator.generator.style(
149
+ latent.reshape(latent.shape[0] * latent.shape[1], latent.shape[2])
150
+ ).reshape(latent.shape)
151
+
152
+ img_gen, _ = generator(
153
+ [instyle],
154
+ exstyle,
155
+ z_plus_latent=True,
156
+ truncation=0.7,
157
+ truncation_latent=0,
158
+ use_res=True,
159
+ interp_weights=[structure_weight] * 7 + [color_weight] * 11,
160
+ )
161
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
162
  img_gen = self.postprocess(img_gen[0])
163
  return img_gen
images/README.md CHANGED
@@ -4,4 +4,3 @@ These images are freely-usable ones from [Unsplash](https://unsplash.com/).
4
  - https://unsplash.com/photos/et_78QkMMQs
5
  - https://unsplash.com/photos/ILip77SbmOE
6
  - https://unsplash.com/photos/95UF6LXe-Lo
7
-
 
4
  - https://unsplash.com/photos/et_78QkMMQs
5
  - https://unsplash.com/photos/ILip77SbmOE
6
  - https://unsplash.com/photos/95UF6LXe-Lo
 
style.css CHANGED
@@ -1,5 +1,6 @@
1
  h1 {
2
  text-align: center;
 
3
  }
4
  img#overview {
5
  max-width: 800px;
@@ -11,7 +12,3 @@ img#style-image {
11
  max-width: 1000px;
12
  max-height: 600px;
13
  }
14
- img#visitor-badge {
15
- display: block;
16
- margin: auto;
17
- }
 
1
  h1 {
2
  text-align: center;
3
+ display: block;
4
  }
5
  img#overview {
6
  max-width: 800px;
 
12
  max-width: 1000px;
13
  max-height: 600px;
14
  }