File size: 6,036 Bytes
b1426fb
 
 
 
c37649d
b1426fb
 
 
2a6784d
0d5109a
 
 
 
 
b3635dd
2a6784d
0d5109a
b3635dd
 
0d5109a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
0d5109a
 
c37649d
0d5109a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b780c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
0d5109a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
Multi-Speaker Audio Analyzer
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.

Author: Manyue
Date: January 2025
"""

import streamlit as st
from src.models.diarization import SpeakerDiarizer
from src.models.transcription import Transcriber
from src.models.summarization import Summarizer
from src.utils.audio_processor import AudioProcessor
from src.utils.formatter import TimeFormatter
import os

# Model loading is expensive; cache the instances for the server's lifetime.
@st.cache_resource
def load_models():
    """
    Load and cache the diarization, transcription, and summarization models.

    Streamlit's resource cache ensures the heavy model objects are built
    only once per server process and shared across reruns.

    Returns:
        tuple: (diarizer, transcriber, summarizer) on success, or
        (None, None, None) if any model fails to load.
    """
    try:
        # The diarizer needs a HuggingFace token from Streamlit secrets.
        diarizer = SpeakerDiarizer(st.secrets["hf_token"])
        diarizer_ok = diarizer.load_model()

        transcriber = Transcriber()
        transcriber_ok = transcriber.load_model()

        summarizer = Summarizer()
        summarizer_ok = summarizer.load_model()

        # All three load attempts run before we decide success/failure.
        if not (diarizer_ok and transcriber_ok and summarizer_ok):
            raise ValueError("One or more models failed to load")

        return diarizer, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None

def process_audio(audio_file, max_duration=600):
    """
    Process the uploaded audio file through all models.

    Args:
        audio_file: Uploaded audio file (Streamlit UploadedFile).
        max_duration (int): Maximum duration in seconds.
            NOTE(review): currently unused — no duration check is performed
            anywhere in this function; confirm whether AudioProcessor or a
            caller is expected to enforce it.

    Returns:
        dict: {"diarization", "transcription", "summary"} on success,
        or None if any stage fails (an error is shown via st.error).
    """
    tmp_path = None
    try:
        # Convert the upload to a standardized temp file on disk.
        audio_processor = AudioProcessor()
        tmp_path = audio_processor.standardize_audio(audio_file)

        # Load (cached) models.
        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            # Return None rather than a message string: a non-empty string
            # is truthy, so main()'s `if results:` check would previously
            # accept the failure message as a results dict and crash on
            # results["diarization"].
            st.error("Model loading failed")
            return None

        # Process with each model, surfacing progress in the UI.
        with st.spinner("Identifying speakers..."):
            diarization_result = diarizer.process(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.process(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer.process(transcription["text"])

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always remove the temp file, even when a processing stage raised
        # (the original only cleaned up on the success path, leaking files).
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

def main():
    """Main application function: upload UI, analysis trigger, result tabs."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # Show basic file info and an inline player before analysis.
    size_mb = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {size_mb:.2f} MB")
    st.audio(uploaded_file, format='audio/wav')

    if not st.button("Analyze Audio"):
        return

    if size_mb > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if not results:
        return

    # One tab per result view.
    speakers_tab, transcription_tab, summary_tab = st.tabs(
        ["Speakers", "Transcription", "Summary"]
    )
    with speakers_tab:
        display_speaker_timeline(results)
    with transcription_tab:
        display_transcription(results)
    with summary_tab:
        display_summary(results)

def display_speaker_timeline(results):
    """Display speaker diarization results in a timeline format."""
    st.write("Speaker Timeline:")
    segments = TimeFormatter.format_speaker_segments(
        results["diarization"],
        results["transcription"]
    )

    if not segments:
        st.warning("No speaker segments detected")
        return

    # One row per segment: speaker | time window | spoken text.
    for segment in segments:
        speaker_col, time_col, text_col = st.columns([2, 3, 5])

        with speaker_col:
            display_speaker_info(segment)
        with time_col:
            display_timestamp(segment)
        with text_col:
            display_text(segment)

        st.markdown("---")

def display_speaker_info(segment):
    """Display speaker information with color coding."""
    speaker_num = int(segment['speaker'].split('_')[1])
    colors = ['🔵', '🔴']
    speaker_color = colors[speaker_num % len(colors)]
    st.write(f"{speaker_color} {segment['speaker']}")

def display_timestamp(segment):
    """Display the segment's start and end as formatted timestamps."""
    window_start = TimeFormatter.format_timestamp(segment['start'])
    window_end = TimeFormatter.format_timestamp(segment['end'])
    st.write(f"{window_start}{window_end}")

def display_text(segment):
    """Display the speaker's transcribed text, quoted, or a placeholder."""
    text = segment['text']
    if not text:
        st.write("(no speech detected)")
    else:
        st.write(f"\"{text}\"")
def display_transcription(results):
    """Display the full transcription text, if available."""
    has_text = "transcription" in results and "text" in results["transcription"]
    if not has_text:
        st.warning("No transcription available.")
        return
    st.write("Transcription:")
    st.write(results["transcription"]["text"])
def display_summary(results):
    """Display the generated summary, if available."""
    if "summary" not in results:
        st.warning("No summary available.")
        return
    st.write("Summary:")
    st.write(results["summary"])


# Run the Streamlit app when this file is executed directly.
if __name__ == "__main__":
    main()