File size: 2,994 Bytes
66a6dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import io
import torch
import PIL.Image
import numpy as np
import scipy.signal
import librosa.display
import matplotlib.pyplot as plt

from torch.functional import Tensor
from torchvision.transforms import ToTensor


def compute_comparison_spectrogram(
    x: np.ndarray,
    y: np.ndarray,
    sample_rate: float = 44100,
    n_fft: int = 2048,
    hop_length: int = 1024,
) -> Tensor:
    """Render log-frequency spectrograms of ``x`` (top) and ``y`` (bottom) as one image tensor.

    Each signal's STFT magnitude is converted to dB relative to that signal's
    own peak (``ref=np.max``), so the two panels are normalized independently.

    Args:
        x: Signal shown in the top panel (passed to ``librosa.stft``).
        y: Signal shown in the bottom panel.
        sample_rate: Sample rate in Hz, used by ``specshow`` to label the axes.
        n_fft: STFT window size.
        hop_length: STFT hop size; also given to ``specshow`` so the time
            axis is scaled correctly.

    Returns:
        The figure rendered to JPEG and converted with torchvision's
        ``ToTensor`` (per its docs, a float (C, H, W) tensor in [0, 1]).
    """
    fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
    try:
        for ax, signal in zip(axs, (x, y)):
            S = librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)
            S_db = librosa.amplitude_to_db(np.abs(S), ref=np.max)
            # Colorbars are intentionally not drawn.
            librosa.display.specshow(
                S_db,
                ax=ax,
                hop_length=hop_length,
                x_axis="time",
                y_axis="log",
                sr=sample_rate,
            )

        fig.tight_layout()

        # Render this figure (not whatever pyplot considers "current") to an
        # in-memory JPEG and convert it before the buffer is closed.
        with io.BytesIO() as buf:
            fig.savefig(buf, format="jpeg")
            buf.seek(0)
            image = ToTensor()(PIL.Image.open(buf))
    finally:
        # Close only the figure created here, even if rendering failed.
        plt.close(fig)

    return image


def plot_multi_spectrum(
    ys=None,
    Hs=None,
    legend=None,
    title="Spectrum",
    filename=None,
    sample_rate=44100,
    n_fft=1024,
    zero_mean=False,
):
    """Plot several magnitude spectra on one log-frequency axis and return the plot as an image tensor.

    Args:
        ys: Iterable of time-domain signals. Used only when ``Hs`` is None;
            each one goes through ``get_average_spectrum(y, n_fft)`` and then
            ``smooth_spectrum``.
        Hs: Iterable of linear-magnitude spectra, one per curve. Each must
            have ``n_fft // 2 + 1`` points to match the frequency axis.
        legend: Optional labels, one per curve. A curve whose label contains
            "Target" is drawn as a dashed black line. Missing labels are allowed.
        title: Plot title.
        filename: If given, the figure is also written to ``f"{filename}.png"``
            at 300 dpi.
        sample_rate: Sample rate in Hz, used to build the frequency axis.
        n_fft: FFT size behind the spectra (frequency axis and ``ys`` analysis).
        zero_mean: If True, subtract each curve's mean dB level before plotting.

    Returns:
        The figure rendered to JPEG and converted with torchvision's ``ToTensor``.

    Raises:
        ValueError: If both ``ys`` and ``Hs`` are None.
    """
    if Hs is None:
        if ys is None:
            raise ValueError("plot_multi_spectrum needs either `ys` or `Hs`.")
        Hs = [smooth_spectrum(get_average_spectrum(y, n_fft)) for y in ys]

    # A None default avoids a shared mutable list. A missing or short legend
    # leaves the extra curves unlabeled instead of raising IndexError.
    legend = [] if legend is None else list(legend)

    # Bin k of an n_fft-point real FFT lies at k * sample_rate / n_fft Hz.
    # rfftfreq always yields exactly n_fft // 2 + 1 points. A float np.arange
    # with a fractional step can gain an extra point through rounding.
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / sample_rate)

    fig, ax1 = plt.subplots()
    try:
        for idx, H in enumerate(Hs):
            # Make values finite and non-negative before taking the log.
            H = np.maximum(np.nan_to_num(H), 0)
            H_dB = 20 * np.log10(H + 1e-8)  # epsilon avoids log10(0)
            if zero_mean:
                H_dB -= np.mean(H_dB)
            if idx < len(legend) and "Target" in legend[idx]:
                ax1.plot(freqs, H_dB, linestyle="--", color="k")
            else:
                ax1.plot(freqs, H_dB)

        if legend:
            ax1.legend(legend)

        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        ax1.set_xlim([100, 11000])
        ax1.set_title(title)
        ax1.set_ylabel("Magnitude (dB)")
        ax1.set_xlabel("Frequency (Hz)")
        ax1.grid(c="lightgray", which="both")

        # Lay out first so the saved PNG matches the returned image.
        fig.tight_layout()

        if filename is not None:
            fig.savefig(f"{filename}.png", dpi=300)

        with io.BytesIO() as buf:
            fig.savefig(buf, format="jpeg")
            buf.seek(0)
            image = ToTensor()(PIL.Image.open(buf))
    finally:
        # Close only this figure, even if plotting failed.
        plt.close(fig)

    return image


def smooth_spectrum(H, window_length=1025, polyorder=2):
    """Smooth a spectrum with a Savitzky-Golay filter (used for smoothed target curves).

    Args:
        H: Array-like spectrum; it is filtered along its last axis.
        window_length: Filter window length. If it is longer than the signal,
            it is shrunk to the largest odd length that fits, because scipy's
            default ``mode="interp"`` rejects windows longer than the input.
        polyorder: Order of the polynomial fitted within each window.

    Returns:
        The smoothed spectrum as a NumPy array. If the signal is too short to
        fit a window longer than ``polyorder``, an unsmoothed float64 copy is
        returned.
    """
    H = np.asarray(H)
    n_points = H.shape[-1]

    window = window_length
    if window > n_points:
        # Largest odd window that still fits inside the signal.
        window = n_points if n_points % 2 else n_points - 1
        if window <= polyorder:
            # Too few points to fit the polynomial; leave the data as-is.
            return H.astype(np.float64)

    return scipy.signal.savgol_filter(H, window, polyorder)


def get_average_spectrum(x, n_fft):
    """Return the STFT magnitude of ``x`` averaged over time frames, flattened to 1-D.

    Args:
        x: Signal tensor accepted by ``torch.stft`` (1-D, or batched 2-D).
        n_fft: FFT size passed to ``torch.stft``.

    Returns:
        A 1-D tensor of frame-averaged magnitudes. For batched input, the
        per-item spectra are concatenated end to end by the final flatten.

    Note:
        No ``window`` or ``hop_length`` is passed, so ``torch.stft`` uses its
        defaults for both. ``normalized=True`` requests torch's normalized STFT.
    """
    complex_spec = torch.stft(
        x,
        n_fft,
        normalized=True,
        return_complex=True,
    )
    # The last axis of the STFT output indexes frames, so average over it.
    frame_mean = torch.abs(complex_spec).mean(dim=-1)
    return frame_mean.reshape(-1)