# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
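
# The tests below exercise two contracts, summarized here for the reader
# (inferred from the assertions themselves, not from separate documentation):
# * `scatter` copies a tensor, or each tensor in a list, to the target
#   device indices; `devices=[-1]` means "leave the data on CPU".
# * `Scatter.forward` performs the same copy but always returns a tuple,
#   with one entry per scattered tensor.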


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mps'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)

    # the input should be a tensor or a list of tensors
    with pytest.raises(Exception):
        scatter(5, [-1])
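
# A minimal usage sketch (illustrative only: the device ids are assumptions,
# and the in-order distribution of list elements across several devices is
# inferred behavior, since the tests above only scatter to a single device):
#
#   batch = [torch.randn(1, 3, 224, 224) for _ in range(2)]
#   replicas = scatter(input=batch, devices=[0, 1])
#   # each element of `replicas` would then live on the device whose id
#   # shares its index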


def test_Scatter():
    # if the device is CPU, just return the input
    target_devices = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_devices, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_devices = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_devices, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            # each element of `outputs` is a tensor here, so compare it
            # directly (indexing with `output[0]` would drop the batch
            # dimension and only pass by broadcasting against zeros)
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mps'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)
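
# To run just this module, the standard pytest invocation applies (the path
# below assumes the usual mmcv test layout and may differ in your checkout):
#
#   pytest tests/test_device/test_functions.py -v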