import os
import re
import locale
from pathlib import Path

import librosa
import numpy as np
import torch

# Ensure a UTF-8 locale if available; fall back silently otherwise.
try:
    locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
except locale.Error:
    pass

from transformers import Wav2Vec2ForCTC, AutoProcessor
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
from utils.text_norm import text_normalize
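
# NOTE: the paths below are environment-specific: TEMP_DIR points at a local
# Windows directory, and a uroman checkout is expected next to this script.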
TEMP_DIR = Path("D:/Ngen/bot/temp_lexicon")
TEMP_DIR.mkdir(parents=True, exist_ok=True)

uroman_dir = "uroman"
assert os.path.exists(uroman_dir), f"uroman checkout not found at '{uroman_dir}'"
UROMAN_PL = os.path.join(uroman_dir, "bin", "uroman.pl")

ASR_SAMPLING_RATE = 16_000

WORD_SCORE_DEFAULT_IF_NOLM = -3.5

MODEL_ID = "mms-meta/mms-zeroshot-300m"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

token_file = hf_hub_download(
    repo_id=MODEL_ID,
    filename="tokens.txt",
)
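
# MY_LOG accumulates a running progress log; process() yields it alongside
# partial transcriptions so a caller can stream status updates.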
class MY_LOG:
    def __init__(self):
        self.text = "[START]"

    def add(self, new_log, new_line=True):
        self.text = self.text + ("\n" if new_line else " ") + new_log
        self.text = self.text.strip()
        return self.text

def create_temp_file(suffix=None):
    """Create an empty temporary file (UTF-8) inside TEMP_DIR and return its path."""
    temp_path = TEMP_DIR / f"temp_{os.urandom(8).hex()}{suffix if suffix else ''}"
    # Touch the file so downstream code can rely on it existing.
    with open(temp_path, 'w', encoding='utf-8'):
        pass
    return temp_path

def error_check_file(filepath):
    if not isinstance(filepath, str):
        return "Expected file to be of type 'str'. Instead got {}".format(
            type(filepath)
        )
    if not os.path.exists(filepath):
        return "Input file '{}' doesn't exist".format(filepath)

def norm_uroman(text):
    """Lowercase romanized text and keep only a-z, apostrophes and spaces."""
    text = text.lower()
    text = text.replace("\u2019", "'")  # map typographic apostrophes to ASCII
    text = re.sub("([^a-z' ])", " ", text)
    text = re.sub(" +", " ", text)
    return text.strip()

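
# uromanize() shells out to the uroman perl script to romanize every word,
# then turns each romanization into a lexicon entry of space-separated
# characters terminated by the "|" word-boundary token used by ctc_decoder.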
def uromanize(words):
    """Romanize words via uroman.pl and return a {word: "c h a r s |"} lexicon."""
    iso = "xxx"  # placeholder language code (no language-specific uroman rules)
    input_file = create_temp_file(suffix=".txt")
    output_file = create_temp_file(suffix=".txt")
    lexicon = {}

    try:
        with open(input_file, "w", encoding='utf-8') as f:
            f.write("\n".join(words))

        # One input word per line in, one romanized word per line out.
        cmd = f"perl {UROMAN_PL} -l {iso} < {input_file} > {output_file}"
        os.system(cmd)

        with open(output_file, encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if not line.strip():
                    continue
                try:
                    line = re.sub(r"\s+", "", norm_uroman(line)).strip()
                    lexicon[words[idx]] = " ".join(line) + " |"
                except Exception as e:
                    print(f"Warning: Could not process line {idx}: {str(e)}")
                    continue
    finally:
        try:
            input_file.unlink(missing_ok=True)
            output_file.unlink(missing_ok=True)
        except Exception as e:
            print(f"Warning: Could not delete temporary files: {str(e)}")

    return lexicon

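
# Two different words can romanize to the same spelling; filter_lexicon()
# keeps only one word per spelling (the most frequent, then the shortest
# as a tie-break) so each spelling maps to a single word.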
def filter_lexicon(lexicon, word_counts):
    spelling_to_words = {}
    for w, s in lexicon.items():
        spelling_to_words.setdefault(s, [])
        spelling_to_words[s].append(w)

    filtered_lexicon = {}
    for s, ws in spelling_to_words.items():
        if len(ws) > 1:
            # Prefer the most frequent word, breaking ties by shorter length.
            ws.sort(key=lambda w: (-word_counts[w], len(w)))
        filtered_lexicon[ws[0]] = s
    return filtered_lexicon

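
# load_words() reads the reference text (trying several encodings), normalizes
# it with text_normalize, and returns per-word counts plus the number of lines.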
def load_words(filepath):
    """Load words from file with proper encoding handling"""
    words = {}

    encodings = ['utf-8', 'cp1251', 'latin-1', 'utf-16']

    for encoding in encodings:
        try:
            with open(filepath, encoding=encoding) as f:
                lines = f.readlines()
            break
        except UnicodeDecodeError:
            continue
        except Exception as e:
            raise Exception(f"Error reading file: {str(e)}")
    else:
        raise Exception("Could not decode file with any of the attempted encodings")

    num_sentences = len(lines)
    all_sentences = " ".join([l.strip() for l in lines])
    norm_all_sentences = text_normalize(all_sentences)
    for w in norm_all_sentences.split():
        words.setdefault(w, 0)
        words[w] += 1
    return words, num_sentences

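
# process() is a generator: it repeatedly yields (transcription, log) tuples so
# the caller can display progress, and only the final yield carries the decoded
# text. The lm_path, lmscore, lmscore_usedefault, autolm and reference
# parameters are accepted but not used in this zero-shot, no-language-model path.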
def process(
    audio_data,
    words_file,
    lm_path=None,
    wscore=None,
    lmscore=None,
    wscore_usedefault=True,
    lmscore_usedefault=True,
    autolm=False,
    reference=None,
):
    transcription, logs = "", MY_LOG()
    if not audio_data or not words_file:
        yield "ERROR: Empty audio data or words file", logs.text
        return

    if isinstance(audio_data, tuple):
        # Microphone-style input: (sample_rate, int16 samples).
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(float)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
    else:
        # Otherwise treat the input as a path to an audio file.
        assert isinstance(audio_data, str)
        audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]

    yield transcription, logs.add(f"Number of audio samples: {len(audio_samples)}")

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    inputs = inputs.to(device)
    yield transcription, logs.add(f"Using device: {device}")

    with torch.no_grad():
        outputs = model(**inputs).logits

    yield transcription, logs.add("Loading words....")
    try:
        word_counts, num_sentences = load_words(words_file)
    except Exception as e:
        yield f"ERROR: Loading words failed '{str(e)}'", logs.text
        return

    yield transcription, logs.add(
        f"Loaded {len(word_counts)} words from {num_sentences} lines.\nPreparing lexicon...."
    )

    try:
        lexicon = uromanize(list(word_counts.keys()))
    except Exception as e:
        yield f"ERROR: Creating lexicon failed '{str(e)}'", logs.text
        return

    yield transcription, logs.add(f"Lexicon size: {len(lexicon)}")

    yield transcription, logs.add("Filtering lexicon....")
    lexicon = filter_lexicon(lexicon, word_counts)
    yield transcription, logs.add(
        f"Ok. Lexicon size after filtering: {len(lexicon)}"
    )

    lexicon_file = create_temp_file(suffix=".txt")
    try:
        # Write the lexicon in the format torchaudio's ctc_decoder expects:
        # "<word> <char> <char> ... |" per line.
        with open(lexicon_file, "w", encoding='utf-8') as f:
            for word, spelling in lexicon.items():
                f.write(word + " " + spelling + "\n")

        if wscore_usedefault:
            wscore = WORD_SCORE_DEFAULT_IF_NOLM

        yield transcription, logs.add(
            f"Using word score: {wscore}"
        )

        beam_search_decoder = ctc_decoder(
            lexicon=str(lexicon_file),
            tokens=token_file,
            nbest=1,
            beam_size=500,
            beam_size_token=50,
            word_score=wscore,
            sil_score=0,
            blank_token="<s>",
        )

        # The beam-search decoder runs on CPU tensors.
        beam_search_result = beam_search_decoder(outputs.to("cpu"))
        transcription = " ".join(beam_search_result[0][0].words).strip()
    finally:
        # Best-effort cleanup of the temporary lexicon file.
        try:
            lexicon_file.unlink(missing_ok=True)
        except Exception as e:
            print(f"Warning: Could not delete temporary file: {str(e)}")

    yield transcription, logs.add("[DONE]")
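

# Example usage: a minimal sketch, not part of the pipeline above. It assumes a
# local "sample.wav" recording and a "words.txt" word list (both hypothetical
# paths) and simply drains the process() generator, printing the final result.
if __name__ == "__main__":
    final_text, final_log = "", "[no output]"
    for final_text, final_log in process("sample.wav", "words.txt"):
        pass
    print(final_log)
    print("Transcription:", final_text)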