yourusername commited on
Commit
7b91e69
·
1 Parent(s): b12f76d

Created using Colaboratory

Browse files
Files changed (1) hide show
  1. animegan_v2_for_videos.ipynb +239 -0
animegan_v2_for_videos.ipynb ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "accelerator": "GPU",
6
+ "colab": {
7
+ "name": "animegan_v2_for_videos.ipynb",
8
+ "provenance": [],
9
+ "collapsed_sections": [],
10
+ "authorship_tag": "ABX9TyP/bydrfrVmE0CzRt9JBw+x",
11
+ "include_colab_link": true
12
+ },
13
+ "kernelspec": {
14
+ "display_name": "Python 3",
15
+ "name": "python3"
16
+ },
17
+ "language_info": {
18
+ "name": "python"
19
+ }
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/nateraw/animegan-v2-for-videos/blob/main/animegan_v2_for_videos.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "metadata": {
35
+ "id": "dufmM-T1Helt"
36
+ },
37
+ "source": [
38
+ "%%capture\n",
39
+ "! pip install gradio encoded-video"
40
+ ],
41
+ "execution_count": null,
42
+ "outputs": []
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "metadata": {
47
+ "id": "9CY3n8A0Lvdi"
48
+ },
49
+ "source": [
50
+ "import gc\n",
51
+ "import math\n",
52
+ "import tempfile\n",
53
+ "from PIL import Image\n",
54
+ "from io import BytesIO\n",
55
+ "\n",
56
+ "import torch\n",
57
+ "import gradio as gr\n",
58
+ "import numpy as np\n",
59
+ "from encoded_video import EncodedVideo, write_video\n",
60
+ "from torchvision.transforms.functional import to_tensor, center_crop"
61
+ ],
62
+ "execution_count": null,
63
+ "outputs": []
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "metadata": {
68
+ "id": "YxdCnrTzLw5V"
69
+ },
70
+ "source": [
71
+ "model = torch.hub.load(\n",
72
+ " \"AK391/animegan2-pytorch:main\",\n",
73
+ " \"generator\",\n",
74
+ " pretrained=True,\n",
75
+ " device=\"cuda\",\n",
76
+ " progress=True,\n",
77
+ ")"
78
+ ],
79
+ "execution_count": null,
80
+ "outputs": []
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "metadata": {
85
+ "id": "TYAyXUP1UeOd"
86
+ },
87
+ "source": [
88
+ "! curl https://upload.wikimedia.org/wikipedia/commons/transcoded/2/29/2017-01-07_President_Obama%27s_Weekly_Address.webm/2017-01-07_President_Obama%27s_Weekly_Address.webm.360p.vp9.webm -o obama.webm"
89
+ ],
90
+ "execution_count": null,
91
+ "outputs": []
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "metadata": {
96
+ "id": "TxT45Nlc88tD"
97
+ },
98
+ "source": [
99
+ "def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = 'cuda'):\n",
100
+ " w, h = img.size\n",
101
+ " s = min(w, h)\n",
102
+ " img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))\n",
103
+ " img = img.resize((size, size), Image.LANCZOS)\n",
104
+ "\n",
105
+ " with torch.no_grad():\n",
106
+ " input = to_tensor(img).unsqueeze(0) * 2 - 1\n",
107
+ " output = model(input.to(device)).cpu()[0]\n",
108
+ "\n",
109
+ " output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
110
+ "\n",
111
+ " return output\n",
112
+ "\n",
113
+ "# This function is taken from pytorchvideo!\n",
114
+ "def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:\n",
115
+ " \"\"\"\n",
116
+ " Uniformly subsamples num_samples indices from the temporal dimension of the video.\n",
117
+ " When num_samples is larger than the size of temporal dimension of the video, it\n",
118
+ " will sample frames based on nearest neighbor interpolation.\n",
119
+ " Args:\n",
120
+ " x (torch.Tensor): A video tensor with dimension larger than one with torch\n",
121
+ " tensor type includes int, long, float, complex, etc.\n",
122
+ " num_samples (int): The number of equispaced samples to be selected\n",
123
+ " temporal_dim (int): dimension of temporal to perform temporal subsample.\n",
124
+ " Returns:\n",
125
+ " An x-like Tensor with subsampled temporal dimension.\n",
126
+ " \"\"\"\n",
127
+ " t = x.shape[temporal_dim]\n",
128
+ " assert num_samples > 0 and t > 0\n",
129
+ " # Sample by nearest neighbor interpolation if num_samples > t.\n",
130
+ " indices = torch.linspace(0, t - 1, num_samples)\n",
131
+ " indices = torch.clamp(indices, 0, t - 1).long()\n",
132
+ " return torch.index_select(x, temporal_dim, indices)\n",
133
+ "\n",
134
+ "\n",
135
+ "def short_side_scale(\n",
136
+ " x: torch.Tensor,\n",
137
+ " size: int,\n",
138
+ " interpolation: str = \"bilinear\",\n",
139
+ ") -> torch.Tensor:\n",
140
+ " \"\"\"\n",
141
+ " Determines the shorter spatial dim of the video (i.e. width or height) and scales\n",
142
+ " it to the given size. To maintain aspect ratio, the longer side is then scaled\n",
143
+ " accordingly.\n",
144
+ " Args:\n",
145
+ " x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.\n",
146
+ " size (int): The size the shorter side is scaled to.\n",
147
+ " interpolation (str): Algorithm used for upsampling,\n",
148
+ " options: nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'\n",
149
+ " Returns:\n",
150
+ " An x-like Tensor with scaled spatial dims.\n",
151
+ " \"\"\"\n",
152
+ " assert len(x.shape) == 4\n",
153
+ " assert x.dtype == torch.float32\n",
154
+ " c, t, h, w = x.shape\n",
155
+ " if w < h:\n",
156
+ " new_h = int(math.floor((float(h) / w) * size))\n",
157
+ " new_w = size\n",
158
+ " else:\n",
159
+ " new_h = size\n",
160
+ " new_w = int(math.floor((float(w) / h) * size))\n",
161
+ "\n",
162
+ " return torch.nn.functional.interpolate(\n",
163
+ " x, size=(new_h, new_w), mode=interpolation, align_corners=False\n",
164
+ " )\n",
165
+ "\n",
166
+ "def inference_step(vid, start_sec, duration, out_fps):\n",
167
+ " clip = vid.get_clip(start_sec, start_sec + duration)\n",
168
+ " video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)\n",
169
+ " audio_arr = np.expand_dims(clip['audio'], 0)\n",
170
+ " audio_fps = None if not vid._has_audio else vid._container.streams.audio[0].sample_rate\n",
171
+ "\n",
172
+ " x = uniform_temporal_subsample(video_arr, duration * out_fps)\n",
173
+ " x = center_crop(short_side_scale(x, 512), 512)\n",
174
+ " x /= 255.\n",
175
+ " x = x.permute(1, 0, 2, 3)\n",
176
+ " with torch.no_grad():\n",
177
+ " output = model(x.to('cuda')).detach().cpu()\n",
178
+ " output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
179
+ " output_video = output.permute(0, 2, 3, 1).numpy()\n",
180
+ " \n",
181
+ " return output_video, audio_arr, out_fps, audio_fps\n",
182
+ "\n",
183
+ "def predict_fn(filepath, start_sec, duration, out_fps):\n",
184
+ " # out_fps=12\n",
185
+ " vid = EncodedVideo.from_path(filepath)\n",
186
+ " for i in range(duration):\n",
187
+ " video, audio, fps, audio_fps = inference_step(\n",
188
+ " vid = vid,\n",
189
+ " start_sec = i + start_sec,\n",
190
+ " duration = 1,\n",
191
+ " out_fps = out_fps\n",
192
+ " )\n",
193
+ " gc.collect()\n",
194
+ " if i == 0:\n",
195
+ " video_all = video\n",
196
+ " audio_all = audio\n",
197
+ " else:\n",
198
+ " video_all = np.concatenate((video_all, video))\n",
199
+ " audio_all = np.hstack((audio_all, audio))\n",
200
+ "\n",
201
+ " write_video(\n",
202
+ " 'out.mp4',\n",
203
+ " video_all,\n",
204
+ " fps=fps,\n",
205
+ " audio_array=audio_all,\n",
206
+ " audio_fps=audio_fps,\n",
207
+ " audio_codec='aac'\n",
208
+ " )\n",
209
+ "\n",
210
+ " del video_all\n",
211
+ " del audio_all\n",
212
+ " \n",
213
+ " return 'out.mp4'\n",
214
+ "\n",
215
+ "article = \"\"\"\n",
216
+ "<p style='text-align: center'>\n",
217
+ " <a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a>\n",
218
+ "</p>\n",
219
+ "\"\"\"\n",
220
+ "\n",
221
+ "gr.Interface(\n",
222
+ " predict_fn,\n",
223
+ " inputs=[gr.inputs.Video(), gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0), gr.inputs.Slider(minimum=1, maximum=10, step=1, default=2), gr.inputs.Slider(minimum=12, maximum=30, step=6, default=24)],\n",
224
+ " outputs=gr.outputs.Video(),\n",
225
+ " title='AnimeGANV2 On Videos',\n",
226
+ " description=\"Applying AnimeGAN-V2 to frame from video clips\",\n",
227
+ " article = article,\n",
228
+ " enable_queue=True,\n",
229
+ " examples=[\n",
230
+ " ['obama.webm', 23, 10, 30],\n",
231
+ " ],\n",
232
+ " allow_flagging=False\n",
233
+ ").launch(debug=True)"
234
+ ],
235
+ "execution_count": null,
236
+ "outputs": []
237
+ }
238
+ ]
239
+ }