Spaces:
Build error
Build error
import re | |
import numpy as np | |
from weakly_supervised_parser.tree.helpers import Tree | |
def CKY(sent_all, prob_s, label_s, verbose=False):
    r"""Decode the best binary tree for every sentence with CKY.

    Chooses the tree with the maximum expected number of constituents, i.e.
    max \sum_{(i,j) \in tree} p((i,j) is constituent).

    Args:
        sent_all: Iterable of sentences; each sentence is a sized sequence of
            tokens. It is iterated exactly once, so generators are accepted.
        prob_s: Flat sequence of span scores. For each sentence (in order) it
            holds one score per half-open span (i, j) with
            0 <= i < j <= len(sent), enumerated with i as the outer and j as
            the inner loop. Scores of single-token spans (j - i == 1) occupy
            slots in this layout but are never read by the decoder.
        label_s: Flat sequence of span labels aligned with ``prob_s``; the
            label of span (i, j) is attached to the internal node covering it.
        verbose: Unused; kept for interface compatibility.

    Returns:
        ``(tree_all, prob_all)``:
          - ``tree_all``: one ``Tree`` per sentence. Leaves are built as
            ``Tree(token, None, token)``. Internal nodes are built as
            ``Tree(label, [left, right], None)``.
          - ``prob_all``: one table per sentence, a list of ``len(sent)`` rows
            of ``len(sent) + 1`` entries. ``table[i][j]`` is the score of span
            (i, j); it is ``None`` where j <= i or where no score was supplied.

    Raises:
        ValueError: If a sentence is empty (there is no tree to build).
    """

    def spans_of(n):
        # All half-open spans of an n-token sentence, in the same order the
        # flat prob_s / label_s layout uses.
        return [(i, j) for i in range(n) for j in range(i + 1, n + 1)]

    def to_table(values, spans, n):
        # n x (n + 1) lookup table with table[i][j] = value of span (i, j).
        # zip truncates, so a short ``values`` simply leaves trailing spans None.
        table = [[None] * (n + 1) for _ in range(n)]
        for (i, j), value in zip(spans, values):
            table[i][j] = value
        return table

    def backpt_to_tree(sent, backpt, label_table):
        # Recursion depth grows with tree height (at most len(sent) - 1).
        def to_tree(i, j):
            if j - i == 1:  # single token -> leaf
                return Tree(sent[i], None, sent[i])
            split = backpt[i][j]
            return Tree(label_table[i][j], [to_tree(i, split), to_tree(split, j)], None)

        return to_tree(0, len(sent))

    tree_all, prob_all = [], []
    offset = 0  # start of the current sentence's slice in prob_s / label_s
    # Single pass over sent_all: iterating it twice would silently yield
    # nothing for generators and other one-shot iterables.
    for sent in sent_all:
        n = len(sent)
        if n == 0:
            raise ValueError("CKY cannot build a tree for an empty sentence")
        spans = spans_of(n)
        end = offset + len(spans)
        prob_table = to_table(prob_s[offset:end], spans, n)
        label_table = to_table(label_s[offset:end], spans, n)
        offset = end

        # score[i][j]: best total score of any binary tree over span (i, j);
        # backpt[i][j]: the split point achieving it.
        score = [[None] * (n + 1) for _ in range(n)]
        backpt = [[None] * (n + 1) for _ in range(n)]
        for start in range(n):
            # Every tree over a span has the same leaves, so this constant
            # never changes which split wins.
            score[start][start + 1] = 1
        for right in range(2, n + 1):
            for left in range(right - 2, -1, -1):
                best, argmax = -np.inf, None
                for split in range(left + 1, right):
                    candidate = score[left][split] + score[split][right]
                    if candidate > best:  # strict '>' keeps the leftmost split on ties
                        best, argmax = candidate, split
                score[left][right] = best + prob_table[left][right]
                backpt[left][right] = argmax

        tree_all.append(backpt_to_tree(sent, backpt, label_table))
        prob_all.append(prob_table)

    return tree_all, prob_all
def get_best_parse(sentence, spans):
    """Decode the best binary parse and return it as a bracketed string.

    Args:
        sentence: Passed to ``CKY`` as ``sent_all``, so it is a sequence of
            tokenized sentences. Only the tree for the first one is returned.
            Presumably callers pass a one-element list; confirm with callers.
        spans: 2-D array-like indexable as ``spans[i, j]`` with a ``.shape``.
            Its upper triangle (i <= j) is flattened row by row. When spans
            is n x n for an n-token sentence, this lines up with CKY's span
            order, so ``spans[i, j]`` scores the span covering tokens i..j
            inclusive (an assumption about the caller; verify).

    Returns:
        ``str(tree)`` for the first sentence's tree, with every span labelled
        ``"S"``, after the regex rewrite described below.

    Raises:
        IndexError: If ``sentence`` is empty (CKY returns no trees).
    """
    n_rows, n_cols = spans.shape[0], spans.shape[1]
    upper_triangle = [spans[i, j] for i in range(n_rows) for j in range(i, n_cols)]
    labels = ["S"] * len(upper_triangle)

    trees, _ = CKY(sent_all=sentence, prob_s=upper_triangle, label_s=labels)
    bracketed = str(trees[0])

    # A token is a maximal run of characters other than whitespace and
    # parentheses. Wherever a token is immediately followed by whitespace and
    # the identical token, the first copy is replaced by "S". Presumably this
    # relabels leaves rendered as "(word word)" to "(S word)"; this depends on
    # Tree.__str__, so confirm.
    return re.sub(r"(?<![^\s()])([^\s()]+)(?=\s+\1(?![^\s()]))", "S", bracketed)