Spaces:
Running
Running
File size: 3,084 Bytes
5fd3fc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import unittest
import pandas as pd
from davisinteractive.common import Path, patch
from davisinteractive.dataset import Davis
from davisinteractive.evaluation import EvaluationService
from davisinteractive.utils.scribbles import annotated_frames, is_empty
class TestEvaluationService(unittest.TestCase):
    """Tests for ``EvaluationService``: sample listing, sizing and validation.

    Every test patches ``Davis.check_files`` to return ``True``; presumably
    this lets the service be constructed without checking dataset files on
    disk (e.g. the non-existent ``/tmp/DAVIS``) -- TODO confirm against
    ``Davis.check_files``.

    Checks use ``self.assert*`` methods instead of bare ``assert`` so they
    still run under ``python -O``, which strips ``assert`` statements.
    """

    @staticmethod
    def _test_dataset_dir():
        """Return the path of the DAVIS test data bundled next to the tests."""
        # Resolved relative to this file: <grandparent>/dataset/test_data/DAVIS.
        return Path(__file__).parent.parent.joinpath(
            'dataset', 'test_data', 'DAVIS')

    def _assert_all_scribbles_listed(self, subset, samples):
        """Assert each (sequence, scribble index) pair of ``subset`` is listed.

        Scribble indices are 1-based: a sequence whose ``num_scribbles`` is N
        must contribute the pairs (seq, 1) .. (seq, N) to ``samples``.
        """
        for seq in Davis.sets[subset]:
            nb_scribbles = Davis.dataset[seq]['num_scribbles']
            for scribble_idx in range(1, nb_scribbles + 1):
                self.assertIn((seq, scribble_idx), samples)

    @patch.object(Davis, 'check_files', return_value=True)
    def test_start(self, mock_check_files):
        """get_samples lists every scribble of the subset with limits unset."""
        self.assertEqual(mock_check_files.call_count, 0)
        service = EvaluationService('train', davis_root='/tmp/DAVIS')
        # Constructing a service performs exactly one files check.
        self.assertEqual(mock_check_files.call_count, 1)
        samples, max_t, max_i = service.get_samples()
        # Neither max_t nor max_i was given, so both come back as None.
        self.assertIsNone(max_t)
        self.assertIsNone(max_i)
        self._assert_all_scribbles_listed('train', samples)

        service = EvaluationService('val', davis_root='/tmp/DAVIS')
        samples, max_t, max_i = service.get_samples()
        self.assertEqual(mock_check_files.call_count, 2)
        self.assertIsNone(max_t)
        self.assertIsNone(max_i)
        self._assert_all_scribbles_listed('val', samples)

    @patch.object(Davis, 'check_files', return_value=True)
    def test_num_entries(self, mock_check_files):
        """num_entries matches the expected counts for the val subset."""
        self.assertEqual(mock_check_files.call_count, 0)
        service = EvaluationService('val', davis_root='/tmp/DAVIS')
        self.assertEqual(mock_check_files.call_count, 1)
        default_entries = 11952
        self.assertEqual(service.num_entries, default_entries)

        service = EvaluationService('val', davis_root='/tmp/DAVIS', max_i=5)
        self.assertEqual(mock_check_files.call_count, 2)
        # Expected value is 59760 == 5 * 11952; presumably num_entries scales
        # linearly with max_i -- confirm against EvaluationService.
        self.assertEqual(service.num_entries, 5 * default_entries)

    @patch.object(Davis, 'check_files', return_value=True)
    def test_starting_scribble(self, _):
        """The first 'bear' scribble in the test data annotates frame 39."""
        service = EvaluationService(
            'train', davis_root=self._test_dataset_dir())
        service.get_samples()
        scribble = service.get_scribble('bear', 1)
        self.assertEqual(scribble['sequence'], 'bear')
        self.assertFalse(is_empty(scribble))
        self.assertEqual(annotated_frames(scribble), [39])

    @patch.object(Davis, 'check_files', return_value=True)
    def test_exceptions(self, _):
        """post_predicted_masks raises ValueError for each invalid call."""
        service = EvaluationService(
            'train', max_i=1, davis_root=self._test_dataset_dir())
        service.get_samples()
        # NOTE(review): the meaning of each positional argument below is
        # inferred from the values used -- confirm against the signature of
        # post_predicted_masks.
        # 5th argument 2 while max_i=1: presumably an interaction past max_i.
        with self.assertRaises(ValueError):
            service.post_predicted_masks('bear', 1, None, 0, 2, None, None)
        # 5th argument 0: presumably interactions are numbered from 1.
        with self.assertRaises(ValueError):
            service.post_predicted_masks('bear', 1, None, 0, 0, None, None)
        # Sequence name that is not part of the dataset.
        with self.assertRaises(ValueError):
            service.post_predicted_masks('novalidsequence', 1, None, 0, 1,
                                         None, None)
        # Scribble index 4 for 'bear': presumably beyond its scribbles.
        with self.assertRaises(ValueError):
            service.post_predicted_masks('bear', 4, None, 0, 1, None, None)
|