|
import os |
|
import gc |
|
import sys |
|
import time |
|
import torch |
|
import spaces |
|
import torchaudio |
|
import numpy as np |
|
from scipy.signal import resample |
|
from pyannote.audio import Pipeline |
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
from difflib import SequenceMatcher |
|
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, Wav2Vec2ForCTC, AutoProcessor, AutoTokenizer, AutoModelForSeq2SeqLM |
|
from difflib import SequenceMatcher |
|
|
|
class ChunkedTranscriber: |
|
def __init__(self, chunk_size=5, overlap=1, sample_rate=16000): |
|
self.chunk_size = chunk_size |
|
self.overlap = overlap |
|
self.sample_rate = sample_rate |
|
self.previous_text = "" |
|
self.previous_lang = None |
|
self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline() |
|
|
|
def load_speaker_diarization_pipeline(self):
    """
    Build the "pyannote/speaker-diarization" pipeline.

    The access token is read from the ``HF_TOKEN`` environment variable
    (``None`` is passed through when the variable is not set).
    """
    hf_token = os.getenv("HF_TOKEN")
    return Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
|
|
|
@spaces.GPU(duration=60) |
|
def diarize_audio(self, audio_path):
    """
    Run the stored speaker diarization pipeline on one audio file.

    Args:
        audio_path: Path of the audio file, forwarded under the "audio" key.

    Returns:
        Whatever the diarization pipeline returns for that file.
    """
    request = {"uri": "audio", "audio": audio_path}
    return self.speaker_diarization_pipeline(request)
|
|
|
def load_lid_mms(self):
    """
    Load the "facebook/mms-lid-256" language-identification checkpoint.

    Returns:
        A ``(processor, model)`` tuple -- note the processor comes first.
    """
    checkpoint = "facebook/mms-lid-256"
    return (
        AutoFeatureExtractor.from_pretrained(checkpoint),
        Wav2Vec2ForSequenceClassification.from_pretrained(checkpoint),
    )
|
|
|
|
|
@spaces.GPU(duration=60) |
|
def language_identification(self, model, processor, chunk, device="cuda"):
    """
    Identify the language spoken in one audio chunk.

    Args:
        model: Classifier called as ``model(**inputs)``; its ``.logits`` are
            arg-maxed and mapped through ``model.config.id2label``.
        processor: Callable invoked as
            ``processor(chunk, sampling_rate=..., return_tensors="pt")``.
        chunk: Waveform samples for the chunk.
        device: Preferred device. A CUDA request silently falls back to CPU
            when CUDA is not available instead of raising.

    Returns:
        The label of the highest-scoring class for the first item in the batch.
    """
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        target_device = torch.device("cpu")

    # Report the rate the audio was actually resampled to (self.sample_rate),
    # matching process_chunk, rather than a hard-coded 16 kHz.
    inputs = processor(chunk, sampling_rate=self.sample_rate, return_tensors="pt")
    model.to(target_device)
    inputs = inputs.to(target_device)

    with torch.no_grad():
        logits = model(**inputs).logits

    lang_id = torch.argmax(logits, dim=-1)[0].item()
    detected_lang = model.config.id2label[lang_id]

    # Drop the per-call tensors before asking the allocator to release memory.
    # (The caller still owns `model`, so deleting the local name would free nothing.)
    del inputs, logits
    gc.collect()
    torch.cuda.empty_cache()
    return detected_lang
|
|
|
|
|
def load_mms(self):
    """
    Load the "facebook/mms-1b-all" speech-recognition checkpoint.

    Returns:
        A ``(model, processor)`` tuple -- note the model comes first.
    """
    checkpoint = "facebook/mms-1b-all"
    # The processor is still fetched before the model, as before.
    asr_processor = AutoProcessor.from_pretrained(checkpoint)
    return Wav2Vec2ForCTC.from_pretrained(checkpoint), asr_processor
|
|
|
|
|
@spaces.GPU(duration=60) |
|
def mms_transcription(self, model, processor, chunk, device="cuda"):
    """
    Transcribe one audio chunk by greedy (arg-max) decoding of the model logits.

    Args:
        model: Model called as ``model(**inputs)`` that exposes ``.logits``.
        processor: Object that is callable as
            ``processor(chunk, sampling_rate=..., return_tensors="pt")`` and
            provides ``decode(ids)``.
        chunk: Waveform samples for the chunk.
        device: Preferred device. A CUDA request silently falls back to CPU
            when CUDA is not available instead of raising.

    Returns:
        The decoded transcription of the first item in the batch.
    """
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        target_device = torch.device("cpu")

    # Report the rate the audio was actually resampled to (self.sample_rate),
    # matching process_chunk, rather than a hard-coded 16 kHz.
    inputs = processor(chunk, sampling_rate=self.sample_rate, return_tensors="pt")
    model.to(target_device)
    inputs = inputs.to(target_device)

    with torch.no_grad():
        logits = model(**inputs).logits

    ids = torch.argmax(logits, dim=-1)[0]
    transcription = processor.decode(ids)

    # Drop the per-call tensors before asking the allocator to release memory.
    # (The caller still owns `model`, so deleting the local name would free nothing.)
    del inputs, logits, ids
    gc.collect()
    torch.cuda.empty_cache()
    return transcription
