|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
import numpy as np |
|
|
|
|
|
class SkeletonConv(nn.Module):
    """1-D convolution over time whose weights are masked by skeleton adjacency.

    Channels are laid out joint-major: joint ``k`` owns input channels
    ``[k * in_channels_per_joint, (k + 1) * in_channels_per_joint)`` and output
    joint ``i`` owns output channels
    ``[i * out_channels_per_joint, (i + 1) * out_channels_per_joint)``.
    Output joint ``i`` may only read the input channels of the joints listed in
    ``neighbour_list[i]``; every other weight is zeroed by ``self.mask``.

    Args:
        neighbour_list: for each output joint, the input joint indices it reads.
        in_channels: total input channels; must be divisible by ``joint_num``.
        out_channels: total output channels; must be divisible by ``joint_num``.
        kernel_size: temporal kernel size.
        joint_num: number of joints the channels are split across.
        stride: convolution stride.
        padding: amount of padding added to both ends of the last dimension.
        bias: whether to learn an additive bias.
        padding_mode: ``F.pad`` mode; ``'zeros'`` is mapped to ``'constant'``.
        add_offset: if True, an encoding of the tensor given to
            :meth:`set_offset` (via a :class:`SkeletonLinear`) is scaled by
            1/100 and added to the convolution output.
        in_offset_channel: offset channels per joint (used when ``add_offset``).

    Raises:
        Exception: if ``in_channels`` or ``out_channels`` is not divisible by
            ``joint_num``.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        super(SkeletonConv, self).__init__()

        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('in/out channels should be divided by joint_num')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        if padding_mode == 'zeros': padding_mode = 'constant'

        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = join_num if False else joint_num  # noqa: keep attribute name
        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)
        # Populated by set_offset(); forward() refuses to run without it when add_offset is set.
        self.offset = None

        # Channel indices (joint-major) of every input joint each output joint may read.
        self.expanded_neighbour_list = [
            [k * self.in_channels_per_joint + c for k in neighbour for c in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]

        self.expanded_neighbour_list_offset = []
        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
            # Same expansion for the offset channels. Previously this iterated
            # range(add_offset), i.e. over a bool, instead of the per-joint
            # offset channel count.
            self.expanded_neighbour_list_offset = [
                [k * in_offset_channel + c for k in neighbour for c in range(in_offset_channel)]
                for neighbour in neighbour_list
            ]

        # Plain tensors for now; reset_parameters() fills them and wraps them as Parameters.
        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        # 1 where output joint i may read input channel, 0 elsewhere (not trained).
        self.mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(self.mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
        )

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise each output joint's unmasked weight block and bias.

        Weights use ``kaiming_uniform_(a=sqrt(5))``. Each bias slice is uniform
        in ``±1/sqrt(fan_in)``, where fan-in is taken from that joint's neighbour
        block only. Updates happen in place under ``torch.no_grad()``, so this
        can be called again after construction; tensors that are already
        Parameters are not re-wrapped.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(self.out_channels_per_joint * i, self.out_channels_per_joint * (i + 1))
                # Indexing with a list returns a copy, so initialise a temporary
                # and write it back rather than initialising the slice directly.
                tmp = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = tmp
                if self.bias is not None:
                    # tmp has exactly the shape of this joint's weight block.
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tmp)
                    bound = 1 / math.sqrt(fan_in)
                    tmp_bias = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(tmp_bias, -bound, bound)
                    self.bias[rows] = tmp_bias

        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if self.bias is not None and not isinstance(self.bias, nn.Parameter):
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        """Store ``offset`` (flattened to ``(offset.shape[0], -1)``) for the next forward passes.

        Raises:
            Exception: if the layer was built with ``add_offset=False``.
        """
        if not self.add_offset: raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        """Pad, convolve with the masked weight, and optionally add the offset encoding.

        Raises:
            RuntimeError: if ``add_offset`` is set but :meth:`set_offset` was never called.
        """
        weight_masked = self.weight * self.mask
        padded = F.pad(input, self._padding_repeated_twice, mode=self.padding_mode)
        # Padding is applied explicitly above, so conv1d itself uses padding=0.
        res = F.conv1d(padded, weight_masked, self.bias, self.stride, 0, self.dilation, self.groups)

        if self.add_offset:
            if self.offset is None:
                raise RuntimeError('SkeletonConv: set_offset() must be called before forward() when add_offset=True')
            offset_res = self.offset_enc(self.offset)
            # Trailing singleton dim so the encoding broadcasts over the last (time) axis.
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            res += offset_res / 100
        return res

    def __repr__(self):
        return self.description
|
|
|
|
|
class SkeletonLinear(nn.Module):
    """Fully connected layer whose weight is masked by skeleton adjacency.

    Channels are laid out joint-major. Output joint ``i`` (rows
    ``[i * out_channels_per_joint, (i + 1) * out_channels_per_joint)``) may only
    read the input channels of the joints in ``neighbour_list[i]``.

    Args:
        neighbour_list: for each output joint, the input joint indices it reads.
        in_channels: total input channels (integer-divided by ``len(neighbour_list)``).
        out_channels: total output channels (integer-divided by ``len(neighbour_list)``).
        extra_dim1: if True, ``forward`` appends a trailing singleton dimension.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1

        # Input channel indices (joint-major) readable by each output joint.
        self.expanded_neighbour_list = [
            [k * self.in_channels_per_joint + c for k in neighbour for c in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]

        # weight/mask stay plain tensors until reset_parameters() wraps them;
        # bias is registered first, which fixes the parameter order (bias, weight, mask).
        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)build the mask and initialise weight and bias.

        Each joint's weight block uses ``kaiming_uniform_(a=sqrt(5))``. The bias
        is uniform in ``±1/sqrt(fan_in)`` with fan-in taken from the full weight
        matrix, not a single joint's block. Updates are in place under
        ``torch.no_grad()``, so this can be called again after construction.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(i * self.out_channels_per_joint, (i + 1) * self.out_channels_per_joint)
                # List indexing returns a copy: initialise a temporary, then write it back.
                tmp = torch.zeros_like(self.weight[rows, neighbour])
                self.mask[rows, neighbour] = 1
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour] = tmp

            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if not isinstance(self.mask, nn.Parameter):
            self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        """Flatten all but the first dim of ``input`` and apply the masked linear map."""
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1: res = res.reshape(res.shape + (1,))
        return res
