Spaces: Running on T4
Update optimum_encoder.py
Browse files
- optimum_encoder.py +3 -0
optimum_encoder.py
CHANGED
@@ -11,6 +11,7 @@ from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
|
11 |
class OptimumEncoder(BaseModel, Embeddings):
|
12 |
name: str = "mixedbread-ai/mxbai-embed-large-v1"
|
13 |
device: Optional[str] = None
|
|
|
14 |
_tokenizer: Any
|
15 |
_model: Any
|
16 |
_torch: Any
|
@@ -24,6 +25,7 @@ class OptimumEncoder(BaseModel, Embeddings):
|
|
24 |
def validate_environment(cls, values: Dict) -> Dict:
|
25 |
name = values.get('name')
|
26 |
device = values.get('device')
|
|
|
27 |
|
28 |
try:
|
29 |
import onnxruntime as ort
|
@@ -66,6 +68,7 @@ class OptimumEncoder(BaseModel, Embeddings):
|
|
66 |
file_name='model_fp16.onnx',
|
67 |
subfolder='onnx',
|
68 |
provider='CUDAExecutionProvider',
|
|
|
69 |
use_io_binding=True,
|
70 |
#provider_options=provider_options,
|
71 |
session_options=session_options
|
|
|
11 |
class OptimumEncoder(BaseModel, Embeddings):
|
12 |
name: str = "mixedbread-ai/mxbai-embed-large-v1"
|
13 |
device: Optional[str] = None
|
14 |
+
cache_dir: str = None
|
15 |
_tokenizer: Any
|
16 |
_model: Any
|
17 |
_torch: Any
|
|
|
25 |
def validate_environment(cls, values: Dict) -> Dict:
|
26 |
name = values.get('name')
|
27 |
device = values.get('device')
|
28 |
+
cache_dir = values.get('cache_dir')
|
29 |
|
30 |
try:
|
31 |
import onnxruntime as ort
|
|
|
68 |
file_name='model_fp16.onnx',
|
69 |
subfolder='onnx',
|
70 |
provider='CUDAExecutionProvider',
|
71 |
+
cache_dir=cache_dir,
|
72 |
use_io_binding=True,
|
73 |
#provider_options=provider_options,
|
74 |
session_options=session_options
|