File size: 6,824 Bytes
6eff5e7
 
580aef7
6eff5e7
 
 
 
963bf46
6eff5e7
 
963bf46
6eff5e7
 
 
 
 
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963bf46
 
 
 
6eff5e7
 
 
 
580aef7
 
6eff5e7
 
 
 
580aef7
6eff5e7
 
 
 
 
 
580aef7
963bf46
 
 
 
580aef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963bf46
 
 
6eff5e7
 
 
 
 
 
 
 
 
580aef7
6eff5e7
 
580aef7
6eff5e7
580aef7
6eff5e7
 
 
 
 
 
580aef7
 
6eff5e7
580aef7
 
 
 
 
 
 
6eff5e7
 
 
 
 
 
 
 
 
580aef7
963bf46
 
 
6eff5e7
 
580aef7
 
6eff5e7
 
 
580aef7
 
6eff5e7
 
 
963bf46
 
6eff5e7
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
963bf46
6eff5e7
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
963bf46
6eff5e7
 
 
 
 
 
 
 
 
 
 
 
 
4bea31b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from sentence_transformers import util
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np

def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Score every query sentence against every candidate sentence.

    Input: embedding model, list of query sentences, list of candidate sentences
    Output: whatever util.cos_sim returns for the two sets of embeddings
            (presumably a query-by-candidate similarity matrix -- see the
            sentence-transformers documentation)
    """
    # TODO make this more general for different types of models
    query_vecs, candidate_vecs = get_embedding(model, query_sents, candidate_sents)
    similarity = util.cos_sim(query_vecs, candidate_vecs)

    return similarity
    
    
def get_embedding(model, query_sents, candidate_sents):
    """
    Encode the query and candidate sentences with the given model.

    Input: model exposing an `encode` method, list of query sentences,
           list of candidate sentences
    Output: (query embeddings, candidate embeddings), each exactly as
            returned by `model.encode`
    """
    # Query is encoded first, then the candidate (tuple elements evaluate in order)
    return model.encode(query_sents), model.encode(candidate_sents)

def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show

    Input: 2-D torch tensor of scores (one row per query sentence, one column
           per candidate sentence), number of sentences K to keep per row
    Output: (picked_sent, picked_scores) -- for each row, the column indices
            of the K highest scores in descending order and the scores at
            those indices. If K exceeds the number of columns, all columns
            are returned.
    """
    # Negate so that the ascending argsort yields a descending ranking
    idx = torch.argsort(-score_mat)
    picked_sent = idx[:, :K]
    # Gather the scores in one vectorized call; unlike stacking per-row slices,
    # this also works when the matrix has no rows at all
    picked_scores = torch.gather(score_mat, 1, picked_sent)

    return picked_sent, picked_scores

def get_words(sent):
    """
    Tokenize each sentence into words and record where each sentence starts.

    Input: list of sentences
    Output: list of list of words per sentence, flat list of all words,
            index (into the flat list) of the first word of each sentence.
            All three are empty lists when no sentences are given.
    """
    words = []
    all_words = []
    sent_start_id = [] # keep track of the word index where the new sentence starts
    for x in sent:
        # the sentence begins right after all words collected so far
        sent_start_id.append(len(all_words))
        # tokenize once and reuse the result for both outputs
        w = word_tokenize(x)
        words.append(w)
        all_words.extend(w)
    return words, all_words, sent_start_id

def get_match_phrase(w1, w2):
    """
    Input: list of words for query (w1) and candidate text (w2)
    Output: numpy array with one entry per word of w2; an entry is 1 when the
            candidate word also occurs in the query (compared
            case-insensitively) and its POS tag is one of the tags listed
            below, otherwise 0
    """
    # POS tags that should be considered for matching phrase
    include = frozenset([
        'JJ',
        'JJR',
        'JJS',
        'MD',
        'NN',
        'NNS',
        'NNP',
        'NNPS',
        'RB',
        'RBR',
        'RBS',
        'SYM',
        'VB', 
        'VBD',
        'VBG',
        'VBN',
        'FW'
    ])
    mask2 = np.zeros(len(w2))
    # nothing can match if either side is empty, so skip the POS tagger
    if not w1 or not w2:
        return mask2

    # lowercase the query words too, otherwise a capitalized query word
    # could never equal the lowercased candidate word
    query_vocab = {w.lower() for w in w1}
    for i, (w, p) in enumerate(pos_tag(w2)):
        if p in include and w.lower() in query_vocab:
            mask2[i] = 1
    return mask2

