Spaces:
Runtime error
Runtime error
Pavankalyan
committed on
Commit
•
17283b0
1
Parent(s):
4681ada
Upload 7 files
Browse files- Responses.csv +0 -0
- app.py +19 -0
- corpus.pt +3 -0
- data_process.py +44 -0
- main.py +22 -0
- requirements.txt +2 -0
- retrieval.py +69 -0
Responses.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from data_process import *
|
3 |
+
from retrieval import *
|
4 |
+
|
5 |
+
# Read the passage table and keep the "text" column as a plain list;
# chitti() hands this list to search() on every request.
df = pd.read_csv("Responses.csv")
text = [*df["text"].values]
|
7 |
+
|
8 |
+
|
9 |
+
def chitti(query):
    """Answer ``query`` with the top-ranked passage.

    Runs ``search`` over the module-level ``text`` passages and returns the
    first column of the first row of the table it produces.
    """
    best_row = search(query, text)[0]
    return best_row[0]
|
12 |
+
|
13 |
+
# One text box in, one text box out, backed by chitti().
demo = gr.Interface(fn=chitti, inputs=["text"], outputs=["text"])
demo.launch(share=True)
|
19 |
+
|
corpus.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90d8781fef8d3a3b5a5130ce095c186c076a05ee25e3980cc3cf2577910302b2
|
3 |
+
size 5803755
|
data_process.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
|
6 |
+
def merge_text(text_list):
    """Merge short lines into the line that follows them.

    Walks the list front to back. Whenever the current entry (including
    anything already merged into it) has 30 or fewer whitespace-separated
    words, it is prepended to the next entry, joined by a single space.
    The final entry is never merged forward, so it may stay short.

    Args:
        text_list: List of strings in document order. It is not modified.

    Returns:
        A new list with the merged strings. An entry that is exactly a
        single space (" ") is dropped.
    """
    merged = []
    carried = None  # short text waiting to be prepended to the next entry
    last_index = len(text_list) - 1

    for index, chunk in enumerate(text_list):
        current = chunk if carried is None else carried + " " + chunk
        if index < last_index and len(current.split()) <= 30:
            carried = current
            continue
        carried = None
        # Compare by value: an identity test (`is not " "`) against a string
        # literal only works by accident of interpreter string caching.
        if current != " ":
            merged.append(current)

    return merged
|
20 |
+
|
21 |
+
|
22 |
+
def get_text(path):
    """Collect merged text passages from every file in a directory.

    Files are visited in sorted name order and read as UTF-8. Newline
    characters are removed from each line, empty lines are skipped, and
    the remaining lines of each file are passed through ``merge_text``
    before being appended to the overall result.

    Args:
        path: Directory whose entries are opened as text files.

    Returns:
        A flat list of passages from all files.
    """
    collected = []
    for name in sorted(os.listdir(path)):
        with open(os.path.join(path, name), encoding='utf-8') as handle:
            stripped = [re.sub("\\n", "", raw) for raw in handle]
        non_empty = [piece for piece in stripped if piece != ""]
        collected.extend(merge_text(non_empty))
    return collected
|
36 |
+
|
37 |
+
|
38 |
+
def dataframe(path):
    """Write the passages found under ``path`` to ``Responses.csv``.

    The passages come from ``get_text(path)`` and are stored in a single
    column named "text". The CSV is written to the current working
    directory; nothing is returned.
    """
    passages = get_text(path)
    pd.DataFrame({"text": passages}).to_csv("Responses.csv")
|
main.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data_process import *
|
2 |
+
from retrieval import *
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
|
6 |
+
# Command-line entry point: answer a single query against the passage corpus.
parser = argparse.ArgumentParser(description="Run the query for the bot")
parser.add_argument('--query', help="Question to the bot", type=str, required=True)
parser.add_argument('--data_path', help="Path for the stored dataset", type=str, required=True)

args = parser.parse_args()
path = args.data_path
query = args.query

