hysts (HF staff) committed on
Commit 33cef83
1 Parent(s): 2db2246

Split file
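
This moves the App class out of app.py into a new dualstylegan.py module, renamed Model; app.py keeps only the CLI parsing and the Gradio UI wiring and imports the model logic with "from dualstylegan import Model". The constructor now also accepts a plain device string, so main() can pass args.device directly.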

Files changed (2)
  1. app.py +7 -163
  2. dualstylegan.py +167 -0
app.py CHANGED
@@ -5,27 +5,10 @@ from __future__ import annotations
 import argparse
 import os
 import pathlib
-import sys
-from typing import Callable
 
-import dlib
 import gradio as gr
-import huggingface_hub
-import numpy as np
-import PIL.Image
-import torch
-import torch.nn as nn
-import torchvision.transforms as T
 
-if os.environ.get('SYSTEM') == 'spaces':
-    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
-    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
-
-sys.path.insert(0, 'DualStyleGAN')
-
-from model.dualstylegan import DualStyleGAN
-from model.encoder.align_all_parallel import align_face
-from model.encoder.psp import pSp
+from dualstylegan import Model
 
 TOKEN = os.environ['TOKEN']
 MODEL_REPO = 'hysts/DualStyleGAN'
@@ -43,146 +26,6 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
-class App:
-
-    def __init__(self, device: torch.device):
-        self.device = device
-        self.landmark_model = self._create_dlib_landmark_model()
-        self.encoder = self._load_encoder()
-        self.transform = self._create_transform()
-
-        self.style_types = [
-            'cartoon',
-            'caricature',
-            'anime',
-            'arcane',
-            'comic',
-            'pixar',
-            'slamdunk',
-        ]
-        self.generator_dict = {
-            style_type: self._load_generator(style_type)
-            for style_type in self.style_types
-        }
-        self.exstyle_dict = {
-            style_type: self._load_exstylecode(style_type)
-            for style_type in self.style_types
-        }
-
-    @staticmethod
-    def _create_dlib_landmark_model():
-        path = huggingface_hub.hf_hub_download(
-            'hysts/dlib_face_landmark_model',
-            'shape_predictor_68_face_landmarks.dat',
-            use_auth_token=TOKEN)
-        return dlib.shape_predictor(path)
-
-    def _load_encoder(self) -> nn.Module:
-        ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
-                                                    'models/encoder.pt',
-                                                    use_auth_token=TOKEN)
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-        opts = ckpt['opts']
-        opts['device'] = self.device.type
-        opts['checkpoint_path'] = ckpt_path
-        opts = argparse.Namespace(**opts)
-        model = pSp(opts)
-        model.to(self.device)
-        model.eval()
-        return model
-
-    @staticmethod
-    def _create_transform() -> Callable:
-        transform = T.Compose([
-            T.Resize(256),
-            T.CenterCrop(256),
-            T.ToTensor(),
-            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-        ])
-        return transform
-
-    def _load_generator(self, style_type: str) -> nn.Module:
-        model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
-        ckpt_path = huggingface_hub.hf_hub_download(
-            MODEL_REPO,
-            f'models/{style_type}/generator.pt',
-            use_auth_token=TOKEN)
-        ckpt = torch.load(ckpt_path, map_location='cpu')
-        model.load_state_dict(ckpt['g_ema'])
-        model.to(self.device)
-        model.eval()
-        return model
-
-    @staticmethod
-    def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
-        if style_type in ['cartoon', 'caricature', 'anime']:
-            filename = 'refined_exstyle_code.npy'
-        else:
-            filename = 'exstyle_code.npy'
-        path = huggingface_hub.hf_hub_download(
-            MODEL_REPO,
-            f'models/{style_type}/{filename}',
-            use_auth_token=TOKEN)
-        exstyles = np.load(path, allow_pickle=True).item()
-        return exstyles
-
-    def detect_and_align_face(self, image) -> np.ndarray:
-        image = align_face(filepath=image.name, predictor=self.landmark_model)
-        return image
-
-    @staticmethod
-    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
-        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
-
-    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
-        tensor = self.denormalize(tensor)
-        return tensor.cpu().numpy().transpose(1, 2, 0)
-
-    @torch.inference_mode()
-    def reconstruct_face(self,
-                         image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
-        image = PIL.Image.fromarray(image)
-        input_data = self.transform(image).unsqueeze(0).to(self.device)
-        img_rec, instyle = self.encoder(input_data,
-                                        randomize_noise=False,
-                                        return_latents=True,
-                                        z_plus_latent=True,
-                                        return_z_plus_latent=True,
-                                        resize=False)
-        img_rec = torch.clamp(img_rec.detach(), -1, 1)
-        img_rec = self.postprocess(img_rec[0])
-        return img_rec, instyle
-
-    @torch.inference_mode()
-    def generate(self, style_type: str, style_id: int, structure_weight: float,
-                 color_weight: float, structure_only: bool,
-                 instyle: torch.Tensor) -> np.ndarray:
-        generator = self.generator_dict[style_type]
-        exstyles = self.exstyle_dict[style_type]
-
-        style_id = int(style_id)
-        stylename = list(exstyles.keys())[style_id]
-
-        latent = torch.tensor(exstyles[stylename]).to(self.device)
-        if structure_only:
-            latent[0, 7:18] = instyle[0, 7:18]
-        exstyle = generator.generator.style(
-            latent.reshape(latent.shape[0] * latent.shape[1],
-                           latent.shape[2])).reshape(latent.shape)
-
-        img_gen, _ = generator([instyle],
-                               exstyle,
-                               z_plus_latent=True,
-                               truncation=0.7,
-                               truncation_latent=0,
-                               use_res=True,
-                               interp_weights=[structure_weight] * 7 +
-                               [color_weight] * 11)
-        img_gen = torch.clamp(img_gen.detach(), -1, 1)
-        img_gen = self.postprocess(img_gen[0])
-        return img_gen
-
-
 def get_style_image_url(style_name: str) -> str:
     base_url = 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images'
     filenames = {
@@ -240,7 +83,7 @@ def set_example_weights(example: list) -> list[dict]:
 
 def main():
     args = parse_args()
-    app = App(device=torch.device(args.device))
+    model = Model(device=args.device)
 
     css = '''
 h1#title {
@@ -304,7 +147,8 @@ img#style-image {
         ''')
         with gr.Row():
             with gr.Column():
-                style_type = gr.Radio(app.style_types, label='Style Type')
+                style_type = gr.Radio(model.style_types,
+                                      label='Style Type')
                 text = get_style_image_markdown_text('cartoon')
                 style_image = gr.Markdown(value=text)
                 style_index = gr.Slider(0,
@@ -366,10 +210,10 @@ img#style-image {
             '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" alt="visitor badge"/></center>'
         )
 
-        detect_button.click(fn=app.detect_and_align_face,
+        detect_button.click(fn=model.detect_and_align_face,
                             inputs=input_image,
                             outputs=aligned_face)
-        reconstruct_button.click(fn=app.reconstruct_face,
+        reconstruct_button.click(fn=model.reconstruct_face,
                                  inputs=aligned_face,
                                  outputs=[reconstructed_face, instyle])
         style_type.change(fn=update_slider,
@@ -378,7 +222,7 @@ img#style-image {
         style_type.change(fn=update_style_image,
                           inputs=style_type,
                           outputs=style_image)
-        generate_button.click(fn=app.generate,
+        generate_button.click(fn=model.generate,
                               inputs=[
                                   style_type,
                                   style_index,
dualstylegan.py ADDED
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from typing import Callable, Union
+
+import dlib
+import huggingface_hub
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+
+if os.environ.get('SYSTEM') == 'spaces':
+    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
+    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
+
+sys.path.insert(0, 'DualStyleGAN')
+
+from model.dualstylegan import DualStyleGAN
+from model.encoder.align_all_parallel import align_face
+from model.encoder.psp import pSp
+
+TOKEN = os.environ['TOKEN']
+MODEL_REPO = 'hysts/DualStyleGAN'
+
+
+class Model:
+
+    def __init__(self, device: Union[torch.device, str]):
+        self.device = torch.device(device)
+        self.landmark_model = self._create_dlib_landmark_model()
+        self.encoder = self._load_encoder()
+        self.transform = self._create_transform()
+
+        self.style_types = [
+            'cartoon',
+            'caricature',
+            'anime',
+            'arcane',
+            'comic',
+            'pixar',
+            'slamdunk',
+        ]
+        self.generator_dict = {
+            style_type: self._load_generator(style_type)
+            for style_type in self.style_types
+        }
+        self.exstyle_dict = {
+            style_type: self._load_exstylecode(style_type)
+            for style_type in self.style_types
+        }
+
+    @staticmethod
+    def _create_dlib_landmark_model():
+        path = huggingface_hub.hf_hub_download(
+            'hysts/dlib_face_landmark_model',
+            'shape_predictor_68_face_landmarks.dat',
+            use_auth_token=TOKEN)
+        return dlib.shape_predictor(path)
+
+    def _load_encoder(self) -> nn.Module:
+        ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                    'models/encoder.pt',
+                                                    use_auth_token=TOKEN)
+        ckpt = torch.load(ckpt_path, map_location='cpu')
+        opts = ckpt['opts']
+        opts['device'] = self.device.type
+        opts['checkpoint_path'] = ckpt_path
+        opts = argparse.Namespace(**opts)
+        model = pSp(opts)
+        model.to(self.device)
+        model.eval()
+        return model
+
+    @staticmethod
+    def _create_transform() -> Callable:
+        transform = T.Compose([
+            T.Resize(256),
+            T.CenterCrop(256),
+            T.ToTensor(),
+            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+        ])
+        return transform
+
+    def _load_generator(self, style_type: str) -> nn.Module:
+        model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
+        ckpt_path = huggingface_hub.hf_hub_download(
+            MODEL_REPO,
+            f'models/{style_type}/generator.pt',
+            use_auth_token=TOKEN)
+        ckpt = torch.load(ckpt_path, map_location='cpu')
+        model.load_state_dict(ckpt['g_ema'])
+        model.to(self.device)
+        model.eval()
+        return model
+
+    @staticmethod
+    def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
+        if style_type in ['cartoon', 'caricature', 'anime']:
+            filename = 'refined_exstyle_code.npy'
+        else:
+            filename = 'exstyle_code.npy'
+        path = huggingface_hub.hf_hub_download(
+            MODEL_REPO,
+            f'models/{style_type}/{filename}',
+            use_auth_token=TOKEN)
+        exstyles = np.load(path, allow_pickle=True).item()
+        return exstyles
+
+    def detect_and_align_face(self, image) -> np.ndarray:
+        image = align_face(filepath=image.name, predictor=self.landmark_model)
+        return image
+
+    @staticmethod
+    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
+        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
+
+    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+        tensor = self.denormalize(tensor)
+        return tensor.cpu().numpy().transpose(1, 2, 0)
+
+    @torch.inference_mode()
+    def reconstruct_face(self,
+                         image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
+        image = PIL.Image.fromarray(image)
+        input_data = self.transform(image).unsqueeze(0).to(self.device)
+        img_rec, instyle = self.encoder(input_data,
+                                        randomize_noise=False,
+                                        return_latents=True,
+                                        z_plus_latent=True,
+                                        return_z_plus_latent=True,
+                                        resize=False)
+        img_rec = torch.clamp(img_rec.detach(), -1, 1)
+        img_rec = self.postprocess(img_rec[0])
+        return img_rec, instyle
+
+    @torch.inference_mode()
+    def generate(self, style_type: str, style_id: int, structure_weight: float,
+                 color_weight: float, structure_only: bool,
+                 instyle: torch.Tensor) -> np.ndarray:
+        generator = self.generator_dict[style_type]
+        exstyles = self.exstyle_dict[style_type]
+
+        style_id = int(style_id)
+        stylename = list(exstyles.keys())[style_id]
+
+        latent = torch.tensor(exstyles[stylename]).to(self.device)
+        if structure_only:
+            latent[0, 7:18] = instyle[0, 7:18]
+        exstyle = generator.generator.style(
+            latent.reshape(latent.shape[0] * latent.shape[1],
+                           latent.shape[2])).reshape(latent.shape)
+
+        img_gen, _ = generator([instyle],
+                               exstyle,
+                               z_plus_latent=True,
+                               truncation=0.7,
+                               truncation_latent=0,
+                               use_res=True,
+                               interp_weights=[structure_weight] * 7 +
+                               [color_weight] * 11)
+        img_gen = torch.clamp(img_gen.detach(), -1, 1)
+        img_gen = self.postprocess(img_gen[0])
+        return img_gen
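
For reference, here is a minimal sketch of driving the extracted Model class outside the Gradio UI. It is not part of the commit: it assumes the DualStyleGAN repo is checked out next to dualstylegan.py and the TOKEN environment variable holds a valid Hugging Face token (both required by the module above), while the file names, the run_demo helper, and the weight values are illustrative.

# Hypothetical driver for dualstylegan.Model -- a sketch, not part of this commit.
import numpy as np
import PIL.Image

from dualstylegan import Model


def run_demo() -> None:
    # After this commit, Model accepts a torch.device or a plain string.
    model = Model(device='cpu')

    # Any RGB face image that has already been aligned; path is illustrative.
    image = np.asarray(PIL.Image.open('aligned_face.png').convert('RGB'))

    # Invert the face into the encoder's z+ latent space.
    img_rec, instyle = model.reconstruct_face(image)

    # Stylize with the first cartoon exemplar; weights are illustrative.
    img_gen = model.generate(style_type='cartoon',
                             style_id=0,
                             structure_weight=0.6,
                             color_weight=1.0,
                             structure_only=False,
                             instyle=instyle)
    PIL.Image.fromarray(img_gen).save('stylized.png')


if __name__ == '__main__':
    run_demo()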