# metricv's picture
# Update model
# d06f65e verified
import nltk
from nltk.tag import PerceptronTagger
from stable_whisper.result import WordTiming
import numpy as np
import torch
# Words that receive a dedicated tag instead of the tag assigned by the NLTK
# tagger (see nltk_extend_tags). Keys are lower-cased surface forms; values
# are the replacement tags, each prefixed with a backtick. Note that "a" and
# "an" deliberately map to the same tag.
additional_tags = {
    "as": "`AS",
    "and": "`AND",
    "of": "`OF",
    "how": "`HOW",
    "but": "`BUT",
    "the": "`THE",
    "a": "`A",
    "an": "`A",
    "which": "`WHICH",
    "what": "`WHAT",
    "where": "`WHERE",
    "that": "`THAT",
    "who": "`WHO",
    "when": "`WHEN",
}
def get_upenn_tags_dict():
    """Return a fresh ``{tag: index}`` mapping for every tag the model knows.

    The table is hard-coded. According to the generation code that used to
    live here, it was originally produced from the values of NLTK's
    ``PerceptronTagger().tagdict`` merged with the Penn Treebank tag list
    (https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html),
    de-duplicated and sorted, with ``BREAK`` appended. The backtick-prefixed
    word-specific tags follow after ``BREAK``.

    Indices are contiguous, starting at 0, in the order listed below. A new
    dict is built on every call, so callers may mutate the result freely.
    """
    ordered_tags = (
        # Punctuation / symbol tags.
        '#', '$', "''", '(', ')', ',', '.', ':',
        # Penn Treebank part-of-speech tags.
        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD',
        'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR',
        'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP',
        'VBZ', 'WDT', 'WP', 'WP$', 'WRB', '``',
        # Segment-boundary marker.
        'BREAK',
        # Word-specific tags (the values of additional_tags).
        '`AS', '`AND', '`OF', '`HOW', '`BUT', '`THE', '`A', '`WHICH',
        '`WHAT', '`WHERE', '`THAT', '`WHO', '`WHEN',
    )
    return {tag: index for index, tag in enumerate(ordered_tags)}
def nltk_extend_tags(tagged_text: list[tuple[str, str]]):
    """Yield ``(text, tag)`` pairs, substituting word-specific tags.

    For each pair, the text is lower-cased and stripped of surrounding
    whitespace and looked up in ``additional_tags``. On a hit the pair is
    yielded with the replacement tag; otherwise it is yielded unchanged.
    The text itself is always passed through exactly as received.

    This is a generator: wrap the call in ``list(...)`` to materialise it.
    """
    for text, tag in tagged_text:
        yield (text, additional_tags.get(text.lower().strip(), tag))
def bind_wordtimings_to_tags(wt: list[WordTiming]):
    """Pair every word timing with the tags of the tokens it splits into.

    Each ``.word`` is tokenized on its own with ``nltk.word_tokenize``, but
    the tokens of all words are tagged together in a single
    ``nltk.pos_tag`` call, then passed through ``nltk_extend_tags``.

    Returns:
        A list of ``(word_timing, tags)`` tuples, one per element of ``wt``
        and in the same order. ``tags`` is a tuple holding one tag per token
        of that word; it is empty if the word produced no tokens.
    """
    tokens = []
    tokens_per_word = []
    for timing in wt:
        word_tokens = nltk.word_tokenize(timing.word)
        tokens.extend(word_tokens)
        tokens_per_word.append(len(word_tokens))

    tagged_tokens = list(nltk_extend_tags(nltk.pos_tag(tokens)))

    # Hand the tagged tokens back to their words with a running offset
    # (rather than repeatedly re-slicing the remaining list, which copies
    # the whole tail for every word).
    tags_per_word = []
    offset = 0
    for count in tokens_per_word:
        word_slice = tagged_tokens[offset:offset + count]
        tags_per_word.append(tuple(tagged[1] for tagged in word_slice))
        offset += count

    return list(zip(wt, tags_per_word))
