Manyue-DataScientist committed
Commit 965e524 · verified · 1 Parent(s): 67e41d5

Update app.py

Files changed (1)
  1. app.py +167 -168
app.py CHANGED
@@ -10,186 +10,185 @@ import io
 
 @st.cache_resource
 def load_models():
-    try:
-        diarization = Pipeline.from_pretrained(
-            "pyannote/speaker-diarization",
-            use_auth_token=st.secrets["hf_token"]
-        )
-
-        transcriber = whisper.load_model("base")
-
-        summarizer = tf_pipeline(
-            "summarization",
-            model="facebook/bart-large-cnn",
-            device=0 if torch.cuda.is_available() else -1
-        )
-
-        if not diarization or not transcriber or not summarizer:
-            raise ValueError("One or more models failed to load")
-
-        return diarization, transcriber, summarizer
-    except Exception as e:
-        st.error(f"Error loading models: {str(e)}")
-        st.error("Debug info: Check if HF token is valid and has necessary permissions")
-        return None, None, None
 
 def process_audio(audio_file, max_duration=600):
-    try:
-        audio_bytes = io.BytesIO(audio_file.getvalue())
-
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-            try:
-                if audio_file.name.lower().endswith('.mp3'):
-                    audio = AudioSegment.from_mp3(audio_bytes)
-                else:
-                    audio = AudioSegment.from_wav(audio_bytes)
-
-                audio = audio.set_frame_rate(16000)
-                audio = audio.set_channels(1)
-                audio = audio.set_sample_width(2)
-
-                audio.export(
-                    tmp.name,
-                    format="wav",
-                    parameters=["-ac", "1", "-ar", "16000"]
-                )
-                tmp_path = tmp.name
-
-            except Exception as e:
-                st.error(f"Error converting audio: {str(e)}")
-                return None
-
-        diarization, transcriber, summarizer = load_models()
-        if not all([diarization, transcriber, summarizer]):
-            return "Model loading failed"
-
-        with st.spinner("Identifying speakers..."):
-            diarization_result = diarization(tmp_path)
-
-        with st.spinner("Transcribing audio..."):
-            transcription = transcriber.transcribe(tmp_path)
-
-        with st.spinner("Generating summary..."):
-            summary = summarizer(transcription["text"], max_length=130, min_length=30)
-
-        os.unlink(tmp_path)
-
-        return {
-            "diarization": diarization_result,
-            "transcription": transcription,
-            "summary": summary[0]["summary_text"]
-        }
-
-    except Exception as e:
-        st.error(f"Error processing audio: {str(e)}")
-        return None
 
 def format_speaker_segments(diarization_result, transcription):
-    if diarization_result is None or transcription is None:
-        return []
-
-    formatted_segments = []
-    # Get whisper segments that include timestamps and text
-    whisper_segments = transcription.get('segments', [])
-
-    try:
-        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
-            # Find matching text from whisper segments
-            segment_text = ""
-            for ws in whisper_segments:
-                # If whisper segment overlaps with diarization segment
-                if (float(ws['start']) >= float(turn.start) and
-                        float(ws['start']) <= float(turn.end)):
-                    segment_text += ws['text'] + " "
-
-            # Only add segments that have text
-            if segment_text.strip():
-                formatted_segments.append({
-                    'speaker': str(speaker),
-                    'start': float(turn.start),
-                    'end': float(turn.end),
-                    'text': segment_text.strip()
-                })
-    except Exception as e:
-        st.error(f"Error formatting segments: {str(e)}")
-        return []
-
-    # Sort by start time and handle overlaps
-    formatted_segments.sort(key=lambda x: x['start'])
-    cleaned_segments = []
-    for i, segment in enumerate(formatted_segments):
-        # Skip if this segment overlaps with previous one
-        if i > 0 and segment['start'] < cleaned_segments[-1]['end']:
-            continue
-        cleaned_segments.append(segment)
-
-    return cleaned_segments
 
 def format_timestamp(seconds):
-    minutes = int(seconds // 60)
-    seconds = seconds % 60
-    return f"{minutes:02d}:{seconds:05.2f}"
 
 def main():
-    st.title("Multi-Speaker Audio Analyzer")
-    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")
-
-    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
-
-    if uploaded_file:
-        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
-        st.write(f"File size: {file_size:.2f} MB")
-
-        st.audio(uploaded_file, format='audio/wav')
-
-        if st.button("Analyze Audio"):
-            if file_size > 200:
-                st.error("File size exceeds 200MB limit")
-            else:
-                results = process_audio(uploaded_file)
-
-                if results:
-                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
-
-                    with tab1:
-                        st.write("Speaker Timeline:")
-                        segments = format_speaker_segments(results["diarization"], results["transcription"])
-
-                        if segments:
-                            for segment in segments:
-                                col1, col2, col3 = st.columns([2,3,5])
-
-                                with col1:
-                                    speaker_num = int(segment['speaker'].split('_')[1])
-                                    colors = ['🔵', '🔴']
-                                    speaker_color = colors[speaker_num % len(colors)]
-                                    st.write(f"{speaker_color} {segment['speaker']}")
-
-                                with col2:
-                                    start_time = format_timestamp(segment['start'])
-                                    end_time = format_timestamp(segment['end'])
-                                    st.write(f"{start_time} → {end_time}")
-
-                                with col3:
-                                    st.write(f"\"{segment['text']}\"")
-
-                                st.markdown("---")
-                        else:
-                            st.warning("No speaker segments detected")
-
-                    with tab2:
-                        st.write("Transcription:")
-                        if "text" in results["transcription"]:
-                            st.write(results["transcription"]["text"])
-                        else:
-                            st.warning("No transcription available")
-
-                    with tab3:
-                        st.write("Summary:")
-                        if results["summary"]:
-                            st.write(results["summary"])
-                        else:
-                            st.warning("No summary available")
 
 if __name__ == "__main__":
-    main()
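For reference, the matching rule removed above attached a Whisper segment to a speaker turn only when the segment's start time fell inside the turn, then sorted the turns and silently dropped any turn overlapping its predecessor. A minimal sketch of that de-overlap pass, using plain dicts with float timestamps in place of pyannote objects (values are illustrative):

    # Old behavior: sort by start time, then drop any turn that overlaps the previous one
    segments = [
        {'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 5.0},
        {'speaker': 'SPEAKER_01', 'start': 4.0, 'end': 8.0},  # overlaps SPEAKER_00
    ]
    segments.sort(key=lambda s: s['start'])
    cleaned = []
    for seg in segments:
        if cleaned and seg['start'] < cleaned[-1]['end']:
            continue  # the overlapping turn is discarded, its text lost
        cleaned.append(seg)
    print(cleaned)  # only SPEAKER_00's turn survives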
 
 
 @st.cache_resource
 def load_models():
+    try:
+        diarization = Pipeline.from_pretrained(
+            "pyannote/speaker-diarization",
+            use_auth_token=st.secrets["hf_token"]
+        )
+
+        transcriber = whisper.load_model("base")
+
+        summarizer = tf_pipeline(
+            "summarization",
+            model="facebook/bart-large-cnn",
+            device=0 if torch.cuda.is_available() else -1
+        )
+
+        if not diarization or not transcriber or not summarizer:
+            raise ValueError("One or more models failed to load")
+
+        return diarization, transcriber, summarizer
+    except Exception as e:
+        st.error(f"Error loading models: {str(e)}")
+        st.error("Debug info: Check if HF token is valid and has necessary permissions")
+        return None, None, None
 
 def process_audio(audio_file, max_duration=600):
+    try:
+        audio_bytes = io.BytesIO(audio_file.getvalue())
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            try:
+                if audio_file.name.lower().endswith('.mp3'):
+                    audio = AudioSegment.from_mp3(audio_bytes)
+                else:
+                    audio = AudioSegment.from_wav(audio_bytes)
+
+                # Standardize format
+                audio = audio.set_frame_rate(16000)
+                audio = audio.set_channels(1)
+                audio = audio.set_sample_width(2)
+
+                audio.export(
+                    tmp.name,
+                    format="wav",
+                    parameters=["-ac", "1", "-ar", "16000"]
+                )
+                tmp_path = tmp.name
+
+            except Exception as e:
+                st.error(f"Error converting audio: {str(e)}")
+                return None
+
+        diarization, transcriber, summarizer = load_models()
+        if not all([diarization, transcriber, summarizer]):
+            return "Model loading failed"
+
+        with st.spinner("Identifying speakers..."):
+            diarization_result = diarization(tmp_path)
+
+        with st.spinner("Transcribing audio..."):
+            transcription = transcriber.transcribe(tmp_path)
+
+        with st.spinner("Generating summary..."):
+            summary = summarizer(transcription["text"], max_length=130, min_length=30)
+
+        os.unlink(tmp_path)
+
+        return {
+            "diarization": diarization_result,
+            "transcription": transcription,
+            "summary": summary[0]["summary_text"]
+        }
+
+    except Exception as e:
+        st.error(f"Error processing audio: {str(e)}")
+        return None
 
 def format_speaker_segments(diarization_result, transcription):
+    if diarization_result is None:
+        return []
+
+    formatted_segments = []
+    whisper_segments = transcription.get('segments', [])
+
+    try:
+        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
+            current_text = ""
+            # Find matching whisper segments for this speaker's time window
+            for w_segment in whisper_segments:
+                w_start = float(w_segment['start'])
+                w_end = float(w_segment['end'])
+
+                # If whisper segment overlaps with speaker segment
+                if (w_start >= turn.start and w_start < turn.end) or \
+                   (w_end > turn.start and w_end <= turn.end):
+                    current_text += w_segment['text'].strip() + " "
+
+            formatted_segments.append({
+                'speaker': str(speaker),
+                'start': float(turn.start),
+                'end': float(turn.end),
+                'text': current_text.strip()
+            })
+
+    except Exception as e:
+        st.error(f"Error formatting segments: {str(e)}")
+        return []
+
+    return formatted_segments
 
 def format_timestamp(seconds):
+    minutes = int(seconds // 60)
+    seconds = seconds % 60
+    return f"{minutes:02d}:{seconds:05.2f}"
 
 def main():
+    st.title("Multi-Speaker Audio Analyzer")
+    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")
+
+    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
+
+    if uploaded_file:
+        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
+        st.write(f"File size: {file_size:.2f} MB")
+
+        st.audio(uploaded_file, format='audio/wav')
+
+        if st.button("Analyze Audio"):
+            if file_size > 200:
+                st.error("File size exceeds 200MB limit")
+            else:
+                results = process_audio(uploaded_file)
+
+                if results:
+                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
+
+                    with tab1:
+                        st.write("Speaker Timeline:")
+                        segments = format_speaker_segments(
+                            results["diarization"],
+                            results["transcription"]
+                        )
+
+                        if segments:
+                            for segment in segments:
+                                col1, col2, col3 = st.columns([2,3,5])
+
+                                with col1:
+                                    speaker_num = int(segment['speaker'].split('_')[1])
+                                    colors = ['🔵', '🔴']
+                                    speaker_color = colors[speaker_num % len(colors)]
+                                    st.write(f"{speaker_color} {segment['speaker']}")
+
+                                with col2:
+                                    start_time = format_timestamp(segment['start'])
+                                    end_time = format_timestamp(segment['end'])
+                                    st.write(f"{start_time} → {end_time}")
+
+                                with col3:
+                                    if segment['text']:
+                                        st.write(f"\"{segment['text']}\"")
+                                    else:
+                                        st.write("(no speech detected)")
+
+                                st.markdown("---")
+                        else:
+                            st.warning("No speaker segments detected")
+
+                    with tab2:
+                        st.write("Transcription:")
+                        if "text" in results["transcription"]:
+                            st.write(results["transcription"]["text"])
+                        else:
+                            st.warning("No transcription available")
+
+                    with tab3:
+                        st.write("Summary:")
+                        if results["summary"]:
+                            st.write(results["summary"])
+                        else:
+                            st.warning("No summary available")
 
 if __name__ == "__main__":
+    main()
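The rewritten format_speaker_segments keeps every diarization turn (empty ones render as "(no speech detected)" in the timeline) and matches a Whisper segment to a turn when either endpoint of the segment falls inside it. A rough standalone check of the new predicate, with plain floats standing in for pyannote's Segment attributes (the helper name and values are made up for illustration):

    def overlaps(w_start, w_end, t_start, t_end):
        # New rule in this commit: the whisper segment matches a turn
        # if it starts inside the turn or ends inside the turn
        return (w_start >= t_start and w_start < t_end) or \
               (w_end > t_start and w_end <= t_end)

    assert overlaps(4.5, 9.0, 2.0, 5.0)      # starts inside the turn
    assert overlaps(1.0, 3.0, 2.0, 5.0)      # ends inside it; the old start-only rule missed this
    assert not overlaps(1.0, 6.0, 2.0, 5.0)  # spans the whole turn, so it matches neither branch

Two edge cases worth flagging in the committed code: a Whisper segment that fully contains a short turn goes unmatched, as the last check shows, and os.unlink(tmp_path) in process_audio runs only on the success path, so the temporary WAV file is left behind if diarization, transcription, or summarization raises; wrapping those steps in try/finally would guarantee cleanup.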