""" |
|
Multi-Speaker Audio Analyzer |
|
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files. |
|
|
|
Author: [Your Name] |
|
Date: January 2025 |
|
""" |
import os

import streamlit as st

from src.models.diarization import SpeakerDiarizer
from src.models.transcription import Transcriber
from src.models.summarization import Summarizer
from src.utils.audio_processor import AudioProcessor
from src.utils.formatter import TimeFormatter


@st.cache_resource
def load_models():
    """
    Load and cache all required models.

    Returns:
        tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
    """
    try:
        diarizer = SpeakerDiarizer(st.secrets["hf_token"])
        diarizer_model = diarizer.load_model()

        transcriber = Transcriber()
        transcriber_model = transcriber.load_model()

        summarizer = Summarizer()
        summarizer_model = summarizer.load_model()

        if not all([diarizer_model, transcriber_model, summarizer_model]):
            raise ValueError("One or more models failed to load")

        return diarizer, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None
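# Note on the model wrappers (as used in this module): SpeakerDiarizer,
# Transcriber, and Summarizer are each expected to expose
#   - load_model() -> a truthy model handle, or a falsy value on failure
#   - process(...) -> the inference result for that stage
# The concrete implementations live under src/models/ and are not shown here.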
def process_audio(audio_file, max_duration=600):
    """
    Process the uploaded audio file through all models.

    Args:
        audio_file: Uploaded audio file
        max_duration (int): Maximum duration in seconds (currently not enforced)

    Returns:
        dict: Processing results containing diarization, transcription, and summary,
        or None if processing fails
    """
    tmp_path = None
    try:
        # Standardize the upload into a temporary audio file on disk.
        audio_processor = AudioProcessor()
        tmp_path = audio_processor.standardize_audio(audio_file)

        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            # Return None (not a truthy string) so the caller's `if results:` check fails.
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarizer.process(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.process(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer.process(transcription["text"])

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Remove the temporary file even if processing fails partway through.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
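# Expected result shapes (assumptions based on how the values are used above):
#   - transcriber.process() returns a dict with at least a "text" field
#     (e.g. Whisper-style output); it is also passed, together with the
#     diarization output, to TimeFormatter.format_speaker_segments().
#   - summarizer.process() returns a Hugging Face pipeline-style list such as
#     [{"summary_text": "..."}], hence the summary[0]["summary_text"] indexing.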
def main():
    """Main application function."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV); files up to about 5 minutes long give the best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")

        # Use the uploaded file's own MIME type so MP3 uploads play back correctly.
        st.audio(uploaded_file, format=uploaded_file.type)

        if st.button("Analyze Audio"):
            # 200 MB matches Streamlit's default upload size limit.
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)

                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        display_speaker_timeline(results)

                    with tab2:
                        display_transcription(results)

                    with tab3:
                        display_summary(results)


def display_speaker_timeline(results):
    """Display speaker diarization results in a timeline format."""
    st.write("Speaker Timeline:")
    segments = TimeFormatter.format_speaker_segments(
        results["diarization"],
        results["transcription"]
    )

    if segments:
        for segment in segments:
            col1, col2, col3 = st.columns([2, 3, 5])

            with col1:
                display_speaker_info(segment)

            with col2:
                display_timestamp(segment)

            with col3:
                display_text(segment)

            st.markdown("---")
    else:
        st.warning("No speaker segments detected")


def display_speaker_info(segment):
    """Display speaker information with color coding."""
    # Speaker labels are expected in the form "SPEAKER_00", "SPEAKER_01", ...;
    # the numeric suffix selects the color, alternating between the two markers.
    speaker_num = int(segment['speaker'].split('_')[1])
    colors = ['🔵', '🔴']
    speaker_color = colors[speaker_num % len(colors)]
    st.write(f"{speaker_color} {segment['speaker']}")


def display_timestamp(segment):
    """Display formatted timestamps."""
    start_time = TimeFormatter.format_timestamp(segment['start'])
    end_time = TimeFormatter.format_timestamp(segment['end'])
    st.write(f"{start_time} → {end_time}")


def display_text(segment):
    """Display the speaker's text for the segment."""
    if segment['text']:
        st.write(f"\"{segment['text']}\"")
    else:
        st.write("(no speech detected)")


if __name__ == "__main__":
    main()