|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool
|
|
|
|
|
|
def calc_node_depth(topology):
    """Return the depth of every node in a parent-pointer forest.

    Args:
        topology: Indexable sequence where ``topology[i]`` is the parent index
            of node ``i``; a negative value marks a root.

    Returns:
        list[int]: ``depth[i]`` is the number of parent hops from node ``i``
        to its root (roots have depth 0).

    Raises:
        ValueError: if following parent pointers from some node never reaches
            a root, i.e. the topology contains a cycle.
    """
    num_nodes = len(topology)
    depth = [None] * num_nodes

    for start in range(num_nodes):
        # Walk upwards until a root or an already-resolved node is reached,
        # remembering the path so each node on it is resolved exactly once.
        # Iterative on purpose: a recursive walk hits Python's recursion limit
        # on long chains.
        path = []
        on_path = set()
        node = start
        while depth[node] is None:
            parent = topology[node]
            if parent < 0:
                depth[node] = 0
                break
            if node in on_path:
                raise ValueError(f"topology contains a cycle through node {node}")
            path.append(node)
            on_path.add(node)
            node = parent

        # Fill in depths from the resolved end of the path back down to start.
        for child in reversed(path):
            depth[child] = depth[topology[child]] + 1

    return depth
|
|
|
|
|
|
def residual_ratio(k):
    """Return the weight ``1 / (k + 1)``.

    Args:
        k: Number whose successor is inverted; ``k + 1`` must be non-zero.

    Returns:
        The reciprocal of ``k + 1`` (true division, so a float for int ``k``).
    """
    denominator = k + 1
    return 1 / denominator
|
|
|
|
|
|
class Affine(nn.Module):
    """Learnable per-channel scale and bias applied along dimension 1.

    Each parameter has shape ``(num_parameters,)``. It is reshaped to
    ``(1, C, 1, ..., 1)`` so it broadcasts against an input whose channel axis
    is dimension 1.

    Args:
        num_parameters: Number of channels, ``C``.
        scale: If True, learn a multiplicative per-channel scale; otherwise the
            ``scale`` parameter is registered as ``None``.
        bias: If True, learn an additive per-channel bias; otherwise the
            ``bias`` parameter is registered as ``None``.
        scale_init: Initial value of every scale entry.
    """

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()

        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _expand_like(param, input):
        """Reshape a ``(C,)`` parameter to ``(1, C, 1, ..., 1)`` to match ``input.dim()``."""
        param = param.unsqueeze(0)
        while param.dim() < input.dim():
            param = param.unsqueeze(2)
        return param

    def forward(self, input):
        """Return ``input * scale + bias``, skipping whichever parameter is disabled."""
        output = input
        if self.scale is not None:
            output = output.mul(self._expand_like(self.scale, input))

        if self.bias is not None:
            # Out-of-place add: when scale is disabled, ``output`` is still the
            # caller's tensor, and ``+=`` would silently modify it in place.
            output = output + self._expand_like(self.bias, input)

        return output
|
|
|
|
|
|
class BatchStatistics(nn.Module):
    """Pass-through (or affine) layer that records a per-channel statistics loss.

    On every forward pass, it stores in ``self.loss`` a penalty that pushes
    each channel's mean towards 0 and its standard deviation towards 1. It then
    returns ``self.affine(input)``.

    Args:
        affine: ``-1`` (default) for an identity pass-through (an empty
            ``nn.Sequential``); otherwise the channel count for an ``Affine``
            transform.
    """

    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        """Reset the stored loss to 0."""
        self.loss = 0

    def compute_loss(self, input):
        """Set ``self.loss = mean_c(mu_c**2) + mean_c(log(sigma_c)**2)``.

        Channels are taken from dimension 1 of ``input``. The statistics for
        channel ``c`` are computed over every other dimension (batch, time, ...).
        """
        channels = input.size(1)
        # Move channels to the front before flattening. A plain
        # ``view(C, -1)`` of an (N, C, ...) tensor mixes different channels
        # whenever N > 1, and fails on non-contiguous input.
        input_flat = input.transpose(0, 1).reshape(channels, -1)
        mu = input_flat.mean(1)
        # Population variance as the mean squared deviation. This cannot go
        # negative through rounding the way E[x^2] - mu^2 can (giving NaN).
        var = (input_flat - mu.unsqueeze(1)).pow(2).mean(1)
        log_std = var.sqrt().log()

        self.loss = mu.pow(2).mean() + log_std.pow(2).mean()

    def forward(self, input):
        """Record the statistics loss for ``input`` and return ``self.affine(input)``."""
        self.compute_loss(input)
        return self.affine(input)
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """1-D convolutional residual block with a weighted shortcut.

    ``forward(x) = residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.

    * ``residual``: ``Conv1d(in, out, kernel_size, stride, padding)``, then an
      optional ``BatchStatistics`` (if ``batch_statistics``), then an optional
      activation (omitted when ``last_layer``). The activation is ``nn.PReLU``
      for ``"relu"`` and ``nn.Tanh`` otherwise.
    * ``shortcut``: ``AvgPool1d(2)`` when ``stride == 2`` (identity otherwise),
      then a 1x1 ``Conv1d``, then ``BatchStatistics`` only when the channel
      count changes and ``batch_statistics is True``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
    ):
        super(ResidualBlock, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        # Main branch. Layer order matters: it fixes the state_dict indices.
        main_layers = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
        if batch_statistics:
            main_layers.append(BatchStatistics(out_channels))
        if last_layer:
            pass  # the final block stays linear
        else:
            main_layers.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.residual = nn.Sequential(*main_layers)

        # Shortcut branch: resample, project, optionally normalise statistics.
        if stride == 2:
            resample = nn.AvgPool1d(kernel_size=2)
        else:
            resample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        wants_stats = batch_statistics is True and in_channels != out_channels
        stats = BatchStatistics(out_channels) if wants_stats else nn.Sequential()
        self.shortcut = nn.Sequential(resample, projection, stats)

    def forward(self, input):
        """Blend the residual and shortcut branches by their fixed ratios."""
        main = self.residual(input) * self.residual_ratio
        skip = self.shortcut(input) * self.shortcut_ratio
        return main + skip
|
|
|
|
|
|
class ResidualBlockTranspose(nn.Module):
    """Transposed-convolution residual block with a weighted shortcut.

    ``forward(x) = residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.

    * ``residual``: ``ConvTranspose1d(in, out, kernel_size, stride, padding)``
      followed by ``nn.PReLU`` (``activation == "relu"``) or ``nn.Tanh``.
    * ``shortcut``: 2x linear ``nn.Upsample`` when ``stride == 2`` (identity
      otherwise), then a 1x1 ``Conv1d``.

    The branches are summed, so their output lengths must agree for the chosen
    ``kernel_size``/``padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        nonlinearity = nn.PReLU() if activation == "relu" else nn.Tanh()
        self.residual = nn.Sequential(deconv, nonlinearity)

        if stride == 2:
            resample = nn.Upsample(scale_factor=2, mode="linear", align_corners=False)
        else:
            resample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.shortcut = nn.Sequential(resample, projection)

    def forward(self, input):
        """Blend the residual and shortcut branches by their fixed ratios."""
        main = self.residual(input) * self.residual_ratio
        skip = self.shortcut(input) * self.shortcut_ratio
        return main + skip
