import gradio as gr
import librosa
import soundfile
import tempfile
import os
import uuid
import json

from nemo.collections.asr.models import ASRModel
from nemo.utils import logging

from align import main, AlignmentConfig, ASSFileConfig
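
# NOTE: `align.py` (providing `main`, `AlignmentConfig` and `ASSFileConfig`) is assumed to be a
# local copy of the NeMo Forced Aligner entry point from
# https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner, placed alongside this app file.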

SAMPLE_RATE = 16000

logging.setLevel(logging.INFO)


def get_audio_data_and_duration(file):
    # load at the native sample rate and resample to 16 kHz only if needed
    data, sr = librosa.load(file, sr=None)
    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
    # monochannel
    data = librosa.to_mono(data)
    duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
    return data, duration


def get_char_tokens(text, model):
    tokens = []
    for character in text:
        if character in model.decoder.vocabulary:
            tokens.append(model.decoder.vocabulary.index(character))
        else:
            tokens.append(len(model.decoder.vocabulary))  # return unk token (same as blank token)
    return tokens


def get_S_prime_and_T(text, model_name, model, audio_duration):
    # estimate T
    if "citrinet" in model_name or "_fastconformer_" in model_name:
        output_timestep_duration = 0.08
    elif "_conformer_" in model_name:
        output_timestep_duration = 0.04
    elif "quartznet" in model_name:
        output_timestep_duration = 0.02
    else:
        raise RuntimeError("unexpected model name")
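
    # These values are the models' output frame shifts: a 10 ms feature hop multiplied by the
    # 8x (Citrinet/FastConformer), 4x (Conformer) or 2x (QuartzNet) encoder subsampling factor.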

    T = int(audio_duration / output_timestep_duration) + 1

    # calculate S_prime = num tokens + num repetitions
    if hasattr(model, 'tokenizer'):
        all_tokens = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
        all_tokens = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    n_token_repetitions = 0
    for i_tok in range(1, len(all_tokens)):
        if all_tokens[i_tok] == all_tokens[i_tok - 1]:
            n_token_repetitions += 1

    S_prime = len(all_tokens) + n_token_repetitions

    return S_prime, T
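
# CTC forced alignment (a Viterbi pass over the CTC output lattice) is only feasible when the number
# of output frames T is at least S' = number of tokens + number of adjacent repeated tokens, because
# a blank must be emitted between repeated tokens. For example, with the 0.08 s FastConformer frame
# duration used below, 60 s of audio gives T of about 750 frames, so the reference text can contain
# at most roughly 750 tokens + token repetitions.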


def hex_to_rgb_list(hex_string):
    hex_string = hex_string.lstrip("#")
    r = int(hex_string[:2], 16)
    g = int(hex_string[2:4], 16)
    b = int(hex_string[4:], 16)
    return [r, g, b]
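
# e.g. hex_to_rgb_list("#fcba03") -> [252, 186, 3] (the default "text already spoken" color below)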


def delete_mp4s_except_given_filepath(filepath):
    files_in_dir = os.listdir()
    mp4_files_in_dir = [x for x in files_in_dir if x.endswith(".mp4")]
    for mp4_file in mp4_files_in_dir:
        if mp4_file != filepath:
            os.remove(mp4_file)


def align(Microphone, File_Upload, text, col1, col2, col3, split_on_newline, progress=gr.Progress()):
    # Create utt_id, specify output_video_filepath and delete any MP4s
    # that are not that filepath. These stray MP4s can be created
    # if a user refreshes or exits the page while this 'align' function is executing.
    # This deletion will not delete any other users' video as long as this 'align' function
    # is run one at a time.
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""
    ass_text = ""

    progress(0, desc="Validating input")

    # decide which of Mic / File_Upload is used as input & do error handling
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or the file upload input - not both")
    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("Please either use the microphone or upload an audio file")
    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    # check audio is not too long
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"The uploaded audio is {duration/60:.1f} mins long - please only upload audio shorter than 4 mins"
        )

    # loading model
    progress(0.1, desc="Loading speech recognition model")
    model_name = "ayymen/stt_zgh_fastconformer_ctc_small"
    model = ASRModel.from_pretrained(model_name)

    if text:  # check that the input text is not too long compared to the audio
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                "The input text contains too many tokens for the duration of the audio."
                f" For this audio the model can handle at most {T} tokens + token repetitions; you have provided {S_prime}."
                " (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)

        # getting the text if it hasn't been provided
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                text = text[0]

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                "You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )

        # split text on new lines if requested
        if split_on_newline:
            text = "|".join(list(filter(None, text.split("\n"))))

        data = {
            "audio_filepath": audio_path,
            "text": text,
        }

        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")
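        # i.e. a NeMo-style manifest with one JSON object per line - here a single line such as
        # {"audio_filepath": "/tmp/.../<utt_id>.wav", "text": "..."}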

        # run alignment; only let NFA resegment the text if the user has not already
        # specified segment boundaries with "|"
        resegment_text_to_fill_space = "|" not in text

        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            # transcribe_device='cpu',
            # viterbi_device='cpu',
            save_output_file_formats=["ass", "ctm"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")

        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")

        # make video file from the word-level ASS file
        ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"
        with open(ass_file_for_video, "r") as ass_file:
            ass_text = ass_file.read()
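
        # ffmpeg draws the subtitles onto a plain white 1280x720, 50 fps background generated by the
        # lavfi color source, burns in the ASS highlighting with the `ass` filter, and stops when the
        # audio ends (-shortest)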
        ffmpeg_command = (
            f"ffmpeg -y -i {audio_path} "
            "-f lavfi -i color=c=white:s=1280x720:r=50 "
            "-crf 1 -shortest -vcodec libx264 -pix_fmt yuv420p "
            f"-vf 'ass={ass_file_for_video}' "
            f"{output_video_filepath}"
        )
        os.system(ffmpeg_command)

        # save ASS file
        ass_path = "word_level.ass"
        with open(ass_path, "w", encoding="utf-8") as f:
            f.write(ass_text)

        # save word-level CTM file
        with open(f"{tmpdir}/nfa_output/ctm/words/{utt_id}.ctm", "r") as word_ctm_file:
            word_ctm_text = word_ctm_file.read()
        word_ctm_path = "word_level.ctm"
        with open(word_ctm_path, "w", encoding="utf-8") as f:
            f.write(word_ctm_text)

        # save segment-level CTM file
        with open(f"{tmpdir}/nfa_output/ctm/segments/{utt_id}.ctm", "r") as segment_ctm_file:
            segment_ctm_text = segment_ctm_file.read()
        segment_ctm_path = "segment_level.ctm"
        with open(segment_ctm_path, "w", encoding="utf-8") as f:
            f.write(segment_ctm_text)

    return (
        output_video_filepath,
        gr.update(value=output_info, visible=bool(output_info)),
        output_video_filepath,
        gr.update(value=ass_path, visible=True),
        gr.update(value=word_ctm_path, visible=True),
        gr.update(value=segment_ctm_path, visible=True),
    )


def delete_non_tmp_video(video_path):
    if video_path and os.path.exists(video_path):
        os.remove(video_path)
    return None


with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
    non_tmp_output_video_filepath = gr.State([])

    with gr.Row():
        with gr.Column():
            gr.Markdown("# NeMo Forced Aligner")
            gr.Markdown(
                "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
                "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken."
            )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            mic_in = gr.Audio(sources=["microphone"], type='filepath', label="Microphone input (max 4 mins)")
            audio_file_in = gr.Audio(sources=["upload"], type='filepath', label="File upload (max 4 mins)")
            ref_text = gr.Textbox(
                label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
                "Leave this field blank to use an ASR model's transcription as the reference text instead."
            )
            split_on_newline = gr.Checkbox(
                True,
                label="Separate text on new lines",
            )

            gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
            with gr.Row():
                col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
                col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
                col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")

            submit_button = gr.Button("Submit")

        with gr.Column(scale=1):
            gr.Markdown("## Output")
            video_out = gr.Video(label="Output Video")
            text_out = gr.Textbox(label="Output Info", visible=False)
            ass_file = gr.File(label="ASS File", visible=False)
            word_ctm_file = gr.File(label="Word-level CTM File", visible=False)
            segment_ctm_file = gr.File(label="Segment-level CTM File", visible=False)

    with gr.Row():
        gr.HTML(
            "<p style='text-align: center'>"
            "Tutorial: <a href='https://colab.research.google.com/github/NVIDIA/NeMo/blob/main/tutorials/tools/NeMo_Forced_Aligner_Tutorial.ipynb' target='_blank'>\"How to use NFA?\"</a> 🚀 | "
            "Blog post: <a href='https://nvidia.github.io/NeMo/blogs/2023/2023-08-forced-alignment/' target='_blank'>\"How does forced alignment work?\"</a> 📚 | "
            "NFA <a href='https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner/' target='_blank'>Github page</a> 👩‍💻"
            "</p>"
        )

    submit_button.click(
        fn=align,
        inputs=[mic_in, audio_file_in, ref_text, col1, col2, col3, split_on_newline],
        outputs=[video_out, text_out, non_tmp_output_video_filepath, ass_file, word_ctm_file, segment_ctm_file],
    ).then(
        fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
    )
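    # the chained .then() runs once the click handler has finished: it takes the filepath stored in
    # the gr.State and deletes the generated MP4 from the working directory so videos do not pile up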
example_2 = """ⵜⴰⴽⵟⵟⵓⵎⵜ ⵏ ⵜⵙⴰⴷⵓⴼⵜ. | |
ⵙ ⵉⵙⵎ ⵏ ⵕⴱⴱⵉ ⴰⵎⴰⵍⵍⴰⵢ ⴰⵎⵙⵎⵓⵍⵍⵓ. | |
ⴰⵎⵓⵢ ⵉ ⵕⴱⴱⵉ ⵍⵍⵉ ⵎⵓ ⵜⴳⴰ ⵜⵓⵍⵖⵉⵜ ⵜⵉⵏⵏⵙ, ⵕⴱⴱⵉ ⵏ ⵉⵖⵥⵡⴰⵕⵏ, ⴽⵔⴰ ⴳⴰⵏ. | |
ⴰⵎⴰⵍⵍⴰⵢ ⴰⵎⵙⵎⵓⵍⵍⵓ, ⵖ ⵜⵎⵣⵡⴰⵔⵓⵜ ⵓⵍⴰ ⵖ ⵜⵎⴳⴳⴰⵔⵓⵜ. | |
ⴰⴳⵍⵍⵉⴷ ⵏ ⵡⴰⵙⵙ ⵏ ⵓⴼⵔⴰ, ⴰⵙⵙ ⵏ ⵓⵙⵙⵃⵙⵓ, ⴽⵔⴰⵉⴳⴰⵜ ⵢⴰⵏ ⴷ ⵎⴰⴷ ⵉⵙⴽⵔ. | |
ⵀⴰ ⵏⵏ ⴽⵢⵢⵉ ⴽⴰ ⵙ ⵏⵙⵙⵓⵎⴷ, ⴷ ⴽⵢⵢⵉ ⴽⴰ ⴰⴷ ⵏⵎⵎⵜⵔ. | |
ⵙⵎⵓⵏ ⴰⵖ, ⵜⵎⵍⵜ ⴰⵖ, ⴰⵖⴰⵔⴰⵙ ⵢⵓⵖⴷⵏ. | |
ⴰⵖⴰⵔⴰⵙ ⵏ ⵖⵡⵉⵍⵍⵉ ⵜⵙⵏⵏⵓⴼⴰⵜ, ⵓⵔ ⴷ ⴰⵢⵜ ⵜⵉⵢⵓⵔⵉ, ⵓⵍⴰ ⵉⵎⵓⴹⴹⴰⵕ.""" | |

    examples = gr.Examples(
        examples=[
            ["common_voice_zgh_37837257.mp3", "ⵎⵍ ⵉⵢⵉ ⵎⴰⴷ ⴷ ⵜⴻⵜⵜⵎⵓⵏⴷ ⴰⴷ ⴰⴽ ⵎⵍⵖ ⵎⴰⴷ ⵜⴳⵉⴷ"],
            ["Voice1410.wav", example_2],
        ],
        inputs=[audio_file_in, ref_text],
    )

demo.queue()
demo.launch()