Spaces:
Runtime error
Runtime error
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
# DWT code borrowed from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py
import pywt | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Names exported by `from <module> import *`.
__all__ = ["DWT_1D"]

# Padding modes DWT_1D accepts for `pad_type`; the chosen mode is handed to
# torch.nn.functional.pad before the strided convolutions.
Pad_Mode = ["constant", "reflect", "replicate", "circular"]
class DWT_1D(nn.Module):
    """Single-level 1D discrete wavelet transform built from two strided conv1d ops.

    The low- and high-pass filters come from ``pywt.Wavelet(wavename)``
    (its ``rec_lo`` / ``rec_hi`` coefficients). They are optionally placed
    inside a longer zero-filled kernel and replicated into grouped ``conv1d``
    weights. ``forward`` pads the input with ``pad_type`` and then convolves
    it with each filter bank using stride 2.

    Args:
        pad_type: padding mode handed to ``F.pad``; must be one of ``Pad_Mode``.
        wavename: wavelet name understood by ``pywt.Wavelet``.
        stride: convolution stride; only 2 is accepted.
        in_channels: expected size of ``input.size(1)``.
        out_channels: number of output channels per band; defaults to
            ``in_channels``.
        groups: ``conv1d`` groups; defaults to ``in_channels``. Must be an int
            that divides ``in_channels``.
        kernel_size: filter length. Must be ``None`` unless ``trainable``.
            When given, it must be >= the wavelet's filter length; the
            wavelet taps are then written into a zero kernel at offset
            ``(kernel_size - length) // 2``.
        trainable: if True, the filters are ``nn.Parameter``s. Otherwise they
            are non-persistent buffers: they follow ``module.to(...)`` /
            ``.cuda()`` but are not stored in ``state_dict``.

    Invalid arguments are rejected with ``assert`` (AssertionError).
    """

    def __init__(self, pad_type='reflect', wavename='haar',
                 stride=2, in_channels=1, out_channels=None, groups=None,
                 kernel_size=None, trainable=False):
        super().__init__()
        self.trainable = trainable
        self.kernel_size = kernel_size
        if not self.trainable:
            # Fixed filters always use the wavelet's own length.
            assert self.kernel_size is None
        self.in_channels = in_channels
        self.out_channels = self.in_channels if out_channels is None else out_channels
        self.groups = self.in_channels if groups is None else groups
        assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
        self.stride = stride
        assert self.stride == 2
        self.wavename = wavename
        self.pad_type = pad_type
        assert self.pad_type in Pad_Mode
        self.get_filters()
        self.initialization()

    def get_filters(self):
        """Build the 1D prototype filters ``self.filt_low`` / ``self.filt_high``.

        If ``self.kernel_size`` is None, it is set to the wavelet's filter
        length. Otherwise the wavelet taps are copied into a zero tensor of
        length ``kernel_size`` starting at ``(kernel_size - length) // 2``.
        """
        wavelet = pywt.Wavelet(self.wavename)
        band_low = torch.tensor(wavelet.rec_lo)
        band_high = torch.tensor(wavelet.rec_hi)
        length_band = band_low.size()[0]
        if self.kernel_size is None:
            self.kernel_size = length_band
        assert self.kernel_size >= length_band
        start = (self.kernel_size - length_band) // 2
        stop = start + length_band
        self.filt_low = torch.zeros(self.kernel_size)
        self.filt_high = torch.zeros(self.kernel_size)
        self.filt_low[start:stop] = band_low
        self.filt_high[start:stop] = band_high

    def initialization(self):
        """Expand the prototype filters into ``conv1d`` weights and set the padding.

        The weights ``self.filter_low`` / ``self.filter_high`` have shape
        ``(out_channels, in_channels // groups, kernel_size)``.
        ``self.pad_sizes`` is the symmetric ``[left, right]`` padding used
        by ``forward``.
        """
        reps = (self.out_channels, self.in_channels // self.groups, 1)
        filter_low = self.filt_low[None, None, :].repeat(reps)
        filter_high = self.filt_high[None, None, :].repeat(reps)
        if self.trainable:
            self.filter_low = nn.Parameter(filter_low)
            self.filter_high = nn.Parameter(filter_high)
        else:
            # Buffers move with module.to()/.cuda() and are replicated by
            # DataParallel, unlike plain tensor attributes. persistent=False
            # keeps them out of state_dict, so existing checkpoints still load.
            self.register_buffer('filter_low', filter_low, persistent=False)
            self.register_buffer('filter_high', filter_high, persistent=False)
        half = self.kernel_size // 2
        # Even kernels pad k//2 - 1 on each side; odd kernels pad k//2.
        pad = half - 1 if self.kernel_size % 2 == 0 else half
        self.pad_sizes = [pad, pad]

    def forward(self, input):
        """Apply the transform.

        Args:
            input: tensor of shape ``(N, in_channels, L)``.

        Returns:
            A tuple ``(low, high)`` holding the low-pass and high-pass
            outputs, each with ``out_channels`` channels and the same dtype
            and device as ``input``.
        """
        assert isinstance(input, torch.Tensor)
        assert len(input.size()) == 3
        assert input.size()[1] == self.in_channels
        input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
        # Match both device and dtype so non-float32 inputs work as well.
        # This is a no-op when they already agree.
        weight_low = self.filter_low.to(device=input.device, dtype=input.dtype)
        weight_high = self.filter_high.to(device=input.device, dtype=input.dtype)
        low = F.conv1d(input, weight_low, stride=self.stride, groups=self.groups)
        high = F.conv1d(input, weight_high, stride=self.stride, groups=self.groups)
        return low, high