cogniveon's picture
Update app.py
ec88959 verified
raw
history blame
No virus
2.72 kB
import functools

import gradio as gr
import joblib
import keras
import pandas as pd
import torch
import torchaudio
from transformers import pipeline

from examples import normal, abnormal
_MURMUR_MODEL_ID = "cogniveon/wav2vec2-base-960h"
# Sampling rate the recording is resampled to before classification.
_MURMUR_SAMPLING_RATE = 16000


@functools.lru_cache(maxsize=1)
def _get_murmur_pipeline():
    """Return the murmur audio-classification pipeline, building it on first use.

    Constructing a ``transformers`` pipeline loads the model, so it is done
    once per process instead of on every request.
    """
    return pipeline("audio-classification", model=_MURMUR_MODEL_ID)


def get_murmur_from_recordings(audio):
    """Classify a heart-sound recording for the presence of a murmur.

    Args:
        audio: ``(sampling_rate, data)`` tuple as produced by ``gr.Audio``;
            ``data`` is a 1-D (mono) or 2-D ``(samples, channels)`` array,
            either integer PCM or floating point.

    Returns:
        dict mapping each label returned by the pipeline to its score,
        ordered from highest to lowest score.
    """
    sampling_rate, data = audio
    waveform = torch.as_tensor(data)
    if waveform.is_floating_point():
        waveform = waveform.float()
    else:
        # Integer PCM (e.g. gradio's int16): scale into [-1, 1]. Presumably
        # this matches the float range the model was trained on — confirm.
        info = torch.iinfo(waveform.dtype)
        waveform = waveform.float() / max(abs(info.min), info.max)
    if waveform.ndim == 2:
        # Down-mix (samples, channels) to mono. Resample operates on the last
        # dimension, which would otherwise be the channel axis.
        waveform = waveform.mean(dim=1)
    if sampling_rate != _MURMUR_SAMPLING_RATE:
        resampler = torchaudio.transforms.Resample(
            sampling_rate, _MURMUR_SAMPLING_RATE)
        waveform = resampler(waveform)
    results = _get_murmur_pipeline()(waveform.numpy())
    ranked = sorted(results, key=lambda item: item['score'], reverse=True)
    return {item['label']: item['score'] for item in ranked}
_OUTCOME_MODEL_PATH = "patient_outcome_classifier_v3.joblib"

# Categorical encodings fed to the outcome classifier.
_SEX_TO_INT = {'Male': 0, 'Female': 1}
_AGE_TO_INT = {'Neonate': 0, 'Infant': 1, 'Child': 2, 'Adolescent': 3}
_MURMUR_TO_INT = {'Absent': 0, 'Present': 1}
_MURMUR_OTHER = 2  # code for any other murmur value, e.g. "Unknown"


@functools.lru_cache(maxsize=1)
def _load_outcome_model():
    """Load the patient-outcome classifier from disk once and reuse it."""
    # NOTE(review): joblib.load unpickles arbitrary objects; only point this
    # at a trusted, bundled model file.
    return joblib.load(_OUTCOME_MODEL_PATH)


def get_patient_outcome(age, sex, height, weight, is_pregnant, murmur):
    """Score a patient's outcome as Normal vs. Abnormal.

    Args:
        age: one of "Neonate", "Infant", "Child", "Adolescent".
        sex: "Male" or "Female".
        height: numeric height (units as used when training the model).
        weight: numeric weight (units as used when training the model).
        is_pregnant: truthy if the patient is pregnant.
        murmur: "Absent", "Present", or anything else (treated as unknown).

    Returns:
        dict ``{'Normal': float, 'Abnormal': float}``.

    Raises:
        KeyError: ``age`` or ``sex`` is not a known category.
        ValueError: ``height`` or ``weight`` is missing.
    """
    if height is None or weight is None:
        raise ValueError('Height and weight are required')
    # Column names and order must match the features the model was fit on.
    data = pd.DataFrame({
        'Age': float(_AGE_TO_INT[age]),
        'Sex': float(_SEX_TO_INT[sex]),
        'Height': float(height),
        'Weight': float(weight),
        'Pregnancy status': float(1 if is_pregnant else 0),
        'Murmur': float(_MURMUR_TO_INT.get(murmur, _MURMUR_OTHER)),
    }, index=[0])
    # NOTE(review): indexing output[0]/output[1] assumes predict() yields one
    # row of two scores per sample (0 = Normal, 1 = Abnormal); confirm
    # against how the v3 classifier was trained.
    output = _load_outcome_model().predict(data)[0]
    return {'Normal': float(output[0]), 'Abnormal': float(output[1])}
# Minimum classifier score required to commit to a murmur label; when neither
# threshold is exceeded the murmur is reported as "Unknown".
MURMUR_PRESENT_THRESHOLD = 0.70
MURMUR_ABSENT_THRESHOLD = 0.80


def predict(audio, age, sex, height, weight, is_pregnant):
    """Gradio entry point: recording + demographics -> outcome scores.

    The recording is first classified for a murmur; the resulting
    "Present"/"Absent"/"Unknown" status is then combined with the patient
    data by ``get_patient_outcome``.

    Raises:
        ValueError: no audio recording was supplied.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if audio is None:
        raise ValueError('Audio cannot be None')
    murmur_scores = get_murmur_from_recordings(audio)
    murmur = "Unknown"
    if murmur_scores['Present'] > MURMUR_PRESENT_THRESHOLD:
        murmur = "Present"
    # Checked second, so "Absent" wins if both thresholds are exceeded.
    if murmur_scores['Absent'] > MURMUR_ABSENT_THRESHOLD:
        murmur = "Absent"
    return get_patient_outcome(age, sex, height, weight, is_pregnant, murmur)
# Gradio UI: one recording plus patient demographics in; the Normal/Abnormal
# scores returned by ``predict`` are shown as a label output.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(label="Recording"),
        gr.Radio(
            label="Age",
            choices=["Neonate", "Infant", "Child", "Adolescent"],
            value="Adolescent",
        ),
        gr.Radio(label="Sex", choices=["Male", "Female"], value="Male"),
        # gr.Number takes a numeric default, not a string.
        gr.Number(label="Height", value=98.0),
        gr.Number(label="Weight", value=38.1),
        gr.Checkbox(label="Pregnant", value=False),
    ],
    outputs="label",
    cache_examples=True,
    examples=[
        normal,
        abnormal,
    ],
)
demo.launch()