"""
Multi-Speaker Audio Analyzer
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.

Author: [Your Name]
Date: January 2025
"""

import os

import streamlit as st

from src.models.diarization import SpeakerDiarizer
from src.models.transcription import Transcriber
from src.models.summarization import Summarizer
from src.utils.audio_processor import AudioProcessor
from src.utils.formatter import TimeFormatter

# Cache heavy models with st.cache_resource so they are loaded once and reused across reruns
@st.cache_resource
def load_models():
    """
    Load and cache all required models.
    
    Returns:
        tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
    """
    try:
        diarizer = SpeakerDiarizer(st.secrets["hf_token"])
        diarizer_model = diarizer.load_model()
        
        transcriber = Transcriber()
        transcriber_model = transcriber.load_model()
        
        summarizer = Summarizer()
        summarizer_model = summarizer.load_model()
        
        if not all([diarizer_model, transcriber_model, summarizer_model]):
            raise ValueError("One or more models failed to load")
            
        return diarizer, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None

def process_audio(audio_file, max_duration=600):
    """
    Process the uploaded audio file through all models.
    
    Args:
        audio_file: Uploaded audio file
        max_duration (int): Maximum duration in seconds (accepted but not currently enforced)
        
    Returns:
        dict: Processing results containing diarization, transcription, and summary,
            or None if processing fails
    """
    try:
        # Process audio file
        audio_processor = AudioProcessor()
        tmp_path = audio_processor.standardize_audio(audio_file)
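        # standardize_audio is assumed to write a normalized temporary copy of the
        # upload and return its filesystem path; the file is removed in the cleanup
        # step below.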
        
        # Load models
        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            # load_models already reported the error; signal failure to the caller.
            return None

        # Process with each model
        with st.spinner("Identifying speakers..."):
            diarization_result = diarizer.process(tmp_path)
        
        with st.spinner("Transcribing audio..."):
            transcription = transcriber.process(tmp_path)
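        # transcription is assumed to be a Whisper-style dict exposing at least a
        # "text" field; the full dict is also passed to TimeFormatter later to align
        # text with speaker segments.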
            
        with st.spinner("Generating summary..."):
            summary = summarizer.process(transcription["text"])
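        # summary is assumed to follow Hugging Face summarization-pipeline output,
        # i.e. a list like [{"summary_text": "..."}], which is unwrapped in the
        # returned dict.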

        # Cleanup
        os.unlink(tmp_path)
        
        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }
        
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

def main():
    """Main application function."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")
        
        # Label playback with a MIME type based on the uploaded file's extension.
        audio_format = "audio/mpeg" if uploaded_file.name.lower().endswith(".mp3") else "audio/wav"
        st.audio(uploaded_file, format=audio_format)
        
        if st.button("Analyze Audio"):
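            # 200 MB mirrors Streamlit's default server.maxUploadSize limit; the
            # explicit check just surfaces a clearer in-app error message.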
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)
                
                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
                    
                    # Display speaker timeline
                    with tab1:
                        display_speaker_timeline(results)
                    
                    # Display transcription
                    with tab2:
                        display_transcription(results)
                    
                    # Display summary
                    with tab3:
                        display_summary(results)

def display_speaker_timeline(results):
    """Display speaker diarization results in a timeline format."""
    st.write("Speaker Timeline:")
    segments = TimeFormatter.format_speaker_segments(
        results["diarization"], 
        results["transcription"]
    )
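    # Each formatted segment is expected to expose "speaker", "start", "end", and
    # "text" keys, which the per-column helpers below render.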
    
    if segments:
        for segment in segments:
            col1, col2, col3 = st.columns([2,3,5])
            
            with col1:
                display_speaker_info(segment)
            
            with col2:
                display_timestamp(segment)
            
            with col3:
                display_text(segment)
            
            st.markdown("---")
    else:
        st.warning("No speaker segments detected")

def display_speaker_info(segment):
    """Display speaker information with color coding."""
    speaker_num = int(segment['speaker'].split('_')[1])
    colors = ['🔵', '🔴']
    speaker_color = colors[speaker_num % len(colors)]
    st.write(f"{speaker_color} {segment['speaker']}")

def display_timestamp(segment):
    """Display formatted timestamps."""
    start_time = TimeFormatter.format_timestamp(segment['start'])
    end_time = TimeFormatter.format_timestamp(segment['end'])
    st.write(f"{start_time} → {end_time}")

def display_text(segment):
    """Display speaker's text."""
    if segment['text']:
        st.write(f"\"{segment['text']}\"")
    else:
        st.write("(no speech detected)")

if __name__ == "__main__":
    main()