Cyrile commited on
Commit
f8bd3d5
1 Parent(s): 627d6d6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -73,7 +73,7 @@ model = AutoModel.from_pretrained('cmarkea/bloomz-3b-retriever-v2')
73
 
74
  def infer(txt: Union[str, List[str]]):
75
  tok = tokenizer(txt, padding=True, return_tensors='pt')
76
- with torch.no_grad():
77
  embedding = model(**tok)
78
  # Inportant: take only last token!
79
  return embedding.get('last_hidden_state')[:,-1,:].numpy()
@@ -104,7 +104,7 @@ from scipy.spatial.distance import cdist
104
  retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever-v2')
105
 
106
  # Inportant: take only last token!
107
- infer = lambda x: [ii[0][-1] for ii in retriever(x)]
108
 
109
  list_of_contexts: List[str] = [...]
110
  emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
 
73
 
74
  def infer(txt: Union[str, List[str]]):
75
  tok = tokenizer(txt, padding=True, return_tensors='pt')
76
+ with torch.inference_mode():
77
  embedding = model(**tok)
78
  # Inportant: take only last token!
79
  return embedding.get('last_hidden_state')[:,-1,:].numpy()
 
104
  retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever-v2')
105
 
106
  # Inportant: take only last token!
107
+ infer = lambda x: [np.array(ii[0][-1]).reshape(1,-1) for ii in retriever(x)]
108
 
109
  list_of_contexts: List[str] = [...]
110
  emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)