|
|
|
|
|
def load_T2T_translation_model(self):
    """
    Load the "facebook/nllb-200-distilled-600M" text-to-text translation checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple -- note the model comes first.
    """
    checkpoint = "facebook/nllb-200-distilled-600M"
    # The tokenizer is still fetched before the model, as before.
    nllb_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return AutoModelForSeq2SeqLM.from_pretrained(checkpoint), nllb_tokenizer
|
|
|
|
|
@spaces.GPU(duration=60) |
|
def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda",
                          target_lang="eng_Latn", max_length=100):
    """
    Translate a transcript with the given sequence-to-sequence model.

    Args:
        translation_model: Model exposing ``to(device)`` and ``generate(...)``.
        translation_tokenizer: Tokenizer that is callable on text and provides
            ``convert_tokens_to_ids`` and ``batch_decode``.
        transcript: Text to translate.
        device: Preferred device. A CUDA request silently falls back to CPU
            when CUDA is not available instead of raising.
        target_lang: Language token forced as the first generated token
            (defaults to "eng_Latn", the previously hard-coded value).
        max_length: ``max_length`` forwarded to ``generate`` (defaults to the
            previously hard-coded 100).

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    target_device = torch.device(device)
    if target_device.type == "cuda" and not torch.cuda.is_available():
        target_device = torch.device("cpu")

    tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
    translation_model.to(target_device)
    tokenized_inputs = tokenized_inputs.to(target_device)

    translated_tokens = translation_model.generate(
        **tokenized_inputs,
        forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids(target_lang),
        max_length=max_length,
    )
    translation = translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # Drop the per-call tensors before asking the allocator to release memory.
    # (The caller still owns the model, so deleting the local name would free nothing.)
    del tokenized_inputs, translated_tokens
    gc.collect()
    torch.cuda.empty_cache()
    return translation
|
|
|
|
|
def preprocess_audio(self, audio):
    """
    Split a 1-D waveform into overlapping chunks with timing metadata.

    The window advances by ``chunk_size - overlap`` seconds. The first chunk is
    prefixed with one second of silence; every later chunk additionally reaches
    back ``overlap`` seconds before its nominal start. Chunks shorter than
    ``chunk_size`` seconds are zero-padded on the right.

    Args:
        audio: 1-D tensor of samples at ``self.sample_rate``.

    Returns:
        A list of dicts with keys ``chunk``, ``start_time``, ``end_time``,
        ``transcribe_start`` and ``transcribe_end`` (times in seconds).
        Empty audio yields an empty list.

    Raises:
        ValueError: If the overlap leaves no positive stride; the loop below
            would otherwise never terminate.
    """
    chunk_samples = int(self.chunk_size * self.sample_rate)
    overlap_samples = int(self.overlap * self.sample_rate)
    step = chunk_samples - overlap_samples
    if step <= 0:
        raise ValueError(
            f"overlap ({self.overlap}s) must be smaller than chunk_size ({self.chunk_size}s)"
        )

    total_samples = len(audio)
    duration = total_samples / self.sample_rate

    chunks_with_times = []
    start_idx = 0

    while start_idx < total_samples:
        end_idx = min(start_idx + chunk_samples, total_samples)

        if start_idx == 0:
            # Lead-in silence for the very first chunk; new_zeros keeps the
            # padding on the same dtype/device as the audio so cat cannot fail.
            padding = audio.new_zeros(int(1 * self.sample_rate))
            chunk = torch.cat([padding, audio[:end_idx]])
        else:
            # Reach back into the previous chunk for acoustic context.
            chunk = audio[max(0, start_idx - overlap_samples):end_idx]

        if len(chunk) < chunk_samples:
            chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

        start_time = start_idx / self.sample_rate
        end_time = end_idx / self.sample_rate

        chunks_with_times.append({
            'chunk': chunk,
            'start_time': start_time,
            'end_time': end_time,
            'transcribe_start': max(0, start_time - self.overlap),
            'transcribe_end': min(end_time + self.overlap, duration),
        })

        # Once a chunk has reached the end of the audio, any further window
        # would contain only samples that chunk already covered.
        if end_idx >= total_samples:
            break
        start_idx += step

    return chunks_with_times
|
|
|
|
|
def merge_close_segments(self, results):
    """
    Merge consecutive segments that belong together.

    Two neighbouring segments are merged when they share the same detected
    language and the same speaker label (if any), and the gap between the end
    of the first and the start of the second is at most ``self.overlap``
    seconds. Segments whose text is blank are dropped.

    Args:
        results: List of segment dicts with at least ``text``,
            ``detected_language``, ``start_time`` and ``end_time``;
            ``speaker`` and ``translated`` are optional.

    Returns:
        A new list of merged segment dicts. The input dicts are not modified.
        A falsy ``results`` is returned unchanged.
    """
    if not results:
        return results

    merged = []
    # Work on copies so the caller's segment dicts are never mutated.
    current = dict(results[0])

    for segment in results[1:]:
        if not segment['text'].strip():
            continue

        same_language = current['detected_language'] == segment['detected_language']
        # Never glue different speakers together; that would undo diarization.
        same_speaker = current.get('speaker') == segment.get('speaker')
        close_in_time = abs(segment['start_time'] - current['end_time']) <= self.overlap

        if same_language and same_speaker and close_in_time:
            # Join only non-blank parts so an empty head adds no stray space.
            current['text'] = ' '.join(
                part for part in (current['text'], segment['text']) if part.strip()
            )
            current['end_time'] = segment['end_time']

            translations = [
                seg['translated'] for seg in (current, segment) if seg.get('translated')
            ]
            if translations:
                current['translated'] = ' '.join(translations)
        else:
            if current['text'].strip():
                merged.append(current)
            current = dict(segment)

    if current['text'].strip():
        merged.append(current)

    return merged
|
|
|
|
|
def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
    """
    Strip from the start of ``current_text`` the words it repeats from ``prev_text``.

    The texts are compared word by word. If a run of at least ``min_overlap``
    matching words begins at the very first word of ``current_text``, that run
    is removed and the remainder is returned (possibly an empty string).
    ``current_text`` is returned untouched when either text is empty, when both
    languages are known and differ, when either text has fewer than two words,
    or when no qualifying run exists.
    """
    nothing_to_compare = not prev_text or not current_text
    language_switched = bool(prev_lang and current_lang and prev_lang != current_lang)
    if nothing_to_compare or language_switched:
        return current_text

    earlier_tokens = prev_text.split()
    later_tokens = current_text.split()
    if min(len(earlier_tokens), len(later_tokens)) < 2:
        return current_text

    blocks = SequenceMatcher(None, earlier_tokens, later_tokens).get_matching_blocks()
    # Only a block anchored at the first word of the current text counts.
    shared = max(
        (blk.size for blk in blocks if blk.b == 0 and blk.size >= min_overlap),
        default=0,
    )
    if shared <= 0:
        return current_text

    return ' '.join(later_tokens[shared:]).strip()
|
|
|
|
|
def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
    """
    Detect the language of one chunk, transcribe it and optionally translate it.

    Args:
        chunk_data: Dict with ``chunk`` (waveform), ``start_time`` and ``end_time``.
        mms_model: Speech-recognition model; its adapter is switched to the
            detected language before inference.
        mms_processor: Processor paired with ``mms_model``.
        translation_model: Optional translation model.
        translation_tokenizer: Optional tokenizer paired with ``translation_model``.

    Returns:
        A dict with ``start_time``, ``end_time``, ``text``, ``detected_language``
        and, when both translation objects are given, ``translated``.
        ``None`` when the cleaned transcription is blank or any step raises
        (the error is printed; processing is best-effort per chunk).

    Side effects:
        Updates ``self.previous_text`` / ``self.previous_lang`` with the raw
        transcription and detected language of this chunk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        chunk = chunk_data['chunk']
        print(f"\n\n Chunk shape: {chunk.shape}")

        # NOTE(review): the LID model is reloaded for every chunk; caching it
        # would be faster but keeps it resident in memory -- confirm budget.
        lid_processor, lid_model = self.load_lid_mms()
        # Pass the device chosen above; the helper's own default is "cuda",
        # which fails on CPU-only hosts.
        lid_lang = self.language_identification(lid_model, lid_processor, chunk, device=device)

        # Point both tokenizer and model adapter at the detected language.
        mms_processor.tokenizer.set_target_lang(lid_lang)
        mms_model.load_adapter(lid_lang)

        inputs = mms_processor(chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        inputs = inputs.to(device)
        mms_model = mms_model.to(device)

        with torch.no_grad():
            outputs = mms_model(**inputs).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = mms_processor.decode(ids)

        # Remove words already emitted by the previous (overlapping) chunk.
        cleaned_transcription = self.clean_overlapping_text(
            transcription,
            self.previous_text,
            lid_lang,
            self.previous_lang,
            min_overlap=3,
        )

        # Remember the raw (uncleaned) text for the next overlap comparison.
        self.previous_text = transcription
        self.previous_lang = lid_lang

        if not cleaned_transcription.strip():
            return None

        result = {
            'start_time': chunk_data['start_time'],
            'end_time': chunk_data['end_time'],
            'text': cleaned_transcription,
            'detected_language': lid_lang,
        }

        if translation_model and translation_tokenizer:
            result['translated'] = self.text2text_translation(
                translation_model,
                translation_tokenizer,
                cleaned_transcription,
                device=device,
            )

        return result

    except Exception as e:
        # Deliberately best-effort: one failing chunk must not abort the file.
        print(f"Error processing chunk: {str(e)}")
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()
|
|
|
|
|
def translate_text(self, text, translation_model, translation_tokenizer, device):
    """
    Translate ``text`` with the supplied model, forcing "eng_Latn" as the first token.

    The tokenized input and the model are both moved to ``device``; generation
    is capped at ``max_length=100``. Returns the first decoded sequence with
    special tokens removed.
    """
    encoded = translation_tokenizer(text, return_tensors='pt').to(device)
    model_on_device = translation_model.to(device)

    english_bos = translation_tokenizer.convert_tokens_to_ids("eng_Latn")
    generated = model_on_device.generate(
        **encoded,
        forced_bos_token_id=english_bos,
        max_length=100,
    )
    translation = translation_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    # Release local references, then ask the allocator and GC to reclaim memory.
    del model_on_device
    del encoded
    torch.cuda.empty_cache()
    gc.collect()
    return translation
|
|
|
|
|
|
|
def transcribe_audio(self, audio_path, translate=False):
    """
    Diarize, chunk, transcribe (and optionally translate) an audio file.

    Args:
        audio_path: Path of the audio file to process.
        translate: When true, a translation model is loaded and each chunk's
            text is also translated.

    Returns:
        The list produced by ``merge_close_segments`` from the per-chunk
        results. Each result carries a ``speaker`` key when at least one
        diarization turn overlaps the chunk.
    """
    diarization_result = self.diarize_audio(audio_path)

    speaker_segments = [
        {'start_time': turn.start, 'end_time': turn.end, 'speaker': speaker}
        for turn, _, speaker in diarization_result.itertracks(yield_label=True)
    ]

    def dominant_speaker(start_time, end_time):
        """Return the speaker with the most talk time inside the interval, or None."""
        talk_time = {}
        for segment in speaker_segments:
            shared = min(end_time, segment['end_time']) - max(start_time, segment['start_time'])
            if shared > 0:
                talk_time[segment['speaker']] = talk_time.get(segment['speaker'], 0.0) + shared
        return max(talk_time, key=talk_time.get) if talk_time else None

    audio = self.load_audio(audio_path)
    chunks = self.preprocess_audio(audio)

    mms_model, mms_processor = self.load_mms()
    translation_model, translation_tokenizer = None, None
    if translate:
        translation_model, translation_tokenizer = self.load_T2T_translation_model()

    results = []
    for chunk_data in chunks:
        result = self.process_chunk(
            chunk_data,
            mms_model,
            mms_processor,
            translation_model,
            translation_tokenizer,
        )
        print(f"\n\nResult:\n{result}")
        if not result:
            continue

        # Attribute the chunk to whoever speaks longest within it. Comparing
        # int-truncated start times missed sub-second turns and ignored
        # everything after the chunk's first instant.
        speaker = dominant_speaker(chunk_data['start_time'], chunk_data['end_time'])
        if speaker is not None:
            result['speaker'] = speaker
        results.append(result)

    return self.merge_close_segments(results)
|
|
|
|
|
def load_audio(self, audio_path):
    """
    Read an audio file and return it as a mono float waveform.

    Multi-channel audio is averaged across channels; single-channel audio has
    its channel dimension squeezed away. The signal is resampled only when the
    file's rate differs from ``self.sample_rate``.
    """
    signal, native_rate = torchaudio.load(audio_path)

    # Collapse the channel dimension: average when there are several channels.
    is_multichannel = signal.shape[0] > 1
    signal = signal.mean(dim=0) if is_multichannel else signal.squeeze(0)

    if native_rate != self.sample_rate:
        to_target_rate = torchaudio.transforms.Resample(
            orig_freq=native_rate,
            new_freq=self.sample_rate,
        )
        signal = to_target_rate(signal)

    return signal.float()