unijoh committed on
Commit
981a951
·
verified ·
1 Parent(s): 45bdfde

Update asr.py

Browse files
Files changed (1) hide show
  1. asr.py +7 -19
asr.py CHANGED
@@ -1,6 +1,5 @@
1
  import librosa
2
- from transformers import Wav2Vec2ForCTC, AutoProcessor
3
- import torch
4
  import logging
5
 
6
  # Set up logging
@@ -8,14 +7,11 @@ logging.basicConfig(level=logging.DEBUG)
8
 
9
  ASR_SAMPLING_RATE = 16_000
10
 
11
- MODEL_ID = "facebook/wav2vec2-large-960h-lv60-self"
12
-
13
  try:
14
- processor = AutoProcessor.from_pretrained(MODEL_ID)
15
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
16
- logging.info("ASR model and processor loaded successfully.")
17
  except Exception as e:
18
- logging.error(f"Error loading ASR model or processor: {e}")
19
 
20
  def transcribe(audio):
21
  try:
@@ -25,17 +21,9 @@ def transcribe(audio):
25
 
26
  logging.info(f"Loading audio file: {audio}")
27
  audio_samples, _ = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)
28
- inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
29
-
30
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- model.to(device)
32
- inputs = inputs.to(device)
33
-
34
- with torch.no_grad():
35
- outputs = model(**inputs).logits
36
-
37
- ids = torch.argmax(outputs, dim=-1)[0]
38
- transcription = processor.decode(ids)
39
 
40
  logging.info("Transcription completed successfully.")
41
  return transcription
 
1
  import librosa
2
+ from transformers import pipeline
 
3
  import logging
4
 
5
  # Set up logging
 
7
 
8
  ASR_SAMPLING_RATE = 16_000
9
 
 
 
10
  try:
11
+ pipe = pipeline("automatic-speech-recognition", model="facebook/mms-1b-all")
12
+ logging.info("ASR pipeline loaded successfully.")
 
13
  except Exception as e:
14
+ logging.error(f"Error loading ASR pipeline: {e}")
15
 
16
  def transcribe(audio):
17
  try:
 
21
 
22
  logging.info(f"Loading audio file: {audio}")
23
  audio_samples, _ = librosa.load(audio, sr=ASR_SAMPLING_RATE, mono=True)
24
+
25
+ # Process the audio with the pipeline
26
+ transcription = pipe(audio_samples, sampling_rate=ASR_SAMPLING_RATE)["text"]
 
 
 
 
 
 
 
 
27
 
28
  logging.info("Transcription completed successfully.")
29
  return transcription