# GenMM/utils/skeleton.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class SkeletonConv(nn.Module):
    """1-D convolution over time whose weight is masked so that each joint's
    output channels only mix features from the joints listed in its
    `neighbour_list` entry."""

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        super(SkeletonConv, self).__init__()

        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('in/out channels should be divisible by joint_num')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        if padding_mode == 'zeros':
            padding_mode = 'constant'

        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        # Expand per-joint neighbour indices into per-channel indices.
        for neighbour in neighbour_list:
            expanded = []
            for k in neighbour:
                for i in range(self.in_channels_per_joint):
                    expanded.append(k * self.in_channels_per_joint + i)
            self.expanded_neighbour_list.append(expanded)

        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)

            for neighbour in neighbour_list:
                expanded = []
                for k in neighbour:
                    for i in range(add_offset):
                        expanded.append(k * in_offset_channel + i)
                self.expanded_neighbour_list_offset.append(expanded)

        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        self.mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(self.mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
                               in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride,
                               padding, bias
                           )

        self.reset_parameters()

    def reset_parameters(self):
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            # Use a temporary variable to avoid assigning to a copy of a slice,
            # which might lead to unexpected results.
            tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                                               neighbour, ...])
            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
            self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                        neighbour, ...] = tmp
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                    self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
                bound = 1 / math.sqrt(fan_in)
                tmp = torch.zeros_like(
                    self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
                nn.init.uniform_(tmp, -bound, bound)
                self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp

        self.weight = nn.Parameter(self.weight)
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        if not self.add_offset:
            raise Exception('Wrong combination of parameters: set_offset() requires add_offset=True')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)

        if self.add_offset:
            offset_res = self.offset_enc(self.offset)
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            res += offset_res / 100
        return res

    def __repr__(self):
        return self.description
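
# Usage sketch for SkeletonConv. Illustrative only: the joint count, channel
# sizes and frame count are assumptions, not values taken from the GenMM pipeline.
#
#     neighbours = find_neighbor_joint([-1, 0, 1, 2], threshold=2)   # 4-joint chain
#     conv = SkeletonConv(neighbours, in_channels=4 * 3, out_channels=4 * 16,
#                         kernel_size=3, joint_num=4, padding=1)
#     motion = torch.randn(8, 4 * 3, 64)    # (batch, joints * channels_per_joint, frames)
#     features = conv(motion)               # -> (8, 4 * 16, 64); the mask keeps each joint
#                                           #    looking only at its skeletal neighbours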

class SkeletonLinear(nn.Module):
    """Masked fully connected layer: each joint's output channels read only from
    the input channels of its neighbouring joints. The input is flattened to
    (batch, joint_num * in_channels_per_joint) before the linear map."""

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        self.expanded_neighbour_list = []

        # Expand per-joint neighbour indices into per-channel indices.
        for neighbour in neighbour_list:
            expanded = []
            for k in neighbour:
                for i in range(self.in_channels_per_joint):
                    expanded.append(k * self.in_channels_per_joint + i)
            self.expanded_neighbour_list.append(expanded)

        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            tmp = torch.zeros_like(
                self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
            )
            self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
            self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

        self.weight = nn.Parameter(self.weight)
        self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1:
            res = res.reshape(res.shape + (1,))
        return res
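
# Usage sketch for SkeletonLinear (shapes below are an assumption, not GenMM's
# configuration):
#
#     neighbours = find_neighbor_joint([-1, 0, 1, 2], threshold=2)
#     fc = SkeletonLinear(neighbours, in_channels=4 * 3, out_channels=4 * 8)
#     offsets = torch.randn(8, 4, 3)        # (batch, joints, offset channels)
#     out = fc(offsets)                     # flattened internally -> (8, 32)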

class SkeletonPoolJoint(nn.Module):
    """Mean-pools joints along single-child chains of the kinematic tree and
    exposes the pooled parent list as `new_topology` for the next layer."""

    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
        super(SkeletonPoolJoint, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_mode = pooling_mode

        self.pooling_map = [-1 for _ in range(len(self.parent))]
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0:
                continue
            children_cnt[pa] += 1
            self.child[pa] = x

        self.pooling_map[0] = 0
        for x in range(len(self.parent)):
            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
                while children_cnt[x] <= 1:
                    pa = self.parent[x]
                    if last_pool:
                        seq = [x]
                        while pa != -1 and children_cnt[pa] == 1:
                            seq = [pa] + seq
                            x = pa
                            pa = self.parent[x]
                        self.pooling_list.append(seq)
                        break
                    else:
                        if pa != -1 and children_cnt[pa] == 1:
                            self.pooling_list.append([pa, x])
                            x = self.parent[pa]
                        else:
                            self.pooling_list.append([x, ])
                            break
            elif children_cnt[x] > 1:
                self.pooling_list.append([x, ])

        self.description = 'SkeletonPoolJoint(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )

        self.pooling_list.sort(key=lambda x: x[0])
        for i, a in enumerate(self.pooling_list):
            for j in a:
                self.pooling_map[j] = i

        self.output_joint_num = len(self.pooling_list)
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, x in enumerate(self.pooling_list):
            if i < 1:
                continue
            self.new_topology[i] = self.pooling_map[self.parent[x[0]]]

        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_joint):
                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
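
# Usage sketch for SkeletonPoolJoint (assumed toy topology: a chain of 5 joints).
# The forward pass multiplies along the last dimension, so inputs are channel-last:
# the final dimension is joint_num * channels_per_joint.
#
#     pool = SkeletonPoolJoint(topology=[-1, 0, 1, 2, 3], pooling_mode='mean',
#                              channels_per_joint=16)
#     x = torch.randn(8, 64, 5 * 16)        # (batch, frames, joints * channels)
#     y = pool(x)                           # -> (8, 64, pool.output_joint_num * 16)
#     # pool.new_topology gives the parent list of the pooled skeleton.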

class SkeletonPool(nn.Module):
    """Edge-based mean pooling: merges pairs of consecutive edges along chains
    whose interior joints have degree two. The last channel slot carries the
    global root position and is kept as its own group."""

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        super(SkeletonPool, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges) + 1
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []
        degree = [0] * 100

        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            nonlocal self, degree, edges

            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            if degree[j] == 1:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        # add global position
        self.pooling_list.append([self.edge_num - 1])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input)
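
# Usage sketch for SkeletonPool (the edge list below is an assumed 5-joint chain;
# the extra channel slot holds the global root position):
#
#     edges = [[0, 1], [1, 2], [2, 3], [3, 4]]
#     pool = SkeletonPool(edges, pooling_mode='mean', channels_per_edge=16)
#     x = torch.randn(8, (len(edges) + 1) * 16, 64)   # (batch, edges * channels, frames)
#     y = pool(x)                                     # -> (8, len(pool.pooling_list) * 16, 64)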

class SkeletonUnpool(nn.Module):
    """Inverse of the pooling layers: copies each pooled channel back to every
    edge it was merged from, using the pooling_list of the matching pool layer."""

    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        self.output_joint_num = 0
        self.channels_per_edge = channels_per_edge
        for t in self.pooling_list:
            self.output_joint_num += len(t)

        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )

        self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1

        self.weight = nn.Parameter(self.weight)
        self.weight.requires_grad_(False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
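
# Usage sketch for SkeletonUnpool, continuing the assumed SkeletonPool example
# above. Note the channel-last layout expected by this forward pass (it unsqueezes
# the last dimension, unlike SkeletonPool.forward):
#
#     unpool = SkeletonUnpool(pool.pooling_list, channels_per_edge=16)
#     z = torch.randn(8, 64, len(pool.pooling_list) * 16)   # (batch, frames, pooled edges * channels)
#     restored = unpool(z)                                  # -> (8, 64, unpool.output_joint_num * 16)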

def find_neighbor_joint(parents, threshold):
    """Return, for every joint, the list of joints whose tree distance is at
    most `threshold` (each joint is included in its own neighbourhood)."""
    n_joint = len(parents)
    dist_mat = np.empty((n_joint, n_joint), dtype=int)
    dist_mat[:, :] = 100000
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        if i != 0:
            dist_mat[i, p] = dist_mat[p, i] = 1

    # Floyd's algorithm (all-pairs shortest paths over the kinematic tree)
    for k in range(n_joint):
        for i in range(n_joint):
            for j in range(n_joint):
                dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])

    neighbor_list = []
    for i in range(n_joint):
        neighbor = []
        for j in range(n_joint):
            if dist_mat[i, j] <= threshold:
                neighbor.append(j)
        neighbor_list.append(neighbor)

    return neighbor_list
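

if __name__ == '__main__':
    # Minimal smoke test (illustrative; the 4-joint chain and channel sizes are
    # assumptions, not GenMM's data): build neighbour lists and run one masked
    # skeleton convolution over a random motion clip.
    parents = [-1, 0, 1, 2]                   # parent index per joint, root first
    neighbours = find_neighbor_joint(parents, threshold=2)
    conv = SkeletonConv(neighbours, in_channels=4 * 3, out_channels=4 * 16,
                        kernel_size=3, joint_num=4, padding=1)
    motion = torch.randn(2, 4 * 3, 30)        # (batch, joints * channels_per_joint, frames)
    print(conv(motion).shape)                 # expected: torch.Size([2, 64, 30])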