Update optimum_encoder.py
Browse files — optimum_encoder.py (+5 −4)
optimum_encoder.py
CHANGED
@@ -10,7 +10,7 @@ from semantic_router.utils.logger import logger
|
|
10 |
|
11 |
|
12 |
class OptimumEncoder(BaseEncoder):
|
13 |
-
name: str = "mixedbread-ai/mxbai-embed-large-v1"
|
14 |
type: str = "huggingface"
|
15 |
score_threshold: float = 0.5
|
16 |
tokenizer_kwargs: Dict = {}
|
@@ -60,13 +60,14 @@ class OptimumEncoder(BaseEncoder):
|
|
60 |
|
61 |
provider_options = {
|
62 |
"trt_engine_cache_enable": True,
|
63 |
-
"trt_engine_cache_path": os.getenv('HF_HOME')
|
|
|
64 |
}
|
65 |
|
66 |
ort_model = ORTModelForFeatureExtraction.from_pretrained(
|
67 |
self.name,
|
68 |
-
|
69 |
-
|
70 |
provider_options=provider_options,
|
71 |
**self.model_kwargs
|
72 |
)
|
|
|
10 |
|
11 |
|
12 |
class OptimumEncoder(BaseEncoder):
|
13 |
+
name: str = "mixedbread-ai/mxbai-embed-large-v1/onnx/model_fp16.onnx"
|
14 |
type: str = "huggingface"
|
15 |
score_threshold: float = 0.5
|
16 |
tokenizer_kwargs: Dict = {}
|
|
|
60 |
|
61 |
provider_options = {
|
62 |
"trt_engine_cache_enable": True,
|
63 |
+
"trt_engine_cache_path": os.getenv('HF_HOME'),
|
64 |
+
"trt_fp16_enable": True
|
65 |
}
|
66 |
|
67 |
ort_model = ORTModelForFeatureExtraction.from_pretrained(
|
68 |
self.name,
|
69 |
+
model_save_dir=os.getenv('HF_HOME'),
|
70 |
+
providers=['TensorrtExecutionProvider'],
|
71 |
provider_options=provider_options,
|
72 |
**self.model_kwargs
|
73 |
)
|