Spaces:
Runtime error
Runtime error
fixes
Browse files- src/models/model.py +2 -2
src/models/model.py
CHANGED
@@ -15,7 +15,6 @@ from pytorch_lightning import LightningModule
|
|
15 |
from datasets import load_metric
|
16 |
from tqdm.auto import tqdm
|
17 |
|
18 |
-
|
19 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
20 |
|
21 |
|
@@ -477,9 +476,10 @@ class Summarization:
|
|
477 |
metric = load_metric(metrics)
|
478 |
input_text = test_df['input_text'][:5]
|
479 |
references = test_df['output_text'][:5]
|
|
|
480 |
|
481 |
predictions = [self.predict(x) for x in input_text]
|
482 |
-
print(type(predictions),type(references))
|
483 |
|
484 |
results = metric.compute(predictions=predictions, references=references)
|
485 |
'''
|
|
|
15 |
from datasets import load_metric
|
16 |
from tqdm.auto import tqdm
|
17 |
|
|
|
18 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
19 |
|
20 |
|
|
|
476 |
metric = load_metric(metrics)
|
477 |
input_text = test_df['input_text'][:5]
|
478 |
references = test_df['output_text'][:5]
|
479 |
+
references = references.to_list()
|
480 |
|
481 |
predictions = [self.predict(x) for x in input_text]
|
482 |
+
print(type(predictions), type(references))
|
483 |
|
484 |
results = metric.compute(predictions=predictions, references=references)
|
485 |
'''
|