Commit 5917f0a by Taejin · verified · 1 parent: a1e6361

Uploading ngram base model

LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,142 @@
1
- ---
2
- license: cc-by-4.0
3
- ---
1
+ # llm_speaker_tagging
2
+
3
+ SLT 2024 Challenge: Post-ASR-Speaker-Tagging Baseline
4
+
5
+ # Project Name
6
+
7
+ SLT 2024 Challenge GenSEC Track 2: Post-ASR-Speaker-Tagging Baseline
8
+
9
+ ## Features
10
+
11
+ - Data download and cleaning
12
+ - n-gram + beam search decoder based baseline system
13
+
14
+ ## Installation
15
+
16
+ Run the following commands from the root of this repository.
17
+
18
+ ### Conda Environment
19
+
20
+ ```
21
+ conda create --name llmspk python=3.10
22
+ ```
23
+ ### Install requirements
24
+
25
+ You need to install the following packages
26
+
27
+ ```
28
+ kenlm
29
+ arpa
30
+ numpy
31
+ hydra-core
32
+ meeteval
33
+ tqdm
34
+ requests
35
+ simplejson
36
+ pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
37
+ ```
38
+
39
+ Install all the requirements with pip.
40
+
41
+ ```
42
+ pip install -r requirements.txt
43
+ ```
44
+
45
+ ### Download ARPA language model
46
+
47
+ ```
48
+ mkdir -p arpa_model
49
+ cd arpa_model
50
+ wget https://kaldi-asr.org/models/5/4gram_small.arpa.gz
51
+ gunzip 4gram_small.arpa.gz
52
+ ```
53
+
54
+ ### Download track-2 challenge dev set and eval set
55
+
56
+ Clone the dataset from the Hugging Face server, then list the dev-set source and reference files.
57
+ ```
58
+ git clone https://huggingface.co/datasets/GenSEC-LLM/SLT-Task2-Post-ASR-Speaker-Tagging
59
+ ```
60
+
61
+ ```
62
+ find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/err_source_text/dev -name "*.seglst.json" > err_dev.src.list
63
+ find $PWD/SLT-Task2-Post-ASR-Speaker-Tagging/ref_annotated_text/dev -name "*.seglst.json" > err_dev.ref.list
64
+ ```
65
+
66
+ ### Launch the baseline script
67
+
68
+ Now you are ready to launch the baseline script `run_speaker_tagging_beam_search.sh`.
69
+ The script sets the following parameters and runs the decoder:
70
+
71
+ ```
72
+ BASEPATH=${PWD}
73
+ DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
74
+ ASRDIAR_FILE_NAME=err_dev
75
+ WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
76
+ INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
77
+ GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
78
+ DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
79
+ mkdir -p $DIAR_OUT_DOWNLOAD
80
+
81
+ ### SLT 2024 Speaker Tagging Setting v1.0.2
82
+ ALPHA=0.4
83
+ BETA=0.04
84
+ PARALLEL_CHUNK_WORD_LEN=100
85
+ BEAM_WIDTH=16
86
+ WORD_WINDOW=32
87
+ PEAK_PROB=0.95
88
+ USE_NGRAM=True
89
+ LM_METHOD=ngram
90
+
91
+ # Get the base name of the test_manifest and remove extension
92
+ UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
93
+ echo "UNIQ MEMO:" $UNIQ_MEMO
94
+ TRIAL=telephonic
95
+ BATCH_SIZE=11
96
+
97
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
98
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
99
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
100
+
101
+ python $BASEPATH/speaker_tagging_beamsearch.py \
102
+ port=[5501,5502,5511,5512,5521,5522,5531,5532] \
103
+ arpa_language_model=$DIAR_LM_PATH \
104
+ batch_size=$BATCH_SIZE \
105
+ groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
106
+ input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
107
+ parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
108
+ use_ngram=$USE_NGRAM \
109
+ alpha=$ALPHA \
110
+ beta=$BETA \
111
+ beam_width=$BEAM_WIDTH \
112
+ word_window=$WORD_WINDOW \
113
+ peak_prob=$PEAK_PROB \
114
+ out_dir=$DIAR_OUT_DOWNLOAD
115
+ ```
116
+
117
+ ### Evaluate
118
+
119
+ We use [MeetEval](https://github.com/fgnt/meeteval) software to evaluate `cpWER`.
120
+ cpWER measures both speaker tagging errors and word error rate (WER) by testing all permutations of the speaker-wise transcripts and choosing the permutation that
121
+ gives the lowest error.
122
+
123
+ ```
124
+ echo "Evaluating the original source transcript."
125
+ meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
126
+ echo "Source cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json)
127
+
128
+ echo "Evaluating the original hypothesis transcript."
129
+ meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
130
+ echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
131
+ ```
132
+
133
+ ### Reference
134
+
135
+ @inproceedings{park2024enhancing,
136
+ title={Enhancing speaker diarization with large language models: A contextual beam search approach},
137
+ author={Park, Tae Jin and Dhawan, Kunal and Koluguri, Nithin and Balam, Jagadeesh},
138
+ booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
139
+ pages={10861--10865},
140
+ year={2024},
141
+ organization={IEEE}
142
+ }
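The README describes cpWER as trying every assignment of hypothesis speakers to reference speakers and keeping the one with the lowest error. The brute-force sketch below illustrates that definition only; it is not MeetEval's implementation, and the names `word_edit_distance` and `cpwer` are hypothetical. It assumes each speaker's transcript is a single whitespace-separated string and pads the shorter side with empty transcripts.

```python
from itertools import permutations
from typing import Dict, List


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Word-level Levenshtein distance between two word lists."""
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,                          # deletion
                curr[j - 1] + 1,                      # insertion
                prev[j - 1] + (ref_word != hyp_word)  # substitution or match
            ))
        prev = curr
    return prev[-1]


def cpwer(ref: Dict[str, str], hyp: Dict[str, str]) -> float:
    """Lowest total word error over all speaker permutations, per reference word."""
    ref_streams = [text.split() for text in ref.values()]
    hyp_streams = [text.split() for text in hyp.values()]
    # Pad with empty transcripts so both sides have the same number of speakers.
    n_spks = max(len(ref_streams), len(hyp_streams))
    ref_streams += [[] for _ in range(n_spks - len(ref_streams))]
    hyp_streams += [[] for _ in range(n_spks - len(hyp_streams))]
    total_ref_words = sum(len(stream) for stream in ref_streams)
    if total_ref_words == 0:
        raise ValueError("The reference contains no words.")
    best_errors = min(
        sum(word_edit_distance(r, h) for r, h in zip(ref_streams, perm))
        for perm in permutations(hyp_streams)
    )
    return best_errors / total_ref_words
```

Because the search is over all permutations, swapping the speaker labels in the hypothesis does not change the score, while moving a word to the wrong speaker costs one deletion plus one insertion. The factorial cost makes this sketch practical only for a handful of speakers.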
beam_search_utils.py ADDED
@@ -0,0 +1,325 @@
1
+ from tqdm import tqdm
2
+ from typing import Dict, List
3
+ from pydiardecode import build_diardecoder
4
+ import numpy as np
5
+ import copy
6
+ import os
7
+ import json
8
+ import concurrent.futures
9
+ import kenlm
10
+
11
+ __INFO_TAG__ = "[INFO]"
12
+
13
+ class SpeakerTaggingBeamSearchDecoder:
14
+ def __init__(self, loaded_kenlm_model: kenlm.Model, cfg: dict):
15
+ self.realigning_lm_params = cfg
16
+ self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
17
+ self._SPLITSYM = "@"
18
+
19
+ def _load_realigning_LM(self, loaded_kenlm_model: kenlm.Model):
20
+ """
21
+ Load ARPA language model for realigning speaker labels for words.
22
+ """
23
+ diar_decoder = build_diardecoder(
24
+ loaded_kenlm_model=loaded_kenlm_model,
25
+ kenlm_model_path=self.realigning_lm_params['arpa_language_model'],
26
+ alpha=self.realigning_lm_params['alpha'],
27
+ beta=self.realigning_lm_params['beta'],
28
+ word_window=self.realigning_lm_params['word_window'],
29
+ use_ngram=self.realigning_lm_params['use_ngram'],
30
+ )
31
+ return diar_decoder
32
+
33
+ def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num=None) -> List[Dict[str, float]]:
34
+ if speaker_count is None:
35
+ spk_list = []
36
+ for k, line_dict in enumerate(word_dict_seq_list):
37
+ _, spk_label = line_dict['word'], line_dict['speaker']
38
+ spk_list.append(spk_label)
39
+ else:
40
+ spk_list = [ f"speaker_{k}" for k in range(speaker_count)]
41
+
42
+ realigned_list = self.realigning_lm.decode_beams(beam_width=self.realigning_lm_params['beam_width'],
43
+ speaker_list=sorted(list(set(spk_list))),
44
+ word_dict_seq_list=word_dict_seq_list,
45
+ port_num=port_num)
46
+ return realigned_list
47
+
48
+ def beam_search_diarization(
49
+ self,
50
+ trans_info_dict: Dict[str, Dict[str, list]],
51
+ port_num: List[int] = None,
52
+ ) -> Dict[str, Dict[str, float]]:
53
+ """
54
+ Match the diarization result with the ASR output.
55
+ The words and the timestamps for the corresponding words are matched in a for loop.
56
+
57
+ Args:
58
+
59
+ Returns:
60
+ trans_info_dict (dict):
61
+ Dictionary containing word timestamps, speaker labels and words from all sessions.
62
+ Each session is indexed by a unique ID.
63
+ """
64
+ for uniq_id, session_dict in tqdm(trans_info_dict.items(), total=len(trans_info_dict), disable=True):
65
+ word_dict_seq_list = session_dict['words']
66
+ output_beams = self.realign_words_with_lm(word_dict_seq_list=word_dict_seq_list, speaker_count=session_dict['speaker_count'], port_num=port_num)
67
+ word_dict_seq_list = output_beams[0][2]
68
+ trans_info_dict[uniq_id]['words'] = word_dict_seq_list
69
+ return trans_info_dict
70
+
71
+ def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16):
72
+ """
73
+ Merge the outputs of parallel processing.
74
+ """
75
+ uniq_id_list = list(org_trans_info_dict.keys())
76
+ sub_div_dict = {}
77
+ for seq_id in div_trans_info_dict.keys():
78
+ div_info = seq_id.split(self._SPLITSYM)
79
+ uniq_id, sub_idx, total_count = div_info[0], int(div_info[1]), int(div_info[2])
80
+ if uniq_id not in sub_div_dict:
81
+ sub_div_dict[uniq_id] = [None] * total_count
82
+ sub_div_dict[uniq_id][sub_idx] = div_trans_info_dict[seq_id]['words']
83
+
84
+ for uniq_id in uniq_id_list:
85
+ org_trans_info_dict[uniq_id]['words'] = []
86
+ for k, div_words in enumerate(sub_div_dict[uniq_id]):
87
+ if k == 0:
88
+ div_words = div_words[:win_len]
89
+ else:
90
+ div_words = div_words[word_window:]
91
+ org_trans_info_dict[uniq_id]['words'].extend(div_words)
92
+ return org_trans_info_dict
93
+
94
+ def divide_chunks(self, trans_info_dict, win_len, word_window, port):
95
+ """
96
+ Divide word sequence into chunks of length `win_len` for parallel processing.
97
+
98
+ Args:
99
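+ trans_info_dict (dict): Sessions keyed by unique ID; each session's 'words' list is split into chunks.
100
+ win_len (int): Number of words per chunk. If None, it is set to ceil(len(words) / len(port)).
101
+ word_window (int): Number of preceding words prepended to every chunk after the first as left context.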
102
+ """
103
+ if len(port) > 1:
104
+ num_workers = len(port)
105
+ else:
106
+ num_workers = 1
107
+ div_trans_info_dict = {}
108
+ for uniq_id in trans_info_dict.keys():
109
+ uniq_trans = trans_info_dict[uniq_id]
110
+ del uniq_trans['status']
111
+ del uniq_trans['transcription']
112
+ del uniq_trans['sentences']
113
+ word_seq = uniq_trans['words']
114
+
115
+ div_word_seq = []
116
+ if win_len is None:
117
+ win_len = int(np.ceil(len(word_seq)/num_workers))
118
+ n_chunks = int(np.ceil(len(word_seq)/win_len))
119
+
120
+ for k in range(n_chunks):
121
+ div_word_seq.append(word_seq[max(k*win_len - word_window, 0):(k+1)*win_len])
122
+
123
+ total_count = len(div_word_seq)
124
+ for k, w_seq in enumerate(div_word_seq):
125
+ seq_id = uniq_id + f"{self._SPLITSYM}{k}{self._SPLITSYM}{total_count}"
126
+ div_trans_info_dict[seq_id] = dict(uniq_trans)
127
+ div_trans_info_dict[seq_id]['words'] = w_seq
128
+ return div_trans_info_dict
129
+
130
+
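The slicing in `divide_chunks` and `merge_div_inputs` can be hard to follow in isolation, so here is a minimal round-trip sketch of the same index arithmetic on a plain word list. The function names `divide_with_overlap` and `merge_chunks` are hypothetical, and the sketch assumes `word_window <= win_len`, which holds for the settings in this commit (32 and 100).

```python
import math
from typing import List


def divide_with_overlap(words: List[str], win_len: int, word_window: int) -> List[List[str]]:
    """Split `words` into chunks of `win_len`, each prefixed by `word_window` context words."""
    n_chunks = math.ceil(len(words) / win_len)
    return [
        # Every chunk after the first starts `word_window` words early.
        words[max(k * win_len - word_window, 0):(k + 1) * win_len]
        for k in range(n_chunks)
    ]


def merge_chunks(chunks: List[List[str]], win_len: int, word_window: int) -> List[str]:
    """Undo `divide_with_overlap` by dropping the duplicated context words."""
    merged: List[str] = []
    for k, chunk in enumerate(chunks):
        merged.extend(chunk[:win_len] if k == 0 else chunk[word_window:])
    return merged


words = [f"w{i}" for i in range(250)]
chunks = divide_with_overlap(words, win_len=100, word_window=32)
restored = merge_chunks(chunks, win_len=100, word_window=32)
```

The overlap gives the decoder some left context at the start of each chunk; the context words are decoded twice, and only the copy from the earlier chunk is kept at merge time.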
131
+ def run_mp_beam_search_decoding(
132
+ speaker_beam_search_decoder,
133
+ loaded_kenlm_model,
134
+ trans_info_dict,
135
+ org_trans_info_dict,
136
+ div_mp,
137
+ win_len,
138
+ word_window,
139
+ port=None,
140
+ use_ngram=False
141
+ ):
142
+ if port is not None and len(port) > 1:
143
+ port = [int(p) for p in port]
144
+ if use_ngram:
145
+ port = [None]
146
+ num_workers = 36
147
+ else:
148
+ num_workers = len(port)
149
+
150
+ uniq_id_list = sorted(trans_info_dict.keys())
151
+ tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers)
152
+ futures = []
153
+
154
+ count = 0
155
+ for uniq_id in uniq_id_list:
156
+ print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
157
+ if port is not None:
158
+ port_num = port[count % len(port)]
159
+ else:
160
+ port_num = None
161
+ count += 1
162
+ uniq_trans_info_dict = {uniq_id: trans_info_dict[uniq_id]}
163
+ futures.append(tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num))
164
+
165
+ pbar = tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files")
166
+ count = 0
167
+ output_trans_info_dict = {}
168
+ for done_future in concurrent.futures.as_completed(futures):
169
+ count += 1
170
+ pbar.update()
171
+ output_trans_info_dict.update(done_future.result())
172
+ pbar.close()
173
+ tp.shutdown()
174
+ if div_mp:
175
+ output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(div_trans_info_dict=output_trans_info_dict,
176
+ org_trans_info_dict=org_trans_info_dict,
177
+ win_len=win_len,
178
+ word_window=word_window)
179
+ return output_trans_info_dict
180
+
181
+ def count_num_of_spks(json_trans_list):
182
+ spk_set = set()
183
+ for sentence_dict in json_trans_list:
184
+ spk_set.add(sentence_dict['speaker'])
185
+ speaker_map = { spk_str: idx for idx, spk_str in enumerate(spk_set)}
186
+ return speaker_map
187
+
188
+ def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
189
+ nemo_json_dict = {}
190
+ word_dict_seq_list = []
191
+ if peak_prob > 1 or peak_prob < 0:
192
+ raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
193
+ speaker_map = count_num_of_spks(json_trans_list)
194
+ base_array = np.ones(max_spks) * (1 - peak_prob)/(max_spks-1)
195
+ stt_sec, end_sec = None, None
196
+ for sentence_dict in json_trans_list:
197
+ word_list = sentence_dict['words'].split()
198
+ speaker = sentence_dict['speaker']
199
+ for word in word_list:
200
+ speaker_softmax = copy.deepcopy(base_array)
201
+ speaker_softmax[speaker_map[speaker]] = peak_prob
202
+ word_dict_seq_list.append({'word': word,
203
+ 'start_time': stt_sec,
204
+ 'end_time': end_sec,
205
+ 'speaker': speaker_map[speaker],
206
+ 'speaker_softmax': speaker_softmax}
207
+ )
208
+ nemo_json_dict.update({'words': word_dict_seq_list,
209
+ 'status': "success",
210
+ 'sentences': json_trans_list,
211
+ 'speaker_count': len(speaker_map),
212
+ 'transcription': None}
213
+ )
214
+ return nemo_json_dict
215
+
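`add_placeholder_speaker_softmax` gives every word a synthetic speaker distribution: `peak_prob` on the labeled speaker and the remaining mass split evenly over the other `max_spks - 1` slots. A minimal sketch of that one vector, using plain lists instead of NumPy; the name `placeholder_softmax` and the extra guards on `max_spks` and `speaker_idx` are assumptions that are not in the original function.

```python
from typing import List


def placeholder_softmax(speaker_idx: int, peak_prob: float = 0.94, max_spks: int = 4) -> List[float]:
    """Return a probability vector peaked at `speaker_idx`."""
    if not 0.0 <= peak_prob <= 1.0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    if max_spks < 2:
        # Assumed guard: with one speaker the even split would divide by zero.
        raise ValueError(f"max_spks must be at least 2 but got {max_spks}")
    if not 0 <= speaker_idx < max_spks:
        raise ValueError(f"speaker_idx must be in [0, {max_spks}) but got {speaker_idx}")
    rest = (1.0 - peak_prob) / (max_spks - 1)  # mass for each non-labeled speaker
    probs = [rest] * max_spks
    probs[speaker_idx] = peak_prob
    return probs


probs = placeholder_softmax(speaker_idx=1, peak_prob=0.94, max_spks=4)
```

A lower `peak_prob` flattens the distribution, which presumably leaves the language model more room to move a word to a different speaker during beam search.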
216
+ def convert_nemo_json_to_seglst(trans_info_dict):
217
+ seglst_seq_list = []
218
+ seg_lst_dict, spk_wise_trans_sessions = {}, {}
219
+ for uniq_id in trans_info_dict.keys():
220
+ spk_wise_trans_sessions[uniq_id] = {}
221
+ seglst_seq_list = []
222
+ word_seq_list = trans_info_dict[uniq_id]['words']
223
+ prev_speaker, sentence = None, ''
224
+ for widx, word_dict in enumerate(word_seq_list):
225
+ curr_speaker = word_dict['speaker']
226
+
227
+ # For making speaker wise transcriptions
228
+ word = word_dict['word']
229
+ if curr_speaker not in spk_wise_trans_sessions[uniq_id]:
230
+ spk_wise_trans_sessions[uniq_id][curr_speaker] = word
231
+ else:
232
+ spk_wise_trans_sessions[uniq_id][curr_speaker] = f"{spk_wise_trans_sessions[uniq_id][curr_speaker]} {word_dict['word']}"
233
+
234
+ # For making segment wise transcriptions
235
+ if curr_speaker != prev_speaker and prev_speaker is not None:
236
+ seglst_seq_list.append({'session_id': uniq_id,
237
+ 'words': sentence.strip(),
238
+ 'start_time': 0.0,
239
+ 'end_time': 0.0,
240
+ 'speaker': prev_speaker,
241
+ })
242
+ sentence = word_dict['word']
243
+ else:
244
+ sentence = f"{sentence} {word_dict['word']}"
245
+ prev_speaker = curr_speaker
246
+
247
+ # For the last word:
248
+ # (1) If there is no speaker change, add the existing sentence and exit the loop
249
+ # (2) If there is a speaker change, add the last word and exit the loop
250
+ if widx == len(word_seq_list) - 1:
251
+ seglst_seq_list.append({'session_id': uniq_id,
252
+ 'words': sentence.strip(),
253
+ 'start_time': 0.0,
254
+ 'end_time': 0.0,
255
+ 'speaker': curr_speaker,
256
+ })
257
+ seg_lst_dict[uniq_id] = seglst_seq_list
258
+ return seg_lst_dict
259
+
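The segment-building loop in `convert_nemo_json_to_seglst` amounts to grouping consecutive words that share a speaker label. A compact sketch of that grouping with `itertools.groupby`; the function name `words_to_segments` is hypothetical, and it omits the per-speaker transcript bookkeeping that the original also does. Timestamps are set to 0.0, as in the original.

```python
from itertools import groupby
from typing import Dict, List


def words_to_segments(session_id: str, words: List[Dict]) -> List[Dict]:
    """Group consecutive same-speaker words into SegLST-style segments."""
    segments = []
    for speaker, run in groupby(words, key=lambda w: w["speaker"]):
        segments.append({
            "session_id": session_id,
            "words": " ".join(w["word"] for w in run),
            "start_time": 0.0,  # placeholder, no timing information is available
            "end_time": 0.0,
            "speaker": speaker,
        })
    return segments


example_words = [
    {"word": "hi", "speaker": 0},
    {"word": "there", "speaker": 0},
    {"word": "hello", "speaker": 1},
    {"word": "ok", "speaker": 0},
]
segments = words_to_segments("session_a", example_words)
```

`groupby` only merges adjacent items, so a speaker who talks again later gets a new segment rather than being folded into the earlier one.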
260
+ def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
261
+ trans_info_dict = {}
262
+ json_filepath_list = open(input_error_src_list_path).readlines()
263
+ for json_path in json_filepath_list:
264
+ json_path = json_path.strip()
265
+ uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
266
+ if os.path.exists(json_path):
267
+ with open(json_path, "r") as file:
268
+ json_trans = json.load(file)
269
+ else:
270
+ raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
271
+ nemo_json_dict = add_placeholder_speaker_softmax(json_trans, peak_prob=peak_prob, max_spks=max_spks)
272
+ trans_info_dict[uniq_id] = nemo_json_dict
273
+ return trans_info_dict
274
+
275
+ def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
276
+ reference_info_dict = {}
277
+ json_filepath_list = open(reference_seglst_list_path).readlines()
278
+ for json_path in json_filepath_list:
279
+ json_path = json_path.strip()
280
+ uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
281
+ if os.path.exists(json_path):
282
+ with open(json_path, "r") as file:
283
+ json_trans = json.load(file)
284
+ else:
285
+ raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
286
+ json_trans_uniq_id = []
287
+ for sentence_dict in json_trans:
288
+ sentence_dict['session_id'] = uniq_id
289
+ json_trans_uniq_id.append(sentence_dict)
290
+ reference_info_dict[uniq_id] = json_trans_uniq_id
291
+ return reference_info_dict
292
+
293
+ def write_seglst_jsons(
294
+ seg_lst_sessions_dict: dict,
295
+ input_error_src_list_path: str,
296
+ diar_out_path: str,
297
+ ext_str: str,
298
+ write_individual_seglst_jsons=True
299
+ ):
300
+ """
301
+ Writes the segment list (seglst) JSON files to the output directory.
302
+
303
+ Parameters:
304
+ seg_lst_sessions_dict (dict): A dictionary containing session IDs as keys and their corresponding segment lists as values.
305
+ input_error_src_list_path (str): The path to the input error source list file.
306
+ diar_out_path (str): The path to the output directory where the seglst JSON files will be written.
307
+ ext_str (str): A string representing the type of the seglst JSON files (e.g., 'hyp' for hypothesis or 'ref' for reference).
308
+ write_individual_seglst_jsons (bool, optional): A flag indicating whether to write individual seglst JSON files for each session. Defaults to True.
309
+
310
+ Returns:
311
+ None
312
+ """
313
+ total_infer_list = []
314
+ total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
315
+ for session_id, seg_lst_list in seg_lst_sessions_dict.items():
316
+ total_infer_list.extend(seg_lst_list)
317
+ if write_individual_seglst_jsons:
318
+ print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
319
+ with open(f'{diar_out_path}/{session_id}.seglst.json', 'w') as file:
320
+ json.dump(seg_lst_list, file, indent=4) # indent=4 for pretty printing
321
+
322
+ total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
323
+ print(f"{__INFO_TAG__} Writing {diar_out_path}/../{total_output_filename}.seglst.json")
324
+ with open(f'{diar_out_path}/../{total_output_filename}.seglst.json', 'w') as file:
325
+ json.dump(total_infer_list, file, indent=4) # indent=4 for pretty printing
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ kenlm
2
+ arpa
3
+ numpy
4
+ hydra-core
5
+ meeteval
6
+ tqdm
7
+ requests
8
+ simplejson
9
+ pydiardecode @ git+https://github.com/tango4j/pydiardecode@main
run_speaker_tagging_beam_search.sh ADDED
@@ -0,0 +1,59 @@
1
+
2
+ ### Speaker Tagging Task-2 Parameters
3
+
4
+
5
+ BASEPATH=${PWD}
6
+ DIAR_LM_PATH=$BASEPATH/arpa_model/4gram_small.arpa
7
+ ASRDIAR_FILE_NAME=err_dev
8
+ WORKSPACE=$BASEPATH/SLT-Task2-Post-ASR-Speaker-Tagging
9
+ INPUT_ERROR_SRC_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.src.list
10
+ GROUNDTRUTH_REF_LIST_PATH=$BASEPATH/$ASRDIAR_FILE_NAME.ref.list
11
+ DIAR_OUT_DOWNLOAD=$WORKSPACE/short2_all_seglst_infer
12
+ mkdir -p $DIAR_OUT_DOWNLOAD
13
+
14
+
15
+ ### SLT 2024 Speaker Tagging Setting v1.0.2
16
+ ALPHA=0.4
17
+ BETA=0.04
18
+ PARALLEL_CHUNK_WORD_LEN=100
19
+ BEAM_WIDTH=16
20
+ WORD_WINDOW=32
21
+ PEAK_PROB=0.95
22
+ USE_NGRAM=True
23
+ LM_METHOD=ngram
24
+
25
+ # Get the base name of the test_manifest and remove extension
26
+ UNIQ_MEMO=$(basename "${INPUT_ERROR_SRC_LIST_PATH}" .json | sed 's/\./_/g')
27
+ echo "UNIQ MEMO:" $UNIQ_MEMO
28
+ TRIAL=telephonic
29
+ BATCH_SIZE=11
30
+
31
+
32
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json
33
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
34
+ rm -f $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json
35
+
36
+
37
+ python $BASEPATH/speaker_tagging_beamsearch.py \
38
+ port=[5501,5502,5511,5512,5521,5522,5531,5532] \
39
+ arpa_language_model=$DIAR_LM_PATH \
40
+ batch_size=$BATCH_SIZE \
41
+ groundtruth_ref_list_path=$GROUNDTRUTH_REF_LIST_PATH \
42
+ input_error_src_list_path=$INPUT_ERROR_SRC_LIST_PATH \
43
+ parallel_chunk_word_len=$PARALLEL_CHUNK_WORD_LEN \
44
+ use_ngram=$USE_NGRAM \
45
+ alpha=$ALPHA \
46
+ beta=$BETA \
47
+ beam_width=$BEAM_WIDTH \
48
+ word_window=$WORD_WINDOW \
49
+ peak_prob=$PEAK_PROB \
50
+ out_dir=$DIAR_OUT_DOWNLOAD
51
+
52
+
53
+ echo "Evaluating the original source transcript."
54
+ meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
55
+ echo "Source cpWER: " $(jq '.error_rate' "[ $WORKSPACE/$ASRDIAR_FILE_NAME.src.seglst_cpwer.json) ]"
56
+
57
+ echo "Evaluating the original hypothesis transcript."
58
+ meeteval-wer cpwer -h $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst.json -r $WORKSPACE/$ASRDIAR_FILE_NAME.ref.seglst.json
59
+ echo "Hypothesis cpWER: " $(jq '.error_rate' $WORKSPACE/$ASRDIAR_FILE_NAME.hyp.seglst_cpwer.json)
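The script relies on `jq` to pull the `error_rate` field out of the `*_cpwer.json` files. Where `jq` is not installed, the same lookup could be done in Python. The sketch below is only an illustration: the helper name `read_error_rate` and the sample value are hypothetical, and it assumes the result file is a JSON object with a top-level `error_rate` key, as the `jq '.error_rate'` calls above imply.

```python
import json
import os
import tempfile


def read_error_rate(cpwer_json_path):
    """Return the top-level 'error_rate' field of a cpWER result file."""
    with open(cpwer_json_path) as f:
        return json.load(f)["error_rate"]


# Demonstration with a throwaway file; the value below is made up.
tmp_dir = tempfile.mkdtemp()
demo_path = os.path.join(tmp_dir, "err_dev.hyp.seglst_cpwer.json")
with open(demo_path, "w") as f:
    json.dump({"error_rate": 0.25}, f)

demo_rate = read_error_rate(demo_path)
print(f"Hypothesis cpWER: {demo_rate}")
```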
speaker_tagging_beamsearch.py ADDED
@@ -0,0 +1,76 @@
+ import hydra
+ from typing import List, Optional
+ from dataclasses import dataclass, field
+ import kenlm
+ from beam_search_utils import (
+     SpeakerTaggingBeamSearchDecoder,
+     load_input_jsons,
+     load_reference_jsons,
+     write_seglst_jsons,
+     run_mp_beam_search_decoding,
+     convert_nemo_json_to_seglst,
+ )
+ from hydra.core.config_store import ConfigStore
+
+ __INFO_TAG__ = "[INFO]"
+
+ @dataclass
+ class RealigningLanguageModelParameters:
+     batch_size: int = 32
+     use_mp: bool = True
+     input_error_src_list_path: Optional[str] = None
+     groundtruth_ref_list_path: Optional[str] = None
+     arpa_language_model: Optional[str] = None
+     word_window: int = 32
+     port: List[int] = field(default_factory=list)
+     parallel_chunk_word_len: int = 250
+     use_ngram: bool = True
+     peak_prob: float = 0.95
+     alpha: float = 0.5
+     beta: float = 0.05
+     beam_width: int = 16
+     out_dir: Optional[str] = None
+
+ cs = ConfigStore.instance()
+ cs.store(name="config", node=RealigningLanguageModelParameters)
+
+ @hydra.main(config_name="config", version_base="1.1")
+ def main(cfg: RealigningLanguageModelParameters) -> None:
+     # Load the error-containing source transcripts, the references, and the n-gram LM.
+     trans_info_dict = load_input_jsons(input_error_src_list_path=cfg.input_error_src_list_path, peak_prob=float(cfg.peak_prob))
+     reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
+     source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
+     loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
+
+     speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)
+
+     # Split each session into word chunks so they can be decoded in parallel.
+     div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(
+         trans_info_dict=trans_info_dict,
+         win_len=cfg.parallel_chunk_word_len,
+         word_window=cfg.word_window,
+         port=cfg.port,
+     )
+
+     trans_info_dict = run_mp_beam_search_decoding(
+         speaker_beam_search_decoder,
+         loaded_kenlm_model=loaded_kenlm_model,
+         trans_info_dict=div_trans_info_dict,
+         org_trans_info_dict=trans_info_dict,
+         div_mp=True,
+         win_len=cfg.parallel_chunk_word_len,
+         word_window=cfg.word_window,
+         port=cfg.port,
+         use_ngram=cfg.use_ngram,
+     )
+     hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
+
+     # Write hypothesis, reference, and source transcripts as SegLST JSON files.
+     write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
+     write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
+     write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
+     print(
+         f"{__INFO_TAG__} Parameters used:"
+         f"\n ALPHA: {cfg.alpha}"
+         f"\n BETA: {cfg.beta}"
+         f"\n BEAM WIDTH: {cfg.beam_width}"
+         f"\n Word Window: {cfg.word_window}"
+         f"\n Use Ngram: {cfg.use_ngram}"
+         f"\n Chunk Word Len: {cfg.parallel_chunk_word_len}"
+         f"\n SpeakerLM Model: {cfg.arpa_language_model}"
+     )
+
+ if __name__ == '__main__':
+     main()
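The dataclass defaults above differ from the values passed by `run_speaker_tagging_beam_search.sh`, so the shell script determines the settings actually used. The sketch below is a stdlib-only illustration of which numeric fields change; it does not use Hydra, and the names `DecodingDefaults` and `shell_overrides` are hypothetical. It mirrors only a subset of the fields of `RealigningLanguageModelParameters`.

```python
from dataclasses import dataclass, replace


@dataclass
class DecodingDefaults:
    # Defaults copied from RealigningLanguageModelParameters (numeric fields only).
    batch_size: int = 32
    parallel_chunk_word_len: int = 250
    word_window: int = 32
    peak_prob: float = 0.95
    alpha: float = 0.5
    beta: float = 0.05
    beam_width: int = 16


# Values set in run_speaker_tagging_beam_search.sh.
shell_overrides = {
    "batch_size": 11,
    "parallel_chunk_word_len": 100,
    "word_window": 32,
    "peak_prob": 0.95,
    "alpha": 0.4,
    "beta": 0.04,
    "beam_width": 16,
}

defaults = DecodingDefaults()
effective = replace(defaults, **shell_overrides)

# Keep only the fields whose shell value differs from the dataclass default.
changed = {
    name: value
    for name, value in shell_overrides.items()
    if getattr(defaults, name) != value
}
print(changed)
```

Running the Python entry point without the shell script would therefore fall back to the dataclass defaults for `alpha`, `beta`, `batch_size`, and `parallel_chunk_word_len`.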