devve1 commited on
Commit
1a57b91
1 Parent(s): a8e91c9

Update optimum_encoder.py

Browse files
Files changed (1) hide show
  1. optimum_encoder.py +10 -4
optimum_encoder.py CHANGED
@@ -19,6 +19,7 @@ class OptimumEncoder(BaseEncoder):
19
  _tokenizer: Any = PrivateAttr()
20
  _model: Any = PrivateAttr()
21
  _torch: Any = PrivateAttr()
 
22
 
23
  def __init__(self, **data):
24
  super().__init__(**data)
@@ -44,7 +45,7 @@ class OptimumEncoder(BaseEncoder):
44
  "`pip install semantic-router[local]`"
45
  )
46
  try:
47
- from transformers import AutoTokenizer
48
  except ImportError:
49
  raise ImportError(
50
  "Please install transformers to use OptimumEncoder. "
@@ -59,6 +60,12 @@ class OptimumEncoder(BaseEncoder):
59
  **self.tokenizer_kwargs,
60
  )
61
 
 
 
 
 
 
 
62
  provider_options = {
63
  "trt_engine_cache_enable": True,
64
  "trt_engine_cache_path": os.getenv('HF_HOME'),
@@ -69,6 +76,7 @@ class OptimumEncoder(BaseEncoder):
69
  session_options.log_severity_level = 0
70
 
71
  ort_model = ORTModelForFeatureExtraction.from_pretrained(
 
72
  model_id=self.name,
73
  file_name='model_fp16.onnx',
74
  subfolder='onnx',
@@ -141,9 +149,7 @@ class OptimumEncoder(BaseEncoder):
141
  else:
142
  embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
143
 
144
- original_dimensions = embeddings.size()
145
-
146
- if original_dimensions > matryoshka_dim:
147
  embeddings = embeddings[:, :matryoshka_dim]
148
 
149
  all_embeddings.extend(embeddings)
 
19
  _tokenizer: Any = PrivateAttr()
20
  _model: Any = PrivateAttr()
21
  _torch: Any = PrivateAttr()
22
+ _dim: int = 1024
23
 
24
  def __init__(self, **data):
25
  super().__init__(**data)
 
45
  "`pip install semantic-router[local]`"
46
  )
47
  try:
48
+ from transformers import AutoTokenizer, AutoConfig
49
  except ImportError:
50
  raise ImportError(
51
  "Please install transformers to use OptimumEncoder. "
 
60
  **self.tokenizer_kwargs,
61
  )
62
 
63
+ config = AutoConfig.from_pretrained(
64
+ self.name
65
+ )
66
+
67
+ self._dim = config['hidden_size']
68
+
69
  provider_options = {
70
  "trt_engine_cache_enable": True,
71
  "trt_engine_cache_path": os.getenv('HF_HOME'),
 
76
  session_options.log_severity_level = 0
77
 
78
  ort_model = ORTModelForFeatureExtraction.from_pretrained(
79
+ config=config,
80
  model_id=self.name,
81
  file_name='model_fp16.onnx',
82
  subfolder='onnx',
 
149
  else:
150
  embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
151
 
152
+ if self._dim > matryoshka_dim:
 
 
153
  embeddings = embeddings[:, :matryoshka_dim]
154
 
155
  all_embeddings.extend(embeddings)