"""
Multi-Speaker Audio Analyzer
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.
Author: [Your Name]
Date: January 2025
"""
import os

import streamlit as st

from src.models.diarization import SpeakerDiarizer
from src.models.transcription import Transcriber
from src.models.summarization import Summarizer
from src.utils.audio_processor import AudioProcessor
from src.utils.formatter import TimeFormatter
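
# Package layout assumed by the imports above (inferred from the import paths, not
# verified against the repository):
#
#     src/
#         models/
#             diarization.py     -> SpeakerDiarizer
#             transcription.py   -> Transcriber
#             summarization.py   -> Summarizer
#         utils/
#             audio_processor.py -> AudioProcessor
#             formatter.py       -> TimeFormatter
#
# Each model wrapper class is expected to expose load_model() and process(), as used below.
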
# Cache for model loading
@st.cache_resource
def load_models():
    """
    Load and cache all required models.

    Returns:
        tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
    """
    try:
        diarizer = SpeakerDiarizer(st.secrets["hf_token"])
        diarizer_model = diarizer.load_model()

        transcriber = Transcriber()
        transcriber_model = transcriber.load_model()

        summarizer = Summarizer()
        summarizer_model = summarizer.load_model()

        if not all([diarizer_model, transcriber_model, summarizer_model]):
            raise ValueError("One or more models failed to load")

        return diarizer, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None
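
# Example usage (a sketch; assumes the Streamlit secret "hf_token" is configured and
# has the necessary permissions):
#
#     diarizer, transcriber, summarizer = load_models()
#     if all([diarizer, transcriber, summarizer]):
#         segments = diarizer.process("path/to/audio.wav")
#
# st.cache_resource keeps the loaded models in memory across Streamlit reruns, so the
# slow model initialization only happens on the first request.
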
def process_audio(audio_file, max_duration=600):
    """
    Process the uploaded audio file through all models.

    Args:
        audio_file: Uploaded audio file
        max_duration (int): Maximum duration in seconds (currently not enforced)

    Returns:
        dict: Processing results containing diarization, transcription, and summary,
        or None if processing fails
    """
    try:
        # Convert the upload to a standardized temporary audio file
        audio_processor = AudioProcessor()
        tmp_path = audio_processor.standardize_audio(audio_file)

        # Load models
        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            st.error("Model loading failed")
            return None

        # Process with each model
        with st.spinner("Identifying speakers..."):
            diarization_result = diarizer.process(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.process(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer.process(transcription["text"])

        # Clean up the temporary file
        os.unlink(tmp_path)

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
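
# Shape of a successful result (keys taken from the return statement above; the exact
# contents of "diarization" and "transcription" depend on the wrapper classes):
#
#     {
#         "diarization": <speaker segments from SpeakerDiarizer.process>,
#         "transcription": {"text": "...", ...},
#         "summary": "Condensed text of the conversation",
#     }
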
def main():
    """Main application function."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")
        # Use the uploaded file's MIME type so MP3 uploads are labeled correctly
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)

                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    # Display speaker timeline
                    with tab1:
                        display_speaker_timeline(results)

                    # Display transcription
                    with tab2:
                        display_transcription(results)

                    # Display summary
                    with tab3:
                        display_summary(results)
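
# Note on main(): the 200 MB check above mirrors Streamlit's default upload limit
# (server.maxUploadSize, 200 MB); both would need to change together to accept
# larger files.
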
def display_speaker_timeline(results):
    """Display speaker diarization results in a timeline format."""
    st.write("Speaker Timeline:")
    segments = TimeFormatter.format_speaker_segments(
        results["diarization"],
        results["transcription"]
    )

    if segments:
        for segment in segments:
            col1, col2, col3 = st.columns([2, 3, 5])

            with col1:
                display_speaker_info(segment)
            with col2:
                display_timestamp(segment)
            with col3:
                display_text(segment)

            st.markdown("---")
    else:
        st.warning("No speaker segments detected")
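
# Each segment returned by TimeFormatter.format_speaker_segments is expected to be a
# dict with at least 'speaker' (e.g. "SPEAKER_00"), 'start', 'end', and 'text' keys,
# based on how the helpers below access it.
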
def display_speaker_info(segment):
    """Display speaker information with color coding."""
    speaker_num = int(segment['speaker'].split('_')[1])
    colors = ['🔵', '🔴']
    speaker_color = colors[speaker_num % len(colors)]
    st.write(f"{speaker_color} {segment['speaker']}")
def display_timestamp(segment):
    """Display formatted timestamps."""
    start_time = TimeFormatter.format_timestamp(segment['start'])
    end_time = TimeFormatter.format_timestamp(segment['end'])
    st.write(f"{start_time} - {end_time}")
def display_text(segment):
    """Display speaker's text."""
    if segment['text']:
        st.write(f"\"{segment['text']}\"")
    else:
        st.write("(no speech detected)")
if __name__ == "__main__":
    main()