File size: 1,385 Bytes
f953fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a7d84
f953fd7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import os

# The script supports two invocation styles:
#   * run directly as a top-level script (`__package__` is None or ""), where
#     sibling modules are imported absolutely and data is read from "data";
#   * run as part of a package, where sibling modules are imported relatively
#     and data is read from "segmenter/data".
# Both data paths are relative to the current working directory; the package
# variant presumably expects to be launched from the directory containing the
# `segmenter` package — confirm against how the script is invoked.
if not __package__:
    from model import BidirLSTMSegmenter, SegmentorDatasetDirectTag, train_bidirlstm_model
    from model import BidirLSTMSegmenterWithEmbedding, SegmentorDatasetNonEmbed, train_bidirlstm_embedding_model
    from utils import get_upenn_tags_dict
    from model_consts import input_size, embedding_size, hidden_size, num_layers
    data_path = "data"
else:
    from .model import BidirLSTMSegmenter, SegmentorDatasetDirectTag, train_bidirlstm_model
    from .model import BidirLSTMSegmenterWithEmbedding, SegmentorDatasetNonEmbed, train_bidirlstm_embedding_model
    from .utils import get_upenn_tags_dict
    from .model_consts import input_size, embedding_size, hidden_size, num_layers
    data_path = "segmenter/data"

# Train on the GPU when one is available; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without CUDA instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model weights are resumed from and saved to this file, relative to the
# current working directory.
CHECKPOINT_PATH = "segmenter.ckpt"

if __name__ == "__main__":
    dataset = SegmentorDatasetNonEmbed(data_path)
    model = BidirLSTMSegmenterWithEmbedding(input_size, embedding_size, hidden_size, num_layers, device)

    # Resume from a previous run when a checkpoint file is present
    # (os.path.isfile is False for missing paths, so no separate exists() check).
    if os.path.isfile(CHECKPOINT_PATH):
        print("Loading checkpoint. If you want to start from scratch, remove segmenter.ckpt.")
        # map_location remaps stored tensors onto the selected device, so a
        # checkpoint saved on a GPU can still be loaded on a CPU-only host.
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

    model.to(device)

    train_bidirlstm_embedding_model(model, dataset, num_epochs=100, batch_size=2)

    # Persist the trained weights so the next run resumes from here.
    torch.save(model.state_dict(), CHECKPOINT_PATH)