File size: 6,092 Bytes
03a856a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from src.utils.mp_utils import LMKExtractor
from src.utils.draw_utils import FaceMeshVisualizer
from src.utils.img_utils import pil_to_cv2, cv2_to_pil, center_crop_cv2, pils_from_video, save_videos_from_pils, save_video_from_cv2_list
from PIL import Image
import cv2
from IPython import embed
import numpy as np
import copy
from src.utils.motion_utils import motion_sync
import pathlib
import torch
import pickle
from glob import glob
import os
# Visualizer used for every landmark overlay drawn by this script.
vis = FaceMeshVisualizer(
    draw_iris=False,
    draw_mouse=True,
    draw_eye=True,
    draw_nose=True,
    draw_eyebrow=True,
    draw_pupil=True,
)
# Working resolution shared by the driver frames and the reference frame.
imsize = (512, 512)
# When True, a side-by-side preview video (./vis_example.mp4) is written.
visualization = True
driver_video = "./assets/driven_videos/a.mp4"
# driver_videos = glob("/nas2/luoque.lym/evaluation/test_datasets/gt_data/OurDataset/*.mp4")
ref_image = './assets/test_imgs/d.png'
# ref_image = 'panda.png'

lmk_extractor = LMKExtractor()

# Driver frames: center-cropped, then resized to the working resolution.
input_frames_cv2 = [
    cv2.resize(center_crop_cv2(pil_to_cv2(pil_frame)), imsize)
    for pil_frame in pils_from_video(driver_video)
]
if not input_frames_cv2:
    # Nothing to drive the reference with; stop with a clear message and non-zero status.
    raise SystemExit(f"{driver_video}, bad video, no frames decoded")

# cv2.imread signals a missing/unreadable file by returning None instead of raising,
# which would otherwise surface later as an opaque cv2.resize error.
_ref_raw = cv2.imread(ref_image)
if _ref_raw is None:
    raise SystemExit(f"could not read reference image: {ref_image}")
ref_frame = cv2.resize(_ref_raw, imsize)

# The extractor returns None when no face is detected (the same convention is
# checked for the driver frames); motion sync cannot proceed without a reference face.
ref_det = lmk_extractor(ref_frame)
if ref_det is None:
    raise SystemExit(f"{ref_image}, bad image, face not detected")
# Detect landmarks on every driver frame. A single frame without a detected face
# makes the sequence unusable, so abort with a non-zero exit status and name the
# offending frame. Explicit check instead of `assert` (stripped under `python -O`);
# other extractor errors propagate with their traceback instead of being swallowed.
sequence_driver_det = []
for frame_idx, frame in enumerate(input_frames_cv2):
    result = lmk_extractor(frame)
    if result is None:
        raise SystemExit(
            f"face detection failed: {driver_video}, bad video, "
            f"face not detected (frame {frame_idx})"
        )
    sequence_driver_det.append(result)
print(len(sequence_driver_det))
if visualization:
    # Draw the raw driver landmarks (stored under "lmks", drawn with normed=True)
    # on a canvas of the working resolution, so it matches the imsize-resized frames.
    pose_frames_driver = [
        vis.draw_landmarks(imsize, det["lmks"], normed=True)
        for det in sequence_driver_det
    ]
    # 50/50 blend of each driver frame with its landmark drawing.
    poses_add_driver = [
        (frame * 0.5 + pose * 0.5).clip(0, 255).astype(np.uint8)
        for frame, pose in zip(input_frames_cv2, pose_frames_driver)
    ]
# Output directory is named after the reference image's file name without its
# extension (any extension, not only ".png"), e.g. "./d" for "d.png".
save_dir = os.path.join(".", pathlib.Path(ref_image).stem)
os.makedirs(save_dir, exist_ok=True)

# motion_sync is a project helper; presumably it maps the driver landmark motion
# onto the reference face — confirm against src.utils.motion_utils.
sequence_det_ms = motion_sync(sequence_driver_det, ref_det)

# One pickle per frame, named by frame index (0.pkl, 1.pkl, ...).
for frame_idx, det_ms in enumerate(sequence_det_ms):
    with open(os.path.join(save_dir, f"{frame_idx}.pkl"), "wb") as fh:
        pickle.dump(det_ms, fh)
