add streamlit app
- app.py +117 -0
- dataset.py +1 -1
- inference_onnx.py +4 -4
- main.py +3 -4
- models/frn.py +2 -2
- sample.wav +0 -0
- utils/utils.py +5 -5
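
Not part of the diff: with the dependencies installed and the exported checkpoint in place at lightning_logs/version_0/checkpoints/frn.onnx (the path hard-coded in app.py below), the demo would typically be launched with "streamlit run app.py".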
app.py
ADDED
@@ -0,0 +1,117 @@
+import streamlit as st
+import librosa
+import librosa.display
+from config import CONFIG
+import torch
+from dataset import MaskGenerator
+import onnxruntime, onnx
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+
+@st.cache_resource
+def load_model():
+    path = 'lightning_logs/version_0/checkpoints/frn.onnx'
+    onnx_model = onnx.load(path)
+    options = onnxruntime.SessionOptions()
+    options.intra_op_num_threads = 2
+    options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    session = onnxruntime.InferenceSession(path, options)
+    input_names = [x.name for x in session.get_inputs()]
+    output_names = [x.name for x in session.get_outputs()]
+    return session, onnx_model, input_names, output_names
+
+def inference(re_im, session, onnx_model, input_names, output_names):
+    inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
+                                       dtype=np.float32)
+              for i, _input in enumerate(onnx_model.graph.input)
+              }
+
+    output_audio = []
+    for t in range(re_im.shape[0]):
+        inputs[input_names[0]] = re_im[t]
+        out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
+        inputs[input_names[1]] = prev_mag
+        inputs[input_names[2]] = predictor_state
+        inputs[input_names[3]] = mlp_state
+        output_audio.append(out)
+
+    output_audio = torch.tensor(np.concatenate(output_audio, 0))
+    output_audio = output_audio.permute(1, 0, 2).contiguous()
+    output_audio = torch.view_as_complex(output_audio)
+    output_audio = torch.istft(output_audio, window, stride, window=hann)
+    return output_audio.numpy()
+
+def visualize(hr, lr, recon):
+    sr = CONFIG.DATA.sr
+    window_size = 1024
+    window = np.hanning(window_size)
+
+    stft_hr = librosa.core.spectrum.stft(hr, n_fft=window_size, hop_length=512, window=window)
+    stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
+
+    stft_lr = librosa.core.spectrum.stft(lr, n_fft=window_size, hop_length=512, window=window)
+    stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
+
+    stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
+    stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
+
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
+    ax1.title.set_text('Target signal')
+    ax2.title.set_text('Lossy signal')
+    ax3.title.set_text('Enhanced signal')
+
+    canvas = FigureCanvas(fig)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_hr), ax=ax1, y_axis='linear', x_axis='time', sr=sr)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_lr), ax=ax2, y_axis='linear', x_axis='time', sr=sr)
+    p = librosa.display.specshow(librosa.amplitude_to_db(stft_recon), ax=ax3, y_axis='linear', x_axis='time', sr=sr)
+    return fig
+
+packet_size = CONFIG.DATA.EVAL.packet_size
+window = CONFIG.DATA.window_size
+stride = CONFIG.DATA.stride
+
+title = 'Packet Loss Concealment'
+st.set_page_config(page_title=title, page_icon=":sound:")
+st.title(title)
+
+uploaded_file = st.file_uploader("Upload your audio file (.wav)")
+
+is_file_uploaded = uploaded_file is not None
+if not is_file_uploaded:
+    uploaded_file = 'sample.wav'
+
+target, sr = librosa.load(uploaded_file, sr=48000)
+target = target[:packet_size * (len(target) // packet_size)]
+
+st.subheader('Original audio')
+st.audio(uploaded_file)
+
+st.subheader('Choose loss packet percentage')
+loss_percent = st.radio('Loss percentage', ['10%', '20%', '30%', '40%'])
+loss_percent = float(loss_percent[:-1])/100
+mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
+lossy_input = target.copy().reshape(-1, packet_size)
+mask = mask_gen.gen_mask(len(lossy_input), seed=0)[:, np.newaxis]
+lossy_input *= mask
+lossy_input = lossy_input.reshape(-1)
+hann = torch.sqrt(torch.hann_window(window))
+lossy_input_tensor = torch.tensor(lossy_input)
+re_im = torch.stft(lossy_input_tensor, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
+    1).numpy().astype(np.float32)
+session, onnx_model, input_names, output_names = load_model()
+
+if st.button('Conceal lossy audio!'):
+    with st.spinner('Please wait for completion'):
+        output = inference(re_im, session, onnx_model, input_names, output_names)
+
+    st.subheader('Visualization')
+    fig = visualize(target, lossy_input, output)
+    st.pyplot(fig)
+    st.success('Done!')
+    st.text('Original audio')
+    st.audio(target, sample_rate=sr)
+    st.text('Lossy audio')
+    st.audio(lossy_input, sample_rate=sr)
+    st.text('Enhanced audio')
+    st.audio(output, sample_rate=sr)
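
Note on the STFT handling above: inference() reconstructs audio with torch.istft using the same sqrt-Hann window and hop that produced re_im, so the analysis/synthesis pair cancels. A minimal round-trip sketch, with hypothetical values standing in for CONFIG.DATA.window_size and CONFIG.DATA.stride (the real values come from config.py):

import torch

n_fft, hop = 960, 480                                   # hypothetical stand-ins for window_size / stride
hann = torch.sqrt(torch.hann_window(n_fft))             # same sqrt-Hann window as app.py
x = torch.randn(48000)                                  # one second of audio at 48 kHz
spec = torch.stft(x, n_fft, hop, window=hann, return_complex=True)
x_rec = torch.istft(spec, n_fft, hop, window=hann, length=x.shape[-1])
assert torch.allclose(x, x_rec, atol=1e-4)              # sqrt-Hann at 50% overlap reconstructs the input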
dataset.py
CHANGED
@@ -67,7 +67,7 @@ class MaskGenerator:
         else:
             assert len(probs) == 1
             prob = self.probs[0]
-            self.mcs.append(MarkovChain([[
+            self.mcs.append(MarkovChain([[prob[0], 1 - prob[0]], [1 - prob[1], prob[1]]], ['1', '0']))
 
     def gen_mask(self, length, seed=0):
         if self.is_train:
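
For context, app.py above drives this evaluation path; a minimal sketch of the same usage, with a hypothetical packet_size standing in for CONFIG.DATA.EVAL.packet_size:

import numpy as np
from dataset import MaskGenerator

loss_percent = 0.1                                             # simulate 10% packet loss
packet_size = 960                                              # hypothetical stand-in for CONFIG.DATA.EVAL.packet_size
sig = np.random.randn(10 * packet_size).astype(np.float32)     # dummy signal, a whole number of packets long

packets = sig.reshape(-1, packet_size)
mask_gen = MaskGenerator(is_train=False, probs=[(1 - loss_percent, loss_percent)])
mask = mask_gen.gen_mask(len(packets), seed=0)[:, np.newaxis]  # one keep/drop decision per packet
lossy = (packets * mask).reshape(-1)                           # zero out the dropped packets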
inference_onnx.py
CHANGED
@@ -38,8 +38,8 @@ if __name__ == '__main__':
     for file in tqdm.tqdm(audio_files, total=len(audio_files)):
         sig, _ = librosa.load(file, sr=48000)
         sig = torch.tensor(sig)
-        re_im = torch.stft(sig, window, stride, window=hann, return_complex=False).permute(
-
+        re_im = torch.stft(sig, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze(
+            1).numpy().astype(np.float32)
 
         inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim],
                                            dtype=np.float32)
@@ -47,8 +47,8 @@ if __name__ == '__main__':
                   }
 
         output_audio = []
-        for t in range(re_im.shape[
-
+        for t in range(re_im.shape[0]):
+            inputs[input_names[0]] = re_im[t]
             out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs)
             inputs[input_names[1]] = prev_mag
             inputs[input_names[2]] = predictor_state
main.py
CHANGED
@@ -4,7 +4,7 @@ import os
 import pytorch_lightning as pl
 import soundfile as sf
 import torch
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
 from pytorch_lightning.utilities.model_summary import summarize
 from torch.utils.data import DataLoader
 
@@ -65,9 +65,8 @@ def train():
                          gradient_clip_val=CONFIG.TRAIN.clipping_val,
                          gpus=len(gpus),
                          max_epochs=CONFIG.TRAIN.epochs,
-                         accelerator="
-
-                         callbacks=[checkpoint_callback]
+                         accelerator="gpu" if len(gpus) > 1 else None,
+                         callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-2)]
                          )
 
     print(model.hparams)
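
Note on the new callback: StochasticWeightAveraging averages the model weights over the later part of training, and swa_lrs=1e-2 is the learning rate used during that averaging phase; the import path above matches recent pytorch_lightning versions.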
models/frn.py
CHANGED
@@ -92,11 +92,11 @@ class PLCModel(pl.LightningModule):
 
     def train_dataloader(self):
         return DataLoader(self.train_dataset, shuffle=False, batch_size=self.hparams.batch_size,
-                          num_workers=CONFIG.TRAIN.workers)
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, shuffle=False, batch_size=self.hparams.batch_size,
-                          num_workers=CONFIG.TRAIN.workers)
+                          num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
 
     def training_step(self, batch, batch_idx):
         x_in, y = batch
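
Note on persistent_workers=True: it keeps the DataLoader worker processes alive between epochs instead of re-spawning them, which saves startup cost; it only takes effect when CONFIG.TRAIN.workers is greater than zero.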
sample.wav
ADDED
Binary file (797 kB).
utils/utils.py
CHANGED
@@ -24,23 +24,23 @@ def mkdir_p(mypath):
         raise
 
 
-def visualize(
+def visualize(target, input, recon, path):
     sr = CONFIG.DATA.sr
     window_size = 1024
     window = np.hanning(window_size)
 
-    stft_hr = librosa.core.spectrum.stft(
+    stft_hr = librosa.core.spectrum.stft(target, n_fft=window_size, hop_length=512, window=window)
     stft_hr = 2 * np.abs(stft_hr) / np.sum(window)
 
-    stft_lr = librosa.core.spectrum.stft(
+    stft_lr = librosa.core.spectrum.stft(input, n_fft=window_size, hop_length=512, window=window)
     stft_lr = 2 * np.abs(stft_lr) / np.sum(window)
 
     stft_recon = librosa.core.spectrum.stft(recon, n_fft=window_size, hop_length=512, window=window)
     stft_recon = 2 * np.abs(stft_recon) / np.sum(window)
 
     fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
-    ax1.title.set_text('
-    ax2.title.set_text('
+    ax1.title.set_text('Target signal')
+    ax2.title.set_text('Lossy signal')
     ax3.title.set_text('Reconstructed signal')
 
     canvas = FigureCanvas(fig)