Staticaliza commited on
Commit
b9b5a56
·
verified ·
1 Parent(s): 45fad82

Upload filter.py

Browse files
Files changed (1) hide show
  1. modules/bigvgan/filter.py +101 -0
modules/bigvgan/filter.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if "sinc" in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
+
26
+
27
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
+ # https://adefossez.github.io/julius/julius/lowpass.html
29
+ # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
+ half_size = kernel_size // 2
35
+
36
+ # For kaiser window
37
+ delta_f = 4 * half_width
38
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
+ beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
+ else:
44
+ beta = 0.0
45
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
+
47
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
+ if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
+ else:
51
+ time = torch.arange(kernel_size) - half_size
52
+ if cutoff == 0:
53
+ filter_ = torch.zeros_like(time)
54
+ else:
55
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ """
57
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58
+ """
59
+ filter_ /= filter_.sum()
60
+ filter = filter_.view(1, 1, kernel_size)
61
+
62
+ return filter
63
+
64
+
65
+ class LowPassFilter1d(nn.Module):
66
+ def __init__(
67
+ self,
68
+ cutoff=0.5,
69
+ half_width=0.6,
70
+ stride: int = 1,
71
+ padding: bool = True,
72
+ padding_mode: str = "replicate",
73
+ kernel_size: int = 12,
74
+ ):
75
+ """
76
+ kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
77
+ """
78
+ super().__init__()
79
+ if cutoff < -0.0:
80
+ raise ValueError("Minimum cutoff must be larger than zero.")
81
+ if cutoff > 0.5:
82
+ raise ValueError("A cutoff above 0.5 does not make sense.")
83
+ self.kernel_size = kernel_size
84
+ self.even = kernel_size % 2 == 0
85
+ self.pad_left = kernel_size // 2 - int(self.even)
86
+ self.pad_right = kernel_size // 2
87
+ self.stride = stride
88
+ self.padding = padding
89
+ self.padding_mode = padding_mode
90
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91
+ self.register_buffer("filter", filter)
92
+
93
+ # Input [B, C, T]
94
+ def forward(self, x):
95
+ _, C, _ = x.shape
96
+
97
+ if self.padding:
98
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
100
+
101
+ return out