You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_builder.py 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import unittest
  4. from asyncio import Task
  5. from typing import Any, Dict, List, Tuple, Union
  6. import numpy as np
  7. import PIL
  8. from modelscope.fileio import io
  9. from modelscope.models.base import Model
  10. from modelscope.pipelines import Pipeline, pipeline
  11. from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
  12. from modelscope.utils.constant import (ConfigFields, Frameworks, ModelFile,
  13. Tasks)
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.registry import default_group
  16. logger = get_logger()
  17. @PIPELINES.register_module(
  18. group_key=Tasks.image_classification, module_name='custom_single_model')
  19. class CustomSingleModelPipeline(Pipeline):
  20. def __init__(self,
  21. config_file: str = None,
  22. model: List[Union[str, Model]] = None,
  23. preprocessor=None,
  24. **kwargs):
  25. super().__init__(config_file, model, preprocessor, **kwargs)
  26. assert isinstance(model, str), 'model is not str'
  27. print(model)
  28. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  29. return super().postprocess(inputs)
  30. @PIPELINES.register_module(
  31. group_key=Tasks.image_classification, module_name='model1_model2')
  32. class CustomMultiModelPipeline(Pipeline):
  33. def __init__(self,
  34. config_file: str = None,
  35. model: List[Union[str, Model]] = None,
  36. preprocessor=None,
  37. **kwargs):
  38. super().__init__(config_file, model, preprocessor, **kwargs)
  39. assert isinstance(model, list), 'model is not list'
  40. for m in model:
  41. assert isinstance(m, str), 'submodel is not str'
  42. print(m)
  43. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  44. return super().postprocess(inputs)
  45. class PipelineInterfaceTest(unittest.TestCase):
  46. def prepare_dir(self, dirname, pipeline_name):
  47. if not os.path.exists(dirname):
  48. os.makedirs(dirname)
  49. cfg_file = os.path.join(dirname, ModelFile.CONFIGURATION)
  50. cfg = {
  51. ConfigFields.framework: Frameworks.torch,
  52. ConfigFields.task: Tasks.image_classification,
  53. ConfigFields.pipeline: {
  54. 'type': pipeline_name,
  55. }
  56. }
  57. io.dump(cfg, cfg_file)
  58. def setUp(self) -> None:
  59. self.prepare_dir('/tmp/custom_single_model', 'custom_single_model')
  60. self.prepare_dir('/tmp/model1', 'model1_model2')
  61. self.prepare_dir('/tmp/model2', 'model1_model2')
  62. def test_single_model(self):
  63. pipe = pipeline(
  64. Tasks.image_classification, model='/tmp/custom_single_model')
  65. assert isinstance(pipe, CustomSingleModelPipeline)
  66. def test_multi_model(self):
  67. pipe = pipeline(
  68. Tasks.image_classification, model=['/tmp/model1', '/tmp/model2'])
  69. assert isinstance(pipe, CustomMultiModelPipeline)
  70. if __name__ == '__main__':
  71. unittest.main()