hysts HF staff commited on
Commit
7ca00ae
1 Parent(s): 9b1f027
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +18 -14
  3. requirements.txt +3 -3
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏢
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
 
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
app.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from __future__ import annotations
4
 
5
- import functools
6
  import pickle
7
  import sys
8
 
@@ -31,19 +30,25 @@ def load_model(file_name: str, device: torch.device) -> nn.Module:
31
  return model
32
 
33
 
34
- def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
35
- return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).to(device).float()
 
 
 
 
36
 
37
 
38
  @torch.inference_mode()
39
  def generate_interpolated_images(
40
- seed0: int, psi0: float, seed1: int, psi1: float, num_intermediate: int, model: nn.Module, device: torch.device
41
  ) -> list[np.ndarray]:
42
  seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
43
  seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
44
 
45
- z0 = generate_z(model.z_dim, seed0, device)
46
- z1 = generate_z(model.z_dim, seed1, device)
 
 
47
  vec = z1 - z0
48
  dvec = vec / (num_intermediate + 1)
49
  zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
@@ -61,12 +66,8 @@ def generate_interpolated_images(
61
  return res
62
 
63
 
64
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
65
- model = load_model("stylegan_human_v2_1024.pkl", device)
66
- fn = functools.partial(generate_interpolated_images, model=model, device=device)
67
-
68
- gr.Interface(
69
- fn=fn,
70
  inputs=[
71
  gr.Slider(label="Seed 1", minimum=0, maximum=100000, step=1, value=0, randomize=True),
72
  gr.Slider(label="Truncation psi 1", minimum=0, maximum=2, step=0.05, value=0.7),
@@ -74,7 +75,10 @@ gr.Interface(
74
  gr.Slider(label="Truncation psi 2", minimum=0, maximum=2, step=0.05, value=0.7),
75
  gr.Slider(label="Number of Intermediate Frames", minimum=0, maximum=21, step=1, value=7),
76
  ],
77
- outputs=gr.Gallery(label="Output Images", type="numpy"),
78
  title=TITLE,
79
  description=DESCRIPTION,
80
- ).queue(max_size=10).launch()
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import pickle
6
  import sys
7
 
 
30
  return model
31
 
32
 
33
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
+ model = load_model("stylegan_human_v2_1024.pkl", device)
35
+
36
+
37
+ def generate_z(z_dim: int, seed: int) -> torch.Tensor:
38
+ return torch.from_numpy(np.random.RandomState(seed).randn(1, z_dim)).float()
39
 
40
 
41
  @torch.inference_mode()
42
  def generate_interpolated_images(
43
+ seed0: int, psi0: float, seed1: int, psi1: float, num_intermediate: int
44
  ) -> list[np.ndarray]:
45
  seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
46
  seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
47
 
48
+ z0 = generate_z(model.z_dim, seed0)
49
+ z1 = generate_z(model.z_dim, seed1)
50
+ z0 = z0.to(device)
51
+ z1 = z1.to(device)
52
  vec = z1 - z0
53
  dvec = vec / (num_intermediate + 1)
54
  zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
 
66
  return res
67
 
68
 
69
+ demo = gr.Interface(
70
+ fn=generate_interpolated_images,
 
 
 
 
71
  inputs=[
72
  gr.Slider(label="Seed 1", minimum=0, maximum=100000, step=1, value=0, randomize=True),
73
  gr.Slider(label="Truncation psi 1", minimum=0, maximum=2, step=0.05, value=0.7),
 
75
  gr.Slider(label="Truncation psi 2", minimum=0, maximum=2, step=0.05, value=0.7),
76
  gr.Slider(label="Number of Intermediate Frames", minimum=0, maximum=21, step=1, value=7),
77
  ],
78
+ outputs=gr.Gallery(label="Output Images"),
79
  title=TITLE,
80
  description=DESCRIPTION,
81
+ )
82
+
83
+ if __name__ == "__main__":
84
+ demo.queue(max_size=10).launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- numpy==1.23.5
2
- Pillow==10.0.0
3
- scipy==1.10.1
4
  torch==2.0.1
5
  torchvision==0.15.2
 
1
+ numpy==1.26.4
2
+ Pillow==10.3.0
3
+ scipy==1.13.1
4
  torch==2.0.1
5
  torchvision==0.15.2