# Copyright (c) OpenMMLab. All rights reserved. import os import time from io import StringIO from unittest.mock import patch import mmcv def reset_string_io(io): io.truncate(0) io.seek(0) class TestProgressBar: def test_start(self): out = StringIO() bar_width = 20 # without total task num prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) assert out.getvalue() == 'completed: 0, elapsed: 0s' reset_string_io(out) prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() assert out.getvalue() == 'completed: 0, elapsed: 0s' # with total task num reset_string_io(out) prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' reset_string_io(out) prog_bar = mmcv.ProgressBar( 10, bar_width=bar_width, start=False, file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' def test_update(self): out = StringIO() bar_width = 20 # without total task num prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s' reset_string_io(out) # with total task num prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ 'task/s, elapsed: 1s, ETA: 9s' def test_adaptive_length(self): with patch.dict('os.environ', {'COLUMNS': '80'}): out = StringIO() bar_width = 20 prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) time.sleep(1) reset_string_io(out) prog_bar.update() assert len(out.getvalue()) == 66 os.environ['COLUMNS'] = '30' reset_string_io(out) prog_bar.update() assert len(out.getvalue()) == 48 os.environ['COLUMNS'] = '60' reset_string_io(out) prog_bar.update() assert len(out.getvalue()) == 60 def sleep_1s(num): time.sleep(1) return num def test_track_progress_list(): out = StringIO() ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') assert ret == [1, 2, 3] def test_track_progress_iterator(): out = StringIO() ret = mmcv.track_progress( sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') assert ret == [1, 2, 3] def test_track_iter_progress(): out = StringIO() ret = [] for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out): ret.append(sleep_1s(num)) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') assert ret == [1, 2, 3] def test_track_enum_progress(): out = StringIO() ret = [] count = [] for i, num in enumerate( mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)): ret.append(sleep_1s(num)) count.append(i) assert out.getvalue() == ( '[ ] 0/3, elapsed: 0s, ETA:' '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') assert ret == [1, 2, 3] assert count == [0, 1, 2] def test_track_parallel_progress_list(): out = StringIO() results = mmcv.track_parallel_progress( sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out) # The following cannot pass CI on Github Action # assert out.getvalue() == ( # '[ ] 0/4, elapsed: 0s, ETA:' # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') assert results == [1, 2, 3, 4] def test_track_parallel_progress_iterator(): out = StringIO() results = mmcv.track_parallel_progress( sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out) # The following cannot pass CI on Github Action # assert out.getvalue() == ( # '[ ] 0/4, elapsed: 0s, ETA:' # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') assert results == [1, 2, 3, 4]