You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_hand_2d_keypoints.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.outputs import OutputKeys
  4. from modelscope.pipelines import pipeline
  5. from modelscope.utils.constant import Tasks
  6. from modelscope.utils.test_utils import test_level
  7. class Hand2DKeypointsPipelineTest(unittest.TestCase):
  8. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  9. def test_hand_2d_keypoints(self):
  10. img_path = 'data/test/images/hand_keypoints.jpg'
  11. model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'
  12. hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id)
  13. outputs = hand_keypoint(img_path)
  14. self.assertEqual(len(outputs), 1)
  15. results = outputs[0]
  16. self.assertIn(OutputKeys.KEYPOINTS, results.keys())
  17. self.assertIn(OutputKeys.BOXES, results.keys())
  18. self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21)
  19. self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3)
  20. self.assertEqual(results[OutputKeys.BOXES].shape[1], 4)
  21. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  22. def test_hand_2d_keypoints_with_default_model(self):
  23. img_path = 'data/test/images/hand_keypoints.jpg'
  24. hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints)
  25. outputs = hand_keypoint(img_path)
  26. self.assertEqual(len(outputs), 1)
  27. results = outputs[0]
  28. self.assertIn(OutputKeys.KEYPOINTS, results.keys())
  29. self.assertIn(OutputKeys.BOXES, results.keys())
  30. self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21)
  31. self.assertEqual(results[OutputKeys.KEYPOINTS].shape[2], 3)
  32. self.assertEqual(results[OutputKeys.BOXES].shape[1], 4)
  33. if __name__ == '__main__':
  34. unittest.main()