import spaces
import torch
import gradio as gr
from transformers import pipeline
import tempfile
import os
import json
from pydub import AudioSegment
import math
#===============
# Define main parameters
#===============
MODEL_NAME = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000  # upload size limit (not yet enforced below)
YT_LENGTH_LIMIT_S = 3600  # limit YouTube inputs to 1 hour (no YouTube tab yet)
device = 0 if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
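# Note: transformers pipelines accept a CUDA device index (0) or "cpu", and
# chunk_length_s=30 enables the pipeline's built-in chunked long-form inference.
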
#===============
# Main functions
#===============
# Split the audio into fixed-length chunks; returns (chunk_path, start_ms) tuples
def split_audio(audio_file, chunk_length_ms):
    audio = AudioSegment.from_file(audio_file)
    duration_ms = len(audio)
    num_chunks = math.ceil(duration_ms / chunk_length_ms)
    chunks = []
    for i in range(num_chunks):
        start_time = i * chunk_length_ms
        end_time = min((i + 1) * chunk_length_ms, duration_ms)
        chunk = audio[start_time:end_time]
        chunk_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        chunk_file.close()  # close the handle so the export below works on all platforms
        chunk.export(chunk_file.name, format="wav")
        chunks.append((chunk_file.name, start_time))  # save the chunk path and its start time
    return chunks
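
# Example (hypothetical file name): a 90-minute recording split into 30-second
# chunks yields 180 (path, start_ms) tuples:
#   chunks = split_audio("long_interview.wav", 30 * 1000)
#   first_path, first_start_ms = chunks[0]  # first_start_ms == 0
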
@spaces.GPU
def transcribe(audio_file, task, language, keywords, chunk_length_s=30):
    # NOTE: `keywords` is collected by the UI but not yet passed to the model.
    if audio_file is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Load the audio file with pydub to get its length
    audio = AudioSegment.from_file(audio_file)
    audio_length_ms = len(audio)  # length in milliseconds

    # Threshold for chunking: 40 minutes = 2,400,000 milliseconds
    chunk_threshold_ms = 40 * 60 * 1000

    # Decide whether to chunk or process the entire file
    if audio_length_ms > chunk_threshold_ms:
        # Audio is longer than 40 minutes: split it and transcribe chunk by chunk
        chunk_length_ms = chunk_length_s * 1000
        audio_chunks = split_audio(audio_file, chunk_length_ms)
        all_text = ""
        all_timestamps = []
        for chunk_file, chunk_start_time in audio_chunks:
            result = pipe(chunk_file, batch_size=BATCH_SIZE, generate_kwargs={"task": task, "language": language}, return_timestamps=True)
            all_text += result["text"] + " "
            # Shift the timestamps by the chunk's offset (ms -> s) in the full audio
            for chunk_timestamp in result["chunks"]:
                start_time_chunk, end_time_chunk = chunk_timestamp["timestamp"]
                adjusted_timestamp = {
                    "start": start_time_chunk + chunk_start_time / 1000,
                    "end": end_time_chunk + chunk_start_time / 1000,
                    "text": chunk_timestamp["text"],
                }
                all_timestamps.append(adjusted_timestamp)
            os.remove(chunk_file)  # clean up the temporary chunk file
    else:
        # Audio is 40 minutes or shorter: process the whole file at once
        result = pipe(audio_file, batch_size=BATCH_SIZE, generate_kwargs={"task": task, "language": language}, return_timestamps=True)
        all_text = result["text"]
        all_timestamps = []
        for chunk_timestamp in result["chunks"]:
            start_time_chunk, end_time_chunk = chunk_timestamp["timestamp"]
            adjusted_timestamp = {
                "start": start_time_chunk,
                "end": end_time_chunk,
                "text": chunk_timestamp["text"],
            }
            all_timestamps.append(adjusted_timestamp)

    # First 200 characters for display
    preview_text = all_text[:200] + "..." if len(all_text) > 200 else all_text

    # Full transcription with timestamps
    full_transcription = {
        "text": all_text,
        "timestamps": all_timestamps,
    }

    # Save the full transcription (with timestamps) as JSON
    json_file_path = os.path.join(tempfile.gettempdir(), f"{os.path.splitext(os.path.basename(audio_file))[0]}_transcription.json")
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(full_transcription, json_file, ensure_ascii=False)

    # Save the plain text transcription as TXT
    txt_file_path = os.path.join(tempfile.gettempdir(), f"{os.path.splitext(os.path.basename(audio_file))[0]}_transcription.txt")
    with open(txt_file_path, "w", encoding="utf-8") as txt_file:
        txt_file.write(all_text)

    return preview_text, json_file_path, txt_file_path
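
# Quick local sanity check (hypothetical path; bypasses the Gradio UI):
#   preview, json_path, txt_path = transcribe("sample.wav", "transcribe", "spanish", "")
#   print(preview)  # first 200 characters of the transcription
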
#===============
# Build the frontend
#===============
file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Audio file"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
        gr.Dropdown(["spanish", "english"], label="Language", info="Will add more later!", value="spanish"),
        gr.Textbox(lines=10, label="Keywords"),
    ],
    outputs=[
        gr.Textbox(label="Preview (first 200 characters)"),
        gr.File(label="Download full transcription as JSON"),
        gr.File(label="Download transcription as TXT"),
    ],
    title="Transcribe Audio",
    description=(
        "Transcribe audio inputs with the click of a button! Demo uses the"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co./{MODEL_NAME}) and 🤗 Transformers to transcribe"
        " audio files. Files longer than **40 minutes** are split into chunks and transcribed chunk by chunk."
    ),
    allow_flagging="never",
)
#===============
# Launch
#===============
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.TabbedInterface([file_transcribe], ["Audio file"])

demo.queue().launch(ssr_mode=False)  # ssr_mode=False disables Gradio's server-side rendering