Update optimum_encoder.py
Browse files- optimum_encoder.py +10 -4
optimum_encoder.py
CHANGED
@@ -19,6 +19,7 @@ class OptimumEncoder(BaseEncoder):
|
|
19 |
_tokenizer: Any = PrivateAttr()
|
20 |
_model: Any = PrivateAttr()
|
21 |
_torch: Any = PrivateAttr()
|
|
|
22 |
|
23 |
def __init__(self, **data):
|
24 |
super().__init__(**data)
|
@@ -44,7 +45,7 @@ class OptimumEncoder(BaseEncoder):
|
|
44 |
"`pip install semantic-router[local]`"
|
45 |
)
|
46 |
try:
|
47 |
-
from transformers import AutoTokenizer
|
48 |
except ImportError:
|
49 |
raise ImportError(
|
50 |
"Please install transformers to use OptimumEncoder. "
|
@@ -59,6 +60,12 @@ class OptimumEncoder(BaseEncoder):
|
|
59 |
**self.tokenizer_kwargs,
|
60 |
)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
provider_options = {
|
63 |
"trt_engine_cache_enable": True,
|
64 |
"trt_engine_cache_path": os.getenv('HF_HOME'),
|
@@ -69,6 +76,7 @@ class OptimumEncoder(BaseEncoder):
|
|
69 |
session_options.log_severity_level = 0
|
70 |
|
71 |
ort_model = ORTModelForFeatureExtraction.from_pretrained(
|
|
|
72 |
model_id=self.name,
|
73 |
file_name='model_fp16.onnx',
|
74 |
subfolder='onnx',
|
@@ -141,9 +149,7 @@ class OptimumEncoder(BaseEncoder):
|
|
141 |
else:
|
142 |
embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
if original_dimensions > matryoshka_dim:
|
147 |
embeddings = embeddings[:, :matryoshka_dim]
|
148 |
|
149 |
all_embeddings.extend(embeddings)
|
|
|
19 |
_tokenizer: Any = PrivateAttr()
|
20 |
_model: Any = PrivateAttr()
|
21 |
_torch: Any = PrivateAttr()
|
22 |
+
_dim: int = 1024
|
23 |
|
24 |
def __init__(self, **data):
|
25 |
super().__init__(**data)
|
|
|
45 |
"`pip install semantic-router[local]`"
|
46 |
)
|
47 |
try:
|
48 |
+
from transformers import AutoTokenizer, AutoConfig
|
49 |
except ImportError:
|
50 |
raise ImportError(
|
51 |
"Please install transformers to use OptimumEncoder. "
|
|
|
60 |
**self.tokenizer_kwargs,
|
61 |
)
|
62 |
|
63 |
+
config = AutoConfig.from_pretrained(
|
64 |
+
self.name
|
65 |
+
)
|
66 |
+
|
67 |
+
self._dim = config['hidden_size']
|
68 |
+
|
69 |
provider_options = {
|
70 |
"trt_engine_cache_enable": True,
|
71 |
"trt_engine_cache_path": os.getenv('HF_HOME'),
|
|
|
76 |
session_options.log_severity_level = 0
|
77 |
|
78 |
ort_model = ORTModelForFeatureExtraction.from_pretrained(
|
79 |
+
config=config,
|
80 |
model_id=self.name,
|
81 |
file_name='model_fp16.onnx',
|
82 |
subfolder='onnx',
|
|
|
149 |
else:
|
150 |
embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
|
151 |
|
152 |
+
if self._dim > matryoshka_dim:
|
|
|
|
|
153 |
embeddings = embeddings[:, :matryoshka_dim]
|
154 |
|
155 |
all_embeddings.extend(embeddings)
|