import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# Normalizations understood by dct/idct (and the 2-D/3-D wrappers built on them).
_SUPPORTED_NORMS = (None, 'ortho')


def _check_norm(norm):
    """Raise ValueError unless `norm` is None or 'ortho'.

    Without this check a typo such as 'Ortho' would silently fall back to the
    unnormalized transform.
    """
    if norm not in _SUPPORTED_NORMS:
        raise ValueError("norm must be None or 'ortho', got {!r}".format(norm))


def dct1(x):
    """
    Discrete Cosine Transform, Type I (unnormalized, as scipy.fft.dct(x, type=1))

    Computed as the real part of the FFT of the even (mirror) extension
    [x_0, ..., x_{N-1}, x_{N-2}, ..., x_1]; for N >= 2 that extension has
    length 2N - 2, and the N bins of its one-sided FFT are the DCT-I
    coefficients.

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    # reshape (not view) so non-contiguous inputs such as transposes also work
    x = x.reshape(-1, x_shape[-1])
    extended = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # NOTE: the second positional argument of torch.fft.fft is `n` (signal
    # length), not the number of dimensions, so the dimension must be passed
    # explicitly. rfft returns exactly N bins for a length-(2N-2) input.
    return torch.fft.rfft(extended, dim=1).real.reshape(x_shape)


def idct1(X):
    """
    The inverse of DCT-I, which is just a scaled DCT-I

    Our definition of idct1 is such that idct1(dct1(x)) == x

    :param X: the input signal
    :return: the inverse DCT-I of the signal over the last dimension
    :raises ValueError: if the last dimension has fewer than 2 elements, where
        the scale factor 2 * (n - 1) would be zero
    """
    n = X.shape[-1]
    if n < 2:
        raise ValueError("idct1 needs at least 2 samples along the last dimension, got {}".format(n))
    # Applying DCT-I twice multiplies the signal by 2 * (n - 1).
    return dct1(X) / (2 * (n - 1))


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    :raises ValueError: if `norm` is neither None nor 'ortho'
    """
    _check_norm(norm)
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder: even-indexed samples followed by the reversed odd-indexed ones,
    # so a single length-N FFT yields the DCT-II after a phase rotation.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    # Multiply bin k by exp(-i*pi*k/(2N)) and keep the real part.
    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct(dct(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    :raises ValueError: if `norm` is neither None nor 'ortho'
    """
    _check_norm(norm)
    x_shape = X.shape
    N = x_shape[-1]

    # The division produces a fresh tensor, so the in-place scaling below
    # never modifies the caller's X.
    X_v = X.contiguous().view(-1, x_shape[-1]) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Rebuild the spectrum of the reordered signal: (X[k] - i*X[N-k]) * exp(i*pi*k/(2N)).
    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)

    # That spectrum is Hermitian, so irfft (which reads only the first
    # N // 2 + 1 bins) recovers the real reordered signal.
    v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)

    # Undo the even/odd reordering performed in dct.
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)


def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    return X2.transpose(-1, -2)


def idct_2d(X, norm=None):
    """
    The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct_2d(dct_2d(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    return x2.transpose(-1, -2)


def dct_3d(x, norm=None):
    """
    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 3 dimensions
    """
    X1 = dct(x, norm=norm)
    X2 = dct(X1.transpose(-1, -2), norm=norm)
    X3 = dct(X2.transpose(-1, -3), norm=norm)
    return X3.transpose(-1, -3).transpose(-1, -2)


def idct_3d(X, norm=None):
    """
    The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III

    Our definition of idct is that idct_3d(dct_3d(x)) == x

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 3 dimensions
    """
    x1 = idct(X, norm=norm)
    x2 = idct(x1.transpose(-1, -2), norm=norm)
    x3 = idct(x2.transpose(-1, -3), norm=norm)
    return x3.transpose(-1, -3).transpose(-1, -2)
class LinearDCT(nn.Module):
    """Implement any DCT as a linear layer; in practice this executes around
    50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
    increase memory usage.

    :param in_features: size of expected input
    :param type: which dct function in this file to use
        ('dct1', 'idct1', 'dct' or 'idct')
    :param norm: normalization passed to 'dct' / 'idct' (None or 'ortho')
    :raises ValueError: if `type` is not one of the supported names
    """

    def __init__(self, in_features, type, norm=None):
        super(LinearDCT, self).__init__()
        self.type = type
        self.N = in_features
        self.norm = norm

        # Applying the transform to the identity gives the transform matrix
        # (transposed, because F.linear computes x @ weight.T).
        I = torch.eye(self.N)
        if type == 'dct1':
            kernel = dct1(I)
        elif type == 'idct1':
            kernel = idct1(I)
        elif type == 'dct':
            kernel = dct(I, norm=self.norm)
        elif type == 'idct':
            kernel = idct(I, norm=self.norm)
        else:
            raise ValueError(
                "type must be 'dct1', 'idct1', 'dct' or 'idct', got {!r}".format(type))

        # A buffer follows module.to()/.cuda(); persistent=False keeps it out
        # of state_dict, as the previous plain attribute was.
        self.register_buffer('weight', kernel.detach().t().contiguous(), persistent=False)

    def forward(self, x):
        # Match the input's device and dtype (no-op when they already agree),
        # so CPU inputs and non-float32 inputs work as well as CUDA ones.
        weight = self.weight.to(device=x.device, dtype=x.dtype)
        return F.linear(x, weight=weight)


def apply_linear_2d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 2D DCT.

    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 2 dimensions
    """
    X1 = linear_layer(x)
    X2 = linear_layer(X1.transpose(-1, -2))
    return X2.transpose(-1, -2)


def apply_linear_3d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 3D DCT.

    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 3 dimensions
    """
    X1 = linear_layer(x)
    X2 = linear_layer(X1.transpose(-1, -2))
    X3 = linear_layer(X2.transpose(-1, -3))
    return X3.transpose(-1, -3).transpose(-1, -2)


if __name__ == '__main__':
    x = torch.randn(1000, 4096)
    linear_dct = LinearDCT(4096, 'dct')
    error = torch.abs(dct(x) - linear_dct(x))
    assert error.max() < 1e-3, (error, error.max())
    linear_idct = LinearDCT(4096, 'idct')
    error = torch.abs(idct(x) - linear_idct(x))
    assert error.max() < 1e-3, (error, error.max())