ajchri5 committed
Commit 4570d8a · verified · 1 Parent(s): d4d32cf

Update app.py

Files changed (1)
  1. app.py +12 -6
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
 from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
 import torch
 import numpy as np
-import librosa
 
 # Load Whisper model for transcription
 whisper_model_name = "openai/whisper-large"
@@ -15,16 +14,23 @@ lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-la
 
 # Function to transcribe audio to text using Whisper model
 def transcribe_audio(audio_file):
-    # Ensure the audio is a numpy array (Gradio input type for audio is numpy)
-    audio = np.array(audio_file)
+    # Check if audio_file is a list (Gradio returns a list when multiple clips are recorded)
+    if isinstance(audio_file, list):
+        audio = np.concatenate(audio_file)  # Concatenate the list of arrays into a single 1D array
+    else:
+        audio = np.array(audio_file)  # Ensure it's a 1D array
 
-    # Prepare input features for Whisper
+    # Ensure the shape is 1D (if the shape is (2, N), we flatten it)
+    if len(audio.shape) > 1:
+        audio = audio.flatten()
+
+    # Prepare input features for Whisper (sampling rate should be 16000 for Whisper)
     input_features = processor(audio, return_tensors="pt", sampling_rate=16000)
-
+
     # Generate transcription
     generated_ids = model.generate(input_features["input_features"])
     transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
-
+
     return transcription
 
 # Function to detect the language of the transcription using zero-shot classification
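
For reference, below is a minimal, self-contained sketch of the transcription path as it stands after this commit. Only the body of transcribe_audio matches the diff; the processor/model loading lines live in a part of app.py not shown in these hunks, so the from_pretrained calls here are assumptions inferred from the imports and the whisper_model_name variable.

# Hedged sketch: post-commit transcribe_audio() with plausible surrounding setup.
# The two from_pretrained() calls are assumed, not taken from the diff.
import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration

whisper_model_name = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(whisper_model_name)              # assumed setup
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)  # assumed setup

def transcribe_audio(audio_file):
    # Gradio may deliver several recorded clips as a list; merge them first
    if isinstance(audio_file, list):
        audio = np.concatenate(audio_file)
    else:
        audio = np.array(audio_file)

    # Collapse stereo (2, N) or other multi-dimensional input to a 1D waveform
    if len(audio.shape) > 1:
        audio = audio.flatten()

    # Whisper's feature extractor expects 16 kHz audio
    input_features = processor(audio, return_tensors="pt", sampling_rate=16000)

    # Decode the generated token ids back to text
    generated_ids = model.generate(input_features["input_features"])
    return processor.decode(generated_ids[0], skip_special_tokens=True)

One caveat on the design: with a microphone input, gr.Audio in numpy mode commonly returns a (sample_rate, data) tuple rather than a bare array, and the commit also drops the librosa import that could have resampled arbitrary rates. If that applies here, the tuple would need unpacking and the data resampling to 16 kHz before the processor call.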