Spaces:
Runtime error
Runtime error
from sentence_transformers import util | |
from nltk.tokenize import sent_tokenize | |
from nltk import word_tokenize, pos_tag | |
import torch | |
import numpy as np | |
def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Cosine similarity between every query sentence and every candidate sentence.

    Both sentence lists are embedded via ``get_embedding`` and compared with
    ``util.cos_sim``. Entry [i, j] of the result compares query sentence i with
    candidate sentence j.
    """
    # TODO make this more general for different types of models
    query_vecs, cand_vecs = get_embedding(model, query_sents, candidate_sents)
    similarity = util.cos_sim(query_vecs, cand_vecs)
    return similarity
def get_embedding(model, query_sents, candidate_sents):
    """
    Embed query and candidate sentences with ``model.encode``.

    The query sentences are encoded first, then the candidate sentences.
    Returns the tuple (query_embeddings, candidate_embeddings).
    """
    encode = model.encode
    return encode(query_sents), encode(candidate_sents)
def get_top_k(score_mat, K=3):
    """
    Pick the top K candidate sentences to show for each query sentence.

    Args:
        score_mat: 2-D tensor of similarity scores, one row per query sentence
            and one column per candidate sentence.
        K: number of columns to keep per row (all columns if the row is shorter).

    Returns:
        (picked_sent, picked_scores): per row, the column indices of the highest
        scores in descending-score order, and the matching scores. Both have
        shape (rows, min(K, cols)).
    """
    idx = torch.argsort(-score_mat)
    picked_sent = idx[:, :K]
    # gather picks score_mat[i, picked_sent[i, j]] for every row in one call; unlike
    # torch.vstack over a Python list, it also works when score_mat has zero rows.
    picked_scores = torch.gather(score_mat, 1, picked_sent)
    return picked_sent, picked_scores
def get_words(sent):
    """
    Tokenize sentences into words and record where each sentence starts.

    Input: list of sentences
    Output: (words, all_words, sent_start_id) where
        words         -- list of token lists, one per sentence
        all_words     -- flat list of every token across all sentences
        sent_start_id -- for each sentence, the index in ``all_words`` of its
                         first token (same length as ``sent``)
    An empty input yields three empty lists.
    """
    # Tokenize each sentence exactly once and reuse the result everywhere.
    words = [word_tokenize(x) for x in sent]
    all_words = [tok for sentence_words in words for tok in sentence_words]
    # Sentence i starts after all tokens of sentences 0..i-1.
    sent_start_id = []
    offset = 0
    for sentence_words in words:
        sent_start_id.append(offset)
        offset += len(sentence_words)
    return words, all_words, sent_start_id
def get_match_phrase(w1, w2):
    """
    Flag candidate words that also appear in the query.

    Input: list of words for the query (w1) and for the candidate text (w2)
    Output: numpy float array of length len(w2); entry i is 1 when w2[i] occurs
        in w1 (compared case-insensitively) and its POS tag is one of the
        content-word tags below, otherwise 0.
    """
    # POS tags that should be considered for matching phrase
    include = {
        'JJ',
        'JJR',
        'JJS',
        'MD',
        'NN',
        'NNS',
        'NNP',
        'NNPS',
        'RB',
        'RBR',
        'RBS',
        'SYM',
        'VB',
        'VBD',
        'VBG',
        'VBN',
        'FW',
    }
    # Lowercase both sides so matching is truly case-insensitive; lowercasing only
    # the candidate word would miss capitalized words shared by both texts.
    query_vocab = {w.lower() for w in w1}
    mask = np.zeros(len(w2))
    for i, (word, tag) in enumerate(pos_tag(w2)):
        if tag in include and word.lower() in query_vocab:
            mask[i] = 1
    return mask
def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Build per-word highlight arrays (sentence-level and phrase-level) per query sentence.

    For query sentence ``qi``, every candidate sentence id in ``sent_ids[qi]`` gets
    its words marked as selected. Those words receive the paired score from
    ``sent_scores[qi]``, and words flagged by ``get_match_phrase`` are marked as
    phrases, with their score then overwritten with -1.

    Returns a dict holding 'all_words' and 'words_by_sentence', plus one integer
    key per query sentence mapping to
    {'is_selected_sent', 'is_selected_phrase', 'scores'}.
    """
    n_words = len(all_words)
    n_sents = len(sent_start_id)
    output = {'all_words': all_words, 'words_by_sentence': words}
    for qi in range(sent_ids.shape[0]):
        query_tokens = word_tokenize(query_sents[qi])
        sent_mask = np.zeros(n_words)
        phrase_mask = np.zeros(n_words)
        scores = np.zeros(n_words)
        for sid, sscore in zip(sent_ids[qi], sent_scores[qi]):
            start = sent_start_id[sid]
            # the last candidate sentence extends to the end of the word list
            end = sent_start_id[sid + 1] if sid + 1 < n_sents else n_words
            sent_mask[start:end] = 1
            scores[start:end] = sscore
            phrase_mask[start:end] = get_match_phrase(query_tokens, all_words[start:end])
        # update selected phrase scores (-1 meaning a different color in gradio)
        scores[sent_mask + phrase_mask == 2] = -1
        output[qi] = {
            'is_selected_sent': sent_mask,
            'is_selected_phrase': phrase_mask,
            'scores': scores,
        }
    return output
