# Skeleton-aware convolution, linear, pooling and unpooling layers (PyTorch).
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
import numpy as np | |
class SkeletonConv(nn.Module):
    """1-D convolution restricted to a skeleton's neighbourhood structure.

    The channel axis is partitioned into `joint_num` equal groups, one per
    joint. A fixed binary mask zeroes every weight except those connecting a
    joint's output channels to the input channels of the joints in its
    neighbour list, so information only mixes between topologically adjacent
    joints.
    """
    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        """
        Args:
            neighbour_list: per-joint lists of neighbouring joint indices.
            in_channels, out_channels: total channel counts; each must be
                divisible by ``joint_num``.
            kernel_size: temporal width of the convolution kernel.
            joint_num: number of joints the channels are split across.
            stride, padding, bias, padding_mode: standard conv options;
                padding is applied manually via F.pad in forward().
            add_offset: when truthy, an extra SkeletonLinear encodes an
                offset (supplied later via set_offset) added to the output.
            in_offset_channel: channels per joint in the offset input.

        Raises:
            Exception: if in/out channels are not divisible by joint_num.
        """
        super(SkeletonConv, self).__init__()
        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('in/out channels should be divided by joint_num')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num
        # F.pad calls zero padding 'constant', unlike nn.Conv1d's 'zeros'.
        if padding_mode == 'zeros': padding_mode = 'constant'
        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num
        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        # F.pad expects one (before, after) pair for the last dimension.
        self._padding_repeated_twice = (padding, padding)
        # Expand joint indices to channel indices: joint k owns input
        # channels [k*cpj, (k+1)*cpj).
        for neighbour in neighbour_list:
            expanded = []
            for k in neighbour:
                for i in range(self.in_channels_per_joint):
                    expanded.append(k * self.in_channels_per_joint + i)
            self.expanded_neighbour_list.append(expanded)
        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
            for neighbour in neighbour_list:
                expanded = []
                for k in neighbour:
                    # NOTE(review): add_offset is used as an int range bound
                    # here (True -> 1 channel per neighbour) — confirm this is
                    # intended rather than range(in_offset_channel).
                    for i in range(add_offset):
                        expanded.append(k * in_offset_channel + i)
                self.expanded_neighbour_list_offset.append(expanded)
        # weight/bias start as plain tensors; reset_parameters() initialises
        # them per joint block and wraps them into nn.Parameter afterwards.
        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)
        # Binary mask: 1 only where joint i's output rows meet the expanded
        # neighbour columns; frozen (requires_grad=False).
        self.mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(self.mask, requires_grad=False)
        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-initialise each joint's weight block and its bias slice,
        then register weight/bias as nn.Parameter."""
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            """ Use temporary variable to avoid assign to copy of slice, which might lead to un expected result """
            tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                                   neighbour, ...])
            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
            self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
                        neighbour, ...] = tmp
            if self.bias is not None:
                # Bias bound derived from the fan-in of this joint's block,
                # mirroring nn.Conv1d's default initialisation.
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
                    self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
                bound = 1 / math.sqrt(fan_in)
                tmp = torch.zeros_like(
                    self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
                nn.init.uniform_(tmp, -bound, bound)
                self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp
        self.weight = nn.Parameter(self.weight)
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        """Store a per-batch offset (flattened to 2-D) used by forward()
        when add_offset is enabled."""
        if not self.add_offset: raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        # Apply the topology mask at every call so masked weights stay zero
        # even after gradient updates.
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)
        if self.add_offset:
            offset_res = self.offset_enc(self.offset)
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            # Offset contribution is down-scaled by 100 — presumably to keep
            # it small relative to the conv output; confirm with training code.
            res += offset_res / 100
        return res

    def __repr__(self):
        return self.description
class SkeletonLinear(nn.Module):
    """Fully connected layer whose weight is masked by skeleton adjacency.

    Channels are split evenly across the joints of `neighbour_list`; the mask
    keeps only the weights connecting a joint's output channels to the input
    channels of its neighbours.
    """
    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        joint_count = len(neighbour_list)
        self.in_channels_per_joint = in_channels // joint_count
        self.out_channels_per_joint = out_channels // joint_count
        self.extra_dim1 = extra_dim1
        # Expand joint indices to channel indices: joint k owns input
        # channels [k*cpj, (k+1)*cpj).
        cpj = self.in_channels_per_joint
        self.expanded_neighbour_list = [
            [k * cpj + c for k in neighbour for c in range(cpj)]
            for neighbour in neighbour_list
        ]
        # weight/mask start as plain tensors; reset_parameters() fills them
        # in and wraps them as nn.Parameter.
        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-initialise each joint's weight block, build the binary
        mask, then register weight and mask as Parameters."""
        opj = self.out_channels_per_joint
        for idx, cols in enumerate(self.expanded_neighbour_list):
            rows = slice(idx * opj, (idx + 1) * opj)
            # Initialise into a temporary, then copy back: in-place init on a
            # fancy-indexed slice would act on a copy.
            block = torch.zeros_like(self.weight[rows, cols])
            self.mask[rows, cols] = 1
            nn.init.kaiming_uniform_(block, a=math.sqrt(5))
            self.weight[rows, cols] = block
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)
        self.weight = nn.Parameter(self.weight)
        self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        flat = input.reshape(input.shape[0], -1)
        out = F.linear(flat, self.weight * self.mask, self.bias)
        # Optionally append a trailing singleton dim (e.g. to feed a conv).
        return out.reshape(out.shape + (1,)) if self.extra_dim1 else out
