Update model
- model_consts.py +2 -2
- segmenter.ckpt +2 -2
- train.py +1 -1
- utils.py +60 -64
model_consts.py
CHANGED
@@ -4,6 +4,6 @@ else:
 from .utils import get_upenn_tags_dict
 
 input_size = len(get_upenn_tags_dict())
-embedding_size =
-hidden_size =
+embedding_size = 256
+hidden_size = 256
 num_layers = 2
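The model class itself is not part of this commit, so the following is only a sketch of how constants like these are typically wired into a bidirectional LSTM over tag embeddings; the class name and layer layout are assumptions, not the repo's actual code. Note that input_size is derived from get_upenn_tags_dict(), which this commit grows from 46 to 59 tags (see utils.py below), so the new 256-unit sizes pair with a larger input vocabulary.

# Hypothetical sketch; the real model is defined elsewhere in this repo.
import torch.nn as nn

class BidirLSTMEmbeddingModel(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers):
        super().__init__()
        # one learned vector per tag id
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)
        # 2 * hidden_size because forward and backward states are concatenated
        self.out = nn.Linear(2 * hidden_size, input_size)

    def forward(self, tag_ids):          # tag_ids: (batch, seq)
        x = self.embedding(tag_ids)      # (batch, seq, embedding_size)
        x, _ = self.lstm(x)              # (batch, seq, 2 * hidden_size)
        return self.out(x)               # (batch, seq, input_size)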
segmenter.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:a8e6209584d0021684bb3a09ec1b717843f3086dfcc6411c57276f743f8e62fa
+size 10584544
train.py
CHANGED
@@ -26,6 +26,6 @@ if __name__ == "__main__":
 
     model.to(device)
 
-    train_bidirlstm_embedding_model(model, dataset, num_epochs=
+    train_bidirlstm_embedding_model(model, dataset, num_epochs=150, batch_size=2)
 
     torch.save(model.state_dict(), "segmenter.ckpt")
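train_bidirlstm_embedding_model is defined elsewhere in the repo and not shown here; as orientation only, a trainer with this call signature might look like the sketch below, where the optimizer and per-token cross-entropy loss are assumptions. The updated segmenter.ckpt above is consistent with a retrain under the new settings: enlarging the tag dictionary changes input_size and therefore the parameter shapes stored in the checkpoint.

# Hypothetical sketch matching the call site above; the actual trainer may differ.
import torch
from torch.utils.data import DataLoader

def train_bidirlstm_embedding_model(model, dataset, num_epochs, batch_size):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(num_epochs):
        for tag_ids, targets in loader:          # both (batch, seq)
            optimizer.zero_grad()
            logits = model(tag_ids)              # (batch, seq, num_tags)
            # CrossEntropyLoss wants the class dimension second
            loss = criterion(logits.transpose(1, 2), targets)
            loss.backward()
            optimizer.step()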
utils.py
CHANGED
@@ -4,6 +4,64 @@ from stable_whisper.result import WordTiming
 import numpy as np
 import torch
 
+additional_tags = {
+    "as": "`AS",
+    "and": "`AND",
+    "of": "`OF",
+    "how": "`HOW",
+    "but": "`BUT",
+    "the": "`THE",
+    "a": "`A",
+    "an": "`A",
+    "which": "`WHICH",
+    "what": "`WHAT",
+    "where": "`WHERE",
+    "that": "`THAT",
+    "who": "`WHO",
+    "when": "`WHEN",
+}
+
+def get_upenn_tags_dict():
+    # tagger = PerceptronTagger()
+
+    # tags = list(tagger.tagdict.values())
+
+    # # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
+    # tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
+    # tags = list(set(tags))
+    # tags.sort()
+    # tags.append("BREAK")
+
+    # tags_dict = dict()
+
+    # for index, tag in enumerate(tags):
+    #     tags_dict[tag] = index
+
+    return {'#': 0, '$': 1, "''": 2, '(': 3, ')': 4, ',': 5, '.': 6, ':': 7, 'CC': 8, 'CD': 9, 'DT': 10, 'EX': 11, 'FW': 12, 'IN': 13, 'JJ': 14, 'JJR': 15, 'JJS': 16, 'LS': 17, 'MD': 18, 'NN': 19, 'NNP': 20, 'NNPS': 21, 'NNS': 22, 'PDT': 23, 'POS': 24, 'PRP': 25, 'PRP$': 26, 'RB': 27, 'RBR': 28, 'RBS': 29, 'RP': 30, 'SYM': 31, 'TO': 32, 'UH': 33, 'VB': 34, 'VBD': 35, 'VBG': 36, 'VBN': 37, 'VBP': 38, 'VBZ': 39, 'WDT': 40, 'WP': 41, 'WP$': 42, 'WRB': 43, '``': 44, 'BREAK': 45,
+            '`AS': 46,
+            '`AND': 47,
+            '`OF': 48,
+            '`HOW': 49,
+            '`BUT': 50,
+            '`THE': 51,
+            '`A': 52,
+            '`WHICH': 53,
+            '`WHAT': 54,
+            '`WHERE': 55,
+            '`THAT': 56,
+            '`WHO': 57,
+            '`WHEN': 58
+            }
+
+def nltk_extend_tags(tagged_text: list[tuple[str, str]]):
+    result = []
+    for text, tag in tagged_text:
+        text_lower = text.lower().strip()
+        if text_lower in additional_tags:
+            yield (text, additional_tags[text_lower])
+        else:
+            yield (text, tag)
+
 def bind_wordtimings_to_tags(wt: list[WordTiming]):
     raw_words = [w.word for w in wt]
 
@@ -16,6 +74,7 @@ def bind_wordtimings_to_tags(wt: list[WordTiming]):
     tokens_wordtiming_map.append(len(tokens_word))
 
     tagged_words = nltk.pos_tag(tokenized_raw_words)
+    tagged_words = list(nltk_extend_tags(tagged_words))
 
     grouped_tags = []
 
@@ -49,6 +108,7 @@ def tag_training_data(filename: str):
 
     tokenized_full_text = nltk.word_tokenize(full_text)
     tagged_full_text = nltk.pos_tag(tokenized_full_text)
+    tagged_full_text = list(nltk_extend_tags(tagged_full_text))
 
     tagged_full_text_copy = tagged_full_text
 
@@ -75,70 +135,6 @@ def tag_training_data(filename: str):
 
     return reconstructed_tags
 
-def get_upenn_tags_dict():
-    # tagger = PerceptronTagger()
-
-    # tags = list(tagger.tagdict.values())
-
-    # # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
-    # tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
-    # tags = list(set(tags))
-    # tags.sort()
-    # tags.append("BREAK")
-
-    # tags_dict = dict()
-
-    # for index, tag in enumerate(tags):
-    #     tags_dict[tag] = index
-
-    return {'#': 0,
-            '$': 1,
-            "''": 2,
-            '(': 3,
-            ')': 4,
-            ',': 5,
-            '.': 6,
-            ':': 7,
-            'CC': 8,
-            'CD': 9,
-            'DT': 10,
-            'EX': 11,
-            'FW': 12,
-            'IN': 13,
-            'JJ': 14,
-            'JJR': 15,
-            'JJS': 16,
-            'LS': 17,
-            'MD': 18,
-            'NN': 19,
-            'NNP': 20,
-            'NNPS': 21,
-            'NNS': 22,
-            'PDT': 23,
-            'POS': 24,
-            'PRP': 25,
-            'PRP$': 26,
-            'RB': 27,
-            'RBR': 28,
-            'RBS': 29,
-            'RP': 30,
-            'SYM': 31,
-            'TO': 32,
-            'UH': 33,
-            'VB': 34,
-            'VBD': 35,
-            'VBG': 36,
-            'VBN': 37,
-            'VBP': 38,
-            'VBZ': 39,
-            'WDT': 40,
-            'WP': 41,
-            'WP$': 42,
-            'WRB': 43,
-            '``': 44,
-            'BREAK': 45}
-
-
 def parse_tags(reconstructed_tags):
     """
     Parse reconstructed tags into input/tag datapoint.