|
|
|
|
|
class SkeletonPoolJoint(nn.Module):
    """Mean-pool the joints of a skeleton given as a parent list.

    ``topology[j]`` is the parent of joint ``j``; a negative entry marks a root
    (the code treats -1 as "no parent" and joint 0 as the first pooled group).
    Pooling groups are built as follows:

    * every joint with more than one child is its own group;
    * along an unbranched chain, walking upward from a leaf (or from a
      single-child joint whose child branches), consecutive ``[parent, child]``
      pairs are merged, and a leftover joint becomes a singleton group;
    * with ``last_pool=True`` the whole chain up to the nearest branching joint
      (or the root) becomes one group.

    Groups are sorted by their first joint. ``pooling_map[j]`` gives the group
    of joint ``j``, and ``new_topology`` is the parent list of the pooled skeleton.
    The forward pass averages along the LAST dimension of ``input``, which must
    hold ``joint_num * channels_per_joint`` joint-major channels.

    Raises:
        Exception: if ``pooling_mode`` is not ``'mean'``.
    """

    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
        super(SkeletonPoolJoint, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_mode = pooling_mode

        self.pooling_map = [-1 for _ in range(len(self.parent))]
        # self.child[p] is the last child seen for p; only consulted when p has exactly one child.
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0: continue
            children_cnt[pa] += 1
            self.child[pa] = x
        self.pooling_map[0] = 0
        for x in range(len(self.parent)):
            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
                # Walk up the unbranched chain that ends at x. The `x >= 0` guard
                # stops the walk once it steps above a single-child root
                # (x = parent[root] = -1). Without it, children_cnt[-1] wraps
                # around to the last joint and the loop can run forever.
                while x >= 0 and children_cnt[x] <= 1:
                    pa = self.parent[x]
                    if last_pool:
                        # Collect the entire chain up to a branching joint / root as one group.
                        seq = [x]
                        while pa != -1 and children_cnt[pa] == 1:
                            seq = [pa] + seq
                            x = pa
                            pa = self.parent[x]
                        self.pooling_list.append(seq)
                        break
                    else:
                        if pa != -1 and children_cnt[pa] == 1:
                            # Merge x with its single-child parent, continue two levels up.
                            self.pooling_list.append([pa, x])
                            x = self.parent[pa]
                        else:
                            self.pooling_list.append([x, ])
                            break
            elif children_cnt[x] > 1:
                self.pooling_list.append([x, ])

        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )

        self.pooling_list.sort(key=lambda x:x[0])
        for i, a in enumerate(self.pooling_list):
            for j in a:
                self.pooling_map[j] = i

        self.output_joint_num = len(self.pooling_list)
        # The parent of a pooled group is the group containing the parent of its first joint.
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, x in enumerate(self.pooling_list):
            if i < 1: continue
            self.new_topology[i] = self.pooling_map[self.parent[x[0]]]

        # Fixed averaging matrix: each output channel is the mean of the same channel over the group's joints.
        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_joint):
                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Apply the pooling matrix to the last dimension of ``input``."""
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
|
|
|
|
|
class SkeletonPool(nn.Module):
    """Mean-pool pairs of consecutive edges along each unbranched chain of a skeleton.

    ``edges`` is a list of directed edges ``(edge[0], edge[1])``. The chains are
    found by a depth-first walk from joint 0 that follows edges whose ``edge[0]``
    is the current joint. A chain ends at a leaf (degree 1, other than joint 0)
    or at a branching joint (degree > 2, other than joint 0). Within a chain of
    odd length, the first edge is kept alone and the rest are pooled in pairs.
    With ``last_pool=True`` each whole chain is one group. One extra channel
    group (index ``len(edges)``) is always appended unpooled as the last group.

    Attributes set: ``seq_list`` (chains as edge indices), ``pooling_list``
    (groups of input edge indices), ``new_edges`` (edges of the pooled skeleton;
    not filled when ``last_pool``), ``weight`` (fixed averaging matrix).

    Raises:
        Exception: if ``pooling_mode`` is not ``'mean'``.
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        super(SkeletonPool, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges) + 1
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []

        # Size the degree table from the joints that occur in `edges`
        # (previously hard-coded to 100, which overflowed on larger skeletons).
        num_joints = max((max(edge[0], edge[1]) for edge in edges), default=0) + 1
        degree = [0] * num_joints

        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            # A branching joint (other than the root) closes the current chain.
            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            # A leaf closes the current chain. The root is excluded: a root with
            # a single child also has degree 1 but must still be traversed.
            if degree[j] == 1 and j != 0:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            if len(seq) % 2 == 1:
                # Odd-length chain: keep the first edge unpooled.
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                # The merged edge spans from the first edge's start to the second edge's end.
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        # The extra trailing channel group passes through unpooled.
        self.pooling_list.append([self.edge_num - 1])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)

        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the pooling matrix (pools along dim -2)."""
        return torch.matmul(self.weight, input)
|
|
|
|
|
class SkeletonUnpool(nn.Module):
    """Undo a mean pool by copying every pooled group back to each of its members.

    Args:
        pooling_list: ``pooling_list[g]`` lists the output joint/edge indices
            that each receive a copy of input group ``g``.
        channels_per_edge: channels carried by every joint/edge.

    The unpooling matrix is fixed (``requires_grad=False``). ``forward`` applies
    it to the last dimension of ``input``.
    """

    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        self.output_joint_num = sum(len(group) for group in self.pooling_list)
        self.channels_per_edge = channels_per_edge

        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )

        width = channels_per_edge
        unpool = torch.zeros(self.output_joint_num * width, self.input_joint_num * width)
        # Each (destination, source) pair gets an identity block: channel c of
        # source group `src` is copied to channel c of destination `dst`.
        for src, group in enumerate(self.pooling_list):
            for dst in group:
                unpool[dst * width:(dst + 1) * width, src * width:(src + 1) * width] = torch.eye(width)

        self.weight = nn.Parameter(unpool, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Apply the unpooling matrix to the last dimension of ``input``."""
        lifted = input[..., None]
        return torch.matmul(self.weight, lifted)[..., 0]
