devve1 committed on
Commit
81baed5
1 Parent(s): 1a57b91

Update optimum_encoder.py

Browse files
Files changed (1) hide show
  1. optimum_encoder.py +2 -13
optimum_encoder.py CHANGED
@@ -19,7 +19,6 @@ class OptimumEncoder(BaseEncoder):
19
  _tokenizer: Any = PrivateAttr()
20
  _model: Any = PrivateAttr()
21
  _torch: Any = PrivateAttr()
22
- _dim: int = 1024
23
 
24
  def __init__(self, **data):
25
  super().__init__(**data)
@@ -45,7 +44,7 @@ class OptimumEncoder(BaseEncoder):
45
  "`pip install semantic-router[local]`"
46
  )
47
  try:
48
- from transformers import AutoTokenizer, AutoConfig
49
  except ImportError:
50
  raise ImportError(
51
  "Please install transformers to use OptimumEncoder. "
@@ -59,12 +58,6 @@ class OptimumEncoder(BaseEncoder):
59
  self.name,
60
  **self.tokenizer_kwargs,
61
  )
62
-
63
- config = AutoConfig.from_pretrained(
64
- self.name
65
- )
66
-
67
- self._dim = config['hidden_size']
68
 
69
  provider_options = {
70
  "trt_engine_cache_enable": True,
@@ -116,7 +109,6 @@ class OptimumEncoder(BaseEncoder):
116
  batch_size: int = 32,
117
  normalize_embeddings: bool = True,
118
  pooling_strategy: str = "mean",
119
- matryoshka_dim: int = 1024,
120
  convert_to_numpy: bool = False
121
  ) -> List[List[float]] | List[np.ndarray]:
122
  all_embeddings = []
@@ -142,15 +134,12 @@ class OptimumEncoder(BaseEncoder):
142
  raise ValueError(
143
  "Invalid pooling_strategy. Please use 'mean' or 'max'."
144
  )
145
- print(f'Embeddings {embeddings}')
146
  if normalize_embeddings:
147
  if convert_to_numpy:
148
  embeddings = normalize(embeddings[:, 0]).astype(np.float32)
149
  else:
150
  embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
151
-
152
- if self._dim > matryoshka_dim:
153
- embeddings = embeddings[:, :matryoshka_dim]
154
 
155
  all_embeddings.extend(embeddings)
156
 
 
19
  _tokenizer: Any = PrivateAttr()
20
  _model: Any = PrivateAttr()
21
  _torch: Any = PrivateAttr()
 
22
 
23
  def __init__(self, **data):
24
  super().__init__(**data)
 
44
  "`pip install semantic-router[local]`"
45
  )
46
  try:
47
+ from transformers import AutoTokenizer
48
  except ImportError:
49
  raise ImportError(
50
  "Please install transformers to use OptimumEncoder. "
 
58
  self.name,
59
  **self.tokenizer_kwargs,
60
  )
 
 
 
 
 
 
61
 
62
  provider_options = {
63
  "trt_engine_cache_enable": True,
 
109
  batch_size: int = 32,
110
  normalize_embeddings: bool = True,
111
  pooling_strategy: str = "mean",
 
112
  convert_to_numpy: bool = False
113
  ) -> List[List[float]] | List[np.ndarray]:
114
  all_embeddings = []
 
134
  raise ValueError(
135
  "Invalid pooling_strategy. Please use 'mean' or 'max'."
136
  )
137
+
138
  if normalize_embeddings:
139
  if convert_to_numpy:
140
  embeddings = normalize(embeddings[:, 0]).astype(np.float32)
141
  else:
142
  embeddings = self._torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().tolist()
 
 
 
143
 
144
  all_embeddings.extend(embeddings)
145