# dataset.py
import ast
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import Vocab


class SeqClsDataset(Dataset):
    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> Dict:
        return self.data[index]

    @property
    def num_classes(self) -> int:
        return len(self.label_mapping)

    def label2idx(self, label: str) -> int:
        return self.label_mapping[label]

    def idx2label(self, idx: int) -> str:
        return self._idx2label[idx]


class SeqTaggingClsDataset(SeqClsDataset):
    def collate_fn(self, samples: List[Dict]) -> Dict:
        # `samples` is a list of instance dicts, each with "tokens" and "tags".
        batch_size = len(samples)
        tokens = [sample["tokens"] for sample in samples]  # list of list[str]
        tags = [sample["tags"] for sample in samples]      # list of list[str]

        # Pre-fill the arrays with the pad id for tokens and an ignore index (-1) for tags.
        pad_id = self.vocab.token_to_id("[PAD]")
        batch_tokens = np.full((batch_size, self.max_len), pad_id, dtype=np.int64)
        batch_tags = np.full((batch_size, self.max_len), -1, dtype=np.int64)

        # Copy each sequence into the numpy arrays, truncating to max_len.
        for j in range(batch_size):
            # Tokens may arrive as a stringified list; parse them safely if so.
            if isinstance(tokens[j], str):
                tokens[j] = ast.literal_eval(tokens[j])
            cur_len = min(len(tokens[j]), self.max_len)
            batch_tokens[j, :cur_len] = self.vocab.encode(tokens[j])[:cur_len]
            batch_tags[j, :cur_len] = [self.label_mapping[tag] for tag in tags[j][:cur_len]]

        # Convert the integer index arrays to PyTorch tensors and return the batch dict.
        return {
            "encoded_tokens": torch.LongTensor(batch_tokens),
            "encoded_tags": torch.LongTensor(batch_tags),
        }
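

# --- Example usage sketch -----------------------------------------------------
# A minimal sketch of how the tagging dataset and its collate_fn might be wired
# into a DataLoader. It assumes slot-tagging instances shaped like
#   {"tokens": [...], "tags": [...], "id": "train-0"}
# plus a pickled Vocab and a tag-to-index JSON mapping; the file paths below
# are hypothetical placeholders, not part of this module's contract.
if __name__ == "__main__":
    import json
    import pickle
    from pathlib import Path

    from torch.utils.data import DataLoader

    data_path = Path("data/slot/train.json")        # hypothetical path
    vocab_path = Path("cache/slot/vocab.pkl")       # hypothetical path
    tag2idx_path = Path("cache/slot/tag2idx.json")  # hypothetical path

    data = json.loads(data_path.read_text())
    with vocab_path.open("rb") as f:
        vocab = pickle.load(f)
    tag2idx = json.loads(tag2idx_path.read_text())

    dataset = SeqTaggingClsDataset(data, vocab, tag2idx, max_len=128)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn)

    batch = next(iter(loader))
    print(batch["encoded_tokens"].shape)  # -> torch.Size([32, 128])
    print(batch["encoded_tags"].shape)    # -> torch.Size([32, 128])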