def embed_tag_list(tags: list[str]):
    """One-hot encode a sequence of tags.

    Returns:
        A float array of shape ``(len(tags), number_of_known_tags)`` where
        row ``i`` is the one-hot vector of ``tags[i]``, using the indices
        from ``get_upenn_tags_dict()``. An empty ``tags`` list yields an
        array with zero rows.

    Raises:
        KeyError: if a tag is not present in ``get_upenn_tags_dict()``.
    """
    tags_dict = get_upenn_tags_dict()
    # dtype=int matters for the empty case: np.array([]) defaults to float64,
    # and a float array is rejected as an index.
    indices = np.array([tags_dict[tag] for tag in tags], dtype=int)
    return np.eye(len(tags_dict))[indices]
def lookup_tag_list(tags: list[str]):
    """Translate tags into their integer indices.

    Returns:
        A 1-D integer array whose ``i``-th entry is the index of ``tags[i]``
        in ``get_upenn_tags_dict()``.

    Raises:
        KeyError: if a tag is not present in ``get_upenn_tags_dict()``.
    """
    index_of = get_upenn_tags_dict()
    indices = [index_of[tag] for tag in tags]
    return np.array(indices, dtype=int)
def tag_training_data(filename: str):
    """Tag a pre-segmented text file and regroup the tagged tokens per line.

    The file is read line by line; lines that are empty after stripping are
    dropped. All remaining lines are joined with single spaces and tokenized
    and tagged as one text (for more accurate tagging than line-by-line),
    then passed through ``nltk_extend_tags``. Finally the tagged tokens are
    split back into the original lines by matching the concatenated token
    texts against each line with its spaces removed.

    Returns:
        A list with one entry per successfully matched line, each entry
        being that line's list of ``(token, tag)`` tuples. A line that
        cannot be matched is reported on stdout and contributes no entry;
        matching then continues with the next line from the same position.
    """
    with open(filename, "r") as f:
        segmented_lines = [line.strip() for line in f if line.strip() != ""]

    # Regain the full text for more accurate tagging.
    full_text = " ".join(segmented_lines)
    tagged_full_text = list(
        nltk_extend_tags(nltk.pos_tag(nltk.word_tokenize(full_text)))
    )
    total_tokens = len(tagged_full_text)

    reconstructed_tags = []
    cursor = 0  # index of the first token not yet assigned to a line
    for line in segmented_lines:
        line_nospace = line.replace(" ", "")

        # Consume tokens while they keep spelling out the line. The scan
        # stops at the first token that diverges or once the whole line is
        # covered, so nothing past the match is ever examined.
        matched_chars = 0
        end = cursor
        while (
            matched_chars < len(line_nospace)
            and end < total_tokens
            and line_nospace.startswith(tagged_full_text[end][0], matched_chars)
        ):
            matched_chars += len(tagged_full_text[end][0])
            end += 1

        if matched_chars == len(line_nospace):
            reconstructed_tags.append(tagged_full_text[cursor:end])
            cursor = end
        else:
            print("Panic. Cannot match further.")
            print(f"Was trying to match: {line}")
            print(tagged_full_text[cursor:])
    return reconstructed_tags
def parse_tags(reconstructed_tags):
    """
    Parse reconstructed tags into input/tag datapoint.
    In the original plan, this type of output is suitable for bidirectional LSTM.

    Input:
        reconstructed_tags:
            Tagged segments, from tag_training_data()
            Example: [
                [('You', 'PRP'), ("'re", 'VBP'), ('back', 'RB'), ('again', 'RB'), ('?', '.')],
                [('You', 'PRP'), ("'ve", 'VBP'), ('been', 'VBN'), ('consuming', 'VBG'), ('a', 'DT'), ('lot', 'NN'), ('of', 'IN'), ('tech', 'JJ'), ('news', 'NN'), ('lately', 'RB'), ('.', '.')]
                ...
            ]

    Output:
        (input_tokens, output_tag)
        input_tokens:
            A sequence of tokens, each number corresponds to a type of word
            (its index in get_upenn_tags_dict()).
            Example: [25, 38, 27, 27, 6, 25, 38, 37, 36, 10, 19, 13, 14, 19, 27, 6]
        output_tag:
            A sequence of 0 and 1, indicating whether a break should be inserted AFTER each location.
            Example: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

    Both lists always have the same length. An empty segment adds no tokens;
    it only (re-)marks the preceding token, if there is one, as a break.
    """
    tags_dict = get_upenn_tags_dict()
    input_tokens = []
    output_tag = []
    for segment in reconstructed_tags:
        for tagged in segment:
            input_tokens.append(tags_dict[tagged[1]])
            output_tag.append(0)
        # The end of a segment is a break after the last token emitted so
        # far. Guarded because a leading empty segment has no token to mark.
        if output_tag:
            output_tag[-1] = 1
    return input_tokens, output_tag
