Pavankalyan committed on
Commit
d42d3ef
1 Parent(s): ecdcda6

Delete retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +0 -69
retrieval.py DELETED
@@ -1,69 +0,0 @@
1
- import os
2
- import textwrap
3
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
4
- import torch
5
- #from tabulate import tabulate
6
- import time
7
-
8
# Names of the pretrained checkpoints loaded below.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Number of bi-encoder candidates requested from the semantic search
# before they are re-scored by the cross-encoder.
top_k = 20

# Bi-encoder: embeds the query and the passages for the first retrieval pass.
bi_encoder = SentenceTransformer(model_bi_encoder)
bi_encoder.max_seq_length = 512

# Cross-encoder: scores (query, passage) pairs to re-rank the candidates.
cross_encoder = CrossEncoder(model_cross_encoder)
17
-
18
-
19
def get_corpus(passages):
    """Return bi-encoder embeddings for ``passages``, cached in ``corpus.pt``.

    The embeddings are stored in ``corpus.pt`` in the current working
    directory.  A cached file is reused only when it holds exactly one
    embedding per passage; otherwise the passages are re-encoded and the
    cache is overwritten.  This prevents a cache built for a different
    passage list from yielding hit indices that do not match ``passages``.

    NOTE(review): a stale cache with the *same* number of passages but
    different content is still not detected -- delete ``corpus.pt`` when the
    passage texts change.

    Args:
        passages: Sequence of passage strings to embed (must support
            ``len()``).

    Returns:
        The value returned by ``bi_encoder.encode(..., convert_to_tensor=True)``
        for ``passages``, either freshly computed or loaded from the cache.
    """
    cache_path = "corpus.pt"

    if os.path.isfile(cache_path):
        corpus_embeddings = torch.load(cache_path)
        # Only trust the cache when it lines up with the current passages;
        # assumes the cached object is indexed by passage along dim 0.
        if len(corpus_embeddings) == len(passages):
            return corpus_embeddings

    corpus_embeddings = bi_encoder.encode(
        passages, convert_to_tensor=True, show_progress_bar=True
    )
    torch.save(corpus_embeddings, cache_path)
    return corpus_embeddings
28
-
29
-
30
def search(query, passages):
    """Retrieve and re-rank the passages most relevant to ``query``.

    Two stages are used:

    1. The bi-encoder embeds the query and ``util.semantic_search`` selects
       up to ``top_k`` candidate passages from the corpus embeddings.
    2. The cross-encoder scores each (query, candidate) pair and the
       candidates are sorted by that score, highest first.

    Args:
        query: The question / search string.
        passages: Indexable sequence of passage strings that forms the corpus.

    Returns:
        A list of at most five rows ``[answer, cross_score, bi_score]``, best
        match first.  ``answer`` is the passage text with newlines replaced
        by spaces and wrapped to 50 columns; both scores are strings.  An
        empty list is returned when there are no passages or no hits.
    """
    # Nothing to search: avoid encoding / scoring empty batches.
    if len(passages) == 0:
        return []

    corpus_embeddings = get_corpus(passages)
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # Stage 1: bi-encoder retrieval.  semantic_search returns one hit list
    # per query; there is a single query here, so take the first list.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        return []

    # Stage 2: cross-encoder re-ranking of the retrieved candidates.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Sort results by the cross-encoder scores
    hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)

    # One wrapper is enough; it is reused for every row.
    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[:5]:
        ans = wrapper.fill(text=passages[hit['corpus_id']].replace("\n", " "))
        cs = "{}".format(hit['cross-score'])
        sc = "{}".format(hit['score'])
        result_table.append([ans, cs, sc])

    return result_table
65
-
66
- #print(tabulate(result_table, headers=["Answer", "Cross-encoder score", "Bi-encoder score"], tablefmt="fancy_grid", maxcolwidths=[None, None, None]))
67
-
68
-
69
-