# -*- coding:utf-8 -*-
# @FileName :postprocee.py
# @Time :2023/8/8 20:26
# @Author :lovemefan
# @Email :[email protected]
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from typing import Any, List, Optional, Tuple, Union
def isChinese(ch: str):
    """Return True if ``ch`` is a CJK ideograph or an ASCII digit.

    Despite the name, the digits ``0``-``9`` are accepted as well as the
    range U+4E00..U+9FFF. The check is a plain string range comparison, so
    an empty string yields False and a multi-character string is compared
    lexicographically; callers are expected to pass a single character.

    Args:
        ch: The character to classify.

    Returns:
        True when ``ch`` falls in either accepted range, otherwise False.
    """
    return "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039"
def isAllChinese(word: Union[List[Any], str]):
    """Return True if every item of ``word`` passes ``isChinese``.

    Each item has spaces and the ``</s>`` / ``<s>`` markers removed before
    it is checked, so an item that consists only of those becomes an empty
    string and makes the result False.

    Args:
        word: A list of tokens, or a string (checked character by
            character).

    Returns:
        False for empty input or if any cleaned item fails ``isChinese``;
        True otherwise.
    """
    # Replacement order matters: spaces go first so "< s>" style tokens
    # collapse into a marker that is then removed.
    cleaned = [
        item.replace(" ", "").replace("</s>", "").replace("<s>", "")
        for item in word
    ]
    if not cleaned:
        return False
    return all(isChinese(ch) for ch in cleaned)
def isAllAlpha(word: Union[List[Any], str]):
    """Return True if every item of ``word`` is alphabetic and not Chinese.

    Each item has spaces and the ``</s>`` / ``<s>`` markers removed first.
    A cleaned item is accepted when it is exactly ``"'"``, or when
    ``str.isalpha()`` is true for it and ``isChinese`` is false. An
    apostrophe embedded in a longer item (e.g. ``"don't"``) is rejected,
    because ``str.isalpha()`` is false for the whole item.

    Args:
        word: A list of tokens, or a string (checked character by
            character).

    Returns:
        False for empty input or if any cleaned item is rejected; True
        otherwise.
    """
    cleaned = [
        item.replace(" ", "").replace("</s>", "").replace("<s>", "")
        for item in word
    ]
    if not cleaned:
        return False

    for ch in cleaned:
        if ch == "'":
            # A bare apostrophe token is tolerated.
            continue
        # str.isalpha() is also true for CJK ideographs, hence the extra
        # isChinese() exclusion.
        if not ch.isalpha() or isChinese(ch):
            return False
    return True
# def abbr_dispose(words: List[Any]) -> List[Any]: | |
def abbr_dispose(
    words: List[Any], time_stamp: Optional[List[List]] = None
) -> Union[List[Any], Tuple[List[Any], List[List]]]:
    """Upper-case runs of single ASCII letters separated by single spaces.

    A run such as ``["a", " ", "b", " ", "c"]`` is treated as a spelled-out
    abbreviation: its letters are upper-cased and the ``" "`` separators
    inside the run are dropped, giving ``["A", "B", "C"]``. All other
    tokens, including the remaining ``" "`` separators, are copied through
    unchanged.

    Args:
        words: Token list in which words are separated by ``" "`` entries.
        time_stamp: Optional list holding one ``[begin, end]`` entry per
            non-space token of ``words``, in order.

    Returns:
        The rewritten token list when ``time_stamp`` is None. Otherwise a
        ``(tokens, time_stamps)`` tuple in which each abbreviation
        contributes a single ``[begin, end]`` entry (begin of its first
        letter, end of its last letter), so ``time_stamps`` may be shorter
        than the number of non-space tokens.
    """
    words_size = len(words)

    def is_single_letter(idx: int) -> bool:
        # bytes.isalpha() only accepts ASCII letters, so CJK characters
        # (multi-byte in UTF-8) never count as abbreviation letters.
        token = words[idx]
        return len(token) == 1 and token.encode("utf-8").isalpha()

    def continues_abbr(idx: int) -> bool:
        # True when the letter at idx is followed by " " and another letter.
        return (
            idx + 2 < words_size
            and words[idx + 1] == " "
            and is_single_letter(idx + 2)
        )

    # Pass 1: locate the first and last index of every abbreviation run.
    abbr_begin = set()
    abbr_end = set()
    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue
        if not (is_single_letter(num) and continues_abbr(num)):
            continue
        abbr_begin.add(num)
        cursor = num + 2
        while continues_abbr(cursor):
            cursor += 2
        abbr_end.add(cursor)
        # Everything up to the run's last letter has been consumed.
        last_num = cursor

    # Map every index of `words` to its index in `time_stamp`; a " " entry
    # has no time stamp of its own and shares the index of the next token.
    ts_nums = []
    ts_index = 0
    for token in words:
        ts_nums.append(ts_index)
        if token != " ":
            ts_index += 1

    # Pass 2: emit the tokens, collapsing each abbreviation run.
    word_lists = []
    ts_lists = []
    last_num = -1
    for num in range(words_size):
        if num <= last_num:
            continue

        if num not in abbr_begin:
            word_lists.append(words[num])
            if time_stamp is not None and words[num] != " ":
                stamp = time_stamp[ts_nums[num]]
                ts_lists.append([stamp[0], stamp[1]])
            continue

        if time_stamp is not None:
            begin = time_stamp[ts_nums[num]][0]
        word_lists.append(words[num].upper())
        cursor = num + 1
        while cursor < words_size:
            if cursor in abbr_end:
                word_lists.append(words[cursor].upper())
                last_num = cursor
                break
            # Letters inside the run are kept, the " " separators are not.
            if words[cursor].encode("utf-8").isalpha():
                word_lists.append(words[cursor].upper())
            cursor += 1
        if time_stamp is not None:
            end = time_stamp[ts_nums[cursor]][1]
            ts_lists.append([begin, end])

    if time_stamp is not None:
        return word_lists, ts_lists
    return word_lists
def sentence_postprocess(words: List[Any], time_stamp: Optional[List[List]] = None):
    """Turn a decoded token sequence into a sentence and its word list.

    Tokens may be ``str`` or UTF-8 ``bytes``. The special tokens ``<s>``,
    ``</s>`` and ``<unk>`` are dropped. A token containing ``@@`` is a word
    piece: the marker is removed and the piece is accumulated until a
    token without ``@@`` completes the word. Runs of single letters are
    then upper-cased by ``abbr_dispose``.

    Args:
        words: Decoded tokens.
        time_stamp: Optional list of ``[begin, end]`` entries, indexed by
            the position of each token that remains after the special
            tokens are dropped.

    Returns:
        Without ``time_stamp``: ``(sentence, word_list)``, where
        ``sentence`` is the concatenation of the words with the ``" "``
        separators placed after alphabetic words, stripped.
        With ``time_stamp``: ``(sentence, time_stamps, word_list)``, where
        ``sentence`` joins every word with a single space.

    Raises:
        ValueError: In mixed input, if a token is neither accepted by
            ``isAllChinese``, nor a ``@@`` piece, nor accepted by
            ``isAllAlpha``.
    """
    # Wash the token list: decode bytes and drop special tokens.
    middle_lists = []
    for raw in words:
        token = raw if isinstance(raw, str) else raw.decode("utf-8")
        if token in ("<s>", "</s>", "<unk>"):
            continue
        middle_lists.append(token)

    word_lists = []
    word_item = ""
    ts_lists = []

    # NOTE(review): in the branches below a trailing "@@" piece that is
    # never followed by a completing token stays in `word_item` and is not
    # emitted.

    # All Chinese characters (digits included).
    if isAllChinese(middle_lists):
        word_lists = [ch.replace(" ", "") for ch in middle_lists]
        if time_stamp is not None:
            ts_lists = time_stamp

    # All alphabetic tokens.
    elif isAllAlpha(middle_lists):
        ts_flag = True  # True while the next token starts a new word
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            if "@@" in ch:
                word_item += ch.replace("@@", "")
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(" ")
                word_item = ""
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])

    # Mixed Chinese and alphabetic tokens.
    else:
        alpha_blank = False  # True when the last emitted entry is a " "
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            if isAllChinese(ch):
                if alpha_blank is True:
                    # No separator between an alphabetic word and the
                    # Chinese character that follows it.
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
            elif "@@" in ch:
                word_item += ch.replace("@@", "")
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(" ")
                word_item = ""
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
            else:
                raise ValueError("invalid character: {}".format(ch))

    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = [ch for ch in word_lists if ch != " "]
        sentence = " ".join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists

    word_lists = abbr_dispose(word_lists)
    real_word_lists = [ch for ch in word_lists if ch != " "]
    # Here the " " separators already present in word_lists provide the
    # spacing, so the entries are concatenated directly.
    sentence = "".join(word_lists).strip()
    return sentence, real_word_lists