|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline |
|
from typing import Dict, List, Any, Literal, Optional, Tuple |
|
import torch |
|
import logging |
|
from pydantic_settings import BaseSettings |
|
|
|
class EndpointHandler():
    """Inference handler that transcribes audio with a Whisper ASR pipeline.

    The model, processor and pipeline are built once in ``__init__``. Each call
    runs the pipeline on the request's ``inputs``. Decoding options are read,
    per call, from ``WHISPER_KWARGS_*`` environment variables via
    ``WhisperParameterHandler``.
    """

    # Options the ASR pipeline accepts as top-level call arguments. Every other
    # option from WhisperParameterHandler is a generation option and has to be
    # nested under ``generate_kwargs``. Passing it at top level makes the
    # pipeline reject it as an unexpected keyword.
    _PIPELINE_LEVEL_KEYS = frozenset({"return_timestamps"})

    def __init__(self, path="", model_id="openai/whisper-large-v3-turbo"):
        """Load the Whisper model and build the speech-recognition pipeline.

        Args:
            path: Path supplied by the hosting runtime. Currently unused; the
                model is loaded from ``model_id``.
            model_id: Model id (or local path) passed to ``from_pretrained``.
                Defaults to the previously hard-coded
                ``"openai/whisper-large-v3-turbo"``.
        """
        use_cuda = torch.cuda.is_available()
        device = "cuda:0" if use_cuda else "cpu"
        # Half precision is only used when a GPU is available.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    @staticmethod
    def _split_pipeline_kwargs(options):
        """Route flat decoding options to where the ASR pipeline expects them.

        Args:
            options: Flat mapping of option name to value (e.g. the output of
                ``WhisperParameterHandler.to_kwargs()``).

        Returns:
            A new dict of keyword arguments for the pipeline call:
            - Keys listed in ``_PIPELINE_LEVEL_KEYS`` are kept at top level.
            - All remaining options are nested under ``"generate_kwargs"``.
            - ``"generate_kwargs"`` is omitted when there are no such options,
              so an empty input yields ``{}``.
        """
        top_level_keys = EndpointHandler._PIPELINE_LEVEL_KEYS
        pipeline_kwargs = {k: v for k, v in options.items() if k in top_level_keys}
        generate_kwargs = {k: v for k, v in options.items() if k not in top_level_keys}
        if generate_kwargs:
            pipeline_kwargs["generate_kwargs"] = generate_kwargs
        return pipeline_kwargs

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Transcribe the audio carried by one request.

        Args:
            data: Request payload.
                - If it has an ``"inputs"`` key, that value is popped from
                  ``data`` and forwarded to the pipeline.
                - Otherwise the whole payload is forwarded unchanged.
                - NOTE: a ``"parameters"`` entry is currently ignored. Decoding
                  options come only from ``WHISPER_KWARGS_*`` environment
                  variables.

        Returns:
            Whatever the ASR pipeline returns (a dict or a list of dicts), to be
            serialized by the caller.
        """
        inputs = data.pop("inputs", data)

        whisper_parameter_handler = WhisperParameterHandler()
        logging.info(whisper_parameter_handler.model_dump())

        # Split the flat options into top-level pipeline args and generate_kwargs.
        pipeline_kwargs = self._split_pipeline_kwargs(whisper_parameter_handler.to_kwargs())

        prediction = self.pipe(inputs, **pipeline_kwargs)
        return prediction
|
|
|
|
|
class WhisperParameterHandler(BaseSettings):
    """Whisper decoding options loaded from the environment.

    Each field is populated from ``WHISPER_KWARGS_<FIELD>``; the name match is
    case-insensitive (see ``model_config``). Every field defaults to ``None``,
    meaning "not configured", and such fields are left out of ``to_kwargs()``.
    """

    # --- decoding options (None means "not set") ---
    language: Optional[str] = None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    # Non-scalar field; pydantic-settings decodes such env values as JSON,
    # e.g. "[0.0, 0.2]" (TODO confirm for the installed pydantic-settings version).
    temperature: Optional[Tuple[float, ...]] = None
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    # Only the string "word" or the boolean True are accepted.
    return_timestamps: Optional[Literal["word", True]] = None

    model_config = dict(
        env_prefix="WHISPER_KWARGS_",
        case_sensitive=False,
    )

    def to_kwargs(self):
        """Return the configured options as a dict, skipping fields that are ``None``."""
        return self.model_dump(exclude_none=True)
|
|