Manyue-DataScientist committed on
Commit
67e41d5
·
verified ·
1 Parent(s): d6e0f11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -22
app.py CHANGED
@@ -11,9 +11,8 @@ import io
11
  @st.cache_resource
12
  def load_models():
13
  try:
14
- # Back to original model name
15
  diarization = Pipeline.from_pretrained(
16
- "pyannote/speaker-diarization", # Original model name
17
  use_auth_token=st.secrets["hf_token"]
18
  )
19
 
@@ -25,7 +24,6 @@ def load_models():
25
  device=0 if torch.cuda.is_available() else -1
26
  )
27
 
28
- # Validate models loaded correctly
29
  if not diarization or not transcriber or not summarizer:
30
  raise ValueError("One or more models failed to load")
31
 
@@ -46,7 +44,6 @@ def process_audio(audio_file, max_duration=600):
46
  else:
47
  audio = AudioSegment.from_wav(audio_bytes)
48
 
49
- # Standardize format
50
  audio = audio.set_frame_rate(16000)
51
  audio = audio.set_channels(1)
52
  audio = audio.set_sample_width(2)
@@ -87,23 +84,46 @@ def process_audio(audio_file, max_duration=600):
87
  st.error(f"Error processing audio: {str(e)}")
88
  return None
89
 
90
- def format_speaker_segments(diarization_result):
91
- if diarization_result is None:
92
  return []
93
 
94
  formatted_segments = []
 
 
 
95
  try:
96
  for turn, _, speaker in diarization_result.itertracks(yield_label=True):
97
- formatted_segments.append({
98
- 'speaker': str(speaker), # Ensure string
99
- 'start': float(turn.start) if turn.start is not None else 0.0,
100
- 'end': float(turn.end) if turn.end is not None else 0.0
101
- })
 
 
 
 
 
 
 
 
 
 
 
102
  except Exception as e:
103
  st.error(f"Error formatting segments: {str(e)}")
104
  return []
105
 
106
- return formatted_segments
 
 
 
 
 
 
 
 
 
107
 
108
  def format_timestamp(seconds):
109
  minutes = int(seconds // 60)
@@ -133,25 +153,25 @@ def main():
133
 
134
  with tab1:
135
  st.write("Speaker Timeline:")
136
- segments = format_speaker_segments(results["diarization"])
137
 
138
- if segments: # Only proceed if we have segments
139
  for segment in segments:
140
- col1, col2 = st.columns([2,8])
141
 
142
  with col1:
143
- try:
144
- speaker_num = int(segment['speaker'].split('_')[1])
145
- colors = ['🔵', '🔴'] # Two colors for alternating speakers
146
- speaker_color = colors[speaker_num % len(colors)]
147
- st.write(f"{speaker_color} {segment['speaker']}")
148
- except (IndexError, ValueError) as e:
149
- st.write(f"⚪ {segment['speaker']}")
150
 
151
  with col2:
152
  start_time = format_timestamp(segment['start'])
153
  end_time = format_timestamp(segment['end'])
154
  st.write(f"{start_time} → {end_time}")
 
 
 
155
 
156
  st.markdown("---")
157
  else:
 
11
  @st.cache_resource
12
  def load_models():
13
  try:
 
14
  diarization = Pipeline.from_pretrained(
15
+ "pyannote/speaker-diarization",
16
  use_auth_token=st.secrets["hf_token"]
17
  )
18
 
 
24
  device=0 if torch.cuda.is_available() else -1
25
  )
26
 
 
27
  if not diarization or not transcriber or not summarizer:
28
  raise ValueError("One or more models failed to load")
29
 
 
44
  else:
45
  audio = AudioSegment.from_wav(audio_bytes)
46
 
 
47
  audio = audio.set_frame_rate(16000)
48
  audio = audio.set_channels(1)
49
  audio = audio.set_sample_width(2)
 
84
  st.error(f"Error processing audio: {str(e)}")
85
  return None
86
 
87
def format_speaker_segments(diarization_result, transcription):
    """Pair each diarized speaker turn with the transcript text that starts in it.

    Args:
        diarization_result: Object exposing ``itertracks(yield_label=True)``,
            which yields ``(turn, track, speaker)`` triples where ``turn`` has
            ``start`` and ``end`` attributes. ``None`` yields an empty result.
        transcription: Mapping with an optional ``'segments'`` entry; each
            segment is read for its ``'start'`` and ``'text'`` keys.
            ``None`` yields an empty result.

    Returns:
        A list of dicts with keys ``'speaker'``, ``'start'``, ``'end'`` and
        ``'text'``, ordered by start time. Turns without any text are left
        out, and a turn that begins before the previously kept turn has ended
        is dropped. On any error while reading the inputs, the error is shown
        with ``st.error`` and an empty list is returned.
    """
    if diarization_result is None or transcription is None:
        return []

    formatted_segments = []
    try:
        # Read the segments inside the guarded region so a malformed
        # transcription is reported like every other formatting error;
        # `or []` also tolerates an explicit ``'segments': None``.
        whisper_segments = transcription.get('segments') or []

        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            # Convert the turn bounds once instead of once per transcript segment.
            turn_start = float(turn.start)
            turn_end = float(turn.end)

            # A transcript segment belongs to this turn when its *start*
            # falls inside the turn (bounds inclusive).
            pieces = []
            for ws in whisper_segments:
                if turn_start <= float(ws['start']) <= turn_end:
                    # Strip each piece so joining never produces doubled spaces.
                    text = ws['text'].strip()
                    if text:
                        pieces.append(text)

            # Only keep turns that actually carry text.
            if pieces:
                formatted_segments.append({
                    'speaker': str(speaker),
                    'start': turn_start,
                    'end': turn_end,
                    'text': " ".join(pieces),
                })
    except Exception as e:
        st.error(f"Error formatting segments: {str(e)}")
        return []

    # Order chronologically, then drop any turn that starts before the
    # previously kept turn has finished.
    formatted_segments.sort(key=lambda seg: seg['start'])
    cleaned_segments = []
    for segment in formatted_segments:
        if cleaned_segments and segment['start'] < cleaned_segments[-1]['end']:
            continue
        cleaned_segments.append(segment)

    return cleaned_segments
127
 
128
  def format_timestamp(seconds):
129
  minutes = int(seconds // 60)
 
153
 
154
  with tab1:
155
  st.write("Speaker Timeline:")
156
+ segments = format_speaker_segments(results["diarization"], results["transcription"])
157
 
158
+ if segments:
159
  for segment in segments:
160
+ col1, col2, col3 = st.columns([2,3,5])
161
 
162
  with col1:
163
+ speaker_num = int(segment['speaker'].split('_')[1])
164
+ colors = ['🔵', '🔴']
165
+ speaker_color = colors[speaker_num % len(colors)]
166
+ st.write(f"{speaker_color} {segment['speaker']}")
 
 
 
167
 
168
  with col2:
169
  start_time = format_timestamp(segment['start'])
170
  end_time = format_timestamp(segment['end'])
171
  st.write(f"{start_time} → {end_time}")
172
+
173
+ with col3:
174
+ st.write(f"\"{segment['text']}\"")
175
 
176
  st.markdown("---")
177
  else: