Spaces:
Runtime error
Runtime error
Pavankalyan
committed on
Commit
•
d42d3ef
1
Parent(s):
ecdcda6
Delete retrieval.py
Browse files- retrieval.py +0 -69
retrieval.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import textwrap
|
3 |
-
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
4 |
-
import torch
|
5 |
-
#from tabulate import tabulate
|
6 |
-
import time
|
7 |
-
|
8 |
-
# Checkpoint names handed to the sentence-transformers constructors below.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Bi-encoder: embeds both the passages (get_corpus) and the query (search).
bi_encoder = SentenceTransformer(model_bi_encoder)
# NOTE(review): presumably caps the encoder input at 512 tokens -- confirm
# against the sentence-transformers documentation.
bi_encoder.max_seq_length = 512

# Cross-encoder: scores [query, passage] pairs in search().
cross_encoder = CrossEncoder(model_cross_encoder)

# Number of bi-encoder hits requested from util.semantic_search in search().
top_k = 20
|
17 |
-
|
18 |
-
|
19 |
-
def get_corpus(passages):
    """Return bi-encoder embeddings for ``passages``, cached in ``corpus.pt``.

    The embeddings are stored in ``corpus.pt`` in the current working
    directory. An existing cache is reused only when it holds exactly one
    entry per passage; otherwise the passages are re-encoded and the cache
    file is overwritten.

    Note: the check compares counts only. A cache built from different
    passages of the same count is still reused, so delete ``corpus.pt``
    after editing passage text.

    Args:
        passages: Sequence of passage strings (must support ``len``).

    Returns:
        The tensor produced by ``bi_encoder.encode(..., convert_to_tensor=True)``
        or the equivalent tensor loaded from the cache.
    """
    cache_path = "corpus.pt"

    if os.path.exists(cache_path):
        corpus_embeddings = torch.load(cache_path)
        # A cache with a different number of rows belongs to another corpus;
        # using it would make search() index `passages` with wrong or
        # out-of-range corpus ids.
        if len(corpus_embeddings) == len(passages):
            return corpus_embeddings

    corpus_embeddings = bi_encoder.encode(
        passages, convert_to_tensor=True, show_progress_bar=True
    )
    torch.save(corpus_embeddings, cache_path)
    return corpus_embeddings
|
28 |
-
|
29 |
-
|
30 |
-
def search(query, passages):
    """Retrieve passages for ``query`` and re-rank them with the cross-encoder.

    Steps:
      1. Embed the query with ``bi_encoder`` and fetch up to ``top_k`` hits
         from the (cached) corpus embeddings via ``util.semantic_search``.
      2. Score each ``[query, passage]`` pair with ``cross_encoder.predict``.
      3. Sort the hits by cross-encoder score, highest first, and keep the
         best five.

    Args:
        query: The query string.
        passages: Sequence of passage strings; must be the same passages the
            corpus embeddings were computed from, since hits are mapped back
            through ``passages[hit['corpus_id']]``.

    Returns:
        A list of at most five rows ``[answer, cross_score, bi_score]``, all
        strings. ``answer`` is the passage with newlines replaced by spaces
        and wrapped to 50 columns. An empty list is returned when retrieval
        yields no hits.
    """
    corpus_embeddings = get_corpus(passages)
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # One query was passed, so only the first hit list is relevant.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        # Nothing retrieved, so there is nothing to re-rank.
        return []

    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach the cross-encoder score to each hit, then sort by it.
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    # The wrapper is loop-invariant, so build it once.
    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[0:5]:
        ans = wrapper.fill(text=passages[hit['corpus_id']].replace("\n", " "))
        cs = "{}".format(hit['cross-score'])
        sc = "{}".format(hit['score'])
        result_table.append([ans, cs, sc])

    return result_table
|
65 |
-
|
66 |
-
#print(tabulate(result_table, headers=["Answer", "Cross-encoder score", "Bi-encoder score"], tablefmt="fancy_grid", maxcolwidths=[None, None, None]))
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|