You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_object_detection.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.pipelines import pipeline
  4. from modelscope.utils.constant import Tasks
  5. from modelscope.utils.logger import get_logger
  6. from modelscope.utils.test_utils import test_level
  7. class ObjectDetectionTest(unittest.TestCase):
  8. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  9. def test_object_detection(self):
  10. input_location = 'data/test/images/image_detection.jpg'
  11. model_id = 'damo/cv_vit_object-detection_coco'
  12. object_detect = pipeline(Tasks.image_object_detection, model=model_id)
  13. result = object_detect(input_location)
  14. if result:
  15. print(result)
  16. else:
  17. raise ValueError('process error')
  18. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  19. def test_object_detection_with_default_task(self):
  20. input_location = 'data/test/images/image_detection.jpg'
  21. object_detect = pipeline(Tasks.image_object_detection)
  22. result = object_detect(input_location)
  23. if result:
  24. print(result)
  25. else:
  26. raise ValueError('process error')
  27. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  28. def test_human_detection(self):
  29. input_location = 'data/test/images/image_detection.jpg'
  30. model_id = 'damo/cv_resnet18_human-detection'
  31. human_detect = pipeline(Tasks.human_detection, model=model_id)
  32. result = human_detect(input_location)
  33. if result:
  34. print(result)
  35. else:
  36. raise ValueError('process error')
  37. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  38. def test_human_detection_with_default_task(self):
  39. input_location = 'data/test/images/image_detection.jpg'
  40. human_detect = pipeline(Tasks.human_detection)
  41. result = human_detect(input_location)
  42. if result:
  43. print(result)
  44. else:
  45. raise ValueError('process error')
  46. if __name__ == '__main__':
  47. unittest.main()