import nltk
from nltk.tag import PerceptronTagger

from stable_whisper.result import WordTiming

import numpy as np
import torch


# Custom overrides for NLTK's POS tags: each of these common function words
# gets its own tag (backtick-prefixed so it cannot clash with the UPenn set).
additional_tags = {
    "as": "`AS",
    "and": "`AND",
    "of": "`OF",
    "how": "`HOW",
    "but": "`BUT",
    "the": "`THE",
    "a": "`A",
    "an": "`A",
    "which": "`WHICH",
    "what": "`WHAT",
    "where": "`WHERE",
    "that": "`THAT",
    "who": "`WHO",
    "when": "`WHEN",
}

def get_upenn_tags_dict():
    # Map every UPenn POS tag, the custom tags above, and the synthetic
    # 'BREAK' tag to a class index.
    return {
        '#': 0, '$': 1, "''": 2, '(': 3, ')': 4, ',': 5, '.': 6, ':': 7,
        'CC': 8, 'CD': 9, 'DT': 10, 'EX': 11, 'FW': 12, 'IN': 13,
        'JJ': 14, 'JJR': 15, 'JJS': 16, 'LS': 17, 'MD': 18, 'NN': 19,
        'NNP': 20, 'NNPS': 21, 'NNS': 22, 'PDT': 23, 'POS': 24, 'PRP': 25,
        'PRP$': 26, 'RB': 27, 'RBR': 28, 'RBS': 29, 'RP': 30, 'SYM': 31,
        'TO': 32, 'UH': 33, 'VB': 34, 'VBD': 35, 'VBG': 36, 'VBN': 37,
        'VBP': 38, 'VBZ': 39, 'WDT': 40, 'WP': 41, 'WP$': 42, 'WRB': 43,
        '``': 44, 'BREAK': 45,
        '`AS': 46, '`AND': 47, '`OF': 48, '`HOW': 49, '`BUT': 50,
        '`THE': 51, '`A': 52, '`WHICH': 53, '`WHAT': 54, '`WHERE': 55,
        '`THAT': 56, '`WHO': 57, '`WHEN': 58,
    }

def nltk_extend_tags(tagged_text: list[tuple[str, str]]):
    # Yield (word, tag) pairs, swapping in the custom tag for any word
    # listed in `additional_tags`.
    for text, tag in tagged_text:
        text_lower = text.lower().strip()
        if text_lower in additional_tags:
            yield (text, additional_tags[text_lower])
        else:
            yield (text, tag)
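
# Illustrative example of the override (hypothetical tokens):
#   list(nltk_extend_tags([('the', 'DT'), ('cat', 'NN')]))
#   -> [('the', '`THE'), ('cat', 'NN')]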


def bind_wordtimings_to_tags(wt: list[WordTiming]):
    raw_words = [w.word for w in wt]

    tokenized_raw_words = []
    tokens_wordtiming_map = []

    # Tokenize each whisper word and remember how many tokens it produced,
    # so the tags can be regrouped per WordTiming afterwards.
    for word in raw_words:
        tokens_word = nltk.word_tokenize(word)
        tokenized_raw_words.extend(tokens_word)
        tokens_wordtiming_map.append(len(tokens_word))

    tagged_words = nltk.pos_tag(tokenized_raw_words)
    tagged_words = list(nltk_extend_tags(tagged_words))

    # Regroup the flat list of tagged tokens into one group per WordTiming.
    grouped_tags = []
    for k in tokens_wordtiming_map:
        grouped_tags.append(tagged_words[:k])
        tagged_words = tagged_words[k:]

    tags_only = [tuple(w[1] for w in t) for t in grouped_tags]

    wordtimings_with_tags = zip(wt, tags_only)

    return list(wordtimings_with_tags)
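
# Illustrative result (hypothetical WordTimings for "You're" and " back"):
#   [(WordTiming(word="You're", ...), ('PRP', 'VBP')),
#    (WordTiming(word=' back', ...), ('RB',))]
# i.e. each WordTiming is paired with the tuple of tags of its NLTK tokens.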


def embed_tag_list(tags: list[str]):
    # One-hot encode a list of tags using the index mapping above.
    tags_dict = get_upenn_tags_dict()
    eye = np.eye(len(tags_dict))
    return eye[np.array([tags_dict[tag] for tag in tags])]


def lookup_tag_list(tags: list[str]):
    # Encode a list of tags as their integer indices (for embedding layers).
    tags_dict = get_upenn_tags_dict()
    return np.array([tags_dict[tag] for tag in tags], dtype=int)
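
# Illustrative values (following from get_upenn_tags_dict):
#   lookup_tag_list(['PRP', 'VBP']) -> array([25, 38])
#   embed_tag_list(['PRP', 'VBP'])  -> a 2 x 59 one-hot matrix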


def tag_training_data(filename: str):
    with open(filename, "r") as f:
        segmented_lines = f.readlines()

    segmented_lines = [s.strip() for s in segmented_lines if s.strip() != ""]

    full_text = " ".join(segmented_lines)

    tokenized_full_text = nltk.word_tokenize(full_text)
    tagged_full_text = nltk.pos_tag(tokenized_full_text)
    tagged_full_text = list(nltk_extend_tags(tagged_full_text))

    tagged_full_text_copy = tagged_full_text

    reconstructed_tags = []

    # Re-align each original line with the tagged tokens of the full text by
    # matching the whitespace-stripped line against growing token prefixes.
    for line in segmented_lines:
        line_nospace = line.replace(" ", "")

        found = False

        for i in range(len(tagged_full_text_copy) + 1):
            rejoined = "".join(x[0] for x in tagged_full_text_copy[:i])

            if line_nospace == rejoined:
                found = True
                reconstructed_tags.append(tagged_full_text_copy[:i])
                tagged_full_text_copy = tagged_full_text_copy[i:]
                break

        if not found:
            print("Panic. Cannot match further.")
            print(f"Was trying to match: {line}")
            print(tagged_full_text_copy)

    return reconstructed_tags


def parse_tags(reconstructed_tags):
    """
    Parse reconstructed tags into an (input_tokens, output_tag) datapoint.
    In the original plan, this type of output is suitable for a bidirectional LSTM.

    Input:
        reconstructed_tags:
            Tagged segments, from tag_training_data().
            Example: [
                [('You', 'PRP'), ("'re", 'VBP'), ('back', 'RB'), ('again', 'RB'), ('?', '.')],
                [('You', 'PRP'), ("'ve", 'VBP'), ('been', 'VBN'), ('consuming', 'VBG'), ('a', 'DT'), ('lot', 'NN'), ('of', 'IN'), ('tech', 'JJ'), ('news', 'NN'), ('lately', 'RB'), ('.', '.')],
                ...
            ]

    Output:
        (input_tokens, output_tag)
        input_tokens:
            A sequence of tag indices; each number corresponds to a POS-tag class.
            Example: [25, 38, 27, 27, 6, 25, 38, 37, 36, 10, 19, 13, 14, 19, 27, 6]
        output_tag:
            A sequence of 0s and 1s, indicating whether a break should be inserted AFTER each position.
            Example: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
    """
    tags_dict = get_upenn_tags_dict()

    # Flatten the segments into one tag sequence, appending a 'BREAK' marker
    # after each segment.
    all_tags_sequence = [[y[1] for y in segments] + ['BREAK'] for segments in reconstructed_tags]
    all_tags_sequence = [tag for tags in all_tags_sequence for tag in tags]

    input_tokens = []
    output_tag = []
    for token in all_tags_sequence:
        if token != 'BREAK':
            input_tokens.append(tags_dict[token])
            output_tag.append(0)
        else:
            # Mark that a break follows the most recent token.
            output_tag[-1] = 1

    return input_tokens, output_tag
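
# Illustrative training-prep sketch (hypothetical file name; the actual
# training loop may differ):
#   reconstructed = tag_training_data("segments.txt")   # one segment per line
#   input_tokens, output_tag = parse_tags(reconstructed)
#   x = torch.tensor(input_tokens, dtype=torch.long)[None, :]   # (1, seq_len)
#   y = torch.tensor(output_tag, dtype=torch.float32)[None, :]  # (1, seq_len)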


def embed_segments(tagged_segments):
    tags_dict = get_upenn_tags_dict()

    result_embedding = []

    classes = len(tags_dict)
    eye = np.eye(classes)

    # One-hot embed each tagged segment, appending a BREAK vector after it.
    for segment in tagged_segments:
        targets = np.array([tags_dict[tag] for word, tag in segment])
        segment_embedding = eye[targets]

        result_embedding.append(segment_embedding)
        result_embedding.append(np.array([eye[tags_dict["BREAK"]]]))

    result_embedding = np.concatenate(result_embedding)

    return result_embedding, tags_dict


def window_embedded_segments_rnn(embeddings, tags_dict):
    datapoints = []
    eye = np.eye(len(tags_dict))

    break_vector = eye[tags_dict["BREAK"]]

    # For every non-BREAK position, emit (all embeddings up to and including
    # that position, label), where the label is 1 if a BREAK vector
    # immediately preceded it (that BREAK is dropped from the sequence).
    for i in range(1, embeddings.shape[0]):
        if (embeddings[i] == break_vector).all():
            continue

        prev_sequence = embeddings[:i]

        if (prev_sequence[-1] == break_vector).all():
            prev_sequence = prev_sequence[:-1]
            tag = 1
        else:
            tag = 0

        entire_sequence = np.concatenate((prev_sequence, np.array([embeddings[i]])))

        datapoints.append((entire_sequence, tag))
    return datapoints
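
# Illustrative usage (assumes `reconstructed` came from tag_training_data):
#   embeddings, tags_dict = embed_segments(reconstructed)
#   datapoints = window_embedded_segments_rnn(embeddings, tags_dict)
#   # each datapoint: (one-hot sequence up to a token, 0/1 break label)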


def print_dataset(datapoints, tags_dict, tokenized_full_text):
    eye = np.eye(len(tags_dict))

    break_vector = eye[tags_dict["BREAK"]]

    for inp, tag in datapoints:
        if tag == 1:
            print("[1] ", end='')
        else:
            print("[0] ", end='')

        # Count the non-BREAK vectors to know how many tokens the window spans.
        count = 0
        for v in inp:
            if not (v == break_vector).all():
                count += 1

        segment = tokenized_full_text[:count]
        print(segment)


from stable_whisper.result import Segment

def get_indicies(segment: Segment, model, device, threshold):
    # Predict, for each word in the segment, whether a break should follow it;
    # return the indices of words whose break score exceeds `threshold`.
    word_list = segment.words
    tagged_wordtiming = bind_wordtimings_to_tags(word_list)

    tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]

    tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]

    embedded_tags = embed_tag_list(tag_list)
    embedded_tags = torch.from_numpy(embedded_tags).float()

    output = model(embedded_tags[None, :].to(device))

    list_output = output.detach().cpu().numpy().tolist()[0]

    # Collapse per-token outputs back to per-word decisions.
    current_index = 0
    cut_indicies = []
    for index, tags_count in enumerate(tag_per_word):
        tags = list_output[current_index:current_index + tags_count]
        if max(tags) > threshold:
            cut_indicies.append(index)
        current_index += tags_count

    return cut_indicies


def get_indicies_autoembed(segment: Segment, model, device, threshold):
    # Same as get_indicies(), but feeds integer tag indices to a model that
    # does its own embedding (e.g. an nn.Embedding layer) instead of one-hot vectors.
    word_list = segment.words
    tagged_wordtiming = bind_wordtimings_to_tags(word_list)

    tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]

    tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]

    embedded_tags = lookup_tag_list(tag_list)
    embedded_tags = torch.from_numpy(embedded_tags).int()

    output = model(embedded_tags[None, :].to(device))

    list_output = output.detach().cpu().numpy().tolist()[0]

    # Collapse per-token outputs back to per-word decisions.
    current_index = 0
    cut_indicies = []
    for index, tags_count in enumerate(tag_per_word):
        tags = list_output[current_index:current_index + tags_count]
        if max(tags) > threshold:
            cut_indicies.append(index)
        current_index += tags_count

    return cut_indicies
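
# Illustrative inference sketch (assumes a trained break-prediction `model`
# that outputs per-token scores, and a stable-whisper transcription `result`;
# these names are placeholders, not part of this module):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   for seg in result.segments:
#       cut_after_words = get_indicies(seg, model, device, threshold=0.5)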