class SkeletonPoolJoint(nn.Module):
    """Pools a joint hierarchy by merging joints along unbranched chains.

    Builds `pooling_list` (groups of original joints that merge into one
    output joint), a `pooling_map` from original to pooled index, the pooled
    `new_topology`, and a fixed averaging matrix applied in forward().
    """
    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
        """
        Args:
            topology: parent index per joint (root's parent is negative).
            pooling_mode: only 'mean' is implemented.
            channels_per_joint: channels owned by each joint in the input.
            last_pool: if True, collapse each whole chain into one output
                joint instead of merging joints pairwise.

        Raises:
            Exception: for any pooling_mode other than 'mean'.
        """
        super(SkeletonPoolJoint, self).__init__()
        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')
        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_mode = pooling_mode
        self.pooling_map = [-1 for _ in range(len(self.parent))]
        # child[x] holds *a* child of x (the last one seen); it is only
        # consulted when children_cnt[x] == 1, where it is unique.
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0: continue
            children_cnt[pa] += 1
            self.child[pa] = x
        # NOTE(review): this root entry is overwritten by the pooling_map
        # rebuild below — appears redundant; confirm.
        self.pooling_map[0] = 0
        for x in range(len(self.parent)):
            # Start a merge walk at leaves, or at the bottom joint of an
            # unbranched chain (single child that is itself a branch point).
            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
                while children_cnt[x] <= 1:
                    pa = self.parent[x]
                    if last_pool:
                        # Collapse the entire upward chain into one group.
                        seq = [x]
                        while pa != -1 and children_cnt[pa] == 1:
                            seq = [pa] + seq
                            x = pa
                            pa = self.parent[x]
                        self.pooling_list.append(seq)
                        break
                    else:
                        # Merge (parent, child) pairs walking up the chain;
                        # a joint with no chain parent stays alone.
                        if pa != -1 and children_cnt[pa] == 1:
                            self.pooling_list.append([pa, x])
                            x = self.parent[pa]
                        else:
                            self.pooling_list.append([x, ])
                            break
            elif children_cnt[x] > 1:
                # Branch points are never merged.
                self.pooling_list.append([x, ])
        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )
        # Order groups by their smallest joint index, then rebuild the
        # original->pooled index map.
        self.pooling_list.sort(key=lambda x:x[0])
        for i, a in enumerate(self.pooling_list):
            for j in a:
                self.pooling_map[j] = i
        self.output_joint_num = len(self.pooling_list)
        # Pooled topology: parent of group i is the pooled index of the
        # parent of its first joint; group 0 (root) keeps parent -1.
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, x in enumerate(self.pooling_list):
            if i < 1: continue
            self.new_topology[i] = self.pooling_map[self.parent[x[0]]]
        # Fixed mean-pooling matrix: row block i averages the channel blocks
        # of all joints in group i.
        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)
        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_joint):
                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)
        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        # Matrix-multiply along the (flattened joint*channel) dimension.
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
class SkeletonPool(nn.Module):
    """Mean-pools skeleton *edges* by merging consecutive edges along
    unbranched chains.

    `pooling_list` holds groups of original edge indices that merge into one
    pooled edge; `new_edges` is the pooled edge list. A fixed averaging
    matrix is applied in forward(). The last pseudo-edge (index
    ``len(edges)``) carries the global position and is never merged.
    """
    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        """
        Args:
            edges: list of [parent_joint, child_joint] pairs; joint 0 is the
                root of the traversal.
            pooling_mode: only 'mean' is implemented.
            channels_per_edge: channels owned by each edge in the input.
            last_pool: if True, collapse each whole chain into one group.

        Raises:
            Exception: for any pooling_mode other than 'mean'.
        """
        super(SkeletonPool, self).__init__()
        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')
        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges) + 1  # +1 for the global-position pseudo-edge
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []
        # Joint degrees, sized from the actual joint indices rather than a
        # hard-coded limit of 100, so skeletons with >100 joints work.
        max_joint = max((max(e[0], e[1]) for e in edges), default=0)
        degree = [0] * (max_joint + 1)
        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            # Recursively collect maximal chains of edge indices between
            # branch points (degree > 2) and leaves (degree == 1).
            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []
            if degree[j] == 1:
                self.seq_list.append(seq)
                return
            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                # Collapse the whole chain into a single pooled edge.
                self.pooling_list.append(seq)
                continue
            # Odd-length chains keep their first edge unmerged so the rest
            # pairs up evenly.
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                # Merged edge spans from the first edge's start joint to the
                # second edge's end joint.
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
        # add global position
        self.pooling_list.append([self.edge_num - 1])
        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )
        # Fixed mean-pooling matrix: row block i averages the channel blocks
        # of the edges in group i.
        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input)