if visualization:
    # Synced landmarks are drawn with normed=False onto a canvas of the working
    # resolution, then blended 50/50 with the reference frame.
    pose_frames = [vis.draw_landmarks(imsize, det_ms, normed=False) for det_ms in sequence_det_ms]
    poses_add = [
        (pose * 0.5 + ref_frame * 0.5).clip(0, 255).astype(np.uint8)
        for pose in pose_frames
    ]
    # NOTE: a variant without per-landmark alignment can be produced with
    # motion_sync(sequence_driver_det, ref_det, per_landmark_align=False).

    # Side-by-side preview: driver overlay on the left, synced overlay on the right.
    poses_cat = [
        np.concatenate([driver_img, synced_img], axis=1)
        for driver_img, synced_img in zip(poses_add_driver, poses_add)
    ]
    save_video_from_cv2_list(poses_cat, "./vis_example.mp4", fps=24.0)
# for ref_image in ref_images[:1]:
# # for driver_video in driver_videos:
# # ref_image = "./samples/007.png"
# # save_dir = '/nas2/jiajiong.caojiajio/data/test_pose/OurDataset/{}'.format(driver_video.split('/')[-1].replace('.mp4', ''))
# save_dir = './{}'.format(ref_image.split('/')[-1].replace('.png', ''))
# os.makedirs(save_dir+'_v1', exist_ok=True)
# os.makedirs(save_dir+'_v2', exist_ok=True)
# #"./samples/hedra_003.png"
# #"./samples/video_temp_fix.mov"
# input_frames_cv2 = [cv2.resize(center_crop_cv2(pil_to_cv2(i)), imsize) for i in pils_from_video(driver_video)]
# # input_frames_cv2 = [cv2.resize(pil_to_cv2(i), imsize) for i in pils_from_video(driver_video)]
# lmk_extractor = LMKExtractor()
# ref_frame =cv2.resize(cv2.imread(ref_image), (512, 512))
# ref_det = lmk_extractor(ref_frame)
# sequence_driver_det = []
# try:
# for frame in input_frames_cv2:
# result = lmk_extractor(frame)
# assert result is not None, "{}, bad video, face not detected".format(driver_video)
# sequence_driver_det.append(result)
# except:
# continue
# print(len(sequence_driver_det))
# # os.makedirs(save_dir, exist_ok=True)
# # for i in range(len(sequence_driver_det)):
# # with open('{}/{}.pkl'.format(save_dir, i), 'wb') as file:
# # pickle.dump(sequence_driver_det[i]["lmks"] * imsize[0], file)
# #[vis.draw_landmarks(imsize, i["lmks"], normed=True, white=True) for i in det_results]
# pose_frames_driver = [vis.draw_landmarks((512, 512), i["lmks"], normed=True) for i in sequence_driver_det]
# poses_add_driver = [(i * 0.5 + j * 0.5).clip(0,255).astype(np.uint8) for i, j in zip(input_frames_cv2, pose_frames_driver)]
# sequence_det_ms = motion_sync(sequence_driver_det, ref_det)
# for i in range(len(sequence_det_ms)):
# tmp = {}
# tmp["lmks"] = sequence_det_ms[i]
# with open('{}_v1/{}.pkl'.format(save_dir, i), 'wb') as file:
# pickle.dump(tmp, file)
# pose_frames = [vis.draw_landmarks((512, 512), i, normed=False) for i in sequence_det_ms]
# poses_add = [(i * 0.5 + ref_frame * 0.5).clip(0,255).astype(np.uint8) for i in pose_frames]
# sequence_det_ms = motion_sync(sequence_driver_det, ref_det, per_landmark_align=False)
# for i in range(len(sequence_det_ms)):
# tmp = {}
# tmp["lmks"] = sequence_det_ms[i]
# with open('{}_v2/{}.pkl'.format(save_dir, i), 'wb') as file:
# pickle.dump(tmp, file)
# pose_frames_wo_lmkalign = [vis.draw_landmarks((512, 512), i, normed=False) for i in sequence_det_ms]
# poses_add_wo_lmkalign = [(i * 0.5 + ref_frame * 0.5).clip(0,255).astype(np.uint8) for i in pose_frames_wo_lmkalign]
# poses_cat = [np.concatenate([i, j, k], axis=1) for i, j, k in zip(poses_add_driver, poses_add_wo_lmkalign, poses_add)]
# save_video_from_cv2_list(poses_cat, "./output/example2.mp4", fps=24.0)
# # exit()
# #embed()
# #poses_cat = [(i * 0.5 + j * 0.5).clip(0,255).astype(np.uint8) for i, j in zip(input_frames_cv2, pose_frames)]
# #save_videos_from_pils([cv2_to_pil(i) for i in poses_cat], "./output/pose_cat.mp4", fps=24) |