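"""Chunked multilingual audio pipeline for a GPU-backed Hugging Face Space:
language identification with facebook/mms-lid-256, transcription with
facebook/mms-1b-all, and optional translation to English with
facebook/nllb-200-distilled-600M."""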
import gc
import torch
import torchaudio
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AudioProcessor:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size      # chunk length in seconds
        self.overlap = overlap            # overlap between chunks in seconds
        self.sample_rate = sample_rate    # target sample rate in Hz
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_models(self):
        """Load the language-ID, speech-recognition, and translation models."""
        logger.info("Loading MMS models...")

        # Spoken language identification (256 languages)
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")

        # Multilingual speech recognition with per-language adapters
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

        # Text translation to English
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }
    @spaces.GPU(duration=60)
    def identify_language(self, audio_chunk, models):
        """Identify the language of an audio chunk."""
        lid_model, lid_processor = models['lid']
        inputs = lid_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        lid_model.to(self.device)
        with torch.no_grad():
            outputs = lid_model(inputs.input_values.to(self.device)).logits
        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = lid_model.config.id2label[lang_id]

        return detected_lang

    @spaces.GPU(duration=60)
    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe an audio chunk with the MMS adapter for the detected language."""
        mms_model, mms_processor = models['mms']

        # Switch the tokenizer vocabulary and model adapter to the detected language
        # (MMS uses ISO 639-3 codes such as "eng" or "fra").
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)
        inputs = mms_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")

        with torch.no_grad():
            outputs = mms_model(inputs.input_values.to(self.device)).logits
        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = mms_processor.decode(ids)

        return transcription
    @spaces.GPU(duration=60)
    def translate_text(self, text, models):
        """Translate text to English with NLLB-200."""
        translation_model, translation_tokenizer = models['translation']

        # Note: the tokenizer's source language is left at its default here;
        # forcing "eng_Latn" as the first generated token targets English output.
        inputs = translation_tokenizer(text, return_tensors="pt")
        inputs = inputs.to(self.device)
        translation_model.to(self.device)
        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=100
            )
        translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        return translation
    def preprocess_audio(self, audio):
        """Split audio into overlapping chunks annotated with timing information."""
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            if start_idx == 0:
                # First chunk: prepend one second of silence before the audio
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Subsequent chunks: include the overlap with the previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad short (final) chunks up to the full chunk length
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Nominal segment boundaries plus the wider window actually transcribed
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times
    @spaces.GPU(duration=60)
    def process_audio(self, audio_path, translate=False):
        """Main processing function: chunk, identify language, transcribe, optionally translate."""
        try:
            # Load audio and downmix to mono
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
            else:
                waveform = waveform.squeeze(0)

            # Resample to the target sample rate if necessary
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            models = self.load_models()

            chunk_samples = int(self.chunk_size * self.sample_rate)
            overlap_samples = int(self.overlap * self.sample_rate)

            segments = []
            language_segments = []

            # Step through the waveform in overlapping chunks
            for i in range(0, len(waveform), chunk_samples - overlap_samples):
                chunk = waveform[i:i + chunk_samples]
                start_time = i / self.sample_rate
                # Record the true end time before padding so the last chunk is not overstated
                end_time = (i + len(chunk)) / self.sample_rate
                if len(chunk) < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

                # Per-chunk language identification
                language = self.identify_language(chunk, models)
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                # Transcription with the matching MMS adapter
                transcription = self.transcribe_chunk(chunk, language, models)

                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"
                }

                if translate:
                    translation = self.translate_text(transcription, models)
                    segment["translated"] = translation

                segments.append(segment)

                # Release GPU memory between chunks
                torch.cuda.empty_cache()
                gc.collect()

            # Merge adjacent, similar segments in the same language
            merged_segments = self.merge_segments(segments)

            return language_segments, merged_segments

        except Exception as e:
            logger.error(f"Error processing audio: {str(e)}")
            raise
    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge nearby segments in the same language whose texts are similar."""
        if not segments:
            return segments

        merged = []
        current = segments[0]

        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold and
                    current['language'] == next_segment['language']):

                # Only merge when the transcripts are similar enough
                matcher = SequenceMatcher(None, current['text'], next_segment['text'])
                similarity = matcher.ratio()

                if similarity > similarity_threshold:
                    current['end'] = next_segment['end']
                    current['text'] = current['text'] + ' ' + next_segment['text']
                    if 'translated' in current and 'translated' in next_segment:
                        current['translated'] = current['translated'] + ' ' + next_segment['translated']
                else:
                    merged.append(current)
                    current = next_segment
            else:
                merged.append(current)
                current = next_segment

        merged.append(current)
        return merged
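

# Minimal usage sketch (the path "sample.wav" is illustrative, not part of this module):
#
#     processor = AudioProcessor(chunk_size=5, overlap=1)
#     language_segments, segments = processor.process_audio("sample.wav", translate=True)
#     for seg in segments:
#         print(f"[{seg['start']:.1f}-{seg['end']:.1f}] ({seg['language']}) {seg['text']}")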