class SkeletonUnpool(nn.Module):
    """Inverse of skeleton pooling: copies each pooled joint's channel block
    back to every original joint in its pooling group, via a fixed 0/1
    expansion matrix."""
    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        # Total number of original joints across all groups.
        self.output_joint_num = sum(len(group) for group in pooling_list)
        self.channels_per_edge = channels_per_edge
        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )
        # Expansion matrix: for pooled joint `src`, write its channels into
        # the rows of every original joint `dst` it was merged from.
        self.weight = torch.zeros(self.output_joint_num * channels_per_edge,
                                  self.input_joint_num * channels_per_edge)
        for src, group in enumerate(self.pooling_list):
            col0 = src * channels_per_edge
            for dst in group:
                row0 = dst * channels_per_edge
                for c in range(channels_per_edge):
                    self.weight[row0 + c, col0 + c] = 1
        self.weight = nn.Parameter(self.weight)
        self.weight.requires_grad_(False)

    def forward(self, input: torch.Tensor):
        expanded = torch.matmul(self.weight, input.unsqueeze(-1))
        return expanded.squeeze(-1)
def find_neighbor_joint(parents, threshold):
    """Return, for each joint, the joints within `threshold` hops on the tree.

    Args:
        parents: parent index per joint (joint 0 is the root; its parent
            entry is ignored).
        threshold: maximum hop distance (inclusive) for neighbourhood.

    Returns:
        list[list[int]]: for each joint i, sorted indices j with
        dist(i, j) <= threshold (always includes i itself).
    """
    n_joint = len(parents)
    # Fix: `np.int` was removed in NumPy 1.24 — use an explicit dtype.
    # 100000 acts as "infinity"; safe from overflow in int64 arithmetic.
    dist_mat = np.full((n_joint, n_joint), 100000, dtype=np.int64)
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        if i != 0:
            dist_mat[i, p] = dist_mat[p, i] = 1
    """
    Floyd's algorithm
    """
    for k in range(n_joint):
        # Vectorized relaxation: dist[i,j] = min(dist[i,j], dist[i,k]+dist[k,j]).
        dist_mat = np.minimum(dist_mat, dist_mat[:, k:k + 1] + dist_mat[k:k + 1, :])
    neighbor_list = []
    for i in range(n_joint):
        neighbor_list.append([j for j in range(n_joint) if dist_mat[i, j] <= threshold])
    return neighbor_list