import io
import tempfile
import unittest
from contextlib import ExitStack

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

from densepose.evaluation.tensor_storage import (
    SingleProcessFileTensorStorage,
    SingleProcessRamTensorStorage,
    SizeData,
    storage_gather,
)

class TestSingleProcessRamTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
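        # schema describes the dtype and shape of each tensor field in a record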
        schema = {
            "tf": SizeData(dtype="float32", shape=(112, 112)),
            "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
        }
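
        # generate a few records of random data (fixed seed for reproducibility)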
        data_elts = []
        torch.manual_seed(23)
        for _i in range(3):
            data_elt = {
                "tf": torch.rand((112, 112), dtype=torch.float32),
                "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32),
            }
            data_elts.append(data_elt)
        storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
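
        # writing records should assign sequential record IDs, starting from 0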
        for i in range(3):
            record_id = storage.put(data_elts[i])
            self.assertEqual(record_id, i)
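
        # read the records back and check that they match the original data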
        for i in range(3):
            record = storage.get(i)
            self.assertEqual(len(record), len(schema))
            for field_name in schema:
                self.assertTrue(field_name in record)
                self.assertEqual(data_elts[i][field_name].shape, record[field_name].shape)
                self.assertEqual(data_elts[i][field_name].dtype, record[field_name].dtype)
                self.assertTrue(torch.allclose(data_elts[i][field_name], record[field_name]))


class TestSingleProcessFileTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
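        # same round-trip as the RAM storage test, but backed by a temporary file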
        schema = {
            "tf": SizeData(dtype="float32", shape=(112, 112)),
            "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
        }

        data_elts = []
        torch.manual_seed(23)
        for _i in range(3):
            data_elt = {
                "tf": torch.rand((112, 112), dtype=torch.float32),
                "ti": (torch.rand(4, 64, 64) * 1000).to(dtype=torch.int32),
            }
            data_elts.append(data_elt)
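
        # the storage reopens the temporary file by name; this works on Unix,
        # but may not work on all platforms (e.g. Windows)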
        with tempfile.NamedTemporaryFile() as hFile:
            storage = SingleProcessFileTensorStorage(schema, hFile.name, "wb")
            for i in range(3):
                record_id = storage.put(data_elts[i])
                self.assertEqual(record_id, i)
            hFile.seek(0)
            storage = SingleProcessFileTensorStorage(schema, hFile.name, "rb")
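            # read the records back and compare them with the original data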
            for i in range(3):
                record = storage.get(i)
                self.assertEqual(len(record), len(schema))
                for field_name in schema:
                    self.assertTrue(field_name in record)
                    self.assertEqual(data_elts[i][field_name].shape, record[field_name].shape)
                    self.assertEqual(data_elts[i][field_name].dtype, record[field_name].dtype)
                    self.assertTrue(torch.allclose(data_elts[i][field_name], record[field_name]))


def _find_free_port():
    """
    Copied from detectron2/engine/launch.py
    """
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # binding to port 0 makes the OS pick an available port
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes
    return port


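# spawns nprocs local worker processes, each running main_func(*args)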
def launch(main_func, nprocs, args=()):
    port = _find_free_port()
    dist_url = f"tcp://127.0.0.1:{port}"
    mp.spawn(
        distributed_worker, nprocs=nprocs, args=(main_func, nprocs, dist_url, args), daemon=False
    )


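# sets up the default "gloo" process group and the local process group expected
# by detectron2.utils.comm, then runs the test payload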
def distributed_worker(local_rank, main_func, nprocs, dist_url, args):
    dist.init_process_group(
        backend="gloo", init_method=dist_url, world_size=nprocs, rank=local_rank
    )
    comm.synchronize()
    assert comm._LOCAL_PROCESS_GROUP is None
    pg = dist.new_group(list(range(nprocs)))
    comm._LOCAL_PROCESS_GROUP = pg
    main_func(*args)


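# each process writes (rank + 1) records to its own RAM storage; after
# gathering, rank 0 verifies the records received from all processes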
def ram_read_write_worker():
    schema = {
        "tf": SizeData(dtype="float32", shape=(112, 112)),
        "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
    }
    storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
    world_size = comm.get_world_size()
    rank = comm.get_rank()
    data_elts = []
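    # fill each record with a value that is unique to the (rank, record) pair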
    for i in range(rank + 1):
        data_elt = {
            "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size),
            "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size),
        }
        data_elts.append(data_elt)

    for i in range(rank + 1):
        record_id = storage.put(data_elts[i])
        assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}"
    comm.synchronize()
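
    # gather the per-process storages into a multi-process storage on rank 0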
    multi_storage = storage_gather(storage)
    if rank != 0:
        return
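
    # rank j stored j + 1 records; verify each of them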
    for j in range(world_size):
        for i in range(j + 1):
            record = multi_storage.get(j, i)
            record_gt = {
                "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size),
                "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size),
            }
            assert len(record) == len(schema), (
                f"Process {rank}: multi storage record, rank {j}, id {i}: "
                f"expected {len(schema)} fields in the record, got {len(record)}"
            )
            for field_name in schema:
                assert field_name in record, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name} not in the record"
                )
                assert record_gt[field_name].shape == record[field_name].shape, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected shape {record_gt[field_name].shape} "
                    f"got {record[field_name].shape}"
                )
                assert record_gt[field_name].dtype == record[field_name].dtype, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected dtype {record_gt[field_name].dtype} "
                    f"got {record[field_name].dtype}"
                )
                assert torch.allclose(record_gt[field_name], record[field_name]), (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, tensors are not close enough: "
                    f"L-inf {(record_gt[field_name] - record[field_name]).abs_().max()} "
                    f"L1 {(record_gt[field_name] - record[field_name]).abs_().sum()}"
                )


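# same as ram_read_write_worker, but each process writes its records to its own
# file, given by rank_to_fpath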
def file_read_write_worker(rank_to_fpath):
    schema = {
        "tf": SizeData(dtype="float32", shape=(112, 112)),
        "ti": SizeData(dtype="int32", shape=(4, 64, 64)),
    }
    world_size = comm.get_world_size()
    rank = comm.get_rank()
    storage = SingleProcessFileTensorStorage(schema, rank_to_fpath[rank], "wb")
    data_elts = []

    for i in range(rank + 1):
        data_elt = {
            "tf": torch.ones((112, 112), dtype=torch.float32) * (rank + i * world_size),
            "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (rank + i * world_size),
        }
        data_elts.append(data_elt)

    for i in range(rank + 1):
        record_id = storage.put(data_elts[i])
        assert record_id == i, f"Process {rank}: record ID {record_id}, expected {i}"
    comm.synchronize()

    multi_storage = storage_gather(storage)
    if rank != 0:
        return

    for j in range(world_size):
        for i in range(j + 1):
            record = multi_storage.get(j, i)
            record_gt = {
                "tf": torch.ones((112, 112), dtype=torch.float32) * (j + i * world_size),
                "ti": torch.ones((4, 64, 64), dtype=torch.int32) * (j + i * world_size),
            }
            assert len(record) == len(schema), (
                f"Process {rank}: multi storage record, rank {j}, id {i}: "
                f"expected {len(schema)} fields in the record, got {len(record)}"
            )
            for field_name in schema:
                assert field_name in record, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name} not in the record"
                )
                assert record_gt[field_name].shape == record[field_name].shape, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected shape {record_gt[field_name].shape} "
                    f"got {record[field_name].shape}"
                )
                assert record_gt[field_name].dtype == record[field_name].dtype, (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, expected dtype {record_gt[field_name].dtype} "
                    f"got {record[field_name].dtype}"
                )
                assert torch.allclose(record_gt[field_name], record[field_name]), (
                    f"Process {rank}: multi storage record, rank {j}, id {i}: "
                    f"field {field_name}, tensors are not close enough: "
                    f"L-inf {(record_gt[field_name] - record[field_name]).abs_().max()} "
                    f"L1 {(record_gt[field_name] - record[field_name]).abs_().sum()}"
                )


class TestMultiProcessRamTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        launch(ram_read_write_worker, 8)


class TestMultiProcessFileTensorStorage(unittest.TestCase):
    def test_read_write_1(self):
        with ExitStack() as stack:
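            # one temporary file per rank; the ExitStack keeps the files alive
            # (and thus on disk) until the workers have finished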
            rank_to_fpath = {
                i: stack.enter_context(tempfile.NamedTemporaryFile()).name for i in range(8)
            }
            launch(file_read_write_worker, 8, (rank_to_fpath,))