|
|
|
|
|
def find_neighbor_joint(parents, threshold):
    """Return, for every joint, the joints within ``threshold`` hops in the skeleton tree.

    Args:
        parents: ``parents[i]`` is the parent index of joint ``i``. Joint 0 is
            treated as the root, so its entry is ignored. A negative parent on
            any other joint adds no edge.
        threshold: maximum graph distance (number of edges) to count as a neighbour.

    Returns:
        A list with one ascending list of Python ints per joint: ``result[i]``
        holds every ``j`` with ``dist(i, j) <= threshold``. It includes ``i``
        itself when ``threshold >= 0``.
    """
    n_joint = len(parents)
    # Stand-in for "unreachable"; adding two of these cannot overflow int64.
    unreachable = 100000
    # np.int was removed in NumPy 1.24; use an explicit fixed-width dtype.
    dist_mat = np.full((n_joint, n_joint), unreachable, dtype=np.int64)
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        # Without the p >= 0 check a negative parent would index from the end
        # and falsely link joint i to the last joint.
        if i != 0 and p >= 0:
            dist_mat[i, p] = dist_mat[p, i] = 1

    # Floyd-Warshall all-pairs shortest paths, vectorised over (i, j) for each
    # intermediate joint k (row/column k do not change during step k).
    for k in range(n_joint):
        dist_mat = np.minimum(dist_mat, dist_mat[:, k:k + 1] + dist_mat[k:k + 1, :])

    return [np.flatnonzero(row <= threshold).tolist() for row in dist_mat]
|
|