# Hugging Face Space file view — author: cogniveon
# Commit 95cd1be "Update app.py" (verified) · 2.97 kB · virus scan: no virus
from functools import lru_cache

import torchaudio
import torch
import gradio as gr
import keras
import pandas as pd
import joblib
from transformers import pipeline

from examples import normal, abnormal
# Sampling rate (Hz) the murmur classifier is fed; audio at any other rate is
# resampled to this before inference.
_MURMUR_SAMPLING_RATE = 16000


@lru_cache(maxsize=1)
def _get_murmur_pipeline():
    """Build the murmur audio-classification pipeline once and reuse it."""
    return pipeline("audio-classification",
                    model="cogniveon/exp_1715080677")


def _prepare_waveform(audio):
    """Convert a ``(sampling_rate, samples)`` tuple to a mono float waveform.

    The result is a 1-D float32 tensor at ``_MURMUR_SAMPLING_RATE``.
    """
    sampling_rate, data = audio
    raw = torch.tensor(data)
    waveform = raw.float()
    if not raw.is_floating_point():
        # Integer PCM (e.g. Gradio's int16 numpy audio): scale into [-1, 1)
        # instead of feeding raw integer amplitudes to the model.
        waveform = waveform / (torch.iinfo(raw.dtype).max + 1)
    if waveform.ndim == 2:
        # Gradio lays out multi-channel audio as (samples, channels); average
        # the channels so resampling runs along time, not across channels.
        waveform = waveform.mean(dim=1)
    if sampling_rate != _MURMUR_SAMPLING_RATE:
        resampler = torchaudio.transforms.Resample(sampling_rate,
                                                   _MURMUR_SAMPLING_RATE)
        waveform = resampler(waveform)
    return waveform


def get_murmur_from_recordings(audio):
    """Score a heart-sound recording with the murmur classifier.

    Args:
        audio: ``(sampling_rate, samples)`` tuple, as produced by ``gr.Audio``
            with its default numpy output.

    Returns:
        dict mapping each label returned by the pipeline to its score,
        inserted in descending score order.
    """
    pipe = _get_murmur_pipeline()
    waveform = _prepare_waveform(audio)
    results = pipe(waveform.numpy())
    ranked = sorted(results, key=lambda item: item['score'], reverse=True)
    return {item['label']: item['score'] for item in ranked}
@lru_cache(maxsize=1)
def _load_outcome_model():
    """Load the patient-outcome classifier from disk once and reuse it."""
    return joblib.load("patient_outcome_classifier_v3.joblib")


def get_patient_outcome(age, sex, height, weight, is_pregnant, murmur):
    """Predict the patient outcome with the joblib-serialized classifier.

    Args:
        age: One of 'Neonate', 'Infant', 'Child', 'Adolescent'.
        sex: 'Male' or 'Female'.
        height: Value convertible with ``float``.
        weight: Value convertible with ``float``.
        is_pregnant: Truthy when the patient is pregnant.
        murmur: 'Absent' or 'Present'; any other value is encoded as 2
            (unknown).

    Returns:
        'Normal' or 'Abnormal'.

    Raises:
        KeyError: if ``age`` or ``sex`` is not one of the values listed above.
    """
    model = _load_outcome_model()

    # Categorical encodings; they must match what the classifier was trained on.
    sex_codes = {'Male': 0, 'Female': 1}
    age_codes = {'Neonate': 0, 'Infant': 1, 'Child': 2, 'Adolescent': 3}
    murmur_code = {'Absent': 0, 'Present': 1}.get(murmur, 2)

    # Column names and order are part of the model's expected input.
    features = pd.DataFrame({
        'Age': float(age_codes[age]),
        'Sex': float(sex_codes[sex]),
        'Height': float(height),
        'Weight': float(weight),
        'Pregnancy status': float(1 if is_pregnant else 0),
        'Murmur': float(murmur_code),
    }, index=[0])

    prediction = model.predict(features)[0]
    # NOTE(review): this maps class 1 -> 'Normal', but an earlier comment here
    # claimed "0 - Normal, 1 - Abnormal". The code mapping is kept as-is;
    # confirm against the label encoding used when training the classifier.
    return 'Normal' if prediction == 1 else 'Abnormal'
def predict(audio, age, sex, height, weight, is_pregnant):
    """Gradio entry point: recording + patient details -> outcome label.

    The recording is first classified for a murmur. Confident 'Present' or
    'Absent' scores become that label; anything else becomes 'Unknown'. The
    murmur label and the patient details are then passed to the outcome
    classifier.

    Returns:
        'Normal' or 'Abnormal', as returned by ``get_patient_outcome``.

    Raises:
        gr.Error: if no recording was supplied, or height/weight is empty.
    """
    # Raise rather than assert: asserts are stripped under ``python -O``, and
    # gr.Error surfaces the message in the UI.
    if audio is None:
        raise gr.Error('Audio cannot be None')
    if height is None or weight is None:
        raise gr.Error('Height and weight are required')

    present_threshold = 0.70
    absent_threshold = 0.80

    murmur_scores = get_murmur_from_recordings(audio)
    murmur = "Unknown"
    if murmur_scores['Present'] > present_threshold:
        murmur = "Present"
    # Checked second so a confident 'Absent' takes precedence.
    if murmur_scores['Absent'] > absent_threshold:
        murmur = "Absent"

    return get_patient_outcome(
        age, sex, height, weight, is_pregnant, murmur)
# Gradio UI: one recording plus patient details in, predicted outcome out.
# cache_examples=True asks Gradio to precompute predict() outputs for the
# bundled examples.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(label="Recording"),
        gr.Radio(label="Age",
                 choices=["Neonate", "Infant", "Child", "Adolescent"],
                 value="Adolescent"),
        gr.Radio(label="Sex", choices=["Male", "Female"], value="Male"),
        # NOTE(review): units are not stated in this file — presumably the
        # units of the training data (cm / kg?); confirm.
        gr.Number(label="Height", value=98.0),
        gr.Number(label="Weight", value=38.1),
        gr.Checkbox(label="Pregnant", value=False),
    ],
    outputs=[
        gr.Label(label="svc_pred"),
    ],
    cache_examples=True,
    examples=[
        normal,
        abnormal,
    ],
)
demo.launch()