You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

builder.py 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import List, Union
  4. import json
  5. from maas_hub.file_download import model_file_download
  6. from maas_lib.models.base import Model
  7. from maas_lib.utils.config import Config, ConfigDict
  8. from maas_lib.utils.constant import CONFIGFILE, Tasks
  9. from maas_lib.utils.registry import Registry, build_from_cfg
  10. from .base import InputModel, Pipeline
  11. from .util import is_model_name
  # Module-level registry of pipeline implementations. build_pipeline() passes
  # it to build_from_cfg(), and pipeline() / get_default_pipeline() read
  # PIPELINES.modules[task] to find the names registered for a task.
  12. PIPELINES = Registry('pipelines')
  def build_pipeline(cfg: ConfigDict,
                     task_name: str = None,
                     default_args: dict = None):
      """Construct a pipeline object from a config dict.

      The work is delegated to ``build_from_cfg`` against the module-level
      ``PIPELINES`` registry, with ``task_name`` forwarded as the group key.

      Args:
          cfg (:obj:`ConfigDict`): config dict describing the object to build.
          task_name (str, optional): task name, forwarded as ``group_key``;
              refer to :obj:`Tasks` for more details.
          default_args (dict, optional): default initialization arguments.

      Returns:
          Whatever ``build_from_cfg`` returns for ``cfg``.
      """
      built = build_from_cfg(
          cfg,
          PIPELINES,
          default_args=default_args,
          group_key=task_name,
      )
      return built
  def pipeline(task: str = None,
               model: Union[InputModel, List[InputModel]] = None,
               preprocessor=None,
               config_file: str = None,
               pipeline_name: str = None,
               framework: str = None,
               device: int = -1,
               **kwargs) -> Pipeline:
      """Factory method to build a :obj:`Pipeline`.

      At least one of ``task`` and ``pipeline_name`` must be given. When
      ``pipeline_name`` is omitted, the default pipeline registered for
      ``task`` is used.

      Args:
          task (str): Task name defining which pipeline will be returned.
          model (str, obj:`Model`, or list): model name or model object, or a
              list of them. Stored on the pipeline config as ``model``.
          preprocessor: preprocessor object, stored on the pipeline config as
              ``preprocessor`` when not None.
          config_file (str, optional): path to config file. Currently not
              used by this function.
          pipeline_name (str, optional): pipeline class name or alias name.
          framework (str, optional): framework type. Currently not used by
              this function.
          device (int, optional): which device is used to do inference.
              Currently not used by this function.
          **kwargs: extra entries merged into the pipeline config.

      Returns:
          pipeline (obj:`Pipeline`): pipeline object for certain task.

      Raises:
          ValueError: if neither ``task`` nor ``pipeline_name`` is given, or
              if ``pipeline_name`` is omitted and no pipeline is registered
              for ``task``.
          TypeError: if ``model`` is given but is not a str, a ``Model`` or a
              list.

      Examples:
          >>> p = pipeline('image-classification')
          >>> p = pipeline('text-classification', model='distilbert-base-uncased')
          >>> # Using model object
          >>> resnet = Model.from_pretrained('Resnet')
          >>> p = pipeline('image-classification', model=resnet)
      """
      if task is None and pipeline_name is None:
          raise ValueError('task or pipeline_name is required')

      if pipeline_name is None:
          # Fall back to the default pipeline for this task. An explicit raise
          # (rather than assert) keeps the check alive under ``python -O``.
          if task not in PIPELINES.modules:
              raise ValueError(f'No pipeline is registered for Task {task}')
          pipeline_name = get_default_pipeline(task)

      cfg = ConfigDict(type=pipeline_name)
      if kwargs:
          cfg.update(kwargs)

      # Falsy values (None, '', []) are treated as "no model supplied".
      if model:
          if not isinstance(model, (str, Model, list)):
              raise TypeError(
                  f'model should be either (list of) str or Model, but got {type(model)}'
              )
          cfg.model = model

      if preprocessor is not None:
          cfg.preprocessor = preprocessor

      return build_pipeline(cfg, task_name=task)
  def get_default_pipeline(task):
      """Return the first pipeline name registered under ``task``.

      Args:
          task: key into ``PIPELINES.modules``.

      Returns:
          The first key of ``PIPELINES.modules[task]``.
      """
      registered = PIPELINES.modules[task]
      names = list(registered.keys())
      # Indexing (not next(iter(...))) keeps the IndexError on an empty group.
      return names[0]

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展