def get_highlight_info(model, text1, text2, K=None):
    """
    Get highlight information from two texts.

    Args:
        model: sentence encoder passed to ``compute_sentencewise_scores``.
        text1: query text.
        text2: candidate text whose words are highlighted.
        K: number of candidate sentences selected per query sentence. When None,
           one third of the candidate's sentence count is used, but at least 1 so
           that short candidates (fewer than 3 sentences) still get a highlight.

    Returns:
        (sent_ids, sent_scores, info): top-K candidate sentence indices and scores
        for each query sentence, and the per-word dict built by ``mark_words``.
    """
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    if K is None:
        # len // 3 alone is 0 for 1-2 sentence candidates, which would select nothing.
        K = max(1, len(sent2) // 3)
    score_mat = compute_sentencewise_scores(model, sent1, sent2)
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
    return sent_ids, sent_scores, info
### Document-level operations | |
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Score each paper against the query by cosine similarity of document embeddings.

    Each paper is encoded as ``title + ' [SEP] ' + abstract``. The query is
    prepended to every batch, and the first-token vector of
    ``last_hidden_state`` serves as the embedding for the query and each paper.

    Args:
        doc_model: model called as ``doc_model(**inputs)`` whose output has
            ``last_hidden_state``; must expose a ``device`` attribute.
        tokenizer: callable tokenizer whose result supports ``.to(device)``.
        query: query text.
        titles, abstracts: parallel lists. A None title or abstract is treated
            as an empty string.
        batch: number of papers per forward pass.

    Returns:
        list of cosine similarities with exactly one entry per paper, in input
        order, so index i always refers to ``titles[i]`` / ``abstracts[i]``.
    """
    # Build one text per paper (rather than skipping papers with a missing field)
    # so the returned scores stay aligned with the caller's titles/abstracts.
    title_abs = [
        ('' if t is None else t) + ' [SEP] ' + ('' if a is None else a)
        for t, a in zip(titles, abstracts)
    ]
    num_docs = len(title_abs)
    scores = []
    with torch.no_grad():
        for start in range(0, num_docs, batch):
            # preprocess the input: query first, then this batch of papers
            inputs = tokenizer(
                [query] + title_abs[start:start + batch],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(doc_model.device)
            result = doc_model(**inputs)
            # take the first token of each sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
            q_emb = embeddings[0, :]  # row 0 is always the query
            p_emb = embeddings[1:, :]
            norms = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / norms)
    return scores
def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Rank papers by document-level similarity to the query, highest first.

    ``papers`` is an iterable of mappings with 'title' and 'abstract' keys.
    Scores come from ``predict_docscore``. Returns three lists (titles,
    abstracts, scores) reordered by descending score.
    """
    pairs = [(paper['title'], paper['abstract']) for paper in papers]
    titles = [title for title, _ in pairs]
    abstracts = [abstract for _, abstract in pairs]
    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    order = np.argsort(scores)[::-1]
    return (
        [titles[j] for j in order],
        [abstracts[j] for j in order],
        [scores[j] for j in order],
    )