IdoMachlev committed on
Commit
9d67e70
·
1 Parent(s): 880d01b

Changed the parameters to be in env variables

Browse files
Files changed (2) hide show
  1. handler.py +26 -19
  2. requirements.txt +3 -1
handler.py CHANGED
@@ -1,7 +1,8 @@
1
- from typing import Dict, List, Any, Literal
2
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
3
  import torch
4
  import logging
 
5
 
6
  class EndpointHandler():
7
  def __init__(self, path=""):
@@ -36,32 +37,38 @@ class EndpointHandler():
36
  """
37
  # get inputs
38
  inputs = data.pop("inputs", data)
39
- logging.info(data.get("parameters"))
40
- parameters = WhisperKwargsHandler(data.get("parameters", {})) # Changed from pop to get
41
 
42
 
43
  # run normal prediction
44
  prediction = self.pipe(
45
  inputs,
46
- **parameters.to_kwargs()
47
  )
48
  return prediction
49
 
50
- class WhisperKwargsHandler:
51
- def __init__(self, kwargs: dict):
52
- self.language: str = kwargs.get('language')
53
- self.max_new_tokens: int = kwargs.get('max_new_tokens')
54
- self.num_beams: int = kwargs.get('num_beams')
55
- self.condition_on_prev_tokens: bool = kwargs.get('condition_on_prev_tokens')
56
- self.compression_ratio_threshold: float = kwargs.get('compression_ratio_threshold')
57
- self.temperature: tuple[float] = kwargs.get('temperature')
58
- self.logprob_threshold: float = kwargs.get('logprob_threshold')
59
- self.no_speech_threshold: float = kwargs.get('no_speech_threshold')
60
- self.return_timestamps: Literal["word", True] = kwargs.get('return_timestamps')
61
-
 
 
 
 
 
62
  def to_kwargs(self):
63
- """Convert object attributes to kwargs dict, excluding None values"""
64
  return {
65
- key: value for key, value in self.__dict__.items()
66
- if value is not None
 
67
  }
 
 
1
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
+ from typing import Dict, List, Any, Literal, Optional, Tuple
3
  import torch
4
  import logging
5
+ from pydantic_settings import BaseSettings
6
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
 
37
  """
38
  # get inputs
39
  inputs = data.pop("inputs", data)
40
+ whisper_parameter_handler = WhisperParameterHandler()
41
+ logging.info(whisper_parameter_handler.model_dump())
42
 
43
 
44
  # run normal prediction
45
  prediction = self.pipe(
46
  inputs,
47
+ **whisper_parameter_handler.to_kwargs()
48
  )
49
  return prediction
50
 
51
+
52
+ class WhisperParameterHandler(BaseSettings):
53
+ language: Optional[str] = None # Optional fields default to None
54
+ max_new_tokens: Optional[int] = None
55
+ num_beams: Optional[int] = None
56
+ condition_on_prev_tokens: Optional[bool] = None
57
+ compression_ratio_threshold: Optional[float] = None
58
+ temperature: Optional[Tuple[float, ...]] = None # Optional Tuple
59
+ logprob_threshold: Optional[float] = None
60
+ no_speech_threshold: Optional[float] = None
61
+ return_timestamps: Optional[Literal["word", True]] = None
62
+
63
+ model_config = {
64
+ "env_prefix": "WHISPER_KWARGS_",
65
+ "case_sensitive": False,
66
+ }
67
+
68
  def to_kwargs(self):
69
+ """Convert object attributes to kwargs dict, excluding None values."""
70
  return {
71
+ key: value
72
+ for key, value in self.model_dump().items() # Use model_dump for accurate representation
73
+ if value is not None # Exclude None values
74
  }
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- torch
 
 
 
1
+ torch
2
+ pydantic-settings
3
+ transformers