import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from os import listdir
from os.path import join
import concurrent.futures
import itertools

if __package__ is None or __package__ == "":
    from utils import tag_training_data, get_upenn_tags_dict, parse_tags
else:
    from .utils import tag_training_data, get_upenn_tags_dict, parse_tags


# Model Type 1: LSTM with 1-logit lookahead.
class SegmentorDataset(Dataset):
    def __init__(self, datapoints):
        # Each datapoint is a (feature_array, target) pair; the target is wrapped
        # in a 1-element tensor so it matches the model's [batch, 1] output.
        self.datapoints = [
            (torch.from_numpy(k).float(), torch.tensor([t]).float())
            for k, t in datapoints
        ]

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        return self.datapoints[idx][0], self.datapoints[idx][1]


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, device=None):
        super().__init__()
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Create the layers directly on the target device so the weights and the
        # zero-initialized hidden states in forward() always live together.
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers,
                           batch_first=True, device=self.device)
        self.fc = nn.Linear(hidden_size, 1, device=self.device)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)
        out, _ = self.rnn(x, (h0, c0))
        # Keep only the final timestep: one lookahead logit per sequence.
        out = out[:, -1, :]
        out = self.fc(out)
        return out


# Model 2: Bidirectional LSTM with entire-sequence context (hopefully).
class SegmentorDatasetDirectTag(Dataset):
    def __init__(self, document_root: str):
        self.tags_dict = get_upenn_tags_dict()
        self.datapoints = []
        # Identity matrix used to one-hot encode tag indices in __getitem__.
        self.eye = np.eye(len(self.tags_dict))
        files = listdir(document_root)
        for f in files:
            if f.endswith(".txt"):
                fname = join(document_root, f)
                print(f"Loaded datafile: {fname}")
                reconstructed_tags = tag_training_data(fname)
                inputs, tag = parse_tags(reconstructed_tags)
                self.datapoints.append((
                    np.array(inputs),
                    np.array(tag)
                ))

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        item = self.datapoints[idx]
        return torch.from_numpy(self.eye[item[0]]).float(), torch.from_numpy(item[1]).float()
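
# A minimal smoke test for the Model 1 pipeline, sketched under assumptions:
# the feature width (16) and sequence length (10) below are illustrative
# placeholders, not values fixed by this module. Real inputs come from the
# tagging utilities above.
def _rnn_shape_check():
    model = RNN(input_size=16, hidden_size=32, num_layers=2, device="cpu")
    x = torch.randn(4, 10, 16)   # [batch, seq_len, input_size]
    out = model(x)
    assert out.shape == (4, 1)   # one lookahead logit per sequence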

# The same dataset without one-hot embedding of the input.
class SegmentorDatasetNonEmbed(Dataset):
    @staticmethod
    def read_file(f: str, document_root: str):
        if f.endswith(".txt"):
            fname = join(document_root, f)
            print(f"Loaded datafile: {fname}")
            reconstructed_tags = tag_training_data(fname)
            inputs, tag = parse_tags(reconstructed_tags)
            return [(np.array(inputs), np.array(tag))]
        else:
            return []

    def __init__(self, document_root: str):
        self.datapoints = []
        files = listdir(document_root)
        # Parse the files in parallel; each worker returns a (possibly empty)
        # list, so the results flatten into a single list of datapoints.
        with concurrent.futures.ProcessPoolExecutor() as pool:
            out = pool.map(SegmentorDatasetNonEmbed.read_file, files,
                           itertools.repeat(document_root))
            self.datapoints = list(itertools.chain.from_iterable(out))

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        item = self.datapoints[idx]
        # Integer tag indices for nn.Embedding; float targets for BCELoss.
        return torch.from_numpy(item[0]).long(), torch.from_numpy(item[1]).float()


class BidirLSTMSegmenter(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, device=None):
        super().__init__()
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True,
                           bidirectional=True, device=self.device)
        self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
        self.final = nn.Sigmoid()

    def forward(self, x):
        # Bidirectional, so the hidden-state depth is num_layers * 2.
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
        out, _ = self.rnn(x, (h0, c0))
        # out: [batch, seq_len, 2 * hidden_size]. The linear layer is applied at
        # every timestep at once, yielding one score per position.
        out_fced = self.fc(out)[:, :, 0]  # [batch, seq_len]
        return self.final(out_fced)


class BidirLSTMSegmenterWithEmbedding(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, device=None):
        super().__init__()
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(input_size, embedding_dim=embedding_size,
                                      device=self.device)
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True,
                           bidirectional=True, device=self.device)
        self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
        self.final = nn.Sigmoid()

    def forward(self, x):
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
        embedded = self.embedding(x)
        out, _ = self.rnn(embedded, (h0, c0))
        # out: [batch, seq_len, 2 * hidden_size]; per-position probabilities.
        out_fced = self.fc(out)[:, :, 0]  # [batch, seq_len]
        return self.final(out_fced)
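
# A quick shape check for the embedding variant, sketched under assumptions:
# the vocabulary size (46) is a placeholder for len(get_upenn_tags_dict()),
# and the batch/sequence sizes are illustrative only.
def _bidir_embed_shape_check():
    vocab_size = 46  # placeholder; the real value is len(get_upenn_tags_dict())
    model = BidirLSTMSegmenterWithEmbedding(
        input_size=vocab_size, embedding_size=8,
        hidden_size=32, num_layers=2, device="cpu")
    x = torch.randint(0, vocab_size, (4, 25))  # [batch, seq_len] of tag indices
    out = model(x)
    assert out.shape == (4, 25)                # one probability per position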

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-length sequences.
    Note: tensors are built manually in the datasets rather than via a ToTensor
    transform, since ToTensor assumes image input rather than arbitrary tensors.
    No mask is computed, so padded positions contribute to the loss as-is.
    '''
    inputs = [i[0] for i in batch]
    tags = [i[1] for i in batch]
    # Pad every sequence in the batch to the length of the longest one.
    padded_input = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    combined_outputs = torch.nn.utils.rnn.pad_sequence(tags, batch_first=True)
    return (padded_input, combined_outputs)


def get_dataloader(dataset: Dataset, batch_size):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      collate_fn=collate_fn_padd)


def train_model(model: RNN, dataset, lr=1e-3, num_epochs=3, batch_size=100):
    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = model.device
    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, tags)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")


def train_bidirlstm_model(model: BidirLSTMSegmenter, dataset: SegmentorDatasetDirectTag,
                          lr=1e-3, num_epochs=3, batch_size=1):
    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    # Per-position boundary probabilities, hence binary cross-entropy.
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = model.device
    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, tags)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")


def train_bidirlstm_embedding_model(model: BidirLSTMSegmenterWithEmbedding,
                                    dataset: SegmentorDatasetNonEmbed,
                                    lr=1e-3, num_epochs=3, batch_size=1):
    train_loader = get_dataloader(dataset, batch_size=batch_size)
    n_total_steps = len(train_loader)
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = model.device
    for epoch in range(num_epochs):
        for i, (inputs, tags) in enumerate(train_loader):
            inputs = inputs.to(device)
            tags = tags.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, tags)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
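
# End-to-end usage sketch, under assumptions flagged inline: "data/" is a
# hypothetical document root containing .txt training files, and the
# hyperparameters below are illustrative defaults, not tuned values. The
# __main__ guard is required because SegmentorDatasetNonEmbed spawns worker
# processes.
if __name__ == "__main__":
    dataset = SegmentorDatasetNonEmbed("data/")   # hypothetical path
    vocab_size = len(get_upenn_tags_dict())       # tag-index vocabulary
    model = BidirLSTMSegmenterWithEmbedding(
        input_size=vocab_size, embedding_size=16,
        hidden_size=128, num_layers=2)
    train_bidirlstm_embedding_model(model, dataset, lr=1e-3,
                                    num_epochs=3, batch_size=1)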