|
|
|
import unittest |
|
import torch |
|
|
|
from detectron2.structures.keypoints import Keypoints |
|
|
|
|
|
class TestKeypoints(unittest.TestCase):
    """Unit tests for concatenating ``Keypoints`` containers."""

    def test_cat_keypoints(self):
        """Check that ``cat`` stacks its inputs in order along the first axis.

        Two containers wrapping random tensors of shape (2, 21, 3) and
        (4, 21, 3) are concatenated. The result must contain the first
        container's rows followed by the second's, and nothing else.
        """
        keypoints1 = Keypoints(torch.rand(2, 21, 3))
        keypoints2 = Keypoints(torch.rand(4, 21, 3))

        cat_keypoints = keypoints1.cat([keypoints1, keypoints2])

        # Pin the full shape so dropped or extra rows cannot go unnoticed.
        self.assertEqual(tuple(cat_keypoints.tensor.shape), (6, 21, 3))

        # torch.equal requires identical size *and* values, whereas an
        # elementwise ``==`` reduced with ``all`` would silently broadcast.
        self.assertTrue(torch.equal(cat_keypoints.tensor[:2], keypoints1.tensor))
        self.assertTrue(torch.equal(cat_keypoints.tensor[2:], keypoints2.tensor))
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script;
# importing the module has no side effects.
if __name__ == "__main__":
    unittest.main()
|
|