Spaces:
Runtime error
Runtime error
import pandas as pd | |
import os | |
import json | |
import re | |
from sentence_transformers import SentenceTransformer, CrossEncoder, util | |
import torch | |
import time | |
import textwrap | |
# Names of the pretrained models used for retrieval (bi-encoder) and for
# re-ranking the retrieved candidates (cross-encoder).
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Both models are instantiated once, at import time, and shared by every
# function in this module.
bi_encoder = SentenceTransformer(model_bi_encoder)
# Cap the bi-encoder input length at 512 (presumably tokens -- confirm
# against the sentence-transformers documentation).
bi_encoder.max_seq_length = 512
cross_encoder = CrossEncoder(model_cross_encoder)
def collect_data(data_lis, meta_count):
    """Return the file names and links that have not been processed yet.

    Args:
        data_lis: DataFrame with a ``file_name`` and a ``link`` column,
            one row per data file.
        meta_count: Number of leading rows that were already processed.

    Returns:
        A ``(new_files, new_links)`` tuple of pandas Series holding the rows
        from position ``meta_count`` onwards.  Both Series are re-indexed
        from 0 so that callers can address them positionally
        (``new_files[0]`` is the first *new* file).
    """
    # A plain slice keeps the original row labels (meta_count, meta_count+1,
    # ...), which makes ``series[0]`` raise KeyError for any meta_count > 0.
    # Slice by position and renumber from zero instead.
    new_files = data_lis['file_name'].iloc[meta_count:].reset_index(drop=True)
    new_links = data_lis['link'].iloc[meta_count:].reset_index(drop=True)
    return new_files, new_links
def merge_text(text_list, max_short_words=30):
    """Glue short passages onto the passage that follows them.

    Walking the list front to back, any passage of at most
    ``max_short_words`` whitespace-separated words is prepended (joined by a
    single space) to its successor; the combined passage is then judged
    again, so several short passages in a row accumulate into one.  The last
    passage is kept as-is because nothing follows it.

    The input list is not modified.

    Args:
        text_list: List of passage strings.
        max_short_words: Largest word count still considered "short".

    Returns:
        A new list of merged passages, in the original order.
    """
    if not text_list:
        return []

    chunks = []
    current = text_list[0]
    for following in text_list[1:]:
        if len(current.split()) <= max_short_words:
            # Still short: keep accumulating into the same chunk.
            current = current + " " + following
        else:
            chunks.append(current)
            current = following
    chunks.append(current)

    # A chunk that is exactly one space (e.g. two empty strings joined
    # together) carries no text; it has always been dropped from the result.
    return [chunk for chunk in chunks if chunk != " "]
def make_data(new_files, new_links, local_path):
    """Read the given files and turn them into passages with their links.

    Each file is read line by line (UTF-8); the newline is removed and empty
    lines are skipped.  The remaining lines are passed through
    ``merge_text`` and every resulting passage is paired with the link of
    the file it came from.

    Args:
        new_files: Iterable of file names, relative to ``local_path``.
        new_links: Iterable of links, parallel to ``new_files``.
        local_path: Directory that contains the files.

    Returns:
        A ``(text, links)`` tuple of two equally long lists: the passages of
        all files in order, and the source link of each passage.
    """
    text = []
    links = []
    # Iterate over the values rather than indexing with 0..n-1: when the
    # inputs are pandas Series sliced out of a larger frame, ``series[0]``
    # is a label lookup and fails if the labels do not start at zero.
    for file_name, link in zip(new_files, new_links):
        with open(os.path.join(local_path, file_name), encoding='utf-8') as f:
            stripped = (line.replace("\n", "") for line in f)
            sub_text = [line for line in stripped if line != ""]
        sub_text = merge_text(sub_text)
        text.extend(sub_text)
        links.extend([link] * len(sub_text))
    return text, links
