"""Streamlit app that diarizes, transcribes, and summarizes a multi-speaker
audio recording using pyannote.audio, OpenAI Whisper, and a BART summarizer."""

import io
import os
import tempfile

import streamlit as st
import torch
import whisper
from pyannote.audio import Pipeline
from pydub import AudioSegment
from transformers import pipeline as tf_pipeline


@st.cache_resource
def load_models():
    """Load the diarization, transcription, and summarization models once and cache them."""
    try:
        # pyannote/speaker-diarization is a gated model; the HF token must have
        # accepted its terms of use on the Hugging Face Hub.
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=st.secrets["hf_token"]
        )
        transcriber = whisper.load_model("base")
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1
        )
        if not diarization or not transcriber or not summarizer:
            raise ValueError("One or more models failed to load")
        return diarization, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None


def process_audio(audio_file, max_duration=600):
    """Convert the upload to 16 kHz mono WAV, then diarize, transcribe, and summarize it.

    max_duration is accepted for future use but is not currently enforced.
    """
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        # Normalize the upload to a 16 kHz, mono, 16-bit WAV file on disk so both
        # pyannote and Whisper can consume it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)

                audio = audio.set_frame_rate(16000)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                audio.export(
                    tmp.name,
                    format="wav",
                    parameters=["-ac", "1", "-ar", "16000"]
                )
                tmp_path = tmp.name
            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # Returning None (rather than a string) keeps the caller's truthiness
            # check from trying to index a non-dict result.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            # BART accepts a limited number of input tokens, so summarization may
            # fail or truncate for very long transcripts.
            summary = summarizer(transcription["text"], max_length=130, min_length=30)

        os.unlink(tmp_path)

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None


def format_speaker_segments(diarization_result, transcription):
    """Merge diarization turns with Whisper segments into a per-speaker timeline."""
    if diarization_result is None or transcription is None:
        return []

    formatted_segments = []
    # Whisper segments include start/end timestamps and text
    whisper_segments = transcription.get('segments', [])

    try:
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            # Collect Whisper text whose start time falls inside this speaker turn
            segment_text = ""
            for ws in whisper_segments:
                if float(turn.start) <= float(ws['start']) <= float(turn.end):
                    segment_text += ws['text'] + " "

            # Only keep segments that have text
            if segment_text.strip():
                formatted_segments.append({
                    'speaker': str(speaker),
                    'start': float(turn.start),
                    'end': float(turn.end),
                    'text': segment_text.strip()
                })
    except Exception as e:
        st.error(f"Error formatting segments: {str(e)}")
        return []

    # Sort by start time and drop segments that overlap the previous one
    formatted_segments.sort(key=lambda x: x['start'])
    cleaned_segments = []
    for i, segment in enumerate(formatted_segments):
        if i > 0 and segment['start'] < cleaned_segments[-1]['end']:
            continue
        cleaned_segments.append(segment)

    return cleaned_segments


def format_timestamp(seconds):
    """Format seconds as MM:SS.ss."""
    minutes = int(seconds // 60)
    seconds = seconds % 60
    return f"{minutes:02d}:{seconds:05.2f}"


def main():
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")
        st.audio(uploaded_file, format='audio/wav')

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)
                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(
                            results["diarization"], results["transcription"]
                        )
                        if segments:
                            for segment in segments:
                                col1, col2, col3 = st.columns([2, 3, 5])
                                with col1:
                                    # Speaker labels look like "SPEAKER_00"; alternate
                                    # marker colors by speaker index
                                    speaker_num = int(segment['speaker'].split('_')[1])
                                    colors = ['🔵', '🔴']
                                    speaker_color = colors[speaker_num % len(colors)]
                                    st.write(f"{speaker_color} {segment['speaker']}")
                                with col2:
                                    start_time = format_timestamp(segment['start'])
                                    end_time = format_timestamp(segment['end'])
                                    st.write(f"{start_time} → {end_time}")
                                with col3:
                                    st.write(f"\"{segment['text']}\"")
                                st.markdown("---")
                        else:
                            st.warning("No speaker segments detected")

                    with tab2:
                        st.write("Transcription:")
                        if "text" in results["transcription"]:
                            st.write(results["transcription"]["text"])
                        else:
                            st.warning("No transcription available")

                    with tab3:
                        st.write("Summary:")
                        if results["summary"]:
                            st.write(results["summary"])
                        else:
                            st.warning("No summary available")


if __name__ == "__main__":
    main()
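
# ---------------------------------------------------------------------------
# Running the app (a minimal sketch; the filename app.py is an assumption,
# not part of the original script):
#
#   1. Save this file as app.py.
#   2. Store a Hugging Face access token in .streamlit/secrets.toml so that
#      st.secrets["hf_token"] resolves, e.g.:
#
#          hf_token = "hf_..."
#
#   3. Launch the app with:
#
#          streamlit run app.py
#
# The token's account must have accepted the terms of use for the gated
# pyannote/speaker-diarization pipeline on the Hugging Face Hub, and ffmpeg
# must be installed for pydub and Whisper to decode audio.
# ---------------------------------------------------------------------------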