import gradio as gr
import moviepy.editor as mp
from moviepy.video.tools.subtitles import SubtitlesClip
from datetime import timedelta
import os
import logging
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    MarianMTModel,
    MarianTokenizer,
    pipeline
)
import torch
import numpy as np
from pydub import AudioSegment
import spaces
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('video_subtitler.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
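# Note: on Hugging Face Spaces the FileHandler writes to ephemeral disk, so
# video_subtitler.log is lost on restart; the StreamHandler output is what
# appears in the Space's runtime logs.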
# Dictionary of supported languages and their codes for MarianMT
LANGUAGE_CODES = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
    "Italian": "it",
    "Portuguese": "pt",
    "Russian": "ru",
    "Chinese": "zh",
    "Japanese": "ja",
    "Korean": "ko"
}
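# These are ISO 639-1 codes; get_model_name() below concatenates a source and
# target code into a Helsinki-NLP/opus-mt-{src}-{tgt} checkpoint name.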
def get_model_name(source_lang, target_lang):
    """Get MarianMT model name for language pair"""
    logger.info(f"Getting model name for translation from {source_lang} to {target_lang}")
    return f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
def format_timestamp(seconds):
    """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)"""
    td = timedelta(seconds=seconds)
    # Include td.days so durations past 24 hours do not wrap around
    hours = td.days * 24 + td.seconds // 3600
    minutes = (td.seconds // 60) % 60
    secs = td.seconds % 60
    milliseconds = td.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
def translate_text(text, source_lang, target_lang):
    """Translate text using MarianMT"""
    if source_lang == target_lang:
        logger.info("Source and target languages are the same, skipping translation")
        return text
    try:
        logger.info(f"Translating text from {source_lang} to {target_lang}")
        model_name = get_model_name(source_lang, target_lang)
        logger.info(f"Loading translation model: {model_name}")
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        logger.debug(f"Input text: {text}")
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        translated = model.generate(**inputs)
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
        logger.debug(f"Translated text: {translated_text}")
        return translated_text
    except Exception as e:
        logger.error(f"Translation error: {str(e)}", exc_info=True)
        return text
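# Performance note: translate_text reloads the MarianMT weights from disk on
# every call, i.e. once per subtitle segment. A minimal sketch of a
# module-level cache that would avoid the repeated loads (an assumed
# optimization, not part of the original app):
_translation_models = {}

def _get_translation_model(model_name):
    """Load a MarianMT tokenizer/model pair once and reuse it on later calls."""
    if model_name not in _translation_models:
        _translation_models[model_name] = (
            MarianTokenizer.from_pretrained(model_name),
            MarianMTModel.from_pretrained(model_name),
        )
    return _translation_models[model_name]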
def load_audio(video_path):
    """Extract and load audio from video file"""
    logger.info(f"Loading audio from video: {video_path}")
    try:
        video = mp.VideoFileClip(video_path)
        logger.info(f"Video loaded. Duration: {video.duration} seconds")
        temp_audio_path = "temp_audio.wav"
        logger.info(f"Extracting audio to temporary file: {temp_audio_path}")
        video.audio.write_audiofile(temp_audio_path)
        logger.info("Loading audio file with pydub")
        audio = AudioSegment.from_wav(temp_audio_path)
        audio_array = np.array(audio.get_array_of_samples())
        logger.info("Converting audio to float32 and normalizing")
        # moviepy writes 16-bit PCM by default, so normalize by the int16 max
        audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
        if audio.channels > 1:
            # pydub returns interleaved samples as a flat array, so reshape
            # into (n_frames, n_channels) before averaging down to mono
            logger.info("Converting multi-channel audio to mono")
            audio_array = audio_array.reshape((-1, audio.channels)).mean(axis=1)
        logger.info(f"Audio loaded successfully. Shape: {audio_array.shape}, Sample rate: {audio.frame_rate}")
        return audio_array, audio.frame_rate, video, temp_audio_path
    except Exception as e:
        logger.error(f"Error loading audio: {str(e)}", exc_info=True)
        raise
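# Whisper-family models expect 16 kHz input, while load_audio returns the
# clip's native frame rate (often 44.1 kHz). The ASR pipeline below resamples
# automatically when given a dict with "array" and "sampling_rate" keys; an
# explicit resampling sketch, assuming scipy were installed, would be:
#
#     from scipy.signal import resample_poly
#     audio_16k = resample_poly(audio_array, 16000, sampling_rate)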
def create_srt(segments, target_lang="en"):
    """Convert transcribed segments to SRT format with optional translation"""
    logger.info(f"Creating SRT content for {len(segments)} segments")
    srt_content = ""
    for i, segment in enumerate(segments, start=1):
        start_time = format_timestamp(segment['start'])
        end_time = format_timestamp(segment['end'])
        text = segment['text'].strip()
        logger.debug(f"Processing segment {i}: {start_time} --> {end_time}")
        if segment.get('language') and segment['language'] != target_lang:
            logger.info(f"Translating segment {i}")
            text = translate_text(text, segment['language'], target_lang)
        srt_content += f"{i}\n{start_time} --> {end_time}\n{text}\n\n"
    return srt_content
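# Each entry in the generated file follows the standard SRT layout:
#
#     1
#     00:00:01,000 --> 00:00:04,200
#     Hello world
#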
def create_subtitle_clips(segments, videosize, target_lang="en"):
    """Create subtitle clips for moviepy with translation support"""
    logger.info(f"Creating subtitle clips for {len(segments)} segments")
    subtitle_clips = []
    for i, segment in enumerate(segments):
        logger.debug(f"Processing subtitle clip {i}")
        start_time = segment['start']
        end_time = segment['end']
        duration = end_time - start_time
        text = segment['text'].strip()
        if segment.get('language') and segment['language'] != target_lang:
            logger.info(f"Translating subtitle {i}")
            text = translate_text(text, segment['language'], target_lang)
        try:
            # Fix the caption box width to the video width but let its height
            # auto-size, so the clip can be anchored at the bottom of the
            # frame rather than spanning (and vertically centering in) it
            text_clip = mp.TextClip(
                text,
                font='Arial',
                fontsize=24,
                color='white',
                stroke_color='black',
                stroke_width=1,
                size=(videosize[0], None),
                method='caption'
            ).set_position(('center', 'bottom'))
            text_clip = text_clip.set_start(start_time).set_duration(duration)
            subtitle_clips.append(text_clip)
        except Exception as e:
            logger.error(f"Error creating subtitle clip {i}: {str(e)}", exc_info=True)
    return subtitle_clips
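# Environment note: moviepy 1.x renders TextClip through ImageMagick, so the
# Space needs the imagemagick system package available (e.g. listed in
# packages.txt) and an ImageMagick policy that permits text rendering.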
@spaces.GPU
def process_video(video_path, target_lang="en"):
    """Main function to process video and add subtitles with translation"""
    logger.info(f"Starting video processing: {video_path}")
    try:
        # Set up device (on ZeroGPU Spaces, CUDA is only visible inside
        # functions decorated with @spaces.GPU, hence the decorator above)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        # Load CrisperWhisper model
        model_id = "nyrahealth/CrisperWhisper"
        logger.info(f"Loading CrisperWhisper model: {model_id}")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            use_safetensors=True
        ).to(device)
        logger.info("Loading processor")
        processor = AutoProcessor.from_pretrained(model_id)
        # Load audio and video
        logger.info("Loading audio from video")
        audio_array, sampling_rate, video, temp_audio_path = load_audio(video_path)
        # Create pipeline
        logger.info("Creating ASR pipeline")
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device=device,
        )
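        # chunk_length_s=30 splits long audio into 30-second windows that are
        # transcribed 16 at a time; lowering batch_size trades throughput for
        # lower peak GPU memory.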
        # Transcribe audio, passing the sampling rate explicitly so the
        # pipeline can resample to the rate the model expects
        logger.info("Starting transcription")
        result = pipe(
            {"array": audio_array, "sampling_rate": sampling_rate},
            return_timestamps="word"
        )
        logger.info("Transcription completed")
        logger.debug(f"Transcription result: {result}")
        # Convert word-level timestamps to segments
        logger.info("Converting word-level timestamps to segments")
        segments = []
        current_segment = {"text": "", "start": result["chunks"][0]["timestamp"][0]}
        for chunk in result["chunks"]:
            current_segment["text"] += " " + chunk["text"]
            current_segment["end"] = chunk["timestamp"][1]
            # Close a segment once it exceeds 10 words or 5 seconds
            if len(current_segment["text"].split()) > 10 or \
               (current_segment["end"] - current_segment["start"]) > 5.0:
                segments.append(current_segment)
                # Start the next segment where this one ended; an empty
                # trailing segment is simply dropped after the loop
                current_segment = {"text": "", "start": chunk["timestamp"][1]}
        if current_segment["text"]:
            segments.append(current_segment)
        logger.info(f"Created {len(segments)} segments")
        # No language detection is performed here; the source language is
        # assumed to be English
        detected_language = "en"
        for segment in segments:
            segment['language'] = detected_language
        # Create SRT content
        logger.info("Creating SRT content")
        srt_content = create_srt(segments, target_lang)
        # Save SRT file
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        srt_path = f"{video_name}_subtitles_{target_lang}.srt"
        logger.info(f"Saving SRT file: {srt_path}")
        with open(srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)
        # Create subtitle clips
        logger.info("Creating subtitle clips")
        subtitle_clips = create_subtitle_clips(segments, video.size, target_lang)
        # Combine video with subtitles
        logger.info("Combining video with subtitles")
        final_video = mp.CompositeVideoClip([video] + subtitle_clips)
        # Save final video
        output_video_path = f"{video_name}_with_subtitles_{target_lang}.mp4"
        logger.info(f"Saving final video: {output_video_path}")
        final_video.write_videofile(output_video_path)
        # Clean up
        logger.info("Cleaning up temporary files")
        os.remove(temp_audio_path)
        video.close()
        final_video.close()
        logger.info("Video processing completed successfully")
        return output_video_path, srt_path
    except Exception as e:
        logger.error(f"Error in video processing: {str(e)}", exc_info=True)
        raise
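# Local smoke test (hypothetical usage; the Space itself only reaches
# process_video through the Gradio handler below):
#
#     output_video, srt_path = process_video("sample.mp4", target_lang="es")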
def gradio_interface(video_file, target_language):
    """Gradio interface function with language selection"""
    try:
        # Recent Gradio versions pass gr.Video input as a filepath string;
        # older versions pass a tempfile-like object with a .name attribute
        video_path = video_file if isinstance(video_file, str) else video_file.name
        logger.info(f"Processing new video request: {video_path}")
        logger.info(f"Target language: {target_language}")
        target_lang = LANGUAGE_CODES[target_language]
        output_video, srt_file = process_video(video_path, target_lang)
        logger.info("Processing completed successfully")
        return output_video, srt_file
    except Exception as e:
        logger.error(f"Error in Gradio interface: {str(e)}", exc_info=True)
        # Surface the failure in the UI instead of returning an error string
        # that gr.Video cannot render
        raise gr.Error(str(e))
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Dropdown(
            choices=list(LANGUAGE_CODES.keys()),
            value="English",
            label="Target Language"
        )
    ],
    outputs=[
        gr.Video(label="Video with Subtitles"),
        gr.File(label="SRT Subtitle File")
    ],
    title="Video Subtitler with CrisperWhisper",
    description="Upload a video to generate subtitles using CrisperWhisper, translate them to your chosen language, and embed them directly in the video."
)
if __name__ == "__main__":
    logger.info("Starting Video Subtitler application")
    iface.launch()