Update segmenter
Browse files
data
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit dd266799aedd72e6381b368eacbe2767b6174aad
|
model.py
CHANGED
@@ -5,6 +5,8 @@ from torch.utils.data import Dataset, DataLoader
|
|
5 |
import numpy as np
|
6 |
from os import listdir
|
7 |
from os.path import isfile, join
|
|
|
|
|
8 |
|
9 |
if __package__ == None or __package__ == "":
|
10 |
from utils import tag_training_data, get_upenn_tags_dict, parse_tags
|
@@ -79,20 +81,38 @@ class SegmentorDatasetDirectTag(Dataset):
|
|
79 |
|
80 |
# The same dataset without one-hot embedding of the input.
|
81 |
class SegmentorDatasetNonEmbed(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def __init__(self, document_root: str):
|
83 |
self.datapoints = []
|
84 |
|
85 |
files = listdir(document_root)
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def __len__(self):
|
98 |
return len(self.datapoints)
|
|
|
import concurrent.futures
import itertools
from os import listdir
from os.path import isfile, join

import numpy as np
|
10 |
|
11 |
if __package__ == None or __package__ == "":
|
12 |
from utils import tag_training_data, get_upenn_tags_dict, parse_tags
|
|
|
81 |
|
82 |
# The same dataset without one-hot embedding of the input.
|
83 |
class SegmentorDatasetNonEmbed(Dataset):
|
84 |
+
@staticmethod
def read_file(f: str, document_root: str) -> list:
    """Parse a single training document into (input, tag) arrays.

    Staticmethod of ``SegmentorDatasetNonEmbed`` so it is picklable and can
    be dispatched to worker processes by ``ProcessPoolExecutor.map``.

    Args:
        f: A file name (not a full path) as returned by ``os.listdir``.
        document_root: Directory that ``f`` lives in.

    Returns:
        A one-element list ``[(np.array(tokens), np.array(tags))]`` for a
        ``.txt`` file, or ``[]`` for anything else — the list-of-zero-or-one
        shape lets the caller flatten results with ``chain.from_iterable``.
    """
    if not f.endswith(".txt"):
        # Non-text entries (checkpoints, subdirs, ...) contribute nothing.
        return []
    fname = join(document_root, f)
    print(f"Loaded datafile: {fname}")
    reconstructed_tags = tag_training_data(fname)
    # "tokens" rather than "input": avoid shadowing the builtin input().
    tokens, tags = parse_tags(reconstructed_tags)
    return [(np.array(tokens), np.array(tags))]
|
97 |
+
|
98 |
def __init__(self, document_root: str):
    """Build the dataset by parsing every ``.txt`` file under *document_root*.

    Files are parsed in parallel across CPU cores; each worker runs the
    picklable staticmethod ``SegmentorDatasetNonEmbed.read_file``.

    Args:
        document_root: Directory containing the ``.txt`` training documents.
    """
    # BUG FIX: the module-level `import concurrent` does NOT import the
    # `concurrent.futures` submodule; accessing concurrent.futures would
    # raise AttributeError. Import the submodule explicitly here.
    import concurrent.futures

    files = listdir(document_root)
    with concurrent.futures.ProcessPoolExecutor() as pool:
        # repeat(document_root) pairs the constant root with every file;
        # map() stops at the shorter iterable (the file list).
        results = pool.map(
            SegmentorDatasetNonEmbed.read_file,
            files,
            itertools.repeat(document_root),
        )
        # Materialize INSIDE the `with` block: Executor.map returns a lazy
        # generator, and draining it after shutdown is fragile. Each result
        # is [] or a one-element list, so flattening drops the non-.txt files.
        self.datapoints = list(itertools.chain.from_iterable(results))
|
116 |
|
117 |
def __len__(self):
|
118 |
return len(self.datapoints)
|
segmenter.ckpt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 10584544
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:005053e2036ac4a30364cdb81501140ef2ca238bee0f9a1a28fc5a4603d725f6
|
3 |
size 10584544
|
train.py
CHANGED
@@ -26,6 +26,6 @@ if __name__ == "__main__":
|
|
26 |
|
27 |
model.to(device)
|
28 |
|
29 |
-
train_bidirlstm_embedding_model(model, dataset, num_epochs=
|
30 |
|
31 |
torch.save(model.state_dict(), "segmenter.ckpt")
|
|
|
26 |
|
27 |
model.to(device)
|
28 |
|
29 |
+
train_bidirlstm_embedding_model(model, dataset, num_epochs=100, batch_size=2)
|
30 |
|
31 |
torch.save(model.state_dict(), "segmenter.ckpt")
|