TANGO / create_graph.py
H-Liu1997's picture
Update create_graph.py
28ffb97 verified
raw
history blame
20.3 kB
"""
input: json file with video, audio, motion paths
output: igraph object with nodes containing video, audio, motion, position, velocity, axis_angle, previous, next, frame, fps
preprocess:
1. assume you have a video for one speaker in folder, listed in
-- video_a.mp4
-- video_b.mp4
run process_video.py to extract frames and audio
"""
import os
import smplx
import torch
import numpy as np
import cv2
import librosa
import igraph
import json
import utils.rotation_conversions as rc
from moviepy.editor import VideoClip, AudioFileClip, VideoFileClip
from tqdm import tqdm
import imageio
import tempfile
import argparse
def get_motion_reps_tensor(motion_tensor, smplx_model, pose_fps=30, device='cuda'):
bs, n, _ = motion_tensor.shape
motion_tensor = motion_tensor.float().to(device)
motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
output = smplx_model(
betas=torch.zeros(bs * n, 300, device=device),
transl=torch.zeros(bs * n, 3, device=device),
expression=torch.zeros(bs * n, 100, device=device),
jaw_pose=torch.zeros(bs * n, 3, device=device),
global_orient=torch.zeros(bs * n, 3, device=device),
body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
return_joints=True,
leye_pose=torch.zeros(bs * n, 3, device=device),
reye_pose=torch.zeros(bs * n, 3, device=device),
)
joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :]
dt = 1 / pose_fps
init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt
middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt)
final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
vel = torch.cat([init_vel, middle_vel, final_vel], dim=1)
position = joints
rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt
middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt)
final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt
angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3)
rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)
return {
"position": position,
"velocity": vel,
"rotation": rot6d,
"axis_angle": motion_tensor,
"angular_velocity": angular_velocity,
"rep15d": rep15d,
}
def get_motion_reps(motion, smplx_model, pose_fps=30):
gt_motion_tensor = motion["poses"]
n = gt_motion_tensor.shape[0]
bs = 1
gt_motion_tensor = torch.from_numpy(gt_motion_tensor).float().to(device).unsqueeze(0)
gt_motion_tensor_reshaped = gt_motion_tensor.reshape(bs * n, 165)
output = smplx_model(
betas=torch.zeros(bs * n, 300).to(device),
transl=torch.zeros(bs * n, 3).to(device),
expression=torch.zeros(bs * n, 100).to(device),
jaw_pose=torch.zeros(bs * n, 3).to(device),
global_orient=torch.zeros(bs * n, 3).to(device),
body_pose=gt_motion_tensor_reshaped[:, 3:21 * 3 + 3],
left_hand_pose=gt_motion_tensor_reshaped[:, 25 * 3:40 * 3],
right_hand_pose=gt_motion_tensor_reshaped[:, 40 * 3:55 * 3],
return_joints=True,
leye_pose=torch.zeros(bs * n, 3).to(device),
reye_pose=torch.zeros(bs * n, 3).to(device),
)
joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]
dt = 1 / pose_fps
init_vel = (joints[1:2] - joints[0:1]) / dt
middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
final_vel = (joints[-1:] - joints[-2:-1]) / dt
vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)
position = joints
rot_matrices = rc.axis_angle_to_matrix(gt_motion_tensor.reshape(1, n, 55, 3))[0]
rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()
init_vel = (motion["poses"][1:2] - motion["poses"][0:1]) / dt
middle_vel = (motion["poses"][2:] - motion["poses"][:-2]) / (2 * dt)
final_vel = (motion["poses"][-1:] - motion["poses"][-2:-1]) / dt
angular_velocity = np.concatenate([init_vel, middle_vel, final_vel], axis=0).reshape(n, 55, 3)
rep15d = np.concatenate([
position,
vel,
rot6d,
angular_velocity],
axis=2
).reshape(n, 55*15)
return {
"position": position,
"velocity": vel,
"rotation": rot6d,
"axis_angle": motion["poses"],
"angular_velocity": angular_velocity,
"rep15d": rep15d,
"trans": motion["trans"]
}
def create_graph(json_path, smplx_model):
fps = 30
data_meta = json.load(open(json_path, "r"))
graph = igraph.Graph(directed=True)
global_i = 0
for data_item in data_meta:
video_path = os.path.join(data_item['video_path'], data_item['video_id'] + ".mp4")
# audio_path = os.path.join(data_item['audio_path'], data_item['video_id'] + ".wav")
motion_path = os.path.join(data_item['motion_path'], data_item['video_id'] + ".npz")
video_id = data_item.get("video_id", "")
motion = np.load(motion_path, allow_pickle=True)
motion_reps = get_motion_reps(motion, smplx_model)
position = motion_reps['position']
velocity = motion_reps['velocity']
trans = motion_reps['trans']
axis_angle = motion_reps['axis_angle']
# audio, sr = librosa.load(audio_path, sr=None)
# audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
all_frames = []
reader = imageio.get_reader(video_path)
all_frames = []
for frame in reader:
all_frames.append(frame)
video_frames = np.array(all_frames)
min_frames = min(len(video_frames), position.shape[0])
position = position[:min_frames]
velocity = velocity[:min_frames]
video_frames = video_frames[:min_frames]
# print(min_frames)
for i in tqdm(range(min_frames)):
if i == 0:
previous = -1
next_node = global_i + 1
elif i == min_frames - 1:
previous = global_i - 1
next_node = -1
else:
previous = global_i - 1
next_node = global_i + 1
graph.add_vertex(
idx=global_i,
name=video_id,
motion=motion_reps,
position=position[i],
velocity=velocity[i],
axis_angle=axis_angle[i],
trans=trans[i],
# audio=audio[],
video=video_frames[i],
previous=previous,
next=next_node,
frame=i,
fps=fps,
)
global_i += 1
return graph
def create_edges(graph, threshold_edges):
adaptive_length = [-4, -3, -2, -1, 1, 2, 3, 4]
# print()
for i, node in enumerate(graph.vs):
current_position = node['position']
current_velocity = node['velocity']
current_trans = node['trans']
# print(current_position.shape, current_velocity.shape)
avg_position = np.zeros(current_position.shape[0])
avg_velocity = np.zeros(current_position.shape[0])
avg_trans = 0
count = 0
for node_offset in adaptive_length:
idx = i + node_offset
if idx < 0 or idx >= len(graph.vs):
continue
if node_offset < 0:
if graph.vs[idx]['next'] == -1:continue
else:
if graph.vs[idx]['previous'] == -1:continue
# add check
other_node = graph.vs[idx]
other_position = other_node['position']
other_velocity = other_node['velocity']
other_trans = other_node['trans']
# print(other_position.shape, other_velocity.shape)
avg_position += np.linalg.norm(current_position - other_position, axis=1)
avg_velocity += np.linalg.norm(current_velocity - other_velocity, axis=1)
avg_trans += np.linalg.norm(current_trans - other_trans, axis=0)
count += 1
if count == 0:
continue
threshold_position = avg_position / count
threshold_velocity = avg_velocity / count
threshold_trans = avg_trans / count
# print(threshold_position, threshold_velocity, threshold_trans)
for j, other_node in enumerate(graph.vs):
if i == j:
continue
if j == node['previous'] or j == node['next']:
graph.add_edge(i, j, is_continue=1)
continue
other_position = other_node['position']
other_velocity = other_node['velocity']
other_trans = other_node['trans']
position_similarity = np.linalg.norm(current_position - other_position, axis=1)
velocity_similarity = np.linalg.norm(current_velocity - other_velocity, axis=1)
trans_similarity = np.linalg.norm(current_trans - other_trans, axis=0)
if trans_similarity < threshold_trans:
if np.sum(position_similarity < threshold_edges*threshold_position) >= 45 and np.sum(velocity_similarity < threshold_edges*threshold_velocity) >= 45:
graph.add_edge(i, j, is_continue=0)
print(f"nodes: {len(graph.vs)}, edges: {len(graph.es)}")
in_degrees = graph.indegree()
out_degrees = graph.outdegree()
avg_in_degree = sum(in_degrees) / len(in_degrees)
avg_out_degree = sum(out_degrees) / len(out_degrees)
print(f"Average In-degree: {avg_in_degree}")
print(f"Average Out-degree: {avg_out_degree}")
print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
# igraph.plot(graph, target="/content/test.png", bbox=(1000, 1000), vertex_size=10)
return graph
def random_walk(graph, walk_length, start_node=None):
if start_node is None:
start_node = np.random.choice(graph.vs)
walk = [start_node]
is_continue = [1]
for _ in range(walk_length):
current_node = walk[-1]
neighbor_indices = graph.neighbors(current_node.index, mode='OUT')
if not neighbor_indices:
break
next_idx = np.random.choice(neighbor_indices)
edge_id = graph.get_eid(current_node.index, next_idx)
is_cont = graph.es[edge_id]['is_continue']
walk.append(graph.vs[next_idx])
is_continue.append(is_cont)
return walk, is_continue
import subprocess
def path_visualization(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
all_frames = [node['video'] for node in path]
average_dis_continue = 1 - sum(is_continue) / len(is_continue)
if verbose_continue:
print("average_dis_continue:", average_dis_continue)
fps = graph.vs[0]['fps']
duration = len(all_frames) / fps
def make_frame(t):
idx = min(int(t * fps), len(all_frames) - 1)
return all_frames[idx]
video_only_path = 'video_only.mp4' # Temporary file
video_clip = VideoClip(make_frame, duration=duration)
video_clip.write_videofile(
video_only_path,
codec='libx264',
fps=fps,
audio=False
)
# Optionally, ensure audio and video durations match
if audio_path is not None:
audio_clip = AudioFileClip(audio_path)
video_duration = video_clip.duration
audio_duration = audio_clip.duration
if audio_duration > video_duration:
# Trim the audio
trimmed_audio_path = 'trimmed_audio.aac'
audio_clip = audio_clip.subclip(0, video_duration)
audio_clip.write_audiofile(trimmed_audio_path)
audio_input = trimmed_audio_path
else:
audio_input = audio_path
# Use FFmpeg to combine video and audio
ffmpeg_command = [
'ffmpeg', '-y',
'-i', video_only_path,
'-i', audio_input,
'-c:v', 'copy',
'-c:a', 'aac',
'-strict', 'experimental',
save_path
]
subprocess.check_call(ffmpeg_command)
# Clean up temporary files if necessary
os.remove(video_only_path)
if audio_input != audio_path:
os.remove(audio_input)
if return_motion:
all_motion = [node['axis_angle'] for node in path]
all_motion = np.stack(all_motion, 0)
return all_motion
def generate_transition_video(frame_start_path, frame_end_path, output_video_path):
import subprocess
import os
# Define the path to your model and inference script
model_path = "./frame-interpolation-pytorch/film_net_fp32.pt"
inference_script = "./frame-interpolation-pytorch/inference.py"
# Build the command to run the inference script
command = [
"python",
inference_script,
model_path,
frame_start_path,
frame_end_path,
"--save_path", output_video_path,
"--gpu",
"--frames", "3",
"--fps", "30"
]
# Run the command
try:
subprocess.run(command, check=True)
print(f"Generated transition video saved at {output_video_path}")
except subprocess.CalledProcessError as e:
print(f"Error occurred while generating transition video: {e}")
def path_visualization_v2(graph, path, is_continue, save_path, verbose_continue=False, audio_path=None, return_motion=False):
'''
this is for hugging face demo for fast interpolation. our paper use a diffusion based interpolation method
'''
all_frames = [node['video'] for node in path]
average_dis_continue = 1 - sum(is_continue) / len(is_continue)
if verbose_continue:
print("average_dis_continue:", average_dis_continue)
duration = len(all_frames) / graph.vs[0]['fps']
# First loop: Confirm where blending is needed
discontinuity_indices = []
for i, cont in enumerate(is_continue):
if cont == 0:
discontinuity_indices.append(i)
# Identify blending positions without overlapping
blend_positions = []
processed_frames = set()
for i in discontinuity_indices:
# Define the frames for blending: i-2 to i+2
start_idx = i - 2
end_idx = i + 2
# Check index boundaries
if start_idx < 0 or end_idx >= len(all_frames):
continue # Skip if indices are out of bounds
# Check for overlapping frames
overlap = any(idx in processed_frames for idx in range(i - 1, i + 2))
if overlap:
continue # Skip if frames have been processed
# Mark frames as processed
processed_frames.update(range(i - 1, i + 2))
blend_positions.append(i)
# Second loop: Perform blending
temp_dir = tempfile.mkdtemp(prefix='blending_frames_')
for i in tqdm(blend_positions):
start_frame_idx = i - 2
end_frame_idx = i + 2
frame_start = all_frames[start_frame_idx]
frame_end = all_frames[end_frame_idx]
frame_start_path = os.path.join(temp_dir, f'frame_{start_frame_idx}.png')
frame_end_path = os.path.join(temp_dir, f'frame_{end_frame_idx}.png')
# Save the start and end frames as images
imageio.imwrite(frame_start_path, frame_start)
imageio.imwrite(frame_end_path, frame_end)
# Call FiLM API to generate video
generated_video_path = os.path.join(temp_dir, f'generated_{start_frame_idx}_{end_frame_idx}.mp4')
generate_transition_video(frame_start_path, frame_end_path, generated_video_path)
# Read the generated video frames
reader = imageio.get_reader(generated_video_path)
generated_frames = [frame for frame in reader]
reader.close()
# Replace the middle three frames (i-1, i, i+1) in all_frames
total_generated_frames = len(generated_frames)
if total_generated_frames < 5:
print(f"Generated video has insufficient frames ({total_generated_frames}). Skipping blending at position {i}.")
continue
middle_start = 1 # Start index for middle 3 frames
middle_frames = generated_frames[middle_start:middle_start+3]
for idx, frame_idx in enumerate(range(i - 1, i + 2)):
all_frames[frame_idx] = middle_frames[idx]
# Create the video clip
def make_frame(t):
idx = min(int(t * graph.vs[0]['fps']), len(all_frames) - 1)
return all_frames[idx]
video_clip = VideoClip(make_frame, duration=duration)
if audio_path is not None:
audio_clip = AudioFileClip(audio_path)
video_clip = video_clip.set_audio(audio_clip)
video_clip.write_videofile(save_path, codec='libx264', fps=graph.vs[0]['fps'], audio_codec='aac')
if return_motion:
all_motion = [node['axis_angle'] for node in path]
all_motion = np.stack(all_motion, 0)
return all_motion
def graph_pruning(graph):
ascc = graph.clusters(mode="STRONG")
lascc = ascc.giant()
print(f"before nodes: {len(graph.vs)}, edges: {len(graph.es)}")
print(f"after nodes: {len(lascc.vs)}, edges: {len(lascc.es)}")
in_degrees = lascc.indegree()
out_degrees = lascc.outdegree()
avg_in_degree = sum(in_degrees) / len(in_degrees)
avg_out_degree = sum(out_degrees) / len(out_degrees)
print(f"Average In-degree: {avg_in_degree}")
print(f"Average Out-degree: {avg_out_degree}")
print(f"max in degree: {max(in_degrees)}, max out degree: {max(out_degrees)}")
print(f"min in degree: {min(in_degrees)}, min out degree: {min(out_degrees)}")
return lascc
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--json_save_path", type=str, default="")
parser.add_argument("--graph_save_path", type=str, default="")
parser.add_argument("--threshold", type=float, default=1.0)
args = parser.parse_args()
json_path = args.json_save_path
graph_path = args.graph_save_path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
smplx_model = smplx.create(
"./emage/smplx_models/",
model_type='smplx',
gender='NEUTRAL_2020',
use_face_contour=False,
num_betas=300,
num_expression_coeffs=100,
ext='npz',
use_pca=False,
).to(device).eval()
# single_test
# graph = create_graph('/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/Abortion_Laws_-_Last_Week_Tonight_with_John_Oliver_HBO-DRauXXz6t0Y.webm.json')
graph = create_graph(json_path, smplx_model)
graph = create_edges(graph, args.threshold)
# pool_path = "/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/show-oliver-test.pkl"
# graph = igraph.Graph.Read_Pickle(fname=pool_path)
# graph = igraph.Graph.Read_Pickle(fname="/content/drive/MyDrive/003_Codes/TANGO-JointEmbedding/datasets/oliver_test/test.pkl")
# walk, is_continue = random_walk(graph, 100)
# motion = path_visualization(graph, walk, is_continue, "./test.mp4", audio_path=None, verbose_continue=True, return_motion=True)
# print(motion.shape)
save_graph = graph.write_pickle(fname=graph_path)
# graph = graph_pruning(graph)
# show-oliver
# json_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/data_json/show_oliver_test/"
# pre_node_path = "/content/drive/MyDrive/003_Codes/TANGO/datasets/cached_graph/show_oliver_test/"
# for json_file in tqdm(os.listdir(json_path)):
# graph = create_graph(os.path.join(json_path, json_file))
# graph = create_edges(graph)
# if not len(graph.vs) >= 1500:
# print(f"skip: {len(graph.vs)}", json_file)
# graph.write_pickle(fname=os.path.join(pre_node_path, json_file.split(".")[0] + ".pkl"))
# print(f"Graph saved at {json_file.split('.')[0]}.pkl")