|
|
|
|
|
|
class SkeletonResidual(nn.Module):
    """Skeleton-aware residual block followed by optional skeleton pooling.

    Sub-modules (attribute names and ``nn.Sequential`` indices define the
    state_dict layout, so their order is significant):

    * ``residual``: ``extra_conv`` repetitions of (SkeletonConv in->in, stride 1,
      + activation), then a SkeletonConv in->out with the given ``stride``
      (``add_offset=False``), then ``nn.GroupNorm(10, out_channels)``. PyTorch
      GroupNorm requires ``out_channels`` to be divisible by 10.
    * ``shortcut``: a kernel-size-1 SkeletonConv in->out with the same ``stride``.
    * ``common``: a SkeletonPool, kept only when
      ``len(pool.pooling_list) != pool.edge_num``, then an activation. It is
      applied to ``residual(x) + shortcut(x)``.

    ``activation == "relu"`` selects ``nn.PReLU``; any other value selects ``nn.Tanh``.
    """

    def __init__(
        self,
        topology,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_mode,
        activation,
        last_pool,
    ):
        super(SkeletonResidual, self).__init__()

        def make_activation():
            return nn.PReLU() if activation == "relu" else nn.Tanh()

        # The extra (stride-1) convs always use an odd kernel.
        extra_kernel = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        layers = []
        for _ in range(extra_conv):
            layers.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=extra_kernel,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            layers.append(make_activation())
        layers += [
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            ),
            nn.GroupNorm(10, out_channels),
        ]
        self.residual = nn.Sequential(*layers)

        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=True,
            add_offset=False,
        )

        pool = SkeletonPool(
            edges=topology, pooling_mode=pooling_mode, channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool
        )
        # NOTE(review): presumably the pool is skipped when it would leave the
        # edge count unchanged — confirm against SkeletonPool.
        tail = [pool] if len(pool.pooling_list) != pool.edge_num else []
        tail.append(make_activation())
        self.common = nn.Sequential(*tail)

    def forward(self, input):
        """Sum the residual and shortcut branches, then apply ``common``."""
        summed = self.residual(input) + self.shortcut(input)
        return self.common(summed)
|
|
|
|
|
|
class SkeletonResidualTranspose(nn.Module):
    """Skeleton-aware residual block preceded by optional upsampling/unpooling.

    Sub-modules (attribute names and ``nn.Sequential`` indices define the
    state_dict layout, so their order is significant):

    * ``common``: ``nn.Upsample(scale_factor=2, mode=upsampling)`` when
      ``upsampling`` is not None, then a SkeletonUnpool, kept only when its
      input and output edge counts differ.
    * ``residual``: ``extra_conv`` repetitions of (SkeletonConv in->in +
      activation), then a SkeletonConv in->out (``add_offset=False``). All of
      them use stride 1 and an odd kernel (``kernel_size - 1`` when
      ``kernel_size`` is even).
    * ``shortcut``: a kernel-size-1 SkeletonConv in->out.
    * ``activation``: applied to the summed branches, or ``None`` when
      ``last_layer`` is truthy.

    ``activation == "relu"`` selects ``nn.PReLU``; any other value selects ``nn.Tanh``.
    """

    def __init__(
        self,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_list,
        upsampling,
        activation,
        last_layer,
    ):
        super(SkeletonResidualTranspose, self).__init__()

        def make_activation():
            return nn.PReLU() if activation == "relu" else nn.Tanh()

        odd_kernel = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        front = []
        if upsampling is not None:
            front.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        if unpool.input_edge_num != unpool.output_edge_num:
            front.append(unpool)
        self.common = nn.Sequential(*front)

        layers = []
        for _ in range(extra_conv):
            layers.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=odd_kernel,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            layers.append(make_activation())
        layers.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=odd_kernel,
                stride=1,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        self.residual = nn.Sequential(*layers)

        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            add_offset=False,
        )

        # The final layer of the stack is left linear.
        self.activation = None if last_layer else make_activation()

    def forward(self, input):
        """Apply ``common``, sum both branches, then the optional activation."""
        features = self.common(input)
        features = self.residual(features) + self.shortcut(features)
        return features if self.activation is None else self.activation(features)
|
|
|