Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
|
|
3 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
|
4 |
import torch
|
5 |
import numpy as np
|
6 |
-
import librosa
|
7 |
|
8 |
# Load Whisper model for transcription
|
9 |
whisper_model_name = "openai/whisper-large"
|
@@ -15,16 +14,23 @@ lang_detect_model = pipeline("zero-shot-classification", model="facebook/bart-la
|
|
15 |
|
16 |
# Function to transcribe audio to text using Whisper model
|
17 |
def transcribe_audio(audio_file):
|
18 |
-
#
|
19 |
-
|
|
|
|
|
|
|
20 |
|
21 |
-
#
|
|
|
|
|
|
|
|
|
22 |
input_features = processor(audio, return_tensors="pt", sampling_rate=16000)
|
23 |
-
|
24 |
# Generate transcription
|
25 |
generated_ids = model.generate(input_features["input_features"])
|
26 |
transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
|
27 |
-
|
28 |
return transcription
|
29 |
|
30 |
# Function to detect the language of the transcription using zero-shot classification
|
|
|
3 |
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
|
4 |
import torch
|
5 |
import numpy as np
|
|
|
6 |
|
7 |
# Load Whisper model for transcription
|
8 |
whisper_model_name = "openai/whisper-large"
|
|
|
14 |
|
15 |
# Function to transcribe audio to text using Whisper model
|
16 |
def transcribe_audio(audio_file):
|
17 |
+
# Check if audio_file is a list (Gradio returns a list when multiple clips are recorded)
|
18 |
+
if isinstance(audio_file, list):
|
19 |
+
audio = np.concatenate(audio_file) # Concatenate the list of arrays into a single 1D array
|
20 |
+
else:
|
21 |
+
audio = np.array(audio_file) # Ensure it's a 1D array
|
22 |
|
23 |
+
# Ensure the shape is 1D (if the shape is (2, N), we flatten it)
|
24 |
+
if len(audio.shape) > 1:
|
25 |
+
audio = audio.flatten()
|
26 |
+
|
27 |
+
# Prepare input features for Whisper (sampling rate should be 16000 for Whisper)
|
28 |
input_features = processor(audio, return_tensors="pt", sampling_rate=16000)
|
29 |
+
|
30 |
# Generate transcription
|
31 |
generated_ids = model.generate(input_features["input_features"])
|
32 |
transcription = processor.decode(generated_ids[0], skip_special_tokens=True)
|
33 |
+
|
34 |
return transcription
|
35 |
|
36 |
# Function to detect the language of the transcription using zero-shot classification
|