devve1 commited on
Commit
a9ba889
1 Parent(s): 49ec730

Update optimum_encoder.py

Browse files
Files changed (1) hide show
  1. optimum_encoder.py +5 -4
optimum_encoder.py CHANGED
@@ -10,7 +10,7 @@ from semantic_router.utils.logger import logger
10
 
11
 
12
  class OptimumEncoder(BaseEncoder):
13
- name: str = "mixedbread-ai/mxbai-embed-large-v1"
14
  type: str = "huggingface"
15
  score_threshold: float = 0.5
16
  tokenizer_kwargs: Dict = {}
@@ -60,13 +60,14 @@ class OptimumEncoder(BaseEncoder):
60
 
61
  provider_options = {
62
  "trt_engine_cache_enable": True,
63
- "trt_engine_cache_path": os.getenv('HF_HOME')
 
64
  }
65
 
66
  ort_model = ORTModelForFeatureExtraction.from_pretrained(
67
  self.name,
68
- use_io_binding=True,
69
- provider=['TensorrtExecutionProvider'],
70
  provider_options=provider_options,
71
  **self.model_kwargs
72
  )
 
10
 
11
 
12
  class OptimumEncoder(BaseEncoder):
13
+ name: str = "mixedbread-ai/mxbai-embed-large-v1/onnx/model_fp16.onnx"
14
  type: str = "huggingface"
15
  score_threshold: float = 0.5
16
  tokenizer_kwargs: Dict = {}
 
60
 
61
  provider_options = {
62
  "trt_engine_cache_enable": True,
63
+ "trt_engine_cache_path": os.getenv('HF_HOME'),
64
+ "trt_fp16_enable": True
65
  }
66
 
67
  ort_model = ORTModelForFeatureExtraction.from_pretrained(
68
  self.name,
69
+ model_save_dir=os.getenv('HF_HOME'),
70
+ providers=['TensorrtExecutionProvider'],
71
  provider_options=provider_options,
72
  **self.model_kwargs
73
  )