jpohhhh commited on
Commit
b9a2f5d
1 Parent(s): fce2850

Use SentenceTransformers instead

Browse files

Per https://huggingface.co./sentence-transformers/multi-qa-MiniLM-L6-cos-v1/blob/main/README.md

Files changed (1) hide show
  1. handler.py +4 -19
handler.py CHANGED
@@ -1,17 +1,8 @@
1
- from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModel
3
- import torch
4
-
5
- #Mean Pooling - Take attention mask into account for correct averaging
6
- def mean_pooling(model_output, attention_mask):
7
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
8
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
9
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
10
 
11
  class EndpointHandler():
12
  def __init__(self, path=""):
13
- self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
14
- self.model = AutoModel.from_pretrained('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
15
 
16
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
  """
@@ -22,11 +13,5 @@ class EndpointHandler():
22
  A :obj:`list` | `dict`: will be serialized and returned
23
  """
24
  sentences = data.pop("inputs",data)
25
- encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
26
- # Compute token embeddings
27
- with torch.no_grad():
28
- model_output = self.model(**encoded_input)
29
-
30
- # Perform pooling. In this case, max pooling.
31
- sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
32
- return sentence_embeddings.tolist()
 
1
from typing import Any, Dict, List

from sentence_transformers import SentenceTransformer, util
 
 
 
 
 
 
 
 
2
 
3
class EndpointHandler():
    """Inference Endpoints handler that embeds input sentences with a
    SentenceTransformers model and returns the raw embedding vectors."""

    def __init__(self, path=""):
        # `path` is the model directory Inference Endpoints passes in; this
        # handler pins the hub model id instead, matching the commit intent.
        self.model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: request payload; the "inputs" key holds the sentence(s)
                to embed. If "inputs" is absent the whole payload is used.
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Fall back to the raw payload when no "inputs" key is present.
        sentences = data.pop("inputs", data)
        # FIX: the committed code referenced bare `model` (NameError at
        # request time); the model lives on the instance as `self.model`.
        embeddings = self.model.encode(sentences)
        # numpy array -> nested lists so the result is JSON-serializable.
        return embeddings.tolist()