# Build the passage CSV only when it is missing from the working directory;
# later runs reuse the existing file.
if not os.path.exists("Responses.csv"):
    dataframe(path)

df = pd.read_csv("Responses.csv")
text = list(df["text"].values)

# search() returns its rows rather than printing them, so show them here;
# discarding the return value would leave the command with no output.
for answer, cross_score, bi_score in search(query, text):
    print("Cross-encoder score: {} | Bi-encoder score: {}".format(cross_score, bi_score))
    print(answer)
    print()
|
22 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
sentence-transformers
|
2 |
+
pandas
tabulate
|
retrieval.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import textwrap
|
3 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
4 |
+
import torch
|
5 |
+
from tabulate import tabulate
|
6 |
+
import time
|
7 |
+
|
8 |
+
# Number of candidates the bi-encoder stage passes on for re-scoring.
top_k = 20

# Model identifiers for the two retrieval stages.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Bi-encoder: embeds the passages (get_corpus) and the query (search).
bi_encoder = SentenceTransformer(model_bi_encoder)
bi_encoder.max_seq_length = 512

# Cross-encoder: scores (query, passage) pairs for the candidates in search().
cross_encoder = CrossEncoder(model_cross_encoder)
|
17 |
+
|
18 |
+
|
19 |
+
def get_corpus(passages):
    """Return bi-encoder embeddings for ``passages``, cached in ``corpus.pt``.

    If ``corpus.pt`` exists in the current working directory and holds one
    embedding per passage, it is loaded and returned. Otherwise the passages
    are encoded with ``bi_encoder`` and the result is saved to ``corpus.pt``
    before being returned.

    Args:
        passages: Sequence of passage strings to embed.

    Returns:
        The embeddings tensor, one row per passage.
    """
    cache_file = "corpus.pt"

    if os.path.exists(cache_file):
        # map_location lets a cache that was written on a GPU machine be
        # loaded on a CPU-only host instead of raising at deserialization.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        corpus_embeddings = torch.load(cache_file, map_location=device)
        if len(corpus_embeddings) == len(passages):
            return corpus_embeddings
        # Row count differs from the passage list: the cache is stale and
        # its corpus ids would point at the wrong passages, so rebuild it.

    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
    torch.save(corpus_embeddings, cache_file)
    return corpus_embeddings
|
28 |
+
|
29 |
+
|
30 |
+
def search(query, passages, top_n=5):
    """Retrieve and re-rank the passages that best match ``query``.

    The query is embedded with ``bi_encoder`` and compared against the
    corpus embeddings from ``get_corpus`` to get up to ``top_k`` candidate
    hits. Each candidate is then scored against the query by
    ``cross_encoder`` and the candidates are sorted by that score, highest
    first.

    Args:
        query: The question text.
        passages: Sequence of passage strings; hit ids index into it.
        top_n: Maximum number of rows to return (default 5).

    Returns:
        A list of ``[answer, cross_score, bi_score]`` rows, all strings.
        ``answer`` is the passage with newlines replaced by spaces and
        re-wrapped to 50 columns. An empty list is returned when the
        bi-encoder stage yields no hits.
    """
    corpus_embeddings = get_corpus(passages)
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # semantic_search returns one hit list per query; there is only one query.
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]
    if not hits:
        return []

    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Attach each cross-encoder score to its hit, then rank by that score.
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    wrapper = textwrap.TextWrapper(width=50)  # built once, reused per row
    result_table = []
    for hit in hits[:top_n]:
        ans = wrapper.fill(text=passages[hit['corpus_id']].replace("\n", " "))
        cs = "{}".format(hit['cross-score'])
        sc = "{}".format(hit['score'])
        result_table.append([ans, cs, sc])

    return result_table
|
65 |
+
|
66 |
+
#print(tabulate(result_table, headers=["Answer", "Cross-encoder score", "Bi-encoder score"], tablefmt="fancy_grid", maxcolwidths=[None, None, None]))
|
67 |
+
|
68 |
+
|
69 |
+
|