devve1 commited on
Commit
ce7e039
1 Parent(s): 950d13d

Update optimum_encoder.py

Browse files
Files changed (1) hide show
  1. optimum_encoder.py +3 -0
optimum_encoder.py CHANGED
@@ -11,6 +11,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
11
  class OptimumEncoder(BaseModel, Embeddings):
12
  name: str = "mixedbread-ai/mxbai-embed-large-v1"
13
  device: Optional[str] = None
 
14
  _tokenizer: Any
15
  _model: Any
16
  _torch: Any
@@ -24,6 +25,7 @@ class OptimumEncoder(BaseModel, Embeddings):
24
  def validate_environment(cls, values: Dict) -> Dict:
25
  name = values.get('name')
26
  device = values.get('device')
 
27
 
28
  try:
29
  import onnxruntime as ort
@@ -66,6 +68,7 @@ class OptimumEncoder(BaseModel, Embeddings):
66
  file_name='model_fp16.onnx',
67
  subfolder='onnx',
68
  provider='CUDAExecutionProvider',
 
69
  use_io_binding=True,
70
  #provider_options=provider_options,
71
  session_options=session_options
 
11
  class OptimumEncoder(BaseModel, Embeddings):
12
  name: str = "mixedbread-ai/mxbai-embed-large-v1"
13
  device: Optional[str] = None
14
+ cache_dir: str = None
15
  _tokenizer: Any
16
  _model: Any
17
  _torch: Any
 
25
  def validate_environment(cls, values: Dict) -> Dict:
26
  name = values.get('name')
27
  device = values.get('device')
28
+ cache_dir = values.get('cache_dir')
29
 
30
  try:
31
  import onnxruntime as ort
 
68
  file_name='model_fp16.onnx',
69
  subfolder='onnx',
70
  provider='CUDAExecutionProvider',
71
+ cache_dir=cache_dir,
72
  use_io_binding=True,
73
  #provider_options=provider_options,
74
  session_options=session_options