def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Build the per-word highlight information for every query sentence.

    The returned dict holds 'all_words' and 'words_by_sentence' as passed in,
    plus one entry per query sentence index with three arrays over all
    candidate words: 'is_selected_sent' (word lies in a selected sentence),
    'is_selected_phrase' (word additionally matches the query according to
    get_match_phrase) and 'scores' (score of the selected sentence, or -1 for
    matched phrase words).
    """
    total_words = len(all_words)
    n_starts = len(sent_start_id)

    output = {
        'all_words': all_words,
        'words_by_sentence': words,
    }

    # one highlight record per query sentence
    for q in range(sent_ids.shape[0]):
        query_words = word_tokenize(query_sents[q])
        sent_mask = np.zeros(total_words)
        phrase_mask = np.zeros(total_words)
        scores = np.zeros(total_words)

        # walk over the candidate sentences selected for this query sentence
        for sid, sscore in zip(sent_ids[q], sent_scores[q]):
            begin = sent_start_id[sid]
            # the last sentence has no successor, so it runs to the final word
            end = sent_start_id[sid+1] if sid+1 < n_starts else total_words
            sent_mask[begin:end] = 1
            scores[begin:end] = sscore
            phrase_mask[begin:end] = get_match_phrase(query_words, all_words[begin:end])

        # words flagged by both masks get -1 (meaning a different color in gradio)
        scores[sent_mask + phrase_mask == 2] = -1

        output[q] = {
            'is_selected_sent': sent_mask,
            'is_selected_phrase': phrase_mask,
            'scores': scores
        }

    return output

def get_highlight_info(model, text1, text2, K=None):
    """
    Get highlight information from two texts

    Input: embedding model, query text (text1), candidate text (text2),
           number K of candidate sentences to select per query sentence
    Output: selected sentence indices, their scores, and the per-word
            highlight dict produced by mark_words
    """
    sent1 = sent_tokenize(text1) # query
    sent2 = sent_tokenize(text2) # candidate
    if K is None: # if K is not set, select based on the length of the candidate
        # a third of the candidate sentences, but never zero: otherwise
        # candidates with fewer than three sentences get no highlight at all
        K = max(1, len(sent2) // 3)
    score_mat = compute_sentencewise_scores(model, sent1, sent2)

    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
    
    return sent_ids, sent_scores, info

### Document-level operations

def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Compute a cosine-similarity score between the query and each paper.

    Each paper is represented as "<title> [SEP] <abstract>"; papers whose
    title or abstract is None are skipped, so the returned list has one score
    per remaining paper, in input order. The query is prepended to every
    batch and the first-token hidden state is used as the embedding.
    """
    # concatenate title and abstract, dropping incomplete papers
    title_abs = [
        t + ' [SEP] ' + a
        for t, a in zip(titles, abstracts)
        if t is not None and a is not None
    ]

    num_docs = len(title_abs)
    num_batches = int(np.ceil(num_docs / batch))

    scores = []
    with torch.no_grad():
        for b in range(num_batches):
            chunk = title_abs[b*batch:(b+1)*batch]

            # preprocess the input (query goes first in every batch)
            inputs = tokenizer(
                [query] + chunk,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)

            # first token of every sequence serves as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
            q_emb, p_emb = embeddings[0,:], embeddings[1:,:]

            # cosine similarity between the query and each paper in the batch
            denom = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / denom)

    assert(len(scores) == num_docs)

    return scores

def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Score every paper against the query and return them best-first.

    Input: document model and tokenizer (forwarded to predict_docscore),
           query text, list of paper dicts with 'title' and 'abstract' keys,
           batch size
    Output: (titles, abstracts, scores), all sorted by descending score.
            Papers whose title or abstract is None are left out.
    """
    titles = []
    abstracts = []
    for p in papers:
        title, abstract = p['title'], p['abstract']
        # predict_docscore skips papers with a missing title or abstract;
        # drop them here as well so score indices line up with these lists
        if title is None or abstract is None:
            continue
        titles.append(title)
        abstracts.append(abstract)

    # nothing to score
    if not titles:
        return [], [], []

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    idx_sorted = np.argsort(scores)[::-1]
    
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    
    return titles_sorted, abstracts_sorted, scores_sorted