# cogniveon's picture
# feat: :man_dancing: YOLO
# 764c9f8
# raw
# history blame
# No virus
# 2.51 kB
import torchaudio
import torch
import gradio as gr
import keras
import pandas as pd
from transformers import pipeline
_MURMUR_MODEL_ID = "cogniveon/eeem069_heart_murmur_classification"
_MURMUR_SAMPLE_RATE = 16000

# Lazily created pipeline, shared by every call so the model is loaded once.
_murmur_pipeline = None


def _get_murmur_pipeline():
    """Return the shared audio-classification pipeline, building it on first use."""
    global _murmur_pipeline
    if _murmur_pipeline is None:
        _murmur_pipeline = pipeline("audio-classification",
                                    model=_MURMUR_MODEL_ID)
    return _murmur_pipeline


def get_murmur_from_recordings(audio):
    """Classify a heart-sound recording and return a score per murmur label.

    Args:
        audio: A ``(sampling_rate, data)`` pair, where ``data`` is an array
            of samples that ``torch.tensor`` can convert.

    Returns:
        A dict mapping each label reported by the classifier to its score,
        ordered from highest to lowest score.
    """
    sampling_rate, data = audio
    waveform = torch.tensor(data).float()
    # NOTE(review): integer PCM samples are passed through unscaled, exactly
    # as before; confirm whether the model expects values in [-1, 1].

    # The pipeline accepts only a 1-D waveform and Resample works along the
    # last axis, so collapse multi-channel audio to mono first.
    # Assumes a 2-D input is laid out as (samples, channels) - TODO confirm.
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=-1)

    # Resample the audio to 16 kHz (if necessary)
    if sampling_rate != _MURMUR_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            sampling_rate, _MURMUR_SAMPLE_RATE)
        waveform = resampler(waveform)

    results = _get_murmur_pipeline()(waveform.numpy())
    ranked = sorted(results, key=lambda item: item['score'], reverse=True)
    return {item['label']: item['score'] for item in ranked}
_OUTCOME_MODEL_PATH = 'patient_outcome_classifier.keras'
_SEX_CODES = {'Male': 0, 'Female': 1}
_AGE_CODES = {'Neonate': 0, 'Infant': 1, 'Child': 2, 'Adolescent': 3}
_MURMUR_CODES = {'Absent': 0, 'Present': 1}
# Any murmur value other than 'Absent' or 'Present' is encoded as 2.
_OTHER_MURMUR_CODE = 2

# Lazily loaded Keras model, shared by every call so the file is read once.
_outcome_model = None


def _get_outcome_model():
    """Return the shared outcome classifier, loading it from disk on first use."""
    global _outcome_model
    if _outcome_model is None:
        _outcome_model = keras.models.load_model(_OUTCOME_MODEL_PATH)
    return _outcome_model


def get_patient_outcome(age, sex, height, weight, is_pregnant, murmur):
    """Predict the patient outcome from demographic data and murmur status.

    Args:
        age: One of 'Neonate', 'Infant', 'Child', 'Adolescent'.
        sex: 'Male' or 'Female'.
        height: Patient height; must be convertible to ``float``.
        weight: Patient weight; must be convertible to ``float``.
        is_pregnant: Truthy if the patient is pregnant.
        murmur: 'Absent' or 'Present'; any other value is encoded as a
            third category.

    Returns:
        A dict with the keys 'Normal' and 'Abnormal' holding the first and
        second model outputs as plain Python floats.

    Raises:
        KeyError: If ``age`` or ``sex`` is not one of the known categories.
    """
    features = pd.DataFrame({
        'Age': float(_AGE_CODES[age]),
        'Sex': float(_SEX_CODES[sex]),
        'Height': float(height),
        'Weight': float(weight),
        'Pregnancy status': float(1 if is_pregnant else 0),
        'Murmur': float(_MURMUR_CODES.get(murmur, _OTHER_MURMUR_CODE)),
    }, index=[0])

    output = _get_outcome_model().predict(features)[0]

    # 0 - Normal, 1 - Abnormal
    # Convert the array scalars to built-in floats so the result serializes
    # cleanly wherever it is sent.
    return {
        'Normal': float(output[0]),
        'Abnormal': float(output[1]),
    }
# Minimum classifier scores required to treat a murmur as present / absent.
_MURMUR_PRESENT_THRESHOLD = 0.70
_MURMUR_ABSENT_THRESHOLD = 0.80


def predict(audio, age, sex, height, weight, is_pregnant):
    """Run murmur detection on a recording, then predict the patient outcome.

    The murmur status defaults to "Unknown" and becomes "Present" or
    "Absent" only when the corresponding score exceeds its threshold.

    Args:
        audio: Recording passed on to ``get_murmur_from_recordings``.
        age, sex, height, weight, is_pregnant: Patient details passed on to
            ``get_patient_outcome``.

    Returns:
        The value returned by ``get_patient_outcome``.

    Raises:
        ValueError: If ``audio`` is None.
    """
    # An explicit check instead of ``assert`` so it still runs under ``-O``.
    if audio is None:
        raise ValueError('Audio cannot be None')

    murmur_scores = get_murmur_from_recordings(audio)

    # A label missing from the scores counts as a score of 0.
    murmur = "Unknown"
    if murmur_scores.get('Present', 0.0) > _MURMUR_PRESENT_THRESHOLD:
        murmur = "Present"
    if murmur_scores.get('Absent', 0.0) > _MURMUR_ABSENT_THRESHOLD:
        murmur = "Absent"

    return get_patient_outcome(
        age, sex, height, weight, is_pregnant, murmur)
# Web form: one audio recording plus patient details in, outcome label out.
# The input order must match the parameter order of ``predict``.
demo = gr.Interface(
    fn=predict,
    inputs=[
        "audio",
        gr.Radio(
            label="Age",
            choices=["Neonate", "Infant", "Child", "Adolescent"],
            value="Adolescent",
        ),
        gr.Radio(label="Sex", choices=["Male", "Female"], value="Male"),
        # Numeric defaults are given as numbers rather than strings.
        gr.Number(label="Height", value=98.0),
        gr.Number(label="Weight", value=38.1),
        gr.Checkbox(label="Pregnant", value=False),
    ],
    outputs="label",
)

# Start the app as soon as this module is executed.
demo.launch()