Manyue-DataScientist committed
Commit b1426fb · Parent: 965e524

Changed Folder Structure

Reformatted the code and reorganized the folder structure to make future optimization easier.
README.md CHANGED
@@ -1,13 +1,16 @@
- ---
- title: Speaker Diarization App
- emoji: 📈
- colorFrom: red
- colorTo: gray
- sdk: streamlit
- sdk_version: 1.41.1
- app_file: app.py
- pinned: false
- short_description: 'College Final Project '
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Multi-Speaker Audio Analyzer
+
+ A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.
+
+ ## Features
+ - Speaker Diarization using Pyannote
+ - Transcription using Whisper
+ - Summarization using BART
+
+ ## Setup
+ 1. Install requirements: `pip install -r requirements.txt`
+ 2. Add HuggingFace token to Streamlit secrets
+ 3. Run app: `streamlit run app.py`
+
+ ## Usage
+ Upload an audio file (MP3/WAV) and click "Analyze Audio" to process.
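
Note: setup step 2 works because the app reads the token as `st.secrets["hf_token"]`. For a local run, that means a `.streamlit/secrets.toml` along these lines (the key name comes from the code; the value is a placeholder, not a real token):

```toml
# .streamlit/secrets.toml -- placeholder value; substitute your own HuggingFace token
hf_token = "hf_xxxxxxxxxxxxxxxxxxxx"
```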
app.py CHANGED
@@ -1,129 +1,93 @@
+ """
+ Multi-Speaker Audio Analyzer
+ A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.
+
+ Author: [Your Name]
+ Date: January 2025
+ """
+
  import streamlit as st
- from pyannote.audio import Pipeline
- import whisper
- import tempfile
+ from src.models.diarization import SpeakerDiarizer
+ from src.models.transcription import Transcriber
+ from src.models.summarization import Summarizer
+ from src.utils.audio_processor import AudioProcessor
+ from src.utils.formatter import TimeFormatter
  import os
- import torch
- from transformers import pipeline as tf_pipeline
- from pydub import AudioSegment
- import io

+ # Cache for model loading
  @st.cache_resource
  def load_models():
+     """
+     Load and cache all required models.
+
+     Returns:
+         tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
+     """
      try:
-         diarization = Pipeline.from_pretrained(
-             "pyannote/speaker-diarization",
-             use_auth_token=st.secrets["hf_token"]
-         )
-
-         transcriber = whisper.load_model("base")
-
-         summarizer = tf_pipeline(
-             "summarization",
-             model="facebook/bart-large-cnn",
-             device=0 if torch.cuda.is_available() else -1
-         )
-
-         if not diarization or not transcriber or not summarizer:
+         diarizer = SpeakerDiarizer(st.secrets["hf_token"])
+         diarizer_model = diarizer.load_model()
+
+         transcriber = Transcriber()
+         transcriber_model = transcriber.load_model()
+
+         summarizer = Summarizer()
+         summarizer_model = summarizer.load_model()
+
+         if not all([diarizer_model, transcriber_model, summarizer_model]):
              raise ValueError("One or more models failed to load")

-         return diarization, transcriber, summarizer
+         return diarizer, transcriber, summarizer
      except Exception as e:
          st.error(f"Error loading models: {str(e)}")
          st.error("Debug info: Check if HF token is valid and has necessary permissions")
          return None, None, None

  def process_audio(audio_file, max_duration=600):
+     """
+     Process the uploaded audio file through all models.
+
+     Args:
+         audio_file: Uploaded audio file
+         max_duration (int): Maximum duration in seconds
+
+     Returns:
+         dict: Processing results containing diarization, transcription, and summary
+     """
      try:
-         audio_bytes = io.BytesIO(audio_file.getvalue())
-
-         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-             try:
-                 if audio_file.name.lower().endswith('.mp3'):
-                     audio = AudioSegment.from_mp3(audio_bytes)
-                 else:
-                     audio = AudioSegment.from_wav(audio_bytes)
-
-                 # Standardize format
-                 audio = audio.set_frame_rate(16000)
-                 audio = audio.set_channels(1)
-                 audio = audio.set_sample_width(2)
-
-                 audio.export(
-                     tmp.name,
-                     format="wav",
-                     parameters=["-ac", "1", "-ar", "16000"]
-                 )
-                 tmp_path = tmp.name
-
-             except Exception as e:
-                 st.error(f"Error converting audio: {str(e)}")
-                 return None
-
-         diarization, transcriber, summarizer = load_models()
-         if not all([diarization, transcriber, summarizer]):
-             return "Model loading failed"
+         # Process audio file
+         audio_processor = AudioProcessor()
+         tmp_path = audio_processor.standardize_audio(audio_file)
+
+         # Load models
+         diarizer, transcriber, summarizer = load_models()
+         if not all([diarizer, transcriber, summarizer]):
+             return "Model loading failed"

+         # Process with each model
          with st.spinner("Identifying speakers..."):
-             diarization_result = diarization(tmp_path)
+             diarization_result = diarizer.process(tmp_path)

          with st.spinner("Transcribing audio..."):
-             transcription = transcriber.transcribe(tmp_path)
+             transcription = transcriber.process(tmp_path)

          with st.spinner("Generating summary..."):
-             summary = summarizer(transcription["text"], max_length=130, min_length=30)
+             summary = summarizer.process(transcription["text"])

+         # Cleanup
          os.unlink(tmp_path)

          return {
              "diarization": diarization_result,
              "transcription": transcription,
              "summary": summary[0]["summary_text"]
          }

      except Exception as e:
          st.error(f"Error processing audio: {str(e)}")
          return None

- def format_speaker_segments(diarization_result, transcription):
-     if diarization_result is None:
-         return []
-
-     formatted_segments = []
-     whisper_segments = transcription.get('segments', [])
-
-     try:
-         for turn, _, speaker in diarization_result.itertracks(yield_label=True):
-             current_text = ""
-             # Find matching whisper segments for this speaker's time window
-             for w_segment in whisper_segments:
-                 w_start = float(w_segment['start'])
-                 w_end = float(w_segment['end'])
-
-                 # If whisper segment overlaps with speaker segment
-                 if (w_start >= turn.start and w_start < turn.end) or \
-                    (w_end > turn.start and w_end <= turn.end):
-                     current_text += w_segment['text'].strip() + " "
-
-             formatted_segments.append({
-                 'speaker': str(speaker),
-                 'start': float(turn.start),
-                 'end': float(turn.end),
-                 'text': current_text.strip()
-             })
-
-     except Exception as e:
-         st.error(f"Error formatting segments: {str(e)}")
-         return []
-
-     return formatted_segments
-
- def format_timestamp(seconds):
-     minutes = int(seconds // 60)
-     seconds = seconds % 60
-     return f"{minutes:02d}:{seconds:05.2f}"
-
  def main():
+     """Main application function."""
      st.title("Multi-Speaker Audio Analyzer")
      st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

@@ -144,51 +108,62 @@ def main():
      if results:
          tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

+         # Display speaker timeline
          with tab1:
-             st.write("Speaker Timeline:")
-             segments = format_speaker_segments(
-                 results["diarization"],
-                 results["transcription"]
-             )
-
-             if segments:
-                 for segment in segments:
-                     col1, col2, col3 = st.columns([2,3,5])
-
-                     with col1:
-                         speaker_num = int(segment['speaker'].split('_')[1])
-                         colors = ['🔵', '🔴']
-                         speaker_color = colors[speaker_num % len(colors)]
-                         st.write(f"{speaker_color} {segment['speaker']}")
-
-                     with col2:
-                         start_time = format_timestamp(segment['start'])
-                         end_time = format_timestamp(segment['end'])
-                         st.write(f"{start_time} → {end_time}")
-
-                     with col3:
-                         if segment['text']:
-                             st.write(f"\"{segment['text']}\"")
-                         else:
-                             st.write("(no speech detected)")
-
-                     st.markdown("---")
-             else:
-                 st.warning("No speaker segments detected")
+             display_speaker_timeline(results)

+         # Display transcription
          with tab2:
-             st.write("Transcription:")
-             if "text" in results["transcription"]:
-                 st.write(results["transcription"]["text"])
-             else:
-                 st.warning("No transcription available")
+             display_transcription(results)

+         # Display summary
          with tab3:
-             st.write("Summary:")
-             if results["summary"]:
-                 st.write(results["summary"])
-             else:
-                 st.warning("No summary available")
+             display_summary(results)
+
+ def display_speaker_timeline(results):
+     """Display speaker diarization results in a timeline format."""
+     st.write("Speaker Timeline:")
+     segments = TimeFormatter.format_speaker_segments(
+         results["diarization"],
+         results["transcription"]
+     )
+
+     if segments:
+         for segment in segments:
+             col1, col2, col3 = st.columns([2,3,5])
+
+             with col1:
+                 display_speaker_info(segment)
+
+             with col2:
+                 display_timestamp(segment)
+
+             with col3:
+                 display_text(segment)
+
+             st.markdown("---")
+     else:
+         st.warning("No speaker segments detected")
+
+ def display_speaker_info(segment):
+     """Display speaker information with color coding."""
+     speaker_num = int(segment['speaker'].split('_')[1])
+     colors = ['🔵', '🔴']
+     speaker_color = colors[speaker_num % len(colors)]
+     st.write(f"{speaker_color} {segment['speaker']}")
+
+ def display_timestamp(segment):
+     """Display formatted timestamps."""
+     start_time = TimeFormatter.format_timestamp(segment['start'])
+     end_time = TimeFormatter.format_timestamp(segment['end'])
+     st.write(f"{start_time} → {end_time}")
+
+ def display_text(segment):
+     """Display speaker's text."""
+     if segment['text']:
+         st.write(f"\"{segment['text']}\"")
+     else:
+         st.write("(no speech detected)")

  if __name__ == "__main__":
      main()
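
Note: the refactored `main()` calls `display_transcription(results)` and `display_summary(results)`, but neither helper is defined anywhere in this commit, so the Transcription and Summary tabs would raise `NameError` at runtime. A sketch of what they would presumably look like, lifted from the inline tab2/tab3 code removed above (hypothetical, not part of the diff):

```python
import streamlit as st

def display_transcription(results):
    """Display the full transcript (assumed shape of the missing helper)."""
    st.write("Transcription:")
    if "text" in results["transcription"]:
        st.write(results["transcription"]["text"])
    else:
        st.warning("No transcription available")

def display_summary(results):
    """Display the generated summary (assumed shape of the missing helper)."""
    st.write("Summary:")
    if results["summary"]:
        st.write(results["summary"])
    else:
        st.warning("No summary available")
```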
src/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """
+ Initialize the src package.
+ """
+ from src.models import SpeakerDiarizer, Transcriber, Summarizer
+ from src.utils import AudioProcessor, TimeFormatter
+
+ __all__ = [
+     'SpeakerDiarizer',
+     'Transcriber',
+     'Summarizer',
+     'AudioProcessor',
+     'TimeFormatter'
+ ]
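
Note: these absolute `src.models`/`src.utils` imports resolve only when the project root (the directory containing `src/`) is on `sys.path`, which holds when running `streamlit run app.py` from the root as the README instructs. A relative-import variant would drop that assumption (a sketch, not part of the commit):

```python
# src/__init__.py -- relative-import variant (sketch)
from .models import SpeakerDiarizer, Transcriber, Summarizer
from .utils import AudioProcessor, TimeFormatter

__all__ = ['SpeakerDiarizer', 'Transcriber', 'Summarizer',
           'AudioProcessor', 'TimeFormatter']
```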
src/models/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """
+ Initialize the models package.
+ """
+ from .diarization import SpeakerDiarizer
+ from .transcription import Transcriber
+ from .summarization import Summarizer
+
+ __all__ = ['SpeakerDiarizer', 'Transcriber', 'Summarizer']
src/models/diarization.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Speaker Diarization Model Handler
+ Manages the pyannote-audio model for speaker diarization tasks.
+ """
+
+ from pyannote.audio import Pipeline
+ import streamlit as st
+
+ class SpeakerDiarizer:
+     def __init__(self, token: str):
+         """Initialize the diarization model.
+
+         Args:
+             token (str): HuggingFace authentication token
+         """
+         self.token = token
+         self.model = None
+
+     def load_model(self):
+         """Load the pyannote speaker diarization model."""
+         try:
+             self.model = Pipeline.from_pretrained(
+                 "pyannote/speaker-diarization",
+                 use_auth_token=self.token
+             )
+             return self.model
+         except Exception as e:
+             st.error(f"Error loading diarization model: {str(e)}")
+             return None
+
+     def process(self, audio_path: str):
+         """Process audio file for speaker diarization.
+
+         Args:
+             audio_path (str): Path to the audio file
+
+         Returns:
+             dict: Diarization results
+         """
+         try:
+             return self.model(audio_path)
+         except Exception as e:
+             st.error(f"Error in diarization: {str(e)}")
+             return None
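
The object `SpeakerDiarizer.process` returns is a pyannote annotation, which `TimeFormatter.format_speaker_segments` later consumes via `itertracks(yield_label=True)`. A minimal standalone usage sketch, assuming `token` is a valid HuggingFace token and `audio.wav` is a 16 kHz mono file as produced by `AudioProcessor`:

```python
# Sketch: run diarization on a prepared file and print speaker turns.
from src.models.diarization import SpeakerDiarizer

diarizer = SpeakerDiarizer(token)
if diarizer.load_model() is not None:
    result = diarizer.process("audio.wav")
    for turn, _, speaker in result.itertracks(yield_label=True):
        print(f"{speaker}: {turn.start:.2f}s - {turn.end:.2f}s")
```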
src/models/summarization.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Summarization Model Handler
+ Manages the BART model for text summarization.
+ """
+
+ from transformers import pipeline
+ import torch
+ import streamlit as st
+
+ class Summarizer:
+     def __init__(self):
+         """Initialize the summarization model."""
+         self.model = None
+
+     def load_model(self):
+         """Load the BART summarization model."""
+         try:
+             self.model = pipeline(
+                 "summarization",
+                 model="facebook/bart-large-cnn",
+                 device=0 if torch.cuda.is_available() else -1
+             )
+             return self.model
+         except Exception as e:
+             st.error(f"Error loading summarization model: {str(e)}")
+             return None
+
+     def process(self, text: str, max_length: int = 130, min_length: int = 30):
+         """Process text for summarization.
+
+         Args:
+             text (str): Text to summarize
+             max_length (int): Maximum length of summary
+             min_length (int): Minimum length of summary
+
+         Returns:
+             str: Summarized text
+         """
+         try:
+             summary = self.model(text, max_length=max_length, min_length=min_length)
+             return summary
+         except Exception as e:
+             st.error(f"Error in summarization: {str(e)}")
+             return None
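
One practical caveat: `facebook/bart-large-cnn` accepts roughly 1024 tokens of input, so passing a full transcript to `Summarizer.process` can fail or silently truncate on long audio. A hedged sketch of chunked summarization (the helper name and the ~3000-character chunk size are assumptions, not from the commit):

```python
# Sketch: summarize a long transcript in chunks to stay under BART's input limit.
def summarize_long_text(summarizer, text, chunk_chars=3000):
    """Split text into rough character chunks, summarize each, and join."""
    chunks = [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]
    partial = []
    for chunk in chunks:
        result = summarizer.process(chunk)  # returns a list of dicts, as in app.py
        if result:
            partial.append(result[0]["summary_text"])
    return " ".join(partial)
```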
src/models/transcription.py ADDED
@@ -0,0 +1,36 @@
+ """
+ Transcription Model Handler
+ Manages the Whisper model for speech-to-text transcription.
+ """
+
+ import whisper
+ import streamlit as st
+
+ class Transcriber:
+     def __init__(self):
+         """Initialize the transcription model."""
+         self.model = None
+
+     def load_model(self):
+         """Load the Whisper transcription model."""
+         try:
+             self.model = whisper.load_model("base")
+             return self.model
+         except Exception as e:
+             st.error(f"Error loading transcription model: {str(e)}")
+             return None
+
+     def process(self, audio_path: str):
+         """Process audio file for transcription.
+
+         Args:
+             audio_path (str): Path to the audio file
+
+         Returns:
+             dict: Transcription results
+         """
+         try:
+             return self.model.transcribe(audio_path)
+         except Exception as e:
+             st.error(f"Error in transcription: {str(e)}")
+             return None
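
`Transcriber.process` returns Whisper's result dict; the rest of the app relies on its `"text"` key (fed to the summarizer) and its `"segments"` list of timestamped entries (matched to speaker turns in `TimeFormatter`). A minimal sketch of reading those fields, assuming `audio.wav` exists:

```python
# Sketch: inspect the fields of a Whisper transcription result.
from src.models.transcription import Transcriber

transcriber = Transcriber()
if transcriber.load_model() is not None:
    result = transcriber.process("audio.wav")
    print(result["text"])                    # full transcript
    for seg in result["segments"]:           # per-segment start/end/text
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['text']}")
```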
src/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ Initialize the utils package.
+ """
+ from .audio_processor import AudioProcessor
+ from .formatter import TimeFormatter
+
+ __all__ = ['AudioProcessor', 'TimeFormatter']
src/utils/audio_processor.py ADDED
@@ -0,0 +1,43 @@
+ """
+ Audio Processing Utilities
+ Handles audio file preprocessing and standardization.
+ """
+
+ from pydub import AudioSegment
+ import io
+ import tempfile
+ import os
+
+ class AudioProcessor:
+     @staticmethod
+     def standardize_audio(audio_file):
+         """Standardize audio file to required format.
+
+         Args:
+             audio_file: Uploaded audio file
+
+         Returns:
+             str: Path to processed audio file
+         """
+         try:
+             audio_bytes = io.BytesIO(audio_file.getvalue())
+
+             with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                 if audio_file.name.lower().endswith('.mp3'):
+                     audio = AudioSegment.from_mp3(audio_bytes)
+                 else:
+                     audio = AudioSegment.from_wav(audio_bytes)
+
+                 audio = audio.set_frame_rate(16000)
+                 audio = audio.set_channels(1)
+                 audio = audio.set_sample_width(2)
+
+                 audio.export(
+                     tmp.name,
+                     format="wav",
+                     parameters=["-ac", "1", "-ar", "16000"]
+                 )
+                 return tmp.name
+
+         except Exception as e:
+             raise Exception(f"Error processing audio: {str(e)}")
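
Two observations on this helper: the temp file is created with `delete=False`, so the caller owns cleanup (`process_audio` handles it via `os.unlink`), and the `max_duration` parameter of `process_audio` is declared but never enforced. If a cap were wanted here, pydub reports a clip's length as `len(audio)` in milliseconds; a hypothetical check, not in the commit:

```python
from pydub import AudioSegment

def check_duration(audio: AudioSegment, max_duration: int = 600) -> None:
    """Raise if the clip exceeds max_duration seconds (len() is in ms in pydub)."""
    if len(audio) > max_duration * 1000:
        raise ValueError(f"Audio exceeds {max_duration}s limit")
```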
src/utils/formatter.py ADDED
@@ -0,0 +1,59 @@
+ """
+ Formatting utilities for timestamps and speaker segments.
+ """
+
+ class TimeFormatter:
+     @staticmethod
+     def format_timestamp(seconds: float) -> str:
+         """Format seconds into MM:SS.ss format.
+
+         Args:
+             seconds (float): Time in seconds
+
+         Returns:
+             str: Formatted time string
+         """
+         minutes = int(seconds // 60)
+         seconds = seconds % 60
+         return f"{minutes:02d}:{seconds:05.2f}"
+
+     @staticmethod
+     def format_speaker_segments(diarization_result, transcription):
+         """Format speaker segments with transcribed text.
+
+         Args:
+             diarization_result: Diarization model output
+             transcription: Whisper transcription output
+
+         Returns:
+             list: Formatted speaker segments
+         """
+         if diarization_result is None:
+             return []
+
+         formatted_segments = []
+         whisper_segments = transcription.get('segments', [])
+
+         try:
+             for turn, _, speaker in diarization_result.itertracks(yield_label=True):
+                 current_text = ""
+                 for w_segment in whisper_segments:
+                     w_start = float(w_segment['start'])
+                     w_end = float(w_segment['end'])
+
+                     if (w_start >= turn.start and w_start < turn.end) or \
+                        (w_end > turn.start and w_end <= turn.end):
+                         current_text += w_segment['text'].strip() + " "
+
+                 formatted_segments.append({
+                     'speaker': str(speaker),
+                     'start': float(turn.start),
+                     'end': float(turn.end),
+                     'text': current_text.strip()
+                 })
+
+         except Exception as e:
+             print(f"Error formatting segments: {str(e)}")
+             return []
+
+         return formatted_segments
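
For reference, the overlap test above attaches a Whisper segment to a speaker turn whenever either of the segment's endpoints falls inside the turn, and `format_timestamp` renders 75.5 seconds as `01:15.50`. A quick sanity check of the formatting behavior:

```python
from src.utils.formatter import TimeFormatter

# 75.5 s -> 1 full minute, 15.5 s remainder, zero-padded on both sides
assert TimeFormatter.format_timestamp(75.5) == "01:15.50"
assert TimeFormatter.format_timestamp(5.0) == "00:05.00"
```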