import torch
import torch.nn as nn
import numpy as np


class Hack_no_grad(nn.Module):
    """Wrap ``module`` so that every forward call runs under ``torch.no_grad()``.

    The wrapped module is stored as ``self.module``; calling this wrapper forwards
    all positional and keyword arguments to it unchanged.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module with all arguments inside ``torch.no_grad()``."""
        with torch.no_grad():
            return self.module(*inputs, **kwargs)


def find_max_subspans(sequence, n_spans, max_length):
    """Select up to ``n_spans`` spans of ``sequence`` whose items have maximal total.

    Constraints enforced by the dynamic program:
      * at most ``n_spans`` spans are selected;
      * each span covers at most ``max_length`` consecutive items;
      * two selected spans are never adjacent (at least one unselected item
        separates them), so every span is recoverable from the 0/1 tags.

    Runs in O(len(sequence) * n_spans * max_length) time and memory.

    Args:
        sequence: indexable sequence of numeric scores (list, 1-D array, ...).
        n_spans: maximum number of spans to select.
        max_length: maximum number of items in one span.

    Returns:
        ``(segments, max_score, tags)`` where
          * ``segments`` is a list of half-open ``(start, end)`` index pairs;
          * ``max_score`` is the total score of the selected items (``0`` when
            selecting nothing is best, e.g. all scores are negative);
          * ``tags`` has one entry per item: ``1`` if selected, ``0`` otherwise.
    """
    length = len(sequence)
    if length == 0:
        # Nothing to select.
        return [], 0, []
    inner_scores, trace = _fill_subspan_tables(sequence, n_spans, max_length)
    max_score, final_state = _best_final_state(inner_scores, n_spans, max_length)
    tags = _backtrack_tags(trace, final_state)
    return _tags_to_segments(tags), max_score, tags


def _fill_subspan_tables(sequence, n_spans, max_length):
    """Run the forward DP and return the ``(inner_scores, trace)`` tables.

    ``inner_scores[i, j, k, s]`` is the best total over items ``0..i`` where
    ``j`` spans are still available, ``s == 1`` means item ``i`` is inside a span
    with ``k`` more items allowed in it, and ``s == 0`` means item ``i`` is
    outside every span (its value does not depend on ``k``).
    ``trace[i, j, k, s]`` holds the ``(j, k, s)`` state of item ``i - 1`` that
    produced that best total.
    """
    length = len(sequence)
    table_shape = (n_spans + 1, max_length + 1, 2)

    inner_scores = np.zeros((length,) + table_shape)
    # In-span states are unreachable until the DP computes them.
    inner_scores[..., 1] = -np.inf
    # Virtual row preceding item 0: nothing selected yet and no span open.
    # Using a separate row (instead of wrapping to inner_scores[-1]) keeps a
    # single-item sequence from reading the row it is currently writing.
    start_row = np.zeros(table_shape)
    start_row[..., 1] = -np.inf

    trace = np.zeros((length,) + table_shape + (3,), dtype=int)
    # Default back-pointer of the outside state: stay outside with the same j.
    trace[:, :, :, 0, 0] = np.arange(n_spans + 1)[:, None]
    trace[:, :, :, 0, 1] = max_length

    for i in range(length):
        prev = inner_scores[i - 1] if i > 0 else start_row
        value = sequence[i]
        for j in range(n_spans):
            # Outside state: stay outside, or close the span open at item i-1.
            stay_out = prev[j, max_length, 0]
            in_span = prev[j, :, 1]
            best_k = int(in_span.argmax())  # first index of the maximum
            if in_span[best_k] > stay_out:
                inner_scores[i, j, :, 0] = in_span[best_k]
                trace[i, j, :, 0] = (j, best_k, 1)
            else:
                inner_scores[i, j, :, 0] = stay_out
                trace[i, j, :, 0] = (j, max_length, 0)

            # Inside state with k items of budget left: extend the span open at
            # item i-1 (which had k+1 left), or open a new span right after an
            # outside item, consuming one of the j+1 spans available there.
            open_new = prev[j + 1, max_length, 0]
            for k in range(max_length):
                extend = prev[j, k + 1, 1]
                if extend > open_new:
                    inner_scores[i, j, k, 1] = extend + value
                    trace[i, j, k, 1] = (j, k + 1, 1)
                else:
                    inner_scores[i, j, k, 1] = open_new + value
                    trace[i, j, k, 1] = (j + 1, max_length, 0)
    return inner_scores, trace


def _best_final_state(inner_scores, n_spans, max_length):
    """Return ``(max_score, (j, k, s))`` of the best state at the last item.

    Starts from score 0 (select nothing). Ties keep the first state found when
    scanning j and k downwards, checking the outside state before the inside one.
    """
    max_score = 0
    best = (0, 0, 0)
    last = inner_scores[-1]
    for j in reversed(range(n_spans + 1)):
        for k in reversed(range(max_length)):
            for state in (0, 1):
                if last[j, k, state] > max_score:
                    max_score = last[j, k, state]
                    best = (j, k, state)
    return max_score, best


def _backtrack_tags(trace, final_state):
    """Follow back-pointers from ``final_state`` and return the 0/1 tag of each item."""
    length = trace.shape[0]
    tags = [0] * length
    j, k, state = final_state
    for i in reversed(range(length)):
        tags[i] = int(state)
        j, k, state = trace[i, j, k, state]
    return tags


def _tags_to_segments(tags):
    """Turn runs of 1-tags into half-open ``(start, end)`` index pairs."""
    segments = []
    start = None
    for i, tag in enumerate(tags):
        if tag == 1 and start is None:
            start = i
        elif tag == 0 and start is not None:
            segments.append((start, i))
            start = None
    if start is not None:
        segments.append((start, len(tags)))
    return segments