def get_final_data():
    """Bring the passage table and its embeddings up to date and return them.

    Files involved (all relative to the working directory):
      * ``data_url.csv``   -- list of data files (``file_name``, ``link``).
      * ``meta_data.json`` -- ``["data"]["count"]`` holds how many entries of
        that list have already been processed.
      * ``Data_final/``    -- directory containing the data files.
      * ``Responses.csv``  -- passage table (``text``, ``links``); created
        empty if it does not exist.
      * ``corpus.pt``      -- bi-encoder embeddings of the passages.

    When the list holds a different number of files than the recorded
    count, the entries after the recorded count are read, appended to the
    passage table and embedded, and all three artefacts are rewritten.

    Returns:
        ``(corpus_embeddings, data)``: the embedding tensor and the passage
        DataFrame.  ``data`` carries a 0..n-1 index after new passages were
        appended, so row position and row label agree.
    """
    # Define all the paths
    meta_path = "meta_data.json"
    data_lis_path = "data_url.csv"
    local_path = "Data_final"
    data_path = "Responses.csv"
    corpus_path = "corpus.pt"

    # Load the list of data files
    data_lis = pd.read_csv(data_lis_path)

    # Load the responses file, creating an empty one on first use.
    # index=False keeps the row index out of the file; otherwise every
    # write/read round trip adds another "Unnamed: 0" column.
    if not os.path.exists(data_path):
        pd.DataFrame({"text": [], "links": []}).to_csv(data_path, index=False)
    data = pd.read_csv(data_path)

    act_count = len(data_lis['file_name'])
    with open(meta_path, "r") as json_file:
        meta_data = json.load(json_file)
    meta_count = meta_data["data"]["count"]

    if meta_count == act_count:
        # Nothing new.  Reuse the stored embeddings, or rebuild them if the
        # file has gone missing instead of failing in torch.load.
        if os.path.exists(corpus_path):
            return torch.load(corpus_path), data
        corpus_embeddings = bi_encoder.encode(
            data["text"].tolist(), convert_to_tensor=True, show_progress_bar=True
        )
        torch.save(corpus_embeddings, corpus_path)
        return corpus_embeddings, data

    # Check for any new files; if present add those to the responses file
    # and extend the stored embeddings accordingly.
    new_files, new_links = collect_data(data_lis, meta_count)
    text, links = make_data(new_files, new_links, local_path)
    df = pd.DataFrame({"text": text, "links": links})
    # ignore_index avoids duplicate row labels, which would make label
    # lookups such as data['text'][i] return several rows.
    data = pd.concat([data, df], ignore_index=True)

    # Plain lists are passed to encode(): it indexes its input by position,
    # which is not what integer indexing on a Series does.
    if not os.path.exists(corpus_path):
        corpus_embeddings = bi_encoder.encode(
            data["text"].tolist(), convert_to_tensor=True, show_progress_bar=True
        )
    else:
        corpus_embeddings = torch.load(corpus_path)
        if not df.empty:
            new_embeddings = bi_encoder.encode(
                df["text"].tolist(), convert_to_tensor=True, show_progress_bar=True
            )
            corpus_embeddings = torch.cat((corpus_embeddings, new_embeddings), 0)

    # Persist only after everything above succeeded, and write the counter
    # last: a failure while reading or embedding must not mark the new
    # files as processed.
    data.to_csv(data_path, index=False)
    torch.save(corpus_embeddings, corpus_path)
    meta_data["data"]["count"] = act_count
    with open(meta_path, "w") as json_file:
        json.dump(meta_data, json_file)

    return corpus_embeddings, data
def search(query, top_k=20, max_results=5):
    """Answer ``query`` with the best-matching passages of the corpus.

    The bi-encoder retrieves the ``top_k`` nearest passages, the
    cross-encoder re-scores each (query, passage) pair, and the
    ``max_results`` passages with the highest cross-encoder score are
    returned.

    Args:
        query: The question text.
        top_k: Number of candidates fetched by the bi-encoder.
        max_results: Number of rows in the returned table.

    Returns:
        A list of ``[answer, cross_score, bi_encoder_score, link]`` rows,
        best first, all four entries as strings.  ``answer`` is wrapped to
        lines of at most 50 characters.  The list is empty when the
        retrieval step yields no hits.
    """
    corpus_embeddings, data = get_final_data()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]
    if not hits:
        return []

    # 'corpus_id' is a row position in corpus_embeddings, so the matching
    # passage must be fetched by position (.iloc), not by index label.
    texts = data['text']
    links = data['links']

    cross_inp = [[query, texts.iloc[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score
    hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)

    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[:max_results]:
        corpus_id = hit['corpus_id']
        ans = str(texts.iloc[corpus_id]).replace("\n", " ")
        result_table.append([
            wrapper.fill(text=ans),
            str(hit['cross-score']),
            str(hit['score']),
            str(links.iloc[corpus_id]),
        ])
    return result_table