import torch
from torch.utils.data import Dataset


class SpamMessageDataset(Dataset):
    """Map-style dataset that yields tokenized messages with binary labels.

    Each item is a dict with keys ``'input_ids'``, ``'attention_mask'`` and
    ``'label'``. A label is 1 when the raw label equals the string
    ``'spam'`` and 0 for any other value.
    """

    def __init__(self, text, labels, tokenizer, max_length):
        """Store the messages and convert the raw labels to a tensor.

        Args:
            text: Indexable collection of messages; each element is passed
                through ``str()`` before tokenization.
            labels: Iterable of raw labels, compared against ``'spam'``.
            tokenizer: Object exposing ``encode_plus`` that returns a mapping
                with ``'input_ids'`` and ``'attention_mask'`` tensors.
                NOTE(review): presumably a Hugging Face tokenizer -- confirm.
            max_length: Length every sequence is padded/truncated to.
        """
        self.text = text
        binary_labels = [1 if label == 'spam' else 0 for label in labels]
        self.labels = torch.tensor(binary_labels, dtype=torch.long)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of messages."""
        return len(self.text)

    def __getitem__(self, idx):
        """Tokenize the message at ``idx`` and return it with its label.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer output with its leading batch axis removed) and
            ``'label'`` (a 0-d ``torch.long`` tensor).
        """
        text = str(self.text[idx])
        # Copy so callers cannot mutate the stored label tensor in place.
        label = self.labels[idx].clone().detach()

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Remove only the leading batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1, producing a 0-d
        # tensor and a wrongly shaped batch after collation.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label,
        }