# ASRfr / audio_processing.py
# Kr08's picture
# Update audio_processing.py
# feed7c4 verified
# raw
# history blame
# 9.85 kB
import gc
import torch
import torchaudio
import numpy as np
from transformers import (
Wav2Vec2ForSequenceClassification,
AutoFeatureExtractor,
Wav2Vec2ForCTC,
AutoProcessor,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher
# Module-wide logging configuration.
# NOTE(review): basicConfig runs at import time and configures the root logger
# (only if it has no handlers yet) for any application importing this module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger; AudioProcessor.process_audio reports failures through it.
logger = logging.getLogger(__name__)
class AudioProcessor:
    """Chunked speech pipeline: language identification, transcription, translation.

    Audio is mixed down to mono, resampled to ``sample_rate``, and split into
    overlapping fixed-length windows. Each window goes through:

    * MMS language identification (``facebook/mms-lid-256``),
    * MMS CTC transcription with the detected language's adapter
      (``facebook/mms-1b-all``),
    * optionally, NLLB translation to English (``facebook/nllb-200-distilled-600M``).

    Args:
        chunk_size: Window length in seconds.
        overlap: Overlap between consecutive windows in seconds. Must be
            non-negative and smaller than ``chunk_size``.
        sample_rate: Target sample rate in Hz. It is also the rate passed to
            the MMS feature extractors, which reject rates other than the one
            the checkpoint was trained with (16 kHz).
    """

    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        # Not read by any method of this class; kept for external callers.
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _chunk_geometry(self):
        """Return ``(chunk_samples, overlap_samples, step_samples)`` for the current settings.

        Raises:
            ValueError: If the window is empty, the overlap is negative, or the
                overlap is not smaller than the window. Any of these would make
                the chunking loops stall, run backwards, or skip audio.
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)
        step_samples = chunk_samples - overlap_samples
        if chunk_samples <= 0 or overlap_samples < 0 or step_samples <= 0:
            raise ValueError(
                f"Invalid chunking: chunk_size={self.chunk_size}s, overlap={self.overlap}s "
                f"(need chunk_size > overlap >= 0)"
            )
        return chunk_samples, overlap_samples, step_samples

    def load_models(self):
        """Load all required models.

        Fresh copies are loaded from the pretrained checkpoints on every call;
        nothing is cached on the instance.

        Returns:
            dict mapping ``'lid'``, ``'mms'`` and ``'translation'`` to
            ``(model, processor_or_tokenizer)`` tuples.
        """
        logger.info("Loading MMS models...")
        # Language identification model
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")
        # Transcription model
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
        # Translation model
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }

    @spaces.GPU(duration=60)
    def identify_language(self, audio_chunk, models):
        """Identify the spoken language of ``audio_chunk``.

        Args:
            audio_chunk: Mono samples at ``self.sample_rate``.
            models: Mapping returned by :meth:`load_models`.

        Returns:
            The top-scoring label from the LID model's ``id2label`` mapping.
            ``process_audio`` passes this label directly to
            :meth:`transcribe_chunk` as the MMS target language.
        """
        lid_model, lid_processor = models['lid']
        # Pass the real rate so the feature extractor can reject mismatched
        # audio instead of silently mis-reading it.
        inputs = lid_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        lid_model.to(self.device)
        with torch.no_grad():
            logits = lid_model(inputs.input_values.to(self.device)).logits
        lang_id = torch.argmax(logits, dim=-1)[0].item()
        return lid_model.config.id2label[lang_id]

    @spaces.GPU(duration=60)
    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe ``audio_chunk`` with the MMS adapter for ``language``.

        The tokenizer target language and model adapter are switched on every
        call, so the shared model ends up configured for the last language used.

        Args:
            audio_chunk: Mono samples at ``self.sample_rate``.
            language: Language label, as returned by :meth:`identify_language`.
            models: Mapping returned by :meth:`load_models`.

        Returns:
            Greedy (argmax) CTC decoding of the chunk as a string.
        """
        mms_model, mms_processor = models['mms']
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)
        inputs = mms_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        with torch.no_grad():
            logits = mms_model(inputs.input_values.to(self.device)).logits
        ids = torch.argmax(logits, dim=-1)[0]
        return mms_processor.decode(ids)

    @spaces.GPU(duration=60)
    def translate_text(self, text, models, max_length=100):
        """Translate ``text`` to English with NLLB.

        Args:
            text: Source text. Empty or whitespace-only input returns ``""``
                without running the model; otherwise the model would generate
                output from nothing.
            models: Mapping returned by :meth:`load_models`.
            max_length: Generation length limit (tokens) passed to ``generate``.

        Returns:
            The decoded English translation.
        """
        if not text or not text.strip():
            return ""
        translation_model, translation_tokenizer = models['translation']
        # NOTE(review): the tokenizer's src_lang is left at its default; NLLB
        # expects it to name the input language. Confirm whether the detected
        # MMS language should be mapped to an NLLB code here.
        inputs = translation_tokenizer(text, return_tensors="pt").to(self.device)
        translation_model.to(self.device)
        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=max_length
            )
        return translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """Split ``audio`` into overlapping windows annotated with timing information.

        The first window is prefixed with one second of silence. Each later
        window starts ``overlap`` seconds before its nominal start, so it shares
        context with the previous one. Windows shorter than ``chunk_size`` are
        zero-padded at the end. Iteration stops once a window reaches the end of
        the audio, so no fully-covered tail window is emitted.

        Args:
            audio: 1-D tensor of samples at ``self.sample_rate`` (assumed mono;
                TODO confirm with callers).

        Returns:
            List of dicts with keys ``chunk`` (tensor), ``start_time`` and
            ``end_time`` (nominal window bounds, seconds), and
            ``transcribe_start`` / ``transcribe_end`` (bounds widened by
            ``overlap`` and clamped to the audio, seconds).

        Raises:
            ValueError: If the chunk/overlap settings are invalid.
        """
        chunk_samples, overlap_samples, step_samples = self._chunk_geometry()
        total_samples = len(audio)
        total_seconds = total_samples / self.sample_rate
        chunks_with_times = []
        start_idx = 0
        while start_idx < total_samples:
            end_idx = min(start_idx + chunk_samples, total_samples)
            if start_idx == 0:
                # One second of lead-in silence; new_zeros matches the audio's
                # dtype and device so torch.cat cannot fail on a mismatch.
                lead_in = audio.new_zeros(int(self.sample_rate))
                chunk = torch.cat([lead_in, audio[start_idx:end_idx]])
            else:
                # Include overlap from the previous window.
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': max(0, (start_idx / self.sample_rate) - self.overlap),
                'transcribe_end': min((end_idx / self.sample_rate) + self.overlap, total_seconds)
            })
            if end_idx >= total_samples:
                break  # the remaining audio is already fully covered
            start_idx += step_samples
        return chunks_with_times

    def _load_waveform(self, audio_path):
        """Load ``audio_path`` as a 1-D mono tensor resampled to ``self.sample_rate``."""
        waveform, sample_rate = torchaudio.load(audio_path)
        # torchaudio.load returns (channels, frames); average the channels to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)
        return waveform

    @spaces.GPU(duration=60)
    def process_audio(self, audio_path, translate=False):
        """Run language ID, transcription and optional translation over a file.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.
            translate: When true, each segment also gets a ``"translated"`` key.

        Returns:
            ``(language_segments, merged_segments)``:

            * ``language_segments`` holds one dict per window with keys
              ``language``, ``start`` and ``end``.
            * ``merged_segments`` is :meth:`merge_segments` applied to
              per-window dicts with keys ``start``, ``end``, ``language``,
              ``text``, ``speaker`` and optionally ``translated``.

            Times are in seconds; ``end`` never exceeds the audio duration.

        Raises:
            Any exception from loading, chunking or inference is logged with
            its traceback and re-raised. Invalid chunk settings raise
            ValueError before any model is loaded.
        """
        try:
            waveform = self._load_waveform(audio_path)
            chunk_samples, _, step_samples = self._chunk_geometry()
            models = self.load_models()

            total_samples = len(waveform)
            segments = []
            language_segments = []
            for start_idx in range(0, total_samples, step_samples):
                end_idx = min(start_idx + chunk_samples, total_samples)
                chunk = waveform[start_idx:end_idx]
                if len(chunk) < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))
                # Timestamps use the real (unpadded) extent of the window.
                start_time = start_idx / self.sample_rate
                end_time = end_idx / self.sample_rate

                language = self.identify_language(chunk, models)
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                transcription = self.transcribe_chunk(chunk, language, models)
                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"  # Simple speaker assignment
                }
                if translate:
                    segment["translated"] = self.translate_text(transcription, models)
                segments.append(segment)

                # Clean up GPU memory
                torch.cuda.empty_cache()
                gc.collect()

                if end_idx >= total_samples:
                    break  # the remaining audio is already fully covered

            merged_segments = self.merge_segments(segments)
            return language_segments, merged_segments
        except Exception as e:
            logger.exception("Error processing audio: %s", e)
            raise

    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge consecutive segments that are close in time, same-language and textually similar.

        Two neighbours merge when all of these hold:

        * the gap between them is at most ``time_threshold`` seconds
          (overlapping windows have a negative gap),
        * their ``language`` values are equal,
        * the ``difflib.SequenceMatcher`` ratio between the accumulated text
          and the next text exceeds ``similarity_threshold``.

        A merge extends ``end`` and joins ``text`` (and ``translated``, when
        both have it) with a space.

        NOTE(review): joining *similar* texts can duplicate words repeated
        across overlapping windows. Confirm this is the intended policy.

        Args:
            segments: List of segment dicts as built by :meth:`process_audio`.
                It is not modified; the output contains shallow copies.

        Returns:
            The merged list, or ``segments`` itself if it is empty.
        """
        if not segments:
            return segments
        merged = []
        current = dict(segments[0])
        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold
                    and current['language'] == next_segment['language']
                    and SequenceMatcher(None, current['text'], next_segment['text']).ratio()
                    > similarity_threshold):
                current['end'] = next_segment['end']
                current['text'] = current['text'] + ' ' + next_segment['text']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                merged.append(current)
                current = dict(next_segment)
        merged.append(current)
        return merged