def embed_segments(tagged_segments):
    """One-hot encode tagged segments, appending a BREAK row after each one.

    Args:
        tagged_segments: iterable of segments, each a sequence of
            ``(word, tag)`` pairs (the shape produced by
            ``tag_training_data``).

    Returns:
        ``(result_embedding, tags_dict)`` where ``result_embedding`` is a 2-D
        float array with one one-hot row per token plus one ``BREAK`` row
        after every segment, and ``tags_dict`` is the ``{tag: index}``
        mapping from ``get_upenn_tags_dict()`` used for the encoding. With
        no segments at all the array has zero rows.

    Raises:
        KeyError: if a tag is not present in ``get_upenn_tags_dict()``.
    """
    # get_upenn_tags_dict() returns the finished mapping only; there is no
    # separate tag list to unpack or to re-index.
    tags_dict = get_upenn_tags_dict()
    classes = len(tags_dict)
    eye = np.eye(classes)
    break_row = eye[tags_dict["BREAK"]][None, :]

    pieces = []
    for segment in tagged_segments:
        # dtype=int keeps an empty segment a valid (integer) index array.
        targets = np.array([tags_dict[tag] for word, tag in segment], dtype=int)
        pieces.append(eye[targets])
        pieces.append(break_row)

    if not pieces:
        # np.concatenate rejects an empty list of arrays.
        return np.empty((0, classes)), tags_dict
    return np.concatenate(pieces), tags_dict
def window_embedded_segments_rnn(embeddings, tags_dict):
    """Build growing-prefix training windows from a one-hot tag sequence.

    For every row ``i >= 1`` of ``embeddings`` that is not the BREAK vector,
    one datapoint ``(window, label)`` is produced:

    * ``label`` is 1 if the row directly before ``i`` is a BREAK (i.e. a
      break belongs BEFORE token ``i``), else 0.
    * ``window`` is all rows before ``i`` followed by row ``i`` itself; when
      ``label`` is 1 the BREAK row directly before ``i`` is left out. BREAK
      rows further back in the prefix are kept.

    Args:
        embeddings: 2-D array of one-hot rows, indexed by ``tags_dict``.
        tags_dict: ``{tag: index}`` mapping containing the key ``"BREAK"``.

    Returns:
        A list of ``(window, label)`` tuples in row order.
    """
    break_vector = np.eye(len(tags_dict))[tags_dict["BREAK"]]

    def is_break(row):
        return (row == break_vector).all()

    datapoints = []
    for position in range(1, embeddings.shape[0]):
        current = embeddings[position]
        if is_break(current):
            # BREAK rows never act as the token being classified.
            continue

        history = embeddings[:position]
        label = 0
        if is_break(history[-1]):
            # A break precedes this token: drop it from the input and make
            # it the target instead.
            history = history[:-1]
            label = 1

        window = np.concatenate((history, current[None, :]))
        datapoints.append((window, label))
    return datapoints
def print_dataset(datapoints, tags_dict, tokenized_full_text):
    """Print each datapoint as its label followed by the tokens it covers.

    For every ``(window, label)`` pair one line is written to stdout:
    ``"[1] "`` if ``label == 1`` else ``"[0] "``, followed by the first ``k``
    entries of ``tokenized_full_text``, where ``k`` is the number of rows in
    ``window`` that are not the BREAK vector.

    Args:
        datapoints: iterable of ``(window, label)`` pairs, e.g. the output
            of ``window_embedded_segments_rnn``.
        tags_dict: ``{tag: index}`` mapping containing the key ``"BREAK"``.
        tokenized_full_text: sequence of tokens to slice the prefix from.
    """
    break_vector = np.eye(len(tags_dict))[tags_dict["BREAK"]]
    for window, label in datapoints:
        marker = "[1] " if label == 1 else "[0] "
        print(marker, end='')

        # BREAK rows have no counterpart in the token list, so skip them.
        token_count = sum(
            1 for row in window if not (row == break_vector).all()
        )
        print(tokenized_full_text[:token_count])
from stable_whisper.result import Segment # Just for typing
def get_indicies(segment: Segment, model, device, threshold):
    """Return the indices of the words in ``segment`` flagged by the model.

    The words are tagged with ``bind_wordtimings_to_tags``, the flattened tag
    sequence is one-hot encoded with ``embed_tag_list`` and fed to ``model``
    as a float tensor with a leading batch axis of size 1. A word is flagged
    when the highest score among its own tags exceeds ``threshold``.

    NOTE(review): assumes ``model`` returns one scalar score per input tag
    with the batch axis first — confirm against the model definition.

    Args:
        segment: segment whose ``.words`` are to be scored.
        model: callable taking the encoded tensor and returning the scores.
        device: torch device the input tensor is moved to.
        threshold: score above which a word is flagged.

    Returns:
        A list of indices into ``segment.words``, in ascending order. Empty
        when the segment yields no tags at all (the model is not called).
    """
    tagged_wordtiming = bind_wordtimings_to_tags(segment.words)
    tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]
    tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]
    if not tag_list:
        # Nothing to score; an empty tag list cannot be encoded or fed to
        # the model.
        return []

    embedded_tags = torch.from_numpy(embed_tag_list(tag_list)).float()
    output = model(embedded_tags[None, :].to(device))
    list_output = output.detach().cpu().numpy().tolist()[0]

    cut_indicies = []
    current_index = 0
    for index, tags_count in enumerate(tag_per_word):
        scores = list_output[current_index:current_index + tags_count]
        # A word that tokenized to nothing has no scores; max() of an empty
        # list would raise, so such a word is simply never flagged.
        if scores and max(scores) > threshold:
            cut_indicies.append(index)
        current_index += tags_count
    return cut_indicies
def get_indicies_autoembed(segment: Segment, model, device, threshold):
    """Return the indices of the words in ``segment`` flagged by the model.

    Same procedure as ``get_indicies``, except that the model receives the
    integer tag indices from ``lookup_tag_list`` (as an int tensor with a
    leading batch axis of size 1) instead of one-hot vectors. A word is
    flagged when the highest score among its own tags exceeds ``threshold``.

    NOTE(review): assumes ``model`` embeds the indices itself and returns one
    scalar score per input tag with the batch axis first — confirm against
    the model definition.

    Args:
        segment: segment whose ``.words`` are to be scored.
        model: callable taking the index tensor and returning the scores.
        device: torch device the input tensor is moved to.
        threshold: score above which a word is flagged.

    Returns:
        A list of indices into ``segment.words``, in ascending order. Empty
        when the segment yields no tags at all (the model is not called).
    """
    tagged_wordtiming = bind_wordtimings_to_tags(segment.words)
    tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]
    tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]
    if not tag_list:
        # Nothing to score; skip the model call for an empty sequence.
        return []

    # Moved to the device once; a second .to(device) would be a no-op.
    embedded_tags = torch.from_numpy(lookup_tag_list(tag_list)).int().to(device)
    output = model(embedded_tags[None, :])
    list_output = output.detach().cpu().numpy().tolist()[0]

    cut_indicies = []
    current_index = 0
    for index, tags_count in enumerate(tag_per_word):
        scores = list_output[current_index:current_index + tags_count]
        # A word that tokenized to nothing has no scores; max() of an empty
        # list would raise, so such a word is simply never flagged.
        if scores and max(scores) > threshold:
            cut_indicies.append(index)
        current_index += tags_count
    return cut_indicies