aadnk commited on
Commit
a1da02d
1 Parent(s): 14761aa

Adding support for diarization

Browse files
app.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
17
  from src.hooks.progressListener import ProgressListener
18
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
  from src.hooks.whisperProgressHook import create_progress_listener_handle
@@ -73,6 +74,7 @@ class WhisperTranscriber:
73
  self.deleteUploadedFiles = delete_uploaded_files
74
  self.output_dir = output_dir
75
 
 
76
  self.app_config = app_config
77
 
78
  def set_parallel_devices(self, vad_parallel_devices: str):
@@ -89,19 +91,27 @@ class WhisperTranscriber:
89
  # Entry function for the simple tab
90
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
91
  vad, vadMergeWindow, vadMaxMergeSize,
92
- word_timestamps: bool = False, highlight_words: bool = False):
 
93
  return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
94
  vad, vadMergeWindow, vadMaxMergeSize,
95
- word_timestamps, highlight_words)
 
96
 
97
  # Entry function for the simple tab progress
98
  def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
99
  vad, vadMergeWindow, vadMaxMergeSize,
100
  word_timestamps: bool = False, highlight_words: bool = False,
 
101
  progress=gr.Progress()):
102
 
103
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
104
 
 
 
 
 
 
105
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
106
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
107
 
@@ -112,14 +122,18 @@ class WhisperTranscriber:
112
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
113
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
114
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
115
- compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
 
 
116
 
117
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
118
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
119
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
120
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
121
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
122
- compression_ratio_threshold, logprob_threshold, no_speech_threshold)
 
 
123
 
124
  # Entry function for the full tab with progress
125
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
@@ -129,6 +143,8 @@ class WhisperTranscriber:
129
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
130
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
131
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
 
 
132
  progress=gr.Progress()):
133
 
134
  # Handle temperature_increment_on_fallback
@@ -139,6 +155,13 @@ class WhisperTranscriber:
139
 
140
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
141
 
 
 
 
 
 
 
 
142
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
143
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
144
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
@@ -202,6 +225,19 @@ class WhisperTranscriber:
202
  # Update progress
203
  current_progress += source_audio_duration
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
206
 
207
  if len(sources) > 1:
@@ -515,6 +551,17 @@ def create_ui(app_config: ApplicationConfig):
515
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
516
  ]
517
 
 
 
 
 
 
 
 
 
 
 
 
518
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
519
 
520
  simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
@@ -522,6 +569,7 @@ def create_ui(app_config: ApplicationConfig):
522
  *common_inputs(),
523
  *common_vad_inputs(),
524
  *common_word_timestamps_inputs(),
 
525
  ], outputs=[
526
  gr.File(label="Download"),
527
  gr.Text(label="Transcription"),
@@ -556,6 +604,11 @@ def create_ui(app_config: ApplicationConfig):
556
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
557
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
558
  gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
 
 
 
 
 
559
  ], outputs=[
560
  gr.File(label="Download"),
561
  gr.Text(label="Transcription"),
 
14
  import torch
15
 
16
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
+ from src.diarization.diarization import Diarization
18
  from src.hooks.progressListener import ProgressListener
19
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
20
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
74
  self.deleteUploadedFiles = delete_uploaded_files
75
  self.output_dir = output_dir
76
 
77
+ self.diarization: Diarization = None
78
  self.app_config = app_config
79
 
80
  def set_parallel_devices(self, vad_parallel_devices: str):
 
91
  # Entry function for the simple tab
92
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
93
  vad, vadMergeWindow, vadMaxMergeSize,
94
+ word_timestamps: bool = False, highlight_words: bool = False,
95
+ diarization: bool = False, diarization_speakers: int = 2):
96
  return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
97
  vad, vadMergeWindow, vadMaxMergeSize,
98
+ word_timestamps, highlight_words,
99
+ diarization, diarization_speakers)
100
 
101
  # Entry function for the simple tab progress
102
  def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
103
  vad, vadMergeWindow, vadMaxMergeSize,
104
  word_timestamps: bool = False, highlight_words: bool = False,
105
+ diarization: bool = False, diarization_speakers: int = 2,
106
  progress=gr.Progress()):
107
 
108
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
109
 
110
+ if diarization:
111
+ self.diarization = Diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers)
112
+ else:
113
+ self.diarization = None
114
+
115
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
116
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
117
 
 
122
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
123
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
124
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
125
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
126
+ diarization: bool = False, diarization_speakers: int = 2,
127
+ diarization_min_speakers = 1, diarization_max_speakers = 5):
128
 
129
  return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
130
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
131
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
132
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
133
  condition_on_previous_text, fp16, temperature_increment_on_fallback,
134
+ compression_ratio_threshold, logprob_threshold, no_speech_threshold,
135
+ diarization, diarization_speakers,
136
+ diarization_min_speakers, diarization_max_speakers)
137
 
138
  # Entry function for the full tab with progress
139
  def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
 
143
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
144
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
145
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
146
+ diarization: bool = False, diarization_speakers: int = 2,
147
+ diarization_min_speakers = 1, diarization_max_speakers = 5,
148
  progress=gr.Progress()):
149
 
150
  # Handle temperature_increment_on_fallback
 
155
 
156
  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
157
 
158
+ # Set diarization
159
+ if diarization:
160
+ self.diarization = Diarization(auth_token=self.app_config.auth_token, num_speakers=diarization_speakers,
161
+ min_speakers=diarization_min_speakers, max_speakers=diarization_max_speakers)
162
+ else:
163
+ self.diarization = None
164
+
165
  return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
166
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
167
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
 
225
  # Update progress
226
  current_progress += source_audio_duration
227
 
228
+ # Diarization
229
+ if self.diarization:
230
+ print("Diarizing ", source.source_path)
231
+ diarization_result = list(self.diarization.run(source.source_path))
232
+
233
+ # Print result
234
+ print("Diarization result: ")
235
+ for entry in diarization_result:
236
+ print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
237
+
238
+ # Add speakers to result
239
+ result = self.diarization.mark_speakers(diarization_result, result)
240
+
241
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
242
 
243
  if len(sources) > 1:
 
551
  gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
552
  ]
553
 
554
+ has_diarization_libs = Diarization.has_libraries()
555
+
556
+ if not has_diarization_libs:
557
+ print("Diarization libraries not found - disabling diarization")
558
+ app_config.diarization = False
559
+
560
+ common_diarization_inputs = lambda : [
561
+ gr.Checkbox(label="Diarization", value=app_config.diarization, interactive=has_diarization_libs),
562
+ gr.Number(label="Diarization - Speakers", precision=0, value=app_config.diarization_speakers, interactive=has_diarization_libs)
563
+ ]
564
+
565
  is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
566
 
567
  simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
 
569
  *common_inputs(),
570
  *common_vad_inputs(),
571
  *common_word_timestamps_inputs(),
572
+ *common_diarization_inputs(),
573
  ], outputs=[
574
  gr.File(label="Download"),
575
  gr.Text(label="Transcription"),
 
604
  gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
605
  gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
606
  gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
607
+
608
+ *common_diarization_inputs(),
609
+ gr.Number(label="Diarization - Min Speakers", precision=0, value=app_config.diarization_min_speakers, interactive=has_diarization_libs),
610
+ gr.Number(label="Diarization - Max Speakers", precision=0, value=app_config.diarization_max_speakers, interactive=has_diarization_libs),
611
+
612
  ], outputs=[
613
  gr.File(label="Download"),
614
  gr.Text(label="Transcription"),
cli.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
 
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
@@ -106,6 +107,14 @@ def cli():
106
  parser.add_argument("--threads", type=optional_int, default=0,
107
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
108
 
 
 
 
 
 
 
 
 
109
  args = parser.parse_args().__dict__
110
  model_name: str = args.pop("model")
111
  model_dir: str = args.pop("model_dir")
@@ -142,10 +151,19 @@ def cli():
142
  compute_type = args.pop("compute_type")
143
  highlight_words = args.pop("highlight_words")
144
 
 
 
 
 
 
 
145
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
146
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
147
  transcriber.set_auto_parallel(auto_parallel)
148
 
 
 
 
149
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
150
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
151
 
 
8
  import torch
9
  from app import VadOptions, WhisperTranscriber
10
  from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
+ from src.diarization.diarization import Diarization
12
  from src.download import download_url
13
  from src.languages import get_language_names
14
 
 
107
  parser.add_argument("--threads", type=optional_int, default=0,
108
  help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
109
 
110
+ # Diarization
111
+ parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
112
+ parser.add_argument("--diarization", type=str2bool, default=app_config.diarization, \
113
+ help="whether to perform speaker diarization")
114
+ parser.add_argument("--num_speakers", type=int, default=None, help="Number of speakers")
115
+ parser.add_argument("--min_speakers", type=int, default=None, help="Minimum number of speakers")
116
+ parser.add_argument("--max_speakers", type=int, default=None, help="Maximum number of speakers")
117
+
118
  args = parser.parse_args().__dict__
119
  model_name: str = args.pop("model")
120
  model_dir: str = args.pop("model_dir")
 
151
  compute_type = args.pop("compute_type")
152
  highlight_words = args.pop("highlight_words")
153
 
154
+ diarization = args.pop("diarization")
155
+ auth_token = args.pop("auth_token")
156
+ num_speakers = args.pop("num_speakers")
157
+ min_speakers = args.pop("min_speakers")
158
+ max_speakers = args.pop("max_speakers")
159
+
160
  transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
161
  transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
162
  transcriber.set_auto_parallel(auto_parallel)
163
 
164
+ if diarization:
165
+ transcriber.set_diarization(Diarization(auth_token=auth_token, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers))
166
+
167
  model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
168
  device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
169
 
config.json5 CHANGED
@@ -138,4 +138,15 @@
138
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
139
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
140
  "highlight_words": false,
 
 
 
 
 
 
 
 
 
 
 
141
  }
 
138
  "append_punctuations": "\"\'.。,,!!??::”)]}、",
139
  // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
140
  "highlight_words": false,
141
+
142
+ // Diarization settings
143
+ "auth_token": null,
144
+ // Whether to perform speaker diarization
145
+ "diarization": false,
146
+ // The number of speakers to detect
147
+ "diarization_speakers": 2,
148
+ // The minimum number of speakers to detect
149
+ "diarization_min_speakers": 1,
150
+ // The maximum number of speakers to detect
151
+ "diarization_max_speakers": 5,
152
  }
src/config.py CHANGED
@@ -69,7 +69,10 @@ class ApplicationConfig:
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
- highlight_words: bool = False):
 
 
 
73
 
74
  self.models = models
75
 
@@ -121,6 +124,13 @@ class ApplicationConfig:
121
  self.append_punctuations = append_punctuations
122
  self.highlight_words = highlight_words
123
 
 
 
 
 
 
 
 
124
  def get_model_names(self):
125
  return [ x.name for x in self.models ]
126
 
 
69
  # Word timestamp settings
70
  word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
  append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
+ highlight_words: bool = False,
73
+ # Diarization
74
+ auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
75
+ diarization_min_speakers: int = 1, diarization_max_speakers: int = 5):
76
 
77
  self.models = models
78
 
 
124
  self.append_punctuations = append_punctuations
125
  self.highlight_words = highlight_words
126
 
127
+ # Diarization settings
128
+ self.auth_token = auth_token
129
+ self.diarization = diarization
130
+ self.diarization_speakers = diarization_speakers
131
+ self.diarization_min_speakers = diarization_min_speakers
132
+ self.diarization_max_speakers = diarization_max_speakers
133
+
134
  def get_model_names(self):
135
  return [ x.name for x in self.models ]
136
 
src/diarization/diarization.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ import tempfile
6
+ from typing import TYPE_CHECKING, List
7
+ import torch
8
+
9
+ import ffmpeg
10
+
11
+ from src.diarization.transcriptLoader import load_transcript
12
+ from src.utils import write_srt
13
+
14
+ class DiarizationEntry:
15
+ def __init__(self, start, end, speaker):
16
+ self.start = start
17
+ self.end = end
18
+ self.speaker = speaker
19
+
20
+ def __repr__(self):
21
+ return f"<DiarizationEntry start={self.start} end={self.end} speaker={self.speaker}>"
22
+
23
+ def toJson(self):
24
+ return {
25
+ "start": self.start,
26
+ "end": self.end,
27
+ "speaker": self.speaker
28
+ }
29
+
30
+ class Diarization:
31
+ def __init__(self, auth_token=None, **kwargs):
32
+ if auth_token is None:
33
+ auth_token = os.environ.get("HK_ACCESS_TOKEN")
34
+ if auth_token is None:
35
+ raise ValueError("No HuggingFace API Token provided - please use the --auth_token argument or set the HK_ACCESS_TOKEN environment variable")
36
+
37
+ self.auth_token = auth_token
38
+ self.initialized = False
39
+ self.pipeline = None
40
+ self.pipeline_kwargs = kwargs
41
+
42
+ @staticmethod
43
+ def has_libraries():
44
+ try:
45
+ import pyannote.audio
46
+ import intervaltree
47
+ return True
48
+ except ImportError:
49
+ return False
50
+
51
+ def initialize(self):
52
+ if self.initialized:
53
+ return
54
+ from pyannote.audio import Pipeline
55
+
56
+ self.pipeline = Pipeline.from_pretrained("pyannote/[email protected]", use_auth_token=self.auth_token)
57
+
58
+ # Load GPU mode if available
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ if device == "cuda":
61
+ print("Diarization - using GPU")
62
+ self.pipeline = self.pipeline.to(torch.device(0))
63
+ else:
64
+ print("Diarization - using CPU")
65
+
66
+ def run(self, audio_file):
67
+ self.initialize()
68
+ audio_file_obj = Path(audio_file)
69
+
70
+ # Supported file types in soundfile is WAV, FLAC, OGG and MAT
71
+ if audio_file_obj.suffix in [".wav", ".flac", ".ogg", ".mat"]:
72
+ target_file = audio_file
73
+ else:
74
+ # Create temp WAV file
75
+ target_file = tempfile.mktemp(prefix="diarization_", suffix=".wav")
76
+ try:
77
+ ffmpeg.input(audio_file).output(target_file, ac=1).run()
78
+ except ffmpeg.Error as e:
79
+ print(f"Error occurred during audio conversion: {e.stderr}")
80
+
81
+ diarization = self.pipeline(target_file, **self.pipeline_kwargs)
82
+
83
+ if target_file != audio_file:
84
+ # Delete temp file
85
+ os.remove(target_file)
86
+
87
+ # Yield result
88
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
89
+ yield DiarizationEntry(turn.start, turn.end, speaker)
90
+
91
+ def mark_speakers(self, diarization_result: List[DiarizationEntry], whisper_result: dict):
92
+ from intervaltree import IntervalTree
93
+ result = whisper_result.copy()
94
+
95
+ # Create an interval tree from the diarization results
96
+ tree = IntervalTree()
97
+ for entry in diarization_result:
98
+ tree[entry.start:entry.end] = entry
99
+
100
+ # Iterate through each segment in the Whisper JSON
101
+ for segment in result["segments"]:
102
+ segment_start = segment["start"]
103
+ segment_end = segment["end"]
104
+
105
+ # Find overlapping speakers using the interval tree
106
+ overlapping_speakers = tree[segment_start:segment_end]
107
+
108
+ # If no speakers overlap with this segment, skip it
109
+ if not overlapping_speakers:
110
+ continue
111
+
112
+ # If multiple speakers overlap with this segment, choose the one with the longest duration
113
+ longest_speaker = None
114
+ longest_duration = 0
115
+
116
+ for speaker_interval in overlapping_speakers:
117
+ overlap_start = max(speaker_interval.begin, segment_start)
118
+ overlap_end = min(speaker_interval.end, segment_end)
119
+ overlap_duration = overlap_end - overlap_start
120
+
121
+ if overlap_duration > longest_duration:
122
+ longest_speaker = speaker_interval.data.speaker
123
+ longest_duration = overlap_duration
124
+
125
+ # Add speakers
126
+ segment["longest_speaker"] = longest_speaker
127
+ segment["speakers"] = list([speaker_interval.data.toJson() for speaker_interval in overlapping_speakers])
128
+
129
+ # The write_srt will use the longest_speaker if it exist, and add it to the text field
130
+
131
+ return result
132
+
133
+ def _write_file(input_file: str, output_path: str, output_extension: str, file_writer: lambda f: None):
134
+ if input_file is None:
135
+ raise ValueError("input_file is required")
136
+ if file_writer is None:
137
+ raise ValueError("file_writer is required")
138
+
139
+ # Write file
140
+ if output_path is None:
141
+ effective_path = os.path.splitext(input_file)[0] + "_output" + output_extension
142
+ else:
143
+ effective_path = output_path
144
+
145
+ with open(effective_path, 'w+', encoding="utf-8") as f:
146
+ file_writer(f)
147
+
148
+ print(f"Output saved to {effective_path}")
149
+
150
+ def main():
151
+ parser = argparse.ArgumentParser(description='Add speakers to a SRT file or Whisper JSON file using pyannote/speaker-diarization.')
152
+ parser.add_argument('audio_file', type=str, help='Input audio file')
153
+ parser.add_argument('whisper_file', type=str, help='Input Whisper JSON/SRT file')
154
+ parser.add_argument('--output_json_file', type=str, default=None, help='Output JSON file (optional)')
155
+ parser.add_argument('--output_srt_file', type=str, default=None, help='Output SRT file (optional)')
156
+ parser.add_argument('--auth_token', type=str, default=None, help='HuggingFace API Token (optional)')
157
+ parser.add_argument("--max_line_width", type=int, default=40, help="Maximum line width for SRT file (default: 40)")
158
+ parser.add_argument("--num_speakers", type=int, default=None, help="Number of speakers")
159
+ parser.add_argument("--min_speakers", type=int, default=None, help="Minimum number of speakers")
160
+ parser.add_argument("--max_speakers", type=int, default=None, help="Maximum number of speakers")
161
+
162
+ args = parser.parse_args()
163
+
164
+ print("\nReading whisper JSON from " + args.whisper_file)
165
+
166
+ # Read whisper JSON or SRT file
167
+ whisper_result = load_transcript(args.whisper_file)
168
+
169
+ diarization = Diarization(auth_token=args.auth_token, num_speakers=args.num_speakers, min_speakers=args.min_speakers, max_speakers=args.max_speakers)
170
+ diarization_result = list(diarization.run(args.audio_file))
171
+
172
+ # Print result
173
+ print("Diarization result:")
174
+ for entry in diarization_result:
175
+ print(f" start={entry.start:.1f}s stop={entry.end:.1f}s speaker_{entry.speaker}")
176
+
177
+ marked_whisper_result = diarization.mark_speakers(diarization_result, whisper_result)
178
+
179
+ # Write output JSON to file
180
+ _write_file(args.whisper_file, args.output_json_file, ".json",
181
+ lambda f: json.dump(marked_whisper_result, f, indent=4, ensure_ascii=False))
182
+
183
+ # Write SRT
184
+ _write_file(args.whisper_file, args.output_srt_file, ".srt",
185
+ lambda f: write_srt(marked_whisper_result["segments"], f, maxLineWidth=args.max_line_width))
186
+
187
+ if __name__ == "__main__":
188
+ main()
src/diarization/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ intervaltree
2
+ srt
3
+ torch
4
+ ffmpeg-python==0.2.0
5
+ https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip
src/diarization/transcriptLoader.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ from pathlib import Path
4
+
5
+ def load_transcript_json(transcript_file: str):
6
+ """
7
+ Parse a Whisper JSON file into a Whisper JSON object
8
+
9
+ # Parameters:
10
+ transcript_file (str): Path to the Whisper JSON file
11
+ """
12
+ with open(transcript_file, "r", encoding="utf-8") as f:
13
+ whisper_result = json.load(f)
14
+
15
+ # Format of Whisper JSON file:
16
+ # {
17
+ # "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.",
18
+ # "segments": [
19
+ # {
20
+ # "text": " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.",
21
+ # "start": 0.0,
22
+ # "end": 10.36,
23
+ # "words": [
24
+ # {
25
+ # "start": 0.0,
26
+ # "end": 0.56,
27
+ # "word": " And",
28
+ # "probability": 0.61767578125
29
+ # },
30
+ # {
31
+ # "start": 0.56,
32
+ # "end": 0.88,
33
+ # "word": " so",
34
+ # "probability": 0.9033203125
35
+ # },
36
+ # etc.
37
+
38
+ return whisper_result
39
+
40
+
41
+ def load_transcript_srt(subtitle_file: str):
42
+ import srt
43
+
44
+ """
45
+ Parse a SRT file into a Whisper JSON object
46
+
47
+ # Parameters:
48
+ subtitle_file (str): Path to the SRT file
49
+ """
50
+ with open(subtitle_file, "r", encoding="utf-8") as f:
51
+ subs = srt.parse(f)
52
+
53
+ whisper_result = {
54
+ "text": "",
55
+ "segments": []
56
+ }
57
+
58
+ for sub in subs:
59
+ # Subtitle(index=1, start=datetime.timedelta(seconds=33, microseconds=843000), end=datetime.timedelta(seconds=38, microseconds=97000), content='地球上只有3%的水是淡水', proprietary='')
60
+ segment = {
61
+ "text": sub.content,
62
+ "start": sub.start.total_seconds(),
63
+ "end": sub.end.total_seconds(),
64
+ "words": []
65
+ }
66
+ whisper_result["segments"].append(segment)
67
+ whisper_result["text"] += sub.content
68
+
69
+ return whisper_result
70
+
71
+ def load_transcript(file: str):
72
+ # Determine file type
73
+ file_extension = Path(file).suffix.lower()
74
+
75
+ if file_extension == ".json":
76
+ return load_transcript_json(file)
77
+ elif file_extension == ".srt":
78
+ return load_transcript_srt(file)
79
+ else:
80
+ raise ValueError(f"Unsupported file type: {file_extension}")
src/utils.py CHANGED
@@ -102,17 +102,26 @@ def write_srt(transcript: Iterator[dict], file: TextIO,
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
- words = segment.get('words', [])
 
 
 
106
 
107
  if len(words) == 0:
108
  # Yield the segment as-is or processed
109
- if maxLineWidth is None or maxLineWidth < 0:
110
  yield segment
111
  else:
 
 
 
 
 
 
112
  yield {
113
  'start': segment['start'],
114
  'end': segment['end'],
115
- 'text': process_text(segment['text'].strip(), maxLineWidth)
116
  }
117
  # We are done
118
  continue
@@ -120,9 +129,17 @@ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: i
120
  subtitle_start = segment['start']
121
  subtitle_end = segment['end']
122
 
 
 
 
 
 
 
 
 
123
  text_words = [ this_word["word"] for this_word in words ]
124
  subtitle_text = __join_words(text_words, maxLineWidth)
125
-
126
  # Iterate over the words in the segment
127
  if highlight_words:
128
  last = subtitle_start
 
102
 
103
  def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
  for segment in transcript:
105
+ words: list = segment.get('words', [])
106
+
107
+ # Append longest speaker ID if available
108
+ segment_longest_speaker = segment.get('longest_speaker', None)
109
 
110
  if len(words) == 0:
111
  # Yield the segment as-is or processed
112
+ if (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
113
  yield segment
114
  else:
115
+ text = segment['text'].strip()
116
+
117
+ # Prepend the longest speaker ID if available
118
+ if segment_longest_speaker is not None:
119
+ text = f"({segment_longest_speaker}) {text}"
120
+
121
  yield {
122
  'start': segment['start'],
123
  'end': segment['end'],
124
+ 'text': process_text(text, maxLineWidth)
125
  }
126
  # We are done
127
  continue
 
129
  subtitle_start = segment['start']
130
  subtitle_end = segment['end']
131
 
132
+ if segment_longest_speaker is not None:
133
+ # Add the beginning
134
+ words.insert(0, {
135
+ 'start': subtitle_start,
136
+ 'end': subtitle_start,
137
+ 'word': f"({segment_longest_speaker})"
138
+ })
139
+
140
  text_words = [ this_word["word"] for this_word in words ]
141
  subtitle_text = __join_words(text_words, maxLineWidth)
142
+
143
  # Iterate over the words in the segment
144
  if highlight_words:
145
  last = subtitle_start