sohomghosh committed on
Commit
ee4af77
1 Parent(s): 54ef9da

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -7
README.md CHANGED
@@ -43,13 +43,13 @@ class Triage(Dataset):
43
  This is a subclass of the torch package's Dataset class. It processes input to create the ids, masks and targets required for model training.
44
  """
45
 
46
- def __init__(self, dataframe, tokenizer, max_len, text_col_name, category_col):
47
  self.len = len(dataframe)
48
  self.data = dataframe
49
  self.tokenizer = tokenizer
50
  self.max_len = max_len
51
  self.text_col_name = text_col_name
52
- self.category_col = category_col
53
 
54
  def __getitem__(self, index):
55
  title = str(self.data[self.text_col_name][index])
@@ -69,9 +69,7 @@ class Triage(Dataset):
69
  return {
70
  "ids": torch.tensor(ids, dtype=torch.long),
71
  "mask": torch.tensor(mask, dtype=torch.long),
72
- "targets": torch.tensor(
73
- self.data[self.category_col][index], dtype=torch.long
74
- ),
75
  }
76
 
77
  def __len__(self):
@@ -97,7 +95,7 @@ class BERTClass(torch.nn.Module):
97
  output = self.classifier(pooler)
98
  return output
99
 
100
- def do_predict(tokenizer):
101
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
102
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
103
  test_loader = DataLoader(test_set, **test_params)
@@ -121,7 +119,7 @@ model_read.to(device)
121
  model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
122
 
123
  tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
124
- actual_predictions_read = do_predict(tokenizer_read)
125
 
126
  test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
127
 
 
43
  This is a subclass of the torch package's Dataset class. It processes input to create the ids, masks and targets required for model training.
44
  """
45
 
46
+ def __init__(self, dataframe, tokenizer, max_len, text_col_name):
47
  self.len = len(dataframe)
48
  self.data = dataframe
49
  self.tokenizer = tokenizer
50
  self.max_len = max_len
51
  self.text_col_name = text_col_name
52
+
53
 
54
  def __getitem__(self, index):
55
  title = str(self.data[self.text_col_name][index])
 
69
  return {
70
  "ids": torch.tensor(ids, dtype=torch.long),
71
  "mask": torch.tensor(mask, dtype=torch.long),
72
+
 
 
73
  }
74
 
75
  def __len__(self):
 
95
  output = self.classifier(pooler)
96
  return output
97
 
98
+ def do_predict(model, tokenizer):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
 
119
  model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
122
+ actual_predictions_read = do_predict(model_read, tokenizer_read)
123
 
124
  test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
125