animesh007 commited on
Commit
1b94225
1 Parent(s): e68ea0d

added app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import torch
5
+ from PIL import Image
6
+ import argparse
7
+ import pathlib
8
+
9
+ os.system("git clone https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model")
10
+ os.chdir("Thin-Plate-Spline-Motion-Model")
11
+ os.system("mkdir checkpoints")
12
+ os.system("wget -c https://cloud.tsinghua.edu.cn/f/da8d61d012014b12a9e4/?dl=1 -O checkpoints/vox.pth.tar")
13
+
14
+
15
+
16
+ title = "# Thin-Plate Spline Motion Model for Image Animation"
17
+ DESCRIPTION = '''### Gradio demo for <b>Thin-Plate Spline Motion Model for Image Animation</b>, CVPR 2022. <a href='https://arxiv.org/abs/2203.14367'>[Paper]</a><a href='https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model'>[Github Code]</a>
18
+
19
+ <img id="overview" alt="overview" src="https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model/raw/main/assets/vox.gif" />
20
+ '''
21
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" />'
22
+
23
+
24
+ def get_style_image_path(style_name: str) -> str:
25
+ base_path = 'assets'
26
+ filenames = {
27
+ 'source': 'source.png',
28
+ 'driving': 'driving.mp4',
29
+ }
30
+ return f'{base_path}/{filenames[style_name]}'
31
+
32
+
33
+ def get_style_image_markdown_text(style_name: str) -> str:
34
+ url = get_style_image_path(style_name)
35
+ return f'<img id="style-image" src="{url}" alt="style image">'
36
+
37
+
38
+ def update_style_image(style_name: str) -> dict:
39
+ text = get_style_image_markdown_text(style_name)
40
+ return gr.Markdown.update(value=text)
41
+
42
+
43
+ def set_example_image(example: list) -> dict:
44
+ return gr.Image.update(value=example[0])
45
+
46
+ def set_example_video(example: list) -> dict:
47
+ return gr.Video.update(value=example[0])
48
+
49
+ def inference(img,vid):
50
+ if not os.path.exists('temp'):
51
+ os.system('mkdir temp')
52
+
53
+ img.save("temp/image.jpg", "JPEG")
54
+ os.system(f"python demo.py --config config/vox-256.yaml --checkpoint ./checkpoints/vox.pth.tar --source_image 'temp/image.jpg' --driving_video {vid} --result_video './temp/result.mp4' --cpu")
55
+ return './temp/result.mp4'
56
+
57
+
58
+
59
+ def main():
60
+ with gr.Blocks(theme="huggingface", css='style.css') as demo:
61
+ gr.Markdown(title)
62
+ gr.Markdown(DESCRIPTION)
63
+
64
+ with gr.Box():
65
+ gr.Markdown('''## Step 1 (Provide Input Face Image)
66
+ - Drop an image containing a face to the **Input Image**.
67
+ - If there are multiple faces in the image, use Edit button in the upper right corner and crop the input image beforehand.
68
+ ''')
69
+ with gr.Row():
70
+ with gr.Column():
71
+ with gr.Row():
72
+ input_image = gr.Image(label='Input Image',
73
+ type="pil")
74
+
75
+ with gr.Row():
76
+ paths = sorted(pathlib.Path('assets').glob('*.png'))
77
+ example_images = gr.Dataset(components=[input_image],
78
+ samples=[[path.as_posix()]
79
+ for path in paths])
80
+
81
+ with gr.Box():
82
+ gr.Markdown('''## Step 2 (Select Driving Video)
83
+ - Select **Style Driving Video for the face image animation**.
84
+ ''')
85
+ with gr.Row():
86
+ with gr.Column():
87
+ with gr.Row():
88
+ driving_video = gr.Video(label='Driving Video',
89
+ format="mp4")
90
+
91
+ with gr.Row():
92
+ paths = sorted(pathlib.Path('assets').glob('*.mp4'))
93
+ example_video = gr.Dataset(components=[driving_video],
94
+ samples=[[path.as_posix()]
95
+ for path in paths])
96
+
97
+ with gr.Box():
98
+ gr.Markdown('''## Step 3 (Generate Animated Image based on the Video)
99
+ - Hit the **Generate** button.
100
+ ''')
101
+ with gr.Row():
102
+ with gr.Column():
103
+ with gr.Row():
104
+ generate_button = gr.Button('Generate')
105
+
106
+ with gr.Column():
107
+ result = gr.Video(type="file", label="Output")
108
+ gr.Markdown(FOOTER)
109
+ generate_button.click(fn=inference,
110
+ inputs=[
111
+ input_image,
112
+ driving_video
113
+ ],
114
+ outputs=result)
115
+ example_images.click(fn=set_example_image,
116
+ inputs=example_images,
117
+ outputs=example_images.components)
118
+ example_video.click(fn=set_example_video,
119
+ inputs=example_video,
120
+ outputs=example_video.components)
121
+
122
+ demo.launch(
123
+ share=True,
124
+ debug=True
125
+ )
126
+
127
+
128
+ if __name__ == '__main__':
129
+ main()