DmitrMakeev commited on
Commit
c626b55
1 Parent(s): 183ec7c

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitattributes +59 -0
  2. README.md +59 -0
  3. config.py +3 -0
  4. image_preprocess.py +57 -0
  5. phindex.json +1 -0
  6. requirements.txt +8 -0
  7. test_script.py +180 -0
.gitattributes ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ OpenFace/FaceLandmarkVidMulti filter=lfs diff=lfs merge=lfs -text
36
+ OpenFace/FeatureExtraction filter=lfs diff=lfs merge=lfs -text
37
+ OpenFace/FaceLandmarkVid filter=lfs diff=lfs merge=lfs -text
38
+ OpenFace/FaceLandmarkImg filter=lfs diff=lfs merge=lfs -text
39
+ OpenFace/model/detection_validation/validator_cnn.txt filter=lfs diff=lfs merge=lfs -text
40
+ OpenFace/model/detection_validation/validator_cnn_68.txt filter=lfs diff=lfs merge=lfs -text
41
+ OpenFace/model/model_inner/patch_experts/ccnf_patches_1.00_inner.txt filter=lfs diff=lfs merge=lfs -text
42
+ OpenFace/model/patch_experts/ccnf_patches_0.5_wild.txt filter=lfs diff=lfs merge=lfs -text
43
+ OpenFace/model/patch_experts/ccnf_patches_1_wild.txt filter=lfs diff=lfs merge=lfs -text
44
+ OpenFace/model/patch_experts/ccnf_patches_0.35_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
45
+ OpenFace/model/patch_experts/ccnf_patches_0.35_wild.txt filter=lfs diff=lfs merge=lfs -text
46
+ OpenFace/model/patch_experts/ccnf_patches_0.25_wild.txt filter=lfs diff=lfs merge=lfs -text
47
+ OpenFace/model/patch_experts/ccnf_patches_0.5_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
48
+ OpenFace/model/patch_experts/ccnf_patches_0.25_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
49
+ OpenFace/model/patch_experts/ccnf_patches_0.5_general.txt filter=lfs diff=lfs merge=lfs -text
50
+ OpenFace/model/patch_experts/ccnf_patches_0.25_general.txt filter=lfs diff=lfs merge=lfs -text
51
+ OpenFace/model/patch_experts/ccnf_patches_0.35_general.txt filter=lfs diff=lfs merge=lfs -text
52
+ OpenFace/model/mtcnn_detector/ONet.dat filter=lfs diff=lfs merge=lfs -text
53
+ samples/audios/trump.wav filter=lfs diff=lfs merge=lfs -text
54
+ samples/audios/abstract.wav filter=lfs diff=lfs merge=lfs -text
55
+ samples/audios/obama2.wav filter=lfs diff=lfs merge=lfs -text
56
+ OpenFace/model/patch_experts/cen_patches_0.35_of.dat filter=lfs diff=lfs merge=lfs -text
57
+ OpenFace/model/patch_experts/cen_patches_0.25_of.dat filter=lfs diff=lfs merge=lfs -text
58
+ OpenFace/model/patch_experts/cen_patches_1.00_of.dat filter=lfs diff=lfs merge=lfs -text
59
+ OpenFace/model/patch_experts/cen_patches_0.50_of.dat filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # One-shot Talking Face Generation from Single-speaker Audio-Visual Correlation Learning (AAAI 2022)
2
+
3
+ #### [Paper](https://arxiv.org/pdf/2112.02749.pdf) | [Demo](https://www.youtube.com/watch?v=HHj-XCXXePY)
4
+
5
+ #### Requirements
6
+
7
+ - Python >= 3.6 , Pytorch >= 1.8 and ffmpeg
8
+ - Set up [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace)
9
+ - We use the OpenFace tools to extract the initial pose of the reference image
10
+ - Make sure you have installed this tool, and set the `OPENFACE_POSE_EXTRACTOR_PATH` in `config.py`. For example, it should be the absolute path of the "`FeatureExtraction.exe`" for Windows.
11
+ - Other requirements are listed in the 'requirements.txt'
12
+
13
+
14
+
15
+ #### Pretrained Checkpoint
16
+
17
+ Please download the pretrained checkpoint from [google-drive](https://drive.google.com/file/d/1mjFEozPR_2vMaVRMd9Agk_sU1VaiUYMl/view?usp=sharing) and unzip it to the directory (`/checkpoints`). Or manually modify the settings of `GENERATOR_CKPT` and `AUDIO2POSE_CKPT` in the `config.py`.
18
+
19
+
20
+
21
+ #### Extract phoneme
22
+
23
+ We employ the [CMU phoneset](https://github.com/cmusphinx/cmudict) to represent phonemes, the extra 'SIL' means silence. All the phonesets can be seen in '`phindex.json`'.
24
+
25
+ We have extracted the phonemes for the audios in the '`sample/audio`' directory. For other audios, you can extract the phonemes by other ASR tools and then map them to the CMU phoneset. Or email to [email protected] for help.
26
+
27
+
28
+
29
+ #### Generate Demo Results
30
+
31
+ ```
32
+ python test_script.py --img_path xxx.jpg --audio_path xxx.wav --phoneme_path xxx.json --save_dir "YOUR_DIR"
33
+ ```
34
+
35
+ Note that the input images must keep the same height and width and the face should be appropriately cropped as in `samples/imgs`. You can also preprocess your images with `image_preprocess.py`.
36
+
37
+
38
+
39
+ #### License and Citation
40
+
41
+ ```
42
+ @InProceedings{wang2021one,
43
+ author = Suzhen Wang, Lincheng Li, Yu Ding, Xin Yu
44
+ title = {One-shot Talking Face Generation from Single-speaker Audio-Visual Correlation Learning},
45
+ booktitle = {AAAI 2022},
46
+ year = {2022},
47
+ }
48
+ ```
49
+
50
+
51
+
52
+ #### Acknowledgement
53
+
54
+ This codebase is based on [First Order Motion Model](https://github.com/AliaksandrSiarohin/first-order-model) and [imaginaire](https://github.com/NVlabs/imaginaire), thanks for their contributions.
55
+
56
+
57
+
58
+
59
+
config.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ OPENFACE_POSE_EXTRACTOR_PATH = "/content/one-shot-talking-face/OpenFace/FeatureExtraction"
2
+ GENERATOR_CKPT = "/content/one-shot-talking-face/checkpoints/generator.ckpt"
3
+ AUDIO2POSE_CKPT = "/content/one-shot-talking-face/checkpoints/audio2pose.ckpt"
image_preprocess.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dlib
2
+ import cv2
3
+ def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
4
+ left, top, right, bot = bbox
5
+ width = right - left
6
+ height = bot - top
7
+
8
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
9
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
10
+
11
+ left_t = int(left - width_increase * width)
12
+ top_t = int(top - height_increase * height)
13
+ right_t = int(right + width_increase * width)
14
+ bot_t = int(bot + height_increase * height)
15
+
16
+ left_oob = -min(0, left_t)
17
+ right_oob = right - min(right_t, w)
18
+ top_oob = -min(0, top_t)
19
+ bot_oob = bot - min(bot_t, h)
20
+
21
+ if max(left_oob, right_oob, top_oob, bot_oob) > 0:
22
+ max_w = max(left_oob, right_oob)
23
+ max_h = max(top_oob, bot_oob)
24
+ if max_w > max_h:
25
+ return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
26
+ else:
27
+ return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
28
+
29
+ else:
30
+ return (left_t, top_t, right_t, bot_t)
31
+
32
+ def crop_src_image(src_img,save_img, detector=None):
33
+ if detector is None:
34
+ detector = dlib.get_frontal_face_detector()
35
+
36
+ img = cv2.imread(src_img)
37
+ faces = detector(img, 0)
38
+ h, width, _ = img.shape
39
+ if len(faces) > 0:
40
+ bbox = [faces[0].left(), faces[0].top(),faces[0].right(), faces[0].bottom()]
41
+ l = bbox[3]-bbox[1]
42
+ bbox[1]= bbox[1]-l*0.1
43
+ bbox[3]= bbox[3]-l*0.1
44
+ bbox[1] = max(0,bbox[1])
45
+ bbox[3] = min(h,bbox[3])
46
+ bbox = compute_aspect_preserved_bbox(tuple(bbox), 0.5, img.shape[0], img.shape[1])
47
+ img = img[bbox[1] :bbox[3] , bbox[0]:bbox[2]]
48
+ img = cv2.resize(img, (256, 256))
49
+ cv2.imwrite(save_img,img)
50
+ else:
51
+ img = cv2.resize(img,(256,256))
52
+ cv2.imwrite(save_img, img)
53
+
54
+ if __name__ == '__main__':
55
+ src_img = ""
56
+ out_img = ""
57
+ crop_src_image(src_img,out_img)
phindex.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5, "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11, "EY": 12, "F": 13, "G": 14, "HH": 15, "IH": 16, "IY": 17, "JH": 18, "K": 19, "L": 20, "M": 21, "N": 22, "NG": 23, "NSN": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "SIL": 31, "T": 32, "TH": 33, "UH": 34, "UW": 35, "V": 36, "W": 37, "Y": 38, "Z": 39, "ZH": 40}
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ scikit-image
2
+ python_speech_features
3
+ pyworld
4
+ pyyaml
5
+ imageio
6
+ scipy
7
+ pyworld
8
+ opencv-python
test_script.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import yaml
5
+ from models.generator import OcclusionAwareGenerator
6
+ from models.keypoint_detector import KPDetector
7
+ import argparse
8
+ import imageio
9
+ from models.util import draw_annotation_box
10
+ from models.transformer import Audio2kpTransformer
11
+ from scipy.io import wavfile
12
+ from tools.interface import read_img,get_img_pose,get_pose_from_audio,get_audio_feature_from_audio,\
13
+ parse_phoneme_file,load_ckpt
14
+ import config
15
+
16
+ def normalize_kp(kp_source, kp_driving, kp_driving_initial,
17
+ use_relative_movement=True, use_relative_jacobian=True):
18
+
19
+ kp_new = {k: v for k, v in kp_driving.items()}
20
+ if use_relative_movement:
21
+ kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
22
+ # kp_value_diff *= adapt_movement_scale
23
+ kp_new['value'] = kp_value_diff + kp_source['value']
24
+
25
+ if use_relative_jacobian:
26
+ jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
27
+ kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
28
+
29
+ return kp_new
30
+
31
+
32
+ def test_with_input_audio_and_image(img_path, audio_path,phs, generator_ckpt, audio2pose_ckpt, save_dir="samples/results"):
33
+ with open("config_file/vox-256.yaml") as f:
34
+ config = yaml.full_load(f)
35
+ # temp_audio = audio_path
36
+ # print(audio_path)
37
+ cur_path = os.getcwd()
38
+
39
+ sr,_ = wavfile.read(audio_path)
40
+ if sr!=16000:
41
+ temp_audio = os.path.join(cur_path,"samples","temp.wav")
42
+ command = "ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (audio_path, temp_audio)
43
+ os.system(command)
44
+ else:
45
+ temp_audio = audio_path
46
+
47
+
48
+ opt = argparse.Namespace(**yaml.full_load(open("config_file/audio2kp.yaml")))
49
+
50
+ img = read_img(img_path).cuda()
51
+
52
+ first_pose = get_img_pose(img_path)#.cuda()
53
+
54
+ audio_feature = get_audio_feature_from_audio(temp_audio)
55
+ frames = len(audio_feature) // 4
56
+ frames = min(frames,len(phs["phone_list"]))
57
+
58
+ tp = np.zeros([256, 256], dtype=np.float32)
59
+ draw_annotation_box(tp, first_pose[:3], first_pose[3:])
60
+ tp = torch.from_numpy(tp).unsqueeze(0).unsqueeze(0).cuda()
61
+ ref_pose = get_pose_from_audio(tp, audio_feature, audio2pose_ckpt)
62
+ torch.cuda.empty_cache()
63
+ trans_seq = ref_pose[:, 3:]
64
+ rot_seq = ref_pose[:, :3]
65
+
66
+
67
+
68
+ audio_seq = audio_feature#[40:]
69
+ ph_seq = phs["phone_list"]
70
+
71
+
72
+ ph_frames = []
73
+ audio_frames = []
74
+ pose_frames = []
75
+ name_len = frames
76
+
77
+ pad = np.zeros((4, audio_seq.shape[1]), dtype=np.float32)
78
+
79
+ for rid in range(0, frames):
80
+ ph = []
81
+ audio = []
82
+ pose = []
83
+ for i in range(rid - opt.num_w, rid + opt.num_w + 1):
84
+ if i < 0:
85
+ rot = rot_seq[0]
86
+ trans = trans_seq[0]
87
+ ph.append(31)
88
+ audio.append(pad)
89
+ elif i >= name_len:
90
+ ph.append(31)
91
+ rot = rot_seq[name_len - 1]
92
+ trans = trans_seq[name_len - 1]
93
+ audio.append(pad)
94
+ else:
95
+ ph.append(ph_seq[i])
96
+ rot = rot_seq[i]
97
+ trans = trans_seq[i]
98
+ audio.append(audio_seq[i * 4:i * 4 + 4])
99
+ tmp_pose = np.zeros([256, 256])
100
+ draw_annotation_box(tmp_pose, np.array(rot), np.array(trans))
101
+ pose.append(tmp_pose)
102
+
103
+ ph_frames.append(ph)
104
+ audio_frames.append(audio)
105
+ pose_frames.append(pose)
106
+
107
+ audio_f = torch.from_numpy(np.array(audio_frames,dtype=np.float32)).unsqueeze(0)
108
+ poses = torch.from_numpy(np.array(pose_frames, dtype=np.float32)).unsqueeze(0)
109
+ ph_frames = torch.from_numpy(np.array(ph_frames)).unsqueeze(0)
110
+ bs = audio_f.shape[1]
111
+ predictions_gen = []
112
+
113
+ kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
114
+ **config['model_params']['common_params'])
115
+ generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
116
+ **config['model_params']['common_params'])
117
+ kp_detector = kp_detector.cuda()
118
+ generator = generator.cuda()
119
+
120
+ ph2kp = Audio2kpTransformer(opt).cuda()
121
+
122
+ load_ckpt(generator_ckpt, kp_detector=kp_detector, generator=generator,ph2kp=ph2kp)
123
+
124
+
125
+ ph2kp.eval()
126
+ generator.eval()
127
+ kp_detector.eval()
128
+
129
+ with torch.no_grad():
130
+ for frame_idx in range(bs):
131
+ t = {}
132
+
133
+ t["audio"] = audio_f[:, frame_idx].cuda()
134
+ t["pose"] = poses[:, frame_idx].cuda()
135
+ t["ph"] = ph_frames[:,frame_idx].cuda()
136
+ t["id_img"] = img
137
+
138
+ kp_gen_source = kp_detector(img, True)
139
+
140
+ gen_kp = ph2kp(t,kp_gen_source)
141
+ if frame_idx == 0:
142
+ drive_first = gen_kp
143
+
144
+ norm = normalize_kp(kp_source=kp_gen_source, kp_driving=gen_kp, kp_driving_initial=drive_first)
145
+ out_gen = generator(img, kp_source=kp_gen_source, kp_driving=norm)
146
+
147
+ predictions_gen.append(
148
+ (np.transpose(out_gen['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))
149
+
150
+
151
+ log_dir = save_dir
152
+ os.makedirs(os.path.join(log_dir, "temp"),exist_ok=True)
153
+
154
+ f_name = os.path.basename(img_path)[:-4] + "_" + os.path.basename(audio_path)[:-4] + ".mp4"
155
+ # kwargs = {'duration': 1. / 25.0}
156
+ video_path = os.path.join(log_dir, "temp", f_name)
157
+ print("save video to: ", video_path)
158
+ imageio.mimsave(video_path, predictions_gen, fps=25.0)
159
+
160
+ # audio_path = os.path.join(audio_dir, x['name'][0].replace(".mp4", ".wav"))
161
+ save_video = os.path.join(log_dir, f_name)
162
+ cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video_path, audio_path, save_video)
163
+ os.system(cmd)
164
+ os.remove(video_path)
165
+
166
+
167
+
168
+
169
+
170
+
171
+ if __name__ == '__main__':
172
+ argparser = argparse.ArgumentParser()
173
+ argparser.add_argument("--img_path", type=str, default=None, help="path of the input image ( .jpg ), preprocessed by image_preprocess.py")
174
+ argparser.add_argument("--audio_path", type=str, default=None, help="path of the input audio ( .wav )")
175
+ argparser.add_argument("--phoneme_path", type=str, default=None, help="path of the input phoneme. It should be note that the phoneme must be consistent with the input audio")
176
+ argparser.add_argument("--save_dir", type=str, default="samples/results", help="path of the output video")
177
+ args = argparser.parse_args()
178
+
179
+ phoneme = parse_phoneme_file(args.phoneme_path)
180
+ test_with_input_audio_and_image(args.img_path,args.audio_path,phoneme,config.GENERATOR_CKPT,config.AUDIO2POSE_CKPT,args.save_dir)