"""Endpoint handler that transcribes audio with a Whisper model.

Generation options are not taken from the request payload; they are read
from ``WHISPER_KWARGS_*`` environment variables by ``WhisperParameterHandler``.
"""
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

import torch
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logger = logging.getLogger(__name__)

DEFAULT_MODEL_ID = "openai/whisper-large-v3-turbo"


class WhisperParameterHandler(BaseSettings):
    """Whisper generation options loaded from ``WHISPER_KWARGS_*`` env vars.

    Every field is optional; unset fields stay ``None`` and are dropped by
    :meth:`to_kwargs`, so they are simply not forwarded to the pipeline.
    Env var names are matched case-insensitively, e.g.
    ``WHISPER_KWARGS_NUM_BEAMS=5``.
    """

    language: Optional[str] = None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    # Complex type: pydantic-settings parses it from JSON in the env var,
    # e.g. WHISPER_KWARGS_TEMPERATURE='[0.0, 0.2, 0.4]'.
    temperature: Optional[Tuple[float, ...]] = None
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    # Forwarded to the pipeline as its own argument, not via generate_kwargs.
    # Per the transformers ASR pipeline docs: "word" -> word-level timestamps,
    # True -> segment-level timestamps, None -> no timestamps.
    return_timestamps: Optional[Literal["word", True]] = None

    model_config = SettingsConfigDict(
        env_prefix="WHISPER_KWARGS_",
        case_sensitive=False,
    )

    @field_validator("return_timestamps", mode="before")
    @classmethod
    def cannonize_timestamps(cls, value: Any) -> Any:
        """Normalize a raw ``return_timestamps`` value before Literal validation.

        Env vars always arrive as strings. ``"true"`` (any case) becomes the
        boolean ``True``. ``"false"`` or an empty string becomes ``None``
        (no timestamps). Any other string is stripped and lower-cased, so
        ``"Word"`` matches ``"word"``.

        A boolean ``False`` also maps to ``None``. Any other non-string value
        (e.g. a real ``True`` passed as an init kwarg) is passed through
        unchanged instead of crashing on ``.lower()``.
        """
        if value is False:
            return None
        if not isinstance(value, str):
            return value
        normalized = value.strip().lower()
        if normalized == "true":
            logger.info("return_timestamps == 'True'")
            return True
        if normalized in ("", "false"):
            return None
        return normalized

    def to_kwargs(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
        """Return the configured options as a kwargs dict, omitting ``None`` values.

        Args:
            exclude: optional field names to leave out as well, e.g.
                ``("return_timestamps",)`` when that option is passed to the
                pipeline separately.

        Returns:
            A plain dict of field name -> value for every field that is set.
        """
        excluded = set(exclude) if exclude is not None else None
        return self.model_dump(exclude_none=True, exclude=excluded)


class EndpointHandler:
    """Wrap a transformers automatic-speech-recognition pipeline around Whisper."""

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the model and build the ASR pipeline.

        Args:
            path: presumably the model directory supplied by the serving
                runtime; unused here, the weights are loaded from ``model_id``.
            model_id: identifier passed to ``from_pretrained`` for both the
                model and its processor.
        """
        use_cuda = torch.cuda.is_available()
        device = "cuda:0" if use_cuda else "cpu"
        # Half precision on GPU, full precision on CPU.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="sdpa",
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Transcribe the audio in ``data["inputs"]``.

        Args:
            data: request payload. ``data["inputs"]`` is popped and forwarded
                to the ASR pipeline; if that key is missing, the whole dict is
                forwarded instead. Any ``parameters`` entry is currently
                ignored: generation options come from ``WHISPER_KWARGS_*``
                environment variables.

        Returns:
            The raw pipeline output, to be serialized by the caller
            (presumably a dict for a single input, a list for a batch).
        """
        inputs = data.pop("inputs", data)

        # Env vars are re-read on every request.
        settings = WhisperParameterHandler()
        # return_timestamps is a pipeline argument, not a generate() kwarg.
        generate_kwargs = settings.to_kwargs(exclude=("return_timestamps",))
        logger.info("generate_kwargs: %s", generate_kwargs)

        prediction = self.pipe(
            inputs,
            return_timestamps=settings.return_timestamps,
            generate_kwargs=generate_kwargs,
        )

        logger.info("prediction: %s", prediction)
        # The output may lack "chunks" (e.g. when no timestamps were
        # requested), so never index it unconditionally.
        if isinstance(prediction, dict) and "chunks" in prediction:
            logger.info("chunks: %s", prediction["chunks"])
        return prediction