hysts HF staff commited on
Commit
d2eda3f
1 Parent(s): c627119
Files changed (5) hide show
  1. .pre-commit-config.yaml +46 -0
  2. .style.yapf +5 -0
  3. app.py +9 -22
  4. dualstylegan.py +5 -6
  5. style.css +17 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: ^DualStyleGAN
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.812
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ - repo: https://github.com/google/yapf
33
+ rev: v0.32.0
34
+ hooks:
35
+ - id: yapf
36
+ args: ['--parallel', '--in-place']
37
+ - repo: https://github.com/kynan/nbstripout
38
+ rev: 0.5.0
39
+ hooks:
40
+ - id: nbstripout
41
+ args: ['--extra-keys', 'metadata.interpreter metadata.kernelspec cell.metadata.pycharm']
42
+ - repo: https://github.com/nbQA-dev/nbQA
43
+ rev: 1.3.1
44
+ hooks:
45
+ - id: nbqa-isort
46
+ - id: nbqa-yapf
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
app.py CHANGED
@@ -9,6 +9,12 @@ import gradio as gr
9
 
10
  from dualstylegan import Model
11
 
 
 
 
 
 
 
12
 
13
  def parse_args() -> argparse.Namespace:
14
  parser = argparse.ArgumentParser()
@@ -81,26 +87,9 @@ def main():
81
  args = parse_args()
82
  model = Model(device=args.device)
83
 
84
- css = '''
85
- h1#title {
86
- text-align: center;
87
- }
88
- img#overview {
89
- max-width: 800px;
90
- max-height: 600px;
91
- }
92
- img#style-image {
93
- max-width: 1000px;
94
- max-height: 600px;
95
- }
96
- '''
97
-
98
- with gr.Blocks(theme=args.theme, css=css) as demo:
99
- gr.Markdown(
100
- '''<h1 id="title">Portrait Style Transfer with <a href="https://github.com/williamyang1991/DualStyleGAN">DualStyleGAN</a></h1>
101
 
102
- <center><img id="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" alt="overview"></center>
103
- ''')
104
  with gr.Box():
105
  gr.Markdown('''## Step 1 (Preprocess Input Image)
106
 
@@ -202,9 +191,7 @@ img#style-image {
202
  [1.0, 0.0],
203
  ])
204
 
205
- gr.Markdown(
206
- '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" alt="visitor badge"/></center>'
207
- )
208
 
209
  detect_button.click(fn=model.detect_and_align_face,
210
  inputs=input_image,
 
9
 
10
  from dualstylegan import Model
11
 
12
+ DESCRIPTION = '''# Portrait Style Transfer with <a href="https://github.com/williamyang1991/DualStyleGAN">DualStyleGAN</a>
13
+
14
+ <img id="overview" alt="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" />
15
+ '''
16
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" />'
17
+
18
 
19
  def parse_args() -> argparse.Namespace:
20
  parser = argparse.ArgumentParser()
 
87
  args = parse_args()
88
  model = Model(device=args.device)
89
 
90
+ with gr.Blocks(theme=args.theme, css='style.css') as demo:
91
+ gr.Markdown(DESCRIPTION)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
 
 
93
  with gr.Box():
94
  gr.Markdown('''## Step 1 (Preprocess Input Image)
95
 
 
191
  [1.0, 0.0],
192
  ])
193
 
194
+ gr.Markdown(FOOTER)
 
 
195
 
196
  detect_button.click(fn=model.detect_and_align_face,
197
  inputs=input_image,
dualstylegan.py CHANGED
@@ -23,12 +23,11 @@ from model.dualstylegan import DualStyleGAN
23
  from model.encoder.align_all_parallel import align_face
24
  from model.encoder.psp import pSp
25
 
26
- TOKEN = os.environ['TOKEN']
27
  MODEL_REPO = 'hysts/DualStyleGAN'
28
 
29
 
30
  class Model:
31
-
32
  def __init__(self, device: Union[torch.device, str]):
33
  self.device = torch.device(device)
34
  self.landmark_model = self._create_dlib_landmark_model()
@@ -58,13 +57,13 @@ class Model:
58
  path = huggingface_hub.hf_hub_download(
59
  'hysts/dlib_face_landmark_model',
60
  'shape_predictor_68_face_landmarks.dat',
61
- use_auth_token=TOKEN)
62
  return dlib.shape_predictor(path)
63
 
64
  def _load_encoder(self) -> nn.Module:
65
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
66
  'models/encoder.pt',
67
- use_auth_token=TOKEN)
68
  ckpt = torch.load(ckpt_path, map_location='cpu')
69
  opts = ckpt['opts']
70
  opts['device'] = self.device.type
@@ -90,7 +89,7 @@ class Model:
90
  ckpt_path = huggingface_hub.hf_hub_download(
91
  MODEL_REPO,
92
  f'models/{style_type}/generator.pt',
93
- use_auth_token=TOKEN)
94
  ckpt = torch.load(ckpt_path, map_location='cpu')
95
  model.load_state_dict(ckpt['g_ema'])
96
  model.to(self.device)
@@ -106,7 +105,7 @@ class Model:
106
  path = huggingface_hub.hf_hub_download(
107
  MODEL_REPO,
108
  f'models/{style_type}/{filename}',
109
- use_auth_token=TOKEN)
110
  exstyles = np.load(path, allow_pickle=True).item()
111
  return exstyles
112
 
 
23
  from model.encoder.align_all_parallel import align_face
24
  from model.encoder.psp import pSp
25
 
26
+ HF_TOKEN = os.environ['HF_TOKEN']
27
  MODEL_REPO = 'hysts/DualStyleGAN'
28
 
29
 
30
  class Model:
 
31
  def __init__(self, device: Union[torch.device, str]):
32
  self.device = torch.device(device)
33
  self.landmark_model = self._create_dlib_landmark_model()
 
57
  path = huggingface_hub.hf_hub_download(
58
  'hysts/dlib_face_landmark_model',
59
  'shape_predictor_68_face_landmarks.dat',
60
+ use_auth_token=HF_TOKEN)
61
  return dlib.shape_predictor(path)
62
 
63
  def _load_encoder(self) -> nn.Module:
64
  ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
65
  'models/encoder.pt',
66
+ use_auth_token=HF_TOKEN)
67
  ckpt = torch.load(ckpt_path, map_location='cpu')
68
  opts = ckpt['opts']
69
  opts['device'] = self.device.type
 
89
  ckpt_path = huggingface_hub.hf_hub_download(
90
  MODEL_REPO,
91
  f'models/{style_type}/generator.pt',
92
+ use_auth_token=HF_TOKEN)
93
  ckpt = torch.load(ckpt_path, map_location='cpu')
94
  model.load_state_dict(ckpt['g_ema'])
95
  model.to(self.device)
 
105
  path = huggingface_hub.hf_hub_download(
106
  MODEL_REPO,
107
  f'models/{style_type}/{filename}',
108
+ use_auth_token=HF_TOKEN)
109
  exstyles = np.load(path, allow_pickle=True).item()
110
  return exstyles
111
 
style.css ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#overview {
5
+ max-width: 800px;
6
+ max-height: 600px;
7
+ display: block;
8
+ margin: auto;
9
+ }
10
+ img#style-image {
11
+ max-width: 1000px;
12
+ max-height: 600px;
13
+ }
14
+ img#visitor-badge {
15
+ display: block;
16
+ margin: auto;
17
+ }