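"""Video dubbing pipeline: extract the audio track from an input video, transcribe
it with Whisper, translate the transcript into Twi or Ewe via the GhanaNLP Khaya
APIs, synthesize speech for the translations, and mux the dubbed audio back into
the original video."""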
import asyncio
import os
import re
import subprocess
import tempfile

import aiofiles
import aiohttp
from aiohttp import ClientSession
from aiohttp_retry import RetryClient, ExponentialRetry
from dotenv import load_dotenv
import ffmpeg
from num2words import num2words
import torch
from tqdm import tqdm
from transformers import pipeline
# load the Khaya API token from the environment
load_dotenv()
KHAYA_TOKEN = os.getenv("KHAYA_TOKEN")

translation_url = "https://translation-api.ghananlp.org/v1/translate"
tts_url = "https://tts-backend-nlpghana-staging.azurewebsites.net/v0/tts"
translation_hdr = {
    # Request headers
    "Content-Type": "application/json",
    "Cache-Control": "no-cache",
    "Ocp-Apim-Subscription-Key": KHAYA_TOKEN,
}

tts_header = {
    # Request headers
    "Content-Type": "application/json",
    "Cache-Control": "no-cache",
    "Ocp-Apim-Subscription-Key": KHAYA_TOKEN,
}
LANG_DICT = {"Twi": "tw", "Ewe": "ee"}
# Check if GPU is available
pipe_device = 0 if torch.cuda.is_available() else -1


def replace_numbers_with_words(text):
    """Spell out digit sequences (e.g. "1,250" -> "one thousand, two hundred
    and fifty") so the TTS engine does not have to read raw digits."""

    def replace(match):
        return num2words(match.group().replace(",", ""), lang="en")

    # match plain integers as well as numbers with decimal/thousands separators,
    # without swallowing a sentence-final period
    return re.sub(r"\d+(?:[.,]\d+)*", replace, text)


async def fetch(session, url, headers, data, semaphore, index):
    """POST one payload and return (index, parsed JSON), so results can be
    put back in order after asyncio.as_completed scrambles completion order.
    Failed requests return (index, None) rather than an error string that
    could be mistaken for a translation."""
    async with semaphore:
        try:
            async with session.post(
                url, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=10)
            ) as response:
                response.raise_for_status()
                return index, await response.json()
        except aiohttp.ClientError as e:
            print(f"Request error: {e}")
            return index, None
        except Exception as e:
            print(f"Unexpected error: {e}")
            return index, None


async def translation_main(sentences, url, headers, lang):
    khaya_translations = [None] * len(sentences)
    semaphore = asyncio.Semaphore(2)  # limit the number of concurrent requests
    retry_options = ExponentialRetry(attempts=3)

    async with RetryClient(ClientSession(), retry_options=retry_options) as session:
        tasks = []
        for index, sent in enumerate(sentences):
            data = {"in": sent, "lang": f"en-{lang}"}
            tasks.append(fetch(session, url, headers, data, semaphore, index))

        for f in tqdm(
            asyncio.as_completed(tasks), total=len(tasks), desc="Translating Sentences"
        ):
            index, result = await f
            # failed requests come back as None; keep an empty string so the
            # list stays aligned with the source sentences
            khaya_translations[index] = result if result is not None else ""

    return khaya_translations


async def convert_text_to_speech(
    session,
    tts_url,
    tts_header,
    text,
    text_index,
    language,
    speaker,
    semaphore,
    output_dir,
):
    """Synthesize one translated sentence to a WAV file named after its index,
    so the chunks can be reassembled in transcript order later."""
    speaker_dict = {
        "tw": {"male": "twi_speaker_5", "female": "twi_speaker_7"},
        "ee": {"male": "ewe_speaker_3", "female": None},
    }
    speaker_id = speaker_dict[language][speaker]
    if speaker_id is None:
        raise ValueError(f"No {speaker} voice is available for language '{language}'")
    data = {"text": text, "language": language, "speaker_id": speaker_id}
    try:
        async with semaphore:
            async with session.post(tts_url, headers=tts_header, json=data) as response:
                response.raise_for_status()
                output_path = os.path.join(output_dir, f"{text_index}_tts.wav")
                # stream the audio response to disk in 16 KB chunks
                async with aiofiles.open(output_path, "wb") as file:
                    while True:
                        chunk = await response.content.read(16384)
                        if not chunk:
                            break
                        await file.write(chunk)
                return output_path
    except aiohttp.ClientError as e:
        print(f"Request error: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")


async def tts_main(khaya_translations, speaker, language):
    with tempfile.TemporaryDirectory() as temp_dir:
        async with aiohttp.ClientSession() as session:
            semaphore = asyncio.Semaphore(3)
            tasks = [
                convert_text_to_speech(
                    session,
                    tts_url,
                    tts_header,
                    sent,
                    text_index,
                    language,
                    speaker,
                    semaphore,
                    temp_dir,
                )
                for text_index, sent in enumerate(khaya_translations)
            ]
            output_files = []
            for task in tqdm(
                asyncio.as_completed(tasks),
                total=len(tasks),
                desc="Converting to Speech",
            ):
                result = await task
                if result:
                    output_files.append(result)

        # combine while the temporary chunk files still exist
        output_audio = combine_audio_streams(output_files, "combined_audio.wav")
    return output_audio


def extract_audio_from_video(input_video):
    """Copy the audio track out of the input video without re-encoding."""
    if input_video:
        output_audio_path = "separated_audio.aac"
        try:
            (
                ffmpeg.input(input_video)
                .output(output_audio_path, acodec="copy", vn=None)
                .run(overwrite_output=True)
            )
            print("Audio extracted successfully")
            return output_audio_path
        except ffmpeg.Error as e:
            print(e.stderr.decode())
            raise e


def transcribe_and_preprocess_audio(input_audio):
    """Transcribe the audio with Whisper and split the transcript into
    sentences ready for machine translation."""
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",
        device=pipe_device,
    )
    pipeline_whisper_output = asr(
        input_audio,
        return_timestamps=True,
    )

    # preprocess the output before machine translation
    sentences = pipeline_whisper_output["text"].split(". ")
    sentences = [el.strip() for el in sentences if el]

    # spell out digits so the TTS engine can pronounce them
    sentences = [replace_numbers_with_words(sent) for sent in sentences]
    return sentences


def combine_audio_streams(list_of_output_chunks, output_audio):
    """Concatenate the numbered TTS chunks into a single audio file."""
    # sort by the numeric prefix of "<index>_tts.wav" to restore transcript order
    list_of_output_chunks = sorted(
        list_of_output_chunks, key=lambda x: int(os.path.basename(x).split("_")[0])
    )
    input_streams = [ffmpeg.input(chunk) for chunk in list_of_output_chunks]
    concatenated = ffmpeg.concat(*input_streams, v=0, a=1).output(output_audio)
    try:
        concatenated.run(overwrite_output=True)
        return output_audio
    except ffmpeg.Error as e:
        print(e.stderr.decode())


def create_combined_output(input_video, output_audio, output_video):
    """Mux the original video stream with the dubbed audio track."""
    try:
        video = ffmpeg.input(input_video)
        audio = ffmpeg.input(output_audio)
        ffmpeg.output(
            video["v"],
            audio["a"],
            filename=output_video,
            vcodec="copy",
        ).run(overwrite_output=True)
        print("Video and audio combined successfully")
        return output_video
    except ffmpeg.Error as e:
        print(e.stderr.decode())
        raise e


def create_combined_output_subprocess(input_video, output_audio, output_video):
    """Mux video and dubbed audio, time-stretching the audio with ffmpeg's
    atempo filter so it matches the video duration."""
    video_duration = get_media_duration(input_video)
    audio_duration = get_media_duration(output_audio)

    # clamp to the range the atempo filter accepts (0.5-100)
    speed_factor = calculate_speed_factor(video_duration, audio_duration)
    speed_factor = min(max(speed_factor, 0.5), 100)
    print(f"Speed factor: {speed_factor}")

    try:
        command = [
            "ffmpeg",
            "-i",
            input_video,
            "-i",
            output_audio,
            "-filter:a",
            f"atempo={speed_factor}",
            "-c:v",
            "copy",
            "-map",
            "0:v:0",
            "-map",
            "1:a:0",
            "-y",  # overwrite the output without prompting
            output_video,
        ]
        # capture stderr so the except branch below has something to decode
        subprocess.run(command, check=True, capture_output=True)
        print("Video and audio combined successfully")
        return output_video
    except subprocess.CalledProcessError as e:
        print(e.stderr.decode())
        raise e


def get_media_duration(media_file):
    """
    Get the duration of a media file in seconds.
    """
    probe = ffmpeg.probe(media_file)
    duration = float(probe["format"]["duration"])
    return duration


def calculate_speed_factor(video_duration, audio_duration):
    """
    Calculate the speed factor to align audio with video.
    """
    return audio_duration / video_duration
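

# Illustrative driver (not part of the original Space): a minimal sketch of how
# the functions above compose into an end-to-end dub. The file names and the
# "male" speaker choice are placeholder assumptions.
if __name__ == "__main__":
    input_video = "input.mp4"  # hypothetical source file
    target_language = LANG_DICT["Twi"]  # "tw"

    audio_path = extract_audio_from_video(input_video)
    sentences = transcribe_and_preprocess_audio(audio_path)
    translations = asyncio.run(
        translation_main(sentences, translation_url, translation_hdr, target_language)
    )
    dubbed_audio = asyncio.run(tts_main(translations, "male", target_language))
    create_combined_output_subprocess(input_video, dubbed_audio, "dubbed_output.mp4")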