import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment
import io
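
# NOTE: "pyannote/speaker-diarization" is a gated model on the Hugging Face Hub;
# loading it requires accepting its usage conditions and supplying a valid access
# token. This app reads the token from Streamlit secrets, e.g. in
# .streamlit/secrets.toml:
#
#     hf_token = "hf_..."
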
@st.cache_resource
def load_models():
    try:
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=st.secrets["hf_token"]
        )
        transcriber = whisper.load_model("base")
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1
        )
        if not diarization or not transcriber or not summarizer:
            raise ValueError("One or more models failed to load")
        return diarization, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None
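
# NOTE: whisper.load_model("base") trades accuracy for speed; larger Whisper
# checkpoints such as "small" or "medium" can be swapped in if the hardware allows.
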
def process_audio(audio_file, max_duration=600):
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)

                # Enforce the duration cap (pydub lengths are in milliseconds)
                if len(audio) > max_duration * 1000:
                    st.warning(f"Audio is longer than {max_duration}s; trimming to the first {max_duration}s")
                    audio = audio[:max_duration * 1000]

                # Standardize format: 16 kHz, mono, 16-bit PCM
                audio = audio.set_frame_rate(16000)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)
                audio.export(
                    tmp.name,
                    format="wav",
                    parameters=["-ac", "1", "-ar", "16000"]
                )
                tmp_path = tmp.name
            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)
        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)
        with st.spinner("Generating summary..."):
            # truncation=True keeps long transcripts within BART's input limit
            summary = summarizer(transcription["text"], max_length=130, min_length=30, truncation=True)

        os.unlink(tmp_path)
        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
def format_speaker_segments(diarization_result, transcription):
    if diarization_result is None:
        return []

    formatted_segments = []
    whisper_segments = transcription.get('segments', [])

    try:
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            current_text = ""
            # Collect whisper segments that overlap this speaker's time window
            for w_segment in whisper_segments:
                w_start = float(w_segment['start'])
                w_end = float(w_segment['end'])
                # Standard interval-overlap test; also covers whisper segments
                # that fully contain the speaker turn
                if w_start < turn.end and w_end > turn.start:
                    current_text += w_segment['text'].strip() + " "

            formatted_segments.append({
                'speaker': str(speaker),
                'start': float(turn.start),
                'end': float(turn.end),
                'text': current_text.strip()
            })
    except Exception as e:
        st.error(f"Error formatting segments: {str(e)}")
        return []

    return formatted_segments
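
# Each returned entry pairs one diarization turn with the transcript text that
# overlaps it, e.g. (illustrative values):
#     {'speaker': 'SPEAKER_00', 'start': 0.5, 'end': 4.2, 'text': 'Hello there.'}
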
def format_timestamp(seconds):
    minutes = int(seconds // 60)
    seconds = seconds % 60
    return f"{minutes:02d}:{seconds:05.2f}"
def main():
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV); up to 5 minutes is recommended for best performance. Clips longer than 10 minutes are trimmed.")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")

        # Pass the matching MIME type so the player handles MP3 uploads correctly
        audio_format = 'audio/mp3' if uploaded_file.name.lower().endswith('.mp3') else 'audio/wav'
        st.audio(uploaded_file, format=audio_format)

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)
                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(
                            results["diarization"],
                            results["transcription"]
                        )
                        if segments:
                            for segment in segments:
                                col1, col2, col3 = st.columns([2, 3, 5])
                                with col1:
                                    # pyannote labels look like "SPEAKER_00"
                                    speaker_num = int(segment['speaker'].split('_')[1])
                                    colors = ['πŸ”΅', 'πŸ”΄']
                                    speaker_color = colors[speaker_num % len(colors)]
                                    st.write(f"{speaker_color} {segment['speaker']}")
                                with col2:
                                    start_time = format_timestamp(segment['start'])
                                    end_time = format_timestamp(segment['end'])
                                    st.write(f"{start_time} β†’ {end_time}")
                                with col3:
                                    if segment['text']:
                                        st.write(f"\"{segment['text']}\"")
                                    else:
                                        st.write("(no speech detected)")
                                st.markdown("---")
                        else:
                            st.warning("No speaker segments detected")

                    with tab2:
                        st.write("Transcription:")
                        if "text" in results["transcription"]:
                            st.write(results["transcription"]["text"])
                        else:
                            st.warning("No transcription available")

                    with tab3:
                        st.write("Summary:")
                        if results["summary"]:
                            st.write(results["summary"])
                        else:
                            st.warning("No summary available")

if __name__ == "__main__":
    main()
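
# To run locally: streamlit run app.py
# (pydub needs ffmpeg available on PATH to decode MP3 uploads)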