gagan3012 commited on
Commit
da13cde
·
1 Parent(s): f01fda5
Files changed (1) hide show
  1. src/models/model.py +2 -2
src/models/model.py CHANGED
@@ -15,7 +15,6 @@ from pytorch_lightning import LightningModule
15
  from datasets import load_metric
16
  from tqdm.auto import tqdm
17
 
18
-
19
  # from dagshub.pytorch_lightning import DAGsHubLogger
20
 
21
 
@@ -477,9 +476,10 @@ class Summarization:
477
  metric = load_metric(metrics)
478
  input_text = test_df['input_text'][:5]
479
  references = test_df['output_text'][:5]
 
480
 
481
  predictions = [self.predict(x) for x in input_text]
482
- print(type(predictions),type(references))
483
 
484
  results = metric.compute(predictions=predictions, references=references)
485
  '''
 
15
  from datasets import load_metric
16
  from tqdm.auto import tqdm
17
 
 
18
  # from dagshub.pytorch_lightning import DAGsHubLogger
19
 
20
 
 
476
  metric = load_metric(metrics)
477
  input_text = test_df['input_text'][:5]
478
  references = test_df['output_text'][:5]
479
+ references = references.to_list()
480
 
481
  predictions = [self.predict(x) for x in input_text]
482
+ print(type(predictions), type(references))
483
 
484
  results = metric.compute(predictions=predictions, references=references)
485
  '''