Update README.md
Browse files
README.md
CHANGED
@@ -51,18 +51,22 @@ from transformers import pipeline
|
|
51 |
from scipy.spatial.distance import cdist
|
52 |
|
53 |
retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever')
|
54 |
-
|
|
|
|
|
55 |
|
56 |
list_of_contexts = [...]
|
57 |
emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
|
58 |
list_of_queries = [...]
|
59 |
emb_queries = np.concatenate(infer(list_of_queries), axis=0)
|
60 |
|
61 |
-
|
|
|
62 |
top_k = lambda x: [
|
63 |
[list_of_contexts[qq] for qq in ii]
|
64 |
for ii in dist.argsort(axis=-1)[:,:x]
|
65 |
]
|
|
|
66 |
# top 5 nearest contexts for each queries
|
67 |
top_contexts = top_k(5)
|
68 |
```
|
|
|
51 |
from scipy.spatial.distance import cdist
|
52 |
|
53 |
retriever = pipeline('feature-extraction', 'cmarkea/bloomz-3b-retriever')
|
54 |
+
|
55 |
+
# Inportant: take only last token!
|
56 |
+
infer = lambda x: [ii[0][-1] for ii in retriever(x)]
|
57 |
|
58 |
list_of_contexts = [...]
|
59 |
emb_contexts = np.concatenate(infer(list_of_contexts), axis=0)
|
60 |
list_of_queries = [...]
|
61 |
emb_queries = np.concatenate(infer(list_of_queries), axis=0)
|
62 |
|
63 |
+
# Important: take l2 distance!
|
64 |
+
dist = cdist(emb_queries, emb_contexts, 'euclidean')
|
65 |
top_k = lambda x: [
|
66 |
[list_of_contexts[qq] for qq in ii]
|
67 |
for ii in dist.argsort(axis=-1)[:,:x]
|
68 |
]
|
69 |
+
|
70 |
# top 5 nearest contexts for each queries
|
71 |
top_contexts = top_k(5)
|
72 |
```
|