"""Generate a talking-head video from a single face image, a speech clip and
its phoneme labels, using an audio-to-keypoint transformer and an
occlusion-aware generator."""
import argparse
import os

import imageio
import numpy as np
import torch
import yaml
from scipy.io import wavfile

import config
from models.generator import OcclusionAwareGenerator
from models.keypoint_detector import KPDetector
from models.transformer import Audio2kpTransformer
from models.util import draw_annotation_box
from tools.interface import (read_img, get_img_pose, get_pose_from_audio,
                             get_audio_feature_from_audio, parse_phoneme_file,
                             load_ckpt)


def normalize_kp(kp_source, kp_driving, kp_driving_initial,
                 use_relative_movement=True, use_relative_jacobian=True):
    """Transfer the driving keypoints' motion onto the source keypoints.

    Motion is expressed relative to the first driving frame, so the generated
    frames move the source face instead of copying the driving keypoints'
    absolute positions.
    """
    kp_new = {k: v for k, v in kp_driving.items()}
    if use_relative_movement:
        # Offset of the current driving frame from the initial driving frame,
        # applied on top of the source keypoints.
        kp_value_diff = kp_driving['value'] - kp_driving_initial['value']
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            # Relative change of the local affine transforms, composed with
            # the source jacobians.
            jacobian_diff = torch.matmul(kp_driving['jacobian'],
                                         torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])

    return kp_new
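
# Note: the keypoint dicts above are assumed to follow the
# first-order-motion-model convention suggested by the imported modules:
# kp['value'] of shape (batch, num_kp, 2) and kp['jacobian'] of shape
# (batch, num_kp, 2, 2), so torch.matmul composes per-keypoint 2x2 transforms.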


def test_with_input_audio_and_image(img_path, audio_path, phs, generator_ckpt,
                                    audio2pose_ckpt, save_dir="samples/results"):
    with open("config_file/vox-256.yaml") as f:
        cfg = yaml.full_load(f)

    cur_path = os.getcwd()

    # The pipeline expects 16 kHz mono PCM audio; resample with ffmpeg if needed.
    sr, _ = wavfile.read(audio_path)
    if sr != 16000:
        temp_audio = os.path.join(cur_path, "samples", "temp.wav")
        command = 'ffmpeg -y -i "%s" -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 "%s"' % (audio_path, temp_audio)
        os.system(command)
    else:
        temp_audio = audio_path

    with open("config_file/audio2kp.yaml") as f:
        opt = argparse.Namespace(**yaml.full_load(f))
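
    # `opt` carries the audio2kp hyper-parameters; the window-building code
    # below relies on opt.num_w, the half-width (in frames) of the temporal
    # context window, so config_file/audio2kp.yaml must define num_w.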

    img = read_img(img_path).cuda()

    # Head pose of the reference image: 3 rotation values followed by 3
    # translation values.
    first_pose = get_img_pose(img_path)

    audio_feature = get_audio_feature_from_audio(temp_audio)
    frames = len(audio_feature) // 4  # 4 audio-feature rows per video frame
    frames = min(frames, len(phs["phone_list"]))

    # Render the reference pose as an annotation-box image and predict the
    # head-pose sequence for the whole audio clip.
    tp = np.zeros([256, 256], dtype=np.float32)
    draw_annotation_box(tp, first_pose[:3], first_pose[3:])
    tp = torch.from_numpy(tp).unsqueeze(0).unsqueeze(0).cuda()
    ref_pose = get_pose_from_audio(tp, audio_feature, audio2pose_ckpt)
    torch.cuda.empty_cache()

    rot_seq = ref_pose[:, :3]
    trans_seq = ref_pose[:, 3:]

    audio_seq = audio_feature
    ph_seq = phs["phone_list"]

    ph_frames = []
    audio_frames = []
    pose_frames = []
    name_len = frames

    # Zero block used to pad the audio window past either end of the clip.
    pad = np.zeros((4, audio_seq.shape[1]), dtype=np.float32)

    # For every output frame, gather a window of 2 * opt.num_w + 1 context
    # frames; positions outside the clip get the padding phoneme id (31),
    # zero audio and the boundary pose.
    for rid in range(0, frames):
        ph = []
        audio = []
        pose = []
        for i in range(rid - opt.num_w, rid + opt.num_w + 1):
            if i < 0:
                rot = rot_seq[0]
                trans = trans_seq[0]
                ph.append(31)
                audio.append(pad)
            elif i >= name_len:
                ph.append(31)
                rot = rot_seq[name_len - 1]
                trans = trans_seq[name_len - 1]
                audio.append(pad)
            else:
                ph.append(ph_seq[i])
                rot = rot_seq[i]
                trans = trans_seq[i]
                audio.append(audio_seq[i * 4:i * 4 + 4])
            tmp_pose = np.zeros([256, 256])
            draw_annotation_box(tmp_pose, np.array(rot), np.array(trans))
            pose.append(tmp_pose)

        ph_frames.append(ph)
        audio_frames.append(audio)
        pose_frames.append(pose)
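
    # Resulting shapes (with w = 2 * opt.num_w + 1 window entries per frame):
    #   audio_frames: frames x w x 4 x audio_seq.shape[1]
    #   pose_frames:  frames x w x 256 x 256
    #   ph_frames:    frames x w phoneme ids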

    audio_f = torch.from_numpy(np.array(audio_frames, dtype=np.float32)).unsqueeze(0)
    poses = torch.from_numpy(np.array(pose_frames, dtype=np.float32)).unsqueeze(0)
    ph_frames = torch.from_numpy(np.array(ph_frames)).unsqueeze(0)
    bs = audio_f.shape[1]  # number of output frames
    predictions_gen = []

    kp_detector = KPDetector(**cfg['model_params']['kp_detector_params'],
                             **cfg['model_params']['common_params'])
    generator = OcclusionAwareGenerator(**cfg['model_params']['generator_params'],
                                        **cfg['model_params']['common_params'])
    kp_detector = kp_detector.cuda()
    generator = generator.cuda()

    ph2kp = Audio2kpTransformer(opt).cuda()

    load_ckpt(generator_ckpt, kp_detector=kp_detector, generator=generator, ph2kp=ph2kp)

    ph2kp.eval()
    generator.eval()
    kp_detector.eval()

    with torch.no_grad():
        for frame_idx in range(bs):
            t = {}
            t["audio"] = audio_f[:, frame_idx].cuda()
            t["pose"] = poses[:, frame_idx].cuda()
            t["ph"] = ph_frames[:, frame_idx].cuda()
            t["id_img"] = img

            # Keypoints of the (static) source image.
            kp_gen_source = kp_detector(img, True)

            # Predict driving keypoints from audio, phonemes and pose; the
            # first frame serves as the reference for relative motion.
            gen_kp = ph2kp(t, kp_gen_source)
            if frame_idx == 0:
                drive_first = gen_kp

            norm = normalize_kp(kp_source=kp_gen_source, kp_driving=gen_kp,
                                kp_driving_initial=drive_first)
            out_gen = generator(img, kp_source=kp_gen_source, kp_driving=norm)

            # NCHW float in [0, 1] -> HWC uint8 frame.
            predictions_gen.append(
                (np.transpose(out_gen['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))
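
    # predictions_gen now holds `bs` uint8 frames of shape (H, W, 3), ready
    # to be written out with imageio.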

    log_dir = save_dir
    os.makedirs(os.path.join(log_dir, "temp"), exist_ok=True)

    f_name = os.path.basename(img_path)[:-4] + "_" + os.path.basename(audio_path)[:-4] + ".mp4"

    # Write the silent video first, then mux in the original audio with ffmpeg.
    video_path = os.path.join(log_dir, "temp", f_name)
    print("save video to: ", video_path)
    imageio.mimsave(video_path, predictions_gen, fps=25.0)

    save_video = os.path.join(log_dir, f_name)
    cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video_path, audio_path, save_video)
    os.system(cmd)
    os.remove(video_path)


if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--img_path", type=str, default=None,
                           help="path of the input image (.jpg), preprocessed by image_preprocess.py")
    argparser.add_argument("--audio_path", type=str, default=None,
                           help="path of the input audio (.wav)")
    argparser.add_argument("--phoneme_path", type=str, default=None,
                           help="path of the input phoneme file; note that the phonemes must match the input audio")
    argparser.add_argument("--save_dir", type=str, default="samples/results",
                           help="directory of the output video")
    args = argparser.parse_args()

    phoneme = parse_phoneme_file(args.phoneme_path)
    test_with_input_audio_and_image(args.img_path, args.audio_path, phoneme,
                                    config.GENERATOR_CKPT, config.AUDIO2POSE_CKPT, args.save_dir)
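
# Example invocation (script and file names are illustrative):
#   python inference.py --img_path samples/source.jpg --audio_path samples/speech.wav \
#       --phoneme_path samples/speech_phonemes.json --save_dir samples/results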