import os
import gc
import sys
import time
import logging
from difflib import SequenceMatcher

import torch
import spaces
import torchaudio
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
from dotenv import load_dotenv

# Keep this call ahead of the transformers import, as in the original layout,
# so variables defined in a .env file are already set when it is imported.
load_dotenv()

from transformers import (  # noqa: E402
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


class ChunkedTranscriber:
    """Transcribe (and optionally translate) audio in overlapping chunks.

    Each chunk goes through language identification and MMS transcription;
    text repeated because of the chunk overlap is trimmed, neighbouring
    same-language segments are merged, and speaker labels from a pyannote
    diarization pass are attached.

    Args:
        chunk_size: Chunk length in seconds.
        overlap: Overlap between consecutive chunks in seconds. Must be
            smaller than ``chunk_size``.
        sample_rate: Sample rate (Hz) the audio is resampled to and that is
            announced to the model processors.
    """

    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        # Raw text / language of the previously processed chunk, used to trim
        # words duplicated by the chunk overlap.
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _release_memory():
        """Drop unreachable Python objects, then ask torch to free cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _translate(model, tokenizer, text, device, target_lang, max_length):
        """Translate ``text`` with a seq2seq model and return the decoded string."""
        # NOTE(review): the tokenizer's source language is never set here, so
        # its default is used regardless of the detected language - confirm.
        tokenized = tokenizer(text, return_tensors='pt').to(device)
        model.to(device)
        try:
            generated = model.generate(
                **tokenized,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_lang),
                max_length=max_length,
            )
            return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        finally:
            del tokenized
            ChunkedTranscriber._release_memory()

    @staticmethod
    def _speaker_for_interval(start, end, speaker_segments):
        """Return the speaker whose turns overlap ``[start, end]`` the most, or None."""
        totals = {}
        for segment in speaker_segments:
            shared = min(end, segment['end_time']) - max(start, segment['start_time'])
            if shared > 0:
                label = segment['speaker']
                totals[label] = totals.get(label, 0.0) + shared
        if not totals:
            return None
        return max(totals, key=totals.get)

    @staticmethod
    def _format_speaker(label):
        """Return the part after the first underscore of ``label`` (e.g. ``X_00`` -> ``00``)."""
        if not label:
            return "Unknown"
        parts = str(label).split('_')
        return parts[1] if len(parts) > 1 else parts[0]

    # ------------------------------------------------------------------
    # Speaker diarization
    # ------------------------------------------------------------------
    def load_speaker_diarization_pipeline(self):
        """
        Load the pre-trained speaker diarization pipeline from pyannote-audio.

        The Hugging Face token is read from the ``HF_TOKEN`` environment
        variable.
        """
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=os.getenv("HF_TOKEN"),
        )
        if pipeline is None:
            # Do not fail here: the other methods remain usable without it.
            logger.warning(
                "Speaker diarization pipeline could not be loaded; "
                "check HF_TOKEN and access to pyannote/speaker-diarization."
            )
        return pipeline

    @spaces.GPU(duration=60)
    def diarize_audio(self, audio_path):
        """
        Perform speaker diarization on the input audio.

        Args:
            audio_path: Path of the audio file handed to the pipeline.

        Returns:
            The object returned by the diarization pipeline.

        Raises:
            RuntimeError: If no diarization pipeline is available.
        """
        if self.speaker_diarization_pipeline is None:
            raise RuntimeError("Speaker diarization pipeline is not loaded.")
        return self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})

    # ------------------------------------------------------------------
    # Language identification
    # ------------------------------------------------------------------
    def load_lid_mms(self):
        """Load ``facebook/mms-lid-256`` and return ``(processor, model)``."""
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    @spaces.GPU(duration=60)
    def language_identification(self, model, processor, chunk, device="cuda"):
        """Return the language label the LID model predicts for ``chunk``.

        Args:
            model: Model returned by :meth:`load_lid_mms`.
            processor: Feature extractor returned by :meth:`load_lid_mms`.
            chunk: Audio samples of one chunk.
            device: Device the model and inputs are moved to.
        """
        inputs = processor(chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        inputs = inputs.to(device)
        model.to(device)
        try:
            with torch.no_grad():
                logits = model(**inputs).logits
            lang_id = torch.argmax(logits, dim=-1)[0].item()
            return model.config.id2label[lang_id]
        finally:
            # Runs on errors too, so a failed chunk does not pin GPU memory.
            del inputs
            self._release_memory()

    # ------------------------------------------------------------------
    # Transcription
    # ------------------------------------------------------------------
    def load_mms(self):
        """Load ``facebook/mms-1b-all`` and return ``(model, processor)``."""
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor

    @spaces.GPU(duration=60)
    def mms_transcription(self, model, processor, chunk, device="cuda"):
        """Return the greedy (argmax) CTC transcription of ``chunk``.

        Args:
            model: Model returned by :meth:`load_mms`.
            processor: Processor returned by :meth:`load_mms`.
            chunk: Audio samples of one chunk.
            device: Device the model and inputs are moved to.
        """
        inputs = processor(chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        inputs = inputs.to(device)
        model.to(device)
        try:
            with torch.no_grad():
                logits = model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            return processor.decode(ids)
        finally:
            del inputs
            self._release_memory()

    # ------------------------------------------------------------------
    # Translation
    # ------------------------------------------------------------------
    def load_T2T_translation_model(self):
        """Load ``facebook/nllb-200-distilled-600M`` and return ``(model, tokenizer)``."""
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    @spaces.GPU(duration=60)
    def text2text_translation(self, translation_model, translation_tokenizer, transcript,
                              device="cuda", target_lang="eng_Latn", max_length=100):
        """Translate ``transcript`` and return the translated string.

        Args:
            translation_model: Model returned by :meth:`load_T2T_translation_model`.
            translation_tokenizer: Matching tokenizer.
            transcript: Text to translate.
            device: Device the model and inputs are moved to.
            target_lang: Token forced as the first generated token; selects
                the output language.
            max_length: ``max_length`` passed to ``generate``.
        """
        return self._translate(
            translation_model, translation_tokenizer, transcript,
            device, target_lang, max_length,
        )

    def translate_text(self, text, translation_model, translation_tokenizer, device,
                       target_lang="eng_Latn", max_length=100):
        """
        Translate cleaned text using the provided translation model.

        Same behaviour as :meth:`text2text_translation`, but without the
        ``spaces.GPU`` decorator and with an explicit ``device``.
        """
        return self._translate(
            translation_model, translation_tokenizer, text,
            device, target_lang, max_length,
        )

    # ------------------------------------------------------------------
    # Chunking and post-processing
    # ------------------------------------------------------------------
    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic.

        Args:
            audio: 1-D tensor of samples at ``self.sample_rate``.

        Returns:
            A list of dicts with keys ``chunk``, ``start_time``, ``end_time``,
            ``transcribe_start`` and ``transcribe_end`` (times in seconds).

        Raises:
            ValueError: If the chunk size is not positive or the overlap is
                not smaller than the chunk size (the loop would never
                advance).
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)
        step = chunk_samples - overlap_samples
        if chunk_samples <= 0 or step <= 0:
            raise ValueError(
                "chunk_size must be positive and overlap must be smaller than chunk_size"
            )

        total_samples = len(audio)
        total_seconds = total_samples / self.sample_rate
        chunks_with_times = []

        for start_idx in range(0, total_samples, step):
            end_idx = min(start_idx + chunk_samples, total_samples)

            if start_idx == 0:
                # First chunk: prepend one second of silence.
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, audio[start_idx:end_idx]])
            else:
                # Later chunks also include the overlap preceding start_idx.
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Right-pad short chunks up to the nominal chunk length.
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            start_time = start_idx / self.sample_rate
            end_time = end_idx / self.sample_rate
            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_time,
                'end_time': end_time,
                # Time range widened by the overlap, clamped to the audio.
                'transcribe_start': max(0, start_time - self.overlap),
                'transcribe_end': min(end_time + self.overlap, total_seconds),
            })

        return chunks_with_times

    def merge_close_segments(self, results):
        """
        Merge segments that are close in time and have the same language.

        Two consecutive segments are merged when their languages match and
        the gap between them is at most ``self.overlap`` seconds. Segments
        with blank text are dropped. The input dicts are not modified.

        Args:
            results: List of segment dicts as produced by :meth:`process_chunk`.

        Returns:
            A new list of merged segment dicts (``results`` itself if empty).
        """
        if not results:
            return results

        merged = []
        current = dict(results[0])  # copy so the caller's dict stays untouched

        for next_segment in results[1:]:
            if not next_segment['text'].strip():
                continue

            same_language = current['detected_language'] == next_segment['detected_language']
            close_in_time = abs(next_segment['start_time'] - current['end_time']) <= self.overlap

            if same_language and close_in_time:
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():
                    merged.append(current)
                current = dict(next_segment)

        if current['text'].strip():
            merged.append(current)

        return merged

    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """
        Remove words at the start of ``current_text`` that repeat ``prev_text``.

        A run of at least ``min_overlap`` words starting at the first word of
        ``current_text`` that also occurs in ``prev_text`` is stripped. Nothing
        is removed when either text is empty, has fewer than two words, or
        the two languages are known and differ.

        Returns:
            The cleaned text; ``""`` if the whole text was a repetition.
        """
        if not prev_text or not current_text:
            return current_text

        # Different languages: do not attempt to de-duplicate.
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text

        prev_words = prev_text.split()
        curr_words = current_text.split()
        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text

        matcher = SequenceMatcher(None, prev_words, curr_words)
        # Only matches anchored at the very start of the current text count.
        best_overlap = max(
            (
                block.size
                for block in matcher.get_matching_blocks()
                if block.b == 0 and block.size >= min_overlap
            ),
            default=0,
        )

        if best_overlap == 0:
            return current_text
        return ' '.join(curr_words[best_overlap:]).strip()

    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None,
                      translation_tokenizer=None, lid_model=None, lid_processor=None):
        """
        Process chunk with improved language handling.

        Identifies the chunk's language, switches the MMS tokenizer/adapter
        to it, transcribes, trims text repeated from the previous chunk and
        optionally translates the result.

        Args:
            chunk_data: One entry from :meth:`preprocess_audio`.
            mms_model: Model returned by :meth:`load_mms`.
            mms_processor: Processor returned by :meth:`load_mms`.
            translation_model: Optional translation model.
            translation_tokenizer: Optional translation tokenizer.
            lid_model: Optional pre-loaded LID model. Loaded on demand when
                omitted; pass it in to avoid reloading for every chunk.
            lid_processor: Optional pre-loaded LID feature extractor.

        Returns:
            A dict with ``start_time``, ``end_time``, ``text``,
            ``detected_language`` (and ``translated`` when translating), or
            ``None`` if the chunk yields no new text or processing fails.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            if lid_model is None or lid_processor is None:
                lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(
                lid_model, lid_processor, chunk_data['chunk'], device
            )

            # Point tokenizer and adapter weights at the detected language.
            mms_processor.tokenizer.set_target_lang(lid_lang)
            mms_model.load_adapter(lid_lang)

            inputs = mms_processor(
                chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt"
            )
            inputs = inputs.to(device)
            mms_model = mms_model.to(device)

            with torch.no_grad():
                logits = mms_model(**inputs).logits
                ids = torch.argmax(logits, dim=-1)[0]
                transcription = mms_processor.decode(ids)

            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3
            )

            # Remember the raw (untrimmed) text for the next chunk's comparison.
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang
            }

            if translation_model and translation_tokenizer:
                result['translated'] = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription,
                    device,
                )

            return result

        except Exception:
            # Best effort: one failing chunk must not abort the whole file.
            logger.exception("Error processing chunk")
            return None
        finally:
            self._release_memory()

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def transcribe_audio(self, audio_path, translate=False):
        """
        Main transcription function with improved segment merging.

        Args:
            audio_path: Path of the audio file to process.
            translate: When true, each segment is also translated.

        Returns:
            A text report with one entry per merged segment: time range,
            speaker, language, text and (when translating) the translation.
        """
        diarization_result = self.diarize_audio(audio_path)
        speaker_segments = [
            {'start_time': turn.start, 'end_time': turn.end, 'speaker': speaker}
            for turn, _, speaker in diarization_result.itertracks(yield_label=True)
        ]

        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)

        mms_model, mms_processor = self.load_mms()
        # Load the LID model once instead of once per chunk.
        lid_processor, lid_model = self.load_lid_mms()

        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        # Do not let text from a previously transcribed file influence this one.
        self.previous_text = ""
        self.previous_lang = None

        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer,
                lid_model=lid_model,
                lid_processor=lid_processor,
            )
            if not result:
                continue
            speaker = self._speaker_for_interval(
                chunk_data['start_time'], chunk_data['end_time'], speaker_segments
            )
            if speaker is not None:
                result['speaker'] = speaker
            results.append(result)

        merged_results = self.merge_close_segments(results)

        entries = []
        translations = []
        for res in merged_results:
            entry = (
                f"{res['start_time']}-{res['end_time']} - "
                f"Speaker: {self._format_speaker(res.get('speaker'))} - "
                f"Language: {res['detected_language']}\n Text: {res['text']}\n"
            )
            translated = res.get('translated')
            if translated is not None:
                translations.append(translated)
                entry += f" Translation: {translated}\n"
            entries.append(entry + "\n")

        if translations:
            logger.info("\n\n TRANSLATION: %s", " ".join(translations))
        return "".join(entries)

    def load_audio(self, audio_path):
        """
        Load and preprocess audio file.

        Returns:
            A 1-D float tensor: channels averaged to mono and resampled to
            ``self.sample_rate`` when the file's rate differs.
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Average the channels for multi-channel input, otherwise drop the
        # channel dimension.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        return waveform.float()