Spaces:
Runtime error
Runtime error
yourusername
commited on
Commit
·
7b91e69
1
Parent(s):
b12f76d
Created using Colaboratory
Browse files- animegan_v2_for_videos.ipynb +239 -0
animegan_v2_for_videos.ipynb
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"accelerator": "GPU",
|
6 |
+
"colab": {
|
7 |
+
"name": "animegan_v2_for_videos.ipynb",
|
8 |
+
"provenance": [],
|
9 |
+
"collapsed_sections": [],
|
10 |
+
"authorship_tag": "ABX9TyP/bydrfrVmE0CzRt9JBw+x",
|
11 |
+
"include_colab_link": true
|
12 |
+
},
|
13 |
+
"kernelspec": {
|
14 |
+
"display_name": "Python 3",
|
15 |
+
"name": "python3"
|
16 |
+
},
|
17 |
+
"language_info": {
|
18 |
+
"name": "python"
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"cells": [
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"metadata": {
|
25 |
+
"id": "view-in-github",
|
26 |
+
"colab_type": "text"
|
27 |
+
},
|
28 |
+
"source": [
|
29 |
+
"<a href=\"https://colab.research.google.com/github/nateraw/animegan-v2-for-videos/blob/main/animegan_v2_for_videos.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"metadata": {
|
35 |
+
"id": "dufmM-T1Helt"
|
36 |
+
},
|
37 |
+
"source": [
|
38 |
+
"%%capture\n",
|
39 |
+
"! pip install gradio encoded-video"
|
40 |
+
],
|
41 |
+
"execution_count": null,
|
42 |
+
"outputs": []
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"metadata": {
|
47 |
+
"id": "9CY3n8A0Lvdi"
|
48 |
+
},
|
49 |
+
"source": [
|
50 |
+
"import gc\n",
|
51 |
+
"import math\n",
|
52 |
+
"import tempfile\n",
|
53 |
+
"from PIL import Image\n",
|
54 |
+
"from io import BytesIO\n",
|
55 |
+
"\n",
|
56 |
+
"import torch\n",
|
57 |
+
"import gradio as gr\n",
|
58 |
+
"import numpy as np\n",
|
59 |
+
"from encoded_video import EncodedVideo, write_video\n",
|
60 |
+
"from torchvision.transforms.functional import to_tensor, center_crop"
|
61 |
+
],
|
62 |
+
"execution_count": null,
|
63 |
+
"outputs": []
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"metadata": {
|
68 |
+
"id": "YxdCnrTzLw5V"
|
69 |
+
},
|
70 |
+
"source": [
|
71 |
+
"model = torch.hub.load(\n",
|
72 |
+
" \"AK391/animegan2-pytorch:main\",\n",
|
73 |
+
" \"generator\",\n",
|
74 |
+
" pretrained=True,\n",
|
75 |
+
" device=\"cuda\",\n",
|
76 |
+
" progress=True,\n",
|
77 |
+
")"
|
78 |
+
],
|
79 |
+
"execution_count": null,
|
80 |
+
"outputs": []
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"metadata": {
|
85 |
+
"id": "TYAyXUP1UeOd"
|
86 |
+
},
|
87 |
+
"source": [
|
88 |
+
"! curl https://upload.wikimedia.org/wikipedia/commons/transcoded/2/29/2017-01-07_President_Obama%27s_Weekly_Address.webm/2017-01-07_President_Obama%27s_Weekly_Address.webm.360p.vp9.webm -o obama.webm"
|
89 |
+
],
|
90 |
+
"execution_count": null,
|
91 |
+
"outputs": []
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"metadata": {
|
96 |
+
"id": "TxT45Nlc88tD"
|
97 |
+
},
|
98 |
+
"source": [
|
99 |
+
"def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = 'cuda'):\n",
|
100 |
+
" w, h = img.size\n",
|
101 |
+
" s = min(w, h)\n",
|
102 |
+
" img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))\n",
|
103 |
+
" img = img.resize((size, size), Image.LANCZOS)\n",
|
104 |
+
"\n",
|
105 |
+
" with torch.no_grad():\n",
|
106 |
+
" input = to_tensor(img).unsqueeze(0) * 2 - 1\n",
|
107 |
+
" output = model(input.to(device)).cpu()[0]\n",
|
108 |
+
"\n",
|
109 |
+
" output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
|
110 |
+
"\n",
|
111 |
+
" return output\n",
|
112 |
+
"\n",
|
113 |
+
"# This function is taken from pytorchvideo!\n",
|
114 |
+
"def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:\n",
|
115 |
+
" \"\"\"\n",
|
116 |
+
" Uniformly subsamples num_samples indices from the temporal dimension of the video.\n",
|
117 |
+
" When num_samples is larger than the size of temporal dimension of the video, it\n",
|
118 |
+
" will sample frames based on nearest neighbor interpolation.\n",
|
119 |
+
" Args:\n",
|
120 |
+
" x (torch.Tensor): A video tensor with dimension larger than one with torch\n",
|
121 |
+
" tensor type includes int, long, float, complex, etc.\n",
|
122 |
+
" num_samples (int): The number of equispaced samples to be selected\n",
|
123 |
+
" temporal_dim (int): dimension of temporal to perform temporal subsample.\n",
|
124 |
+
" Returns:\n",
|
125 |
+
" An x-like Tensor with subsampled temporal dimension.\n",
|
126 |
+
" \"\"\"\n",
|
127 |
+
" t = x.shape[temporal_dim]\n",
|
128 |
+
" assert num_samples > 0 and t > 0\n",
|
129 |
+
" # Sample by nearest neighbor interpolation if num_samples > t.\n",
|
130 |
+
" indices = torch.linspace(0, t - 1, num_samples)\n",
|
131 |
+
" indices = torch.clamp(indices, 0, t - 1).long()\n",
|
132 |
+
" return torch.index_select(x, temporal_dim, indices)\n",
|
133 |
+
"\n",
|
134 |
+
"\n",
|
135 |
+
"def short_side_scale(\n",
|
136 |
+
" x: torch.Tensor,\n",
|
137 |
+
" size: int,\n",
|
138 |
+
" interpolation: str = \"bilinear\",\n",
|
139 |
+
") -> torch.Tensor:\n",
|
140 |
+
" \"\"\"\n",
|
141 |
+
" Determines the shorter spatial dim of the video (i.e. width or height) and scales\n",
|
142 |
+
" it to the given size. To maintain aspect ratio, the longer side is then scaled\n",
|
143 |
+
" accordingly.\n",
|
144 |
+
" Args:\n",
|
145 |
+
" x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.\n",
|
146 |
+
" size (int): The size the shorter side is scaled to.\n",
|
147 |
+
" interpolation (str): Algorithm used for upsampling,\n",
|
148 |
+
" options: nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'\n",
|
149 |
+
" Returns:\n",
|
150 |
+
" An x-like Tensor with scaled spatial dims.\n",
|
151 |
+
" \"\"\"\n",
|
152 |
+
" assert len(x.shape) == 4\n",
|
153 |
+
" assert x.dtype == torch.float32\n",
|
154 |
+
" c, t, h, w = x.shape\n",
|
155 |
+
" if w < h:\n",
|
156 |
+
" new_h = int(math.floor((float(h) / w) * size))\n",
|
157 |
+
" new_w = size\n",
|
158 |
+
" else:\n",
|
159 |
+
" new_h = size\n",
|
160 |
+
" new_w = int(math.floor((float(w) / h) * size))\n",
|
161 |
+
"\n",
|
162 |
+
" return torch.nn.functional.interpolate(\n",
|
163 |
+
" x, size=(new_h, new_w), mode=interpolation, align_corners=False\n",
|
164 |
+
" )\n",
|
165 |
+
"\n",
|
166 |
+
"def inference_step(vid, start_sec, duration, out_fps):\n",
|
167 |
+
" clip = vid.get_clip(start_sec, start_sec + duration)\n",
|
168 |
+
" video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)\n",
|
169 |
+
" audio_arr = np.expand_dims(clip['audio'], 0)\n",
|
170 |
+
" audio_fps = None if not vid._has_audio else vid._container.streams.audio[0].sample_rate\n",
|
171 |
+
"\n",
|
172 |
+
" x = uniform_temporal_subsample(video_arr, duration * out_fps)\n",
|
173 |
+
" x = center_crop(short_side_scale(x, 512), 512)\n",
|
174 |
+
" x /= 255.\n",
|
175 |
+
" x = x.permute(1, 0, 2, 3)\n",
|
176 |
+
" with torch.no_grad():\n",
|
177 |
+
" output = model(x.to('cuda')).detach().cpu()\n",
|
178 |
+
" output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
|
179 |
+
" output_video = output.permute(0, 2, 3, 1).numpy()\n",
|
180 |
+
" \n",
|
181 |
+
" return output_video, audio_arr, out_fps, audio_fps\n",
|
182 |
+
"\n",
|
183 |
+
"def predict_fn(filepath, start_sec, duration, out_fps):\n",
|
184 |
+
" # out_fps=12\n",
|
185 |
+
" vid = EncodedVideo.from_path(filepath)\n",
|
186 |
+
" for i in range(duration):\n",
|
187 |
+
" video, audio, fps, audio_fps = inference_step(\n",
|
188 |
+
" vid = vid,\n",
|
189 |
+
" start_sec = i + start_sec,\n",
|
190 |
+
" duration = 1,\n",
|
191 |
+
" out_fps = out_fps\n",
|
192 |
+
" )\n",
|
193 |
+
" gc.collect()\n",
|
194 |
+
" if i == 0:\n",
|
195 |
+
" video_all = video\n",
|
196 |
+
" audio_all = audio\n",
|
197 |
+
" else:\n",
|
198 |
+
" video_all = np.concatenate((video_all, video))\n",
|
199 |
+
" audio_all = np.hstack((audio_all, audio))\n",
|
200 |
+
"\n",
|
201 |
+
" write_video(\n",
|
202 |
+
" 'out.mp4',\n",
|
203 |
+
" video_all,\n",
|
204 |
+
" fps=fps,\n",
|
205 |
+
" audio_array=audio_all,\n",
|
206 |
+
" audio_fps=audio_fps,\n",
|
207 |
+
" audio_codec='aac'\n",
|
208 |
+
" )\n",
|
209 |
+
"\n",
|
210 |
+
" del video_all\n",
|
211 |
+
" del audio_all\n",
|
212 |
+
" \n",
|
213 |
+
" return 'out.mp4'\n",
|
214 |
+
"\n",
|
215 |
+
"article = \"\"\"\n",
|
216 |
+
"<p style='text-align: center'>\n",
|
217 |
+
" <a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a>\n",
|
218 |
+
"</p>\n",
|
219 |
+
"\"\"\"\n",
|
220 |
+
"\n",
|
221 |
+
"gr.Interface(\n",
|
222 |
+
" predict_fn,\n",
|
223 |
+
" inputs=[gr.inputs.Video(), gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0), gr.inputs.Slider(minimum=1, maximum=10, step=1, default=2), gr.inputs.Slider(minimum=12, maximum=30, step=6, default=24)],\n",
|
224 |
+
" outputs=gr.outputs.Video(),\n",
|
225 |
+
" title='AnimeGANV2 On Videos',\n",
|
226 |
+
" description=\"Applying AnimeGAN-V2 to frame from video clips\",\n",
|
227 |
+
" article = article,\n",
|
228 |
+
" enable_queue=True,\n",
|
229 |
+
" examples=[\n",
|
230 |
+
" ['obama.webm', 23, 10, 30],\n",
|
231 |
+
" ],\n",
|
232 |
+
" allow_flagging=False\n",
|
233 |
+
").launch(debug=True)"
|
234 |
+
],
|
235 |
+
"execution_count": null,
|
236 |
+
"outputs": []
|
237 |
+
}
|
238 |
+
]
|
239 |
+
}
|