# Hugging Face Space status shown when this source was captured: "Runtime error".
import inspect

# Several of the names below are not referenced directly in this file;
# presumably they must be importable so joblib can unpickle the saved model
# objects (ResNet, skorch wrapper, custom transformers) — confirm before removing.
import gradio as gr
import librosa
import numpy as np
import skorch
import torch
import torch.nn as nn
from joblib import dump, load
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

from dataloading import to_numpy, uniformize
from gradio_utils import load_as_librosa, predict_gradio
from preprocessing import MfccTransformer, TorchTransform
from resnet import ResNet
# Fixed seed applied to both NumPy's legacy global RNG and PyTorch's global
# RNG so repeated runs of the demo start from the same random state.
SEED: int = 42

for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng
# NOTE: joblib.load unpickles arbitrary Python objects — only ever point it at
# trusted files, never at user-supplied uploads.
def _load_artifact(name):
    """Return the object joblib-dumped to ``./model/<name>.joblib``."""
    return load(f"./model/{name}.joblib")


# Trained estimator, the MFCC feature transform, and the label encoder.
model = _load_artifact("model")
only_mffc_transform = _load_artifact("only_mffc_transform")
label_encoder = _load_artifact("label_encoder")

# Scalar preprocessing settings, each stored in its own file.
SAMPLE_RATE = _load_artifact("SAMPLE_RATE")
METHOD = _load_artifact("METHOD")
MAX_TIME = _load_artifact("MAX_TIME")
N_MFCC = _load_artifact("N_MFCC")
HOP_LENGHT = _load_artifact("HOP_LENGHT")  # spelling matches the file on disk
# Chain the loaded MFCC transform ("mfcc") and the trained model ("model")
# into one scikit-learn Pipeline; predict_gradio receives it as sklearn_model.
_pipeline_steps = [
    ("mfcc", only_mffc_transform),
    ("model", model),
]
sklearn_model = Pipeline(steps=_pipeline_steps)
def uniform_lambda(y, sr):
    """Uniformize one signal using the stored METHOD / MAX_TIME settings.

    Thin wrapper around ``dataloading.uniformize`` that binds the two
    settings loaded from ``./model/METHOD.joblib`` and
    ``./model/MAX_TIME.joblib``, so callers pass only the signal and its
    sampling rate. The globals are read at call time, exactly as the former
    lambda did.

    Args:
        y: the audio signal (format as expected by ``uniformize`` — presumably
            a 1-D numpy array; TODO confirm).
        sr: sampling rate of ``y``.

    Returns:
        Whatever ``uniformize`` returns for these arguments.
    """
    return uniformize(y, sr, METHOD, MAX_TIME)
# Texts passed to gr.Interface as its title, HTML description and Markdown
# article.
title = "ResNet 9"
description = """
<center>
The ResNet 9 model was trained to classify drone speech commands.
<img src="http://zeus.blanchon.cc/dropshare/modia.png" width="200">
</center>
"""
article = """
- [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385)
"""
def _microphone_input():
    """Build a microphone-only ``gr.Audio`` input that delivers numpy audio.

    Gradio 4 replaced the ``source="microphone"`` keyword with
    ``sources=["microphone"]`` and rejects the old spelling with a
    ``TypeError`` at start-up. Inspect the installed ``gr.Audio`` and use
    whichever keyword it actually accepts, so the app starts on Gradio 3
    and Gradio 4 alike.
    """
    audio_params = inspect.signature(gr.Audio.__init__).parameters
    if "sources" in audio_params:
        return gr.Audio(sources=["microphone"], type="numpy")
    return gr.Audio(source="microphone", type="numpy")


def _predict(data):
    """Classify one microphone recording for the Gradio interface.

    Args:
        data: the value produced by ``gr.Audio(type="numpy")`` — presumably a
            ``(sample_rate, samples)`` tuple, or ``None`` when the user
            submits without recording (TODO confirm).

    Returns:
        The output of ``predict_gradio`` for ``gr.Label``, or ``None``
        (empty label) when there is no recording.
    """
    if data is None:
        # No recording: return an empty label instead of passing None on to
        # predict_gradio, which presumably expects real audio.
        return None
    return predict_gradio(
        data=data,
        uniform_lambda=uniform_lambda,
        sklearn_model=sklearn_model,
        label_transform=label_encoder,
        target_sr=SAMPLE_RATE,
    )


demo_men = gr.Interface(
    fn=_predict,
    inputs=_microphone_input(),
    outputs=gr.Label(),
    title=title,
    description=description,
    article=article,
    # Optional flagging configuration, currently disabled:
    # allow_flagging = "manual",
    # flagging_options = ['recule', 'tournedroite', 'arretetoi', 'tournegauche', 'gauche', 'avance', 'droite'],
    # flagging_dir = "./flag/men"
)

if __name__ == "__main__":
    demo_men.launch()