IdoMachlev
commited on
Commit
·
9d67e70
1
Parent(s):
880d01b
Changed the parameters to be in env variables
Browse files- handler.py +26 -19
- requirements.txt +3 -1
handler.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from typing import Dict, List, Any, Literal
|
2 |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
|
|
3 |
import torch
|
4 |
import logging
|
|
|
5 |
|
6 |
class EndpointHandler():
|
7 |
def __init__(self, path=""):
|
@@ -36,32 +37,38 @@ class EndpointHandler():
|
|
36 |
"""
|
37 |
# get inputs
|
38 |
inputs = data.pop("inputs", data)
|
39 |
-
|
40 |
-
|
41 |
|
42 |
|
43 |
# run normal prediction
|
44 |
prediction = self.pipe(
|
45 |
inputs,
|
46 |
-
**
|
47 |
)
|
48 |
return prediction
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
62 |
def to_kwargs(self):
|
63 |
-
"""Convert object attributes to kwargs dict, excluding None values"""
|
64 |
return {
|
65 |
-
key: value
|
66 |
-
|
|
|
67 |
}
|
|
|
|
|
1 |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
2 |
+
from typing import Dict, List, Any, Literal, Optional, Tuple
|
3 |
import torch
|
4 |
import logging
|
5 |
+
from pydantic_settings import BaseSettings
|
6 |
|
7 |
class EndpointHandler():
|
8 |
def __init__(self, path=""):
|
|
|
37 |
"""
|
38 |
# get inputs
|
39 |
inputs = data.pop("inputs", data)
|
40 |
+
whisper_parameter_handler = WhisperParameterHandler()
|
41 |
+
logging.info(whisper_parameter_handler.model_dump())
|
42 |
|
43 |
|
44 |
# run normal prediction
|
45 |
prediction = self.pipe(
|
46 |
inputs,
|
47 |
+
**whisper_parameter_handler.to_kwargs()
|
48 |
)
|
49 |
return prediction
|
50 |
|
51 |
+
|
52 |
+
class WhisperParameterHandler(BaseSettings):
|
53 |
+
language: Optional[str] = None # Optional fields default to None
|
54 |
+
max_new_tokens: Optional[int] = None
|
55 |
+
num_beams: Optional[int] = None
|
56 |
+
condition_on_prev_tokens: Optional[bool] = None
|
57 |
+
compression_ratio_threshold: Optional[float] = None
|
58 |
+
temperature: Optional[Tuple[float, ...]] = None # Optional Tuple
|
59 |
+
logprob_threshold: Optional[float] = None
|
60 |
+
no_speech_threshold: Optional[float] = None
|
61 |
+
return_timestamps: Optional[Literal["word", True]] = None
|
62 |
+
|
63 |
+
model_config = {
|
64 |
+
"env_prefix": "WHISPER_KWARGS_",
|
65 |
+
"case_sensitive": False,
|
66 |
+
}
|
67 |
+
|
68 |
def to_kwargs(self):
|
69 |
+
"""Convert object attributes to kwargs dict, excluding None values."""
|
70 |
return {
|
71 |
+
key: value
|
72 |
+
for key, value in self.model_dump().items() # Use model_dump for accurate representation
|
73 |
+
if value is not None # Exclude None values
|
74 |
}
|
requirements.txt
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
torch
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
pydantic-settings
|
3 |
+
transformers
|