import os, sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import modules.generator as GEN
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
import depth
from modules.keypoint_detector import KPDetector
from scipy.spatial import ConvexHull
from collections import OrderedDict
import warnings
warnings.filterwarnings("ignore")


def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    """Adapt driving keypoints to the source: optional motion-scale adaptation
    (ratio of convex-hull areas) and relative value/jacobian transfer."""
    if adapt_movement_scale:
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_new = {k: v for k, v in kp_driving.items()}

    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff *= adapt_movement_scale
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            jacobian_diff = torch.matmul(kp_driving['jacobian'],
                                         torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])

    return kp_new


def find_best_frame(source, driving, cpu=False):
    """Return the index of the driving frame whose (normalized) facial landmarks
    are closest to the source image, i.e. the best-aligned starting pose."""
    import face_alignment

    def normalize_kp(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    # Note: LandmarksType._2D is the older face_alignment API; newer releases rename it TWO_D.
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = fa.get_landmarks(255 * source)[0]
    kp_source = normalize_kp(kp_source)
    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        kp_driving = fa.get_landmarks(255 * image)[0]
        kp_driving = normalize_kp(kp_driving)
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num


def animation(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder,
              relative=True, adapt_scale=True, cpu=False):
    """Animate the source image with the driving video, transferring motion relative
    to the initial driving frame."""
    with torch.no_grad():
        predictions = []
        # (B, H, W, C) -> (B, C, H, W) for the image, (B, T, H, W, C) -> (B, C, T, H, W) for the video.
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        if not cpu:
            source = source.cuda()
            driving = driving.cuda()

        # Disparity maps from the depth network; ("disp", 0) is the full-resolution output.
        outputs = depth_decoder(depth_encoder(source))
        depth_source = outputs[("disp", 0)]
        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
        depth_driving = outputs[("disp", 0)]

        # The keypoint detector consumes RGB + depth as a 4-channel input.
        source_kp = torch.cat((source, depth_source), 1)
        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)
        kp_source = kp_detector(source_kp)
        kp_driving_initial = kp_detector(driving_kp)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            outputs = depth_decoder(depth_encoder(driving_frame))
            depth_map = outputs[("disp", 0)]

            frame = torch.cat((driving_frame, depth_map), 1)
            kp_driving = kp_detector(frame)

            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial,
                                   use_relative_movement=relative,
                                   use_relative_jacobian=relative,
                                   adapt_movement_scale=adapt_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm,
                            source_depth=depth_source, driving_depth=depth_map)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions


def reconstruction(source_image, driving_video, generator, kp_detector, depth_encoder, depth_decoder,
                   cpu=False):
    """Reconstruct the driving video from its own first frame (absolute keypoints,
    no relative-motion normalization)."""
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        if not cpu:
            source = source.cuda()
            driving = driving.cuda()

        outputs = depth_decoder(depth_encoder(source))
        depth_source = outputs[("disp", 0)]
        outputs = depth_decoder(depth_encoder(driving[:, :, 0]))
        depth_driving = outputs[("disp", 0)]

        source_kp = torch.cat((source, depth_source), 1)
        driving_kp = torch.cat((driving[:, :, 0], depth_driving), 1)
        kp_source = kp_detector(source_kp)
        kp_driving_initial = kp_detector(driving_kp)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            outputs = depth_decoder(depth_encoder(driving_frame))
            depth_map = outputs[("disp", 0)]

            frame = torch.cat((driving_frame, depth_map), 1)
            kp_driving = kp_detector(frame)

            out = generator(source, kp_source=kp_source, kp_driving=kp_driving,
                            source_depth=depth_source, driving_depth=depth_map)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions


parser = ArgumentParser()
parser.add_argument("--source_image", help="path to source image")
parser.add_argument("--driving_video", help="path to driving video")
parser.add_argument("--result_video", help="path to output")
parser.add_argument("--mode", type=str, choices=["reconstruction", "animation"], help="mode to run")
opt = parser.parse_args()

with open("config/vox-adv-256.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# The generator is built with the config's original num_channels (RGB); the keypoint
# detector takes RGB + depth, so num_channels is bumped to 4 only before building it.
generator = GEN.SPADEDepthAwareGenerator(**config['model_params']['generator_params'],
                                         **config['model_params']['common_params'])
config['model_params']['common_params']['num_channels'] = 4
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                         **config['model_params']['common_params'])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cpu = not torch.cuda.is_available()

g_checkpoint = torch.load("checkpoints/generator.pth", map_location=device)
kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)

# Strip the 'module.' prefix left over from DataParallel training before loading.
ckp_generator = OrderedDict((k.replace('module.', ''), v) for k, v in g_checkpoint.items())
generator.load_state_dict(ckp_generator)
ckp_kp_detector = OrderedDict((k.replace('module.', ''), v) for k, v in kp_checkpoint.items())
kp_detector.load_state_dict(ckp_kp_detector)

depth_encoder = depth.ResnetEncoder(18, False)
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
loaded_dict_enc = torch.load('checkpoints/encoder.pth', map_location=device)
loaded_dict_dec = torch.load('checkpoints/depth.pth', map_location=device)
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
depth_encoder.load_state_dict(filtered_dict_enc)
depth_decoder.load_state_dict(loaded_dict_dec)

depth_encoder.eval()
depth_decoder.eval()
generator.eval()
kp_detector.eval()

generator.to(device)
kp_detector.to(device)
depth_encoder.to(device)
depth_decoder.to(device)

with torch.inference_mode():
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

    # Read the driving video; a RuntimeError from the reader simply ends the frame loop.
    reader = imageio.get_reader(opt.driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    if opt.mode == 'animation':
        source_image = imageio.imread(opt.source_image)
    elif opt.mode == 'reconstruction':
        source_image = driving_video[0]

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    if opt.mode == 'animation':
        # Start from the driving frame best aligned with the source, animate forward and
        # backward from it, then stitch (dropping the duplicated pivot frame).
        i = find_best_frame(source_image, driving_video, cpu=cpu)
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = animation(source_image, driving_forward, generator, kp_detector,
                                        depth_encoder, depth_decoder, cpu=cpu)
        predictions_backward = animation(source_image, driving_backward, generator, kp_detector,
                                         depth_encoder, depth_decoder, cpu=cpu)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    elif opt.mode == 'reconstruction':
        predictions = reconstruction(source_image, driving_video, generator, kp_detector,
                                     depth_encoder, depth_decoder, cpu)

    imageio.mimsave(opt.result_video, [img_as_ubyte(p) for p in predictions], fps=fps)
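
# Example invocation (a usage sketch: "demo.py" stands in for whatever this file is named,
# and the image/video paths are placeholders; the hard-coded config and checkpoint paths
# above must exist relative to the working directory):
#
#   python demo.py --source_image source.png --driving_video driving.mp4 \
#       --result_video result.mp4 --mode animation
#
#   python demo.py --driving_video driving.mp4 --result_video reconstruction.mp4 \
#       --mode reconstruction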