|
|
|
|
|
import random |
|
import unittest |
|
from typing import Any, Iterable, Iterator, Tuple |
|
|
|
from densepose.data import CombinedDataLoader |
|
|
|
|
|
def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]: |
|
""" |
|
Group elements of an iterable by chunks of size `n`, e.g. |
|
grouper(range(9), 4) -> |
|
(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None) |
|
""" |
|
it = iter(iterable) |
|
while True: |
|
values = [] |
|
for _ in range(n): |
|
try: |
|
value = next(it) |
|
except StopIteration: |
|
values.extend([fillvalue] * (n - len(values))) |
|
yield tuple(values) |
|
return |
|
values.append(value) |
|
yield tuple(values) |
|
|
|
|
|
class TestCombinedDataLoader(unittest.TestCase):
    """Checks the batches produced by `CombinedDataLoader` over two grouped loaders."""

    def test_combine_loaders_1(self):
        """
        Combine a loader with chunks of 2 and a loader with chunks of 3 into
        batches of 4 and compare every batch against the expected sequence.
        """
        loader1 = _grouper([f"1_{i}" for i in range(10)], 2)
        loader2 = _grouper([f"2_{i}" for i in range(11)], 3)
        batch_size = 4
        ratios = (0.1, 0.9)
        # Fixed seed so the run is reproducible; presumably CombinedDataLoader
        # samples the source loader through the `random` module - TODO confirm.
        random.seed(43)
        combined = CombinedDataLoader((loader1, loader2), batch_size, ratios)
        BATCHES_GT = [
            ["1_0", "1_1", "2_0", "2_1"],
            ["2_2", "2_3", "2_4", "2_5"],
            ["1_2", "1_3", "2_6", "2_7"],
            ["2_8", "2_9", "2_10", None],
        ]
        n_batches = 0
        for i, batch in enumerate(combined):
            # Fail with a clear message instead of an IndexError on BATCHES_GT
            # when more batches than expected are produced.
            self.assertLess(i, len(BATCHES_GT), "more batches produced than expected")
            self.assertEqual(len(batch), batch_size)
            self.assertEqual(batch, BATCHES_GT[i])
            n_batches = i + 1
        # Without this check the test passes vacuously when the combined
        # loader yields fewer batches than expected (or none at all).
        self.assertEqual(n_batches, len(BATCHES_GT))
|
|