sohomghosh
commited on
Commit
•
ed4339d
1
Parent(s):
ed29696
Update README.md
Browse files
README.md
CHANGED
@@ -95,7 +95,7 @@ class BERTClass(torch.nn.Module):
|
|
95 |
output = self.classifier(pooler)
|
96 |
return output
|
97 |
|
98 |
-
def do_predict(model, tokenizer):
|
99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
101 |
test_loader = DataLoader(test_set, **test_params)
|
@@ -119,7 +119,7 @@ model_sustain.to(device)
|
|
119 |
model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
120 |
|
121 |
tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
|
122 |
-
actual_predictions_sus = do_predict(model_sustain, tokenizer_sus)
|
123 |
|
124 |
test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
|
125 |
```
|
|
|
95 |
output = self.classifier(pooler)
|
96 |
return output
|
97 |
|
98 |
+
def do_predict(model, tokenizer, test_df):
|
99 |
test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
|
100 |
test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
|
101 |
test_loader = DataLoader(test_set, **test_params)
|
|
|
119 |
model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
|
120 |
|
121 |
tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
|
122 |
+
actual_predictions_sus = do_predict(model_sustain, tokenizer_sus, test_df)
|
123 |
|
124 |
test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
|
125 |
```
|