sohomghosh
commited on
Commit
•
9f31d25
1
Parent(s):
ee4af77
Update README.md
Browse files
README.md
CHANGED
@@ -95,7 +95,7 @@ class BERTClass(torch.nn.Module):
|
|
95 |
output = self.classifier(pooler)
|
96 |
return output
|
97 |
|
98 |
-
def do_predict(model, tokenizer):
|
99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
101 |
test_loader = DataLoader(test_set, **test_params)
|
@@ -119,7 +119,7 @@ model_read.to(device)
|
|
119 |
model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
120 |
|
121 |
tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
|
122 |
-
actual_predictions_read = do_predict(model_read, tokenizer_read)
|
123 |
|
124 |
test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
|
125 |
|
|
|
95 |
output = self.classifier(pooler)
|
96 |
return output
|
97 |
|
98 |
+
def do_predict(model, tokenizer, test_df):
|
99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
101 |
test_loader = DataLoader(test_set, **test_params)
|
|
|
119 |
model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
120 |
|
121 |
tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
|
122 |
+
actual_predictions_read = do_predict(model_read, tokenizer_read, test_df)
|
123 |
|
124 |
test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
|
125 |
|