* add trainer interface * add trainer script * add model init support for pipelineadd pipeline tutorial and fix bugs * add text classification evaluation to maas lib * add quickstart and prepare env doc * relax requirements for torch and sentencepiece * merge release/0.1 and fix conflict * modelhub support for model and pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8868339master
| @@ -0,0 +1,131 @@ | |||
| { | |||
| "framework": "pytorch", | |||
| "task": "image_classification", | |||
| "model": { | |||
| "type": "Resnet50ForImageClassification", | |||
| "pretrained": null, | |||
| "backbone": { | |||
| "type": "ResNet", | |||
| "depth": 50, | |||
| "out_indices": [ | |||
| 4 | |||
| ], | |||
| "norm_cfg": { | |||
| "type": "BN" | |||
| } | |||
| }, | |||
| "head": { | |||
| "type": "ClsHead", | |||
| "with_avg_pool": true, | |||
| "in_channels": 2048, | |||
| "loss_config": { | |||
| "type": "CrossEntropyLossWithLabelSmooth", | |||
| "label_smooth": 0 | |||
| }, | |||
| "num_classes": 1000 | |||
| } | |||
| }, | |||
| "dataset": { | |||
| "train": { | |||
| "type": "ClsDataset", | |||
| "data_source": { | |||
| "list_file": "data/imagenet_raw/meta/train_labeled.txt", | |||
| "root": "data/imagenet_raw/train/", | |||
| "type": "ClsSourceImageList" | |||
| } | |||
| }, | |||
| "val": { | |||
| "type": "ClsDataset", | |||
| "data_source": { | |||
| "list_file": "data/imagenet_raw/meta/val_labeled.txt", | |||
| "root": "data/imagenet_raw/validation/", | |||
| "type": "ClsSourceImageList" | |||
| } | |||
| } | |||
| }, | |||
| "preprocessor":{ | |||
| "train": [ | |||
| { | |||
| "type": "RandomResizedCrop", | |||
| "size": 224 | |||
| }, | |||
| { | |||
| "type": "RandomHorizontalFlip" | |||
| }, | |||
| { | |||
| "type": "ToTensor" | |||
| }, | |||
| { | |||
| "type": "Normalize", | |||
| "mean": [ | |||
| 0.485, | |||
| 0.456, | |||
| 0.406 | |||
| ], | |||
| "std": [ | |||
| 0.229, | |||
| 0.224, | |||
| 0.225 | |||
| ] | |||
| }, | |||
| { | |||
| "type": "Collect", | |||
| "keys": [ | |||
| "img", | |||
| "gt_labels" | |||
| ] | |||
| } | |||
| ], | |||
| "val": [ | |||
| { | |||
| "type": "Resize", | |||
| "size": 256 | |||
| }, | |||
| { | |||
| "type": "CenterCrop", | |||
| "size": 224 | |||
| }, | |||
| { | |||
| "type": "ToTensor" | |||
| }, | |||
| { | |||
| "type": "Normalize", | |||
| "mean": [ | |||
| 0.485, | |||
| 0.456, | |||
| 0.406 | |||
| ], | |||
| "std": [ | |||
| 0.229, | |||
| 0.224, | |||
| 0.225 | |||
| ] | |||
| }, | |||
| { | |||
| "type": "Collect", | |||
| "keys": [ | |||
| "img", | |||
| "gt_labels" | |||
| ] | |||
| } | |||
| ] | |||
| }, | |||
| "train": { | |||
| "batch_size": 32, | |||
| "learning_rate": 0.00001, | |||
| "lr_scheduler_type": "cosine", | |||
| "num_epochs": 20 | |||
| }, | |||
| "evaluation": { | |||
| "batch_size": 32, | |||
| "metrics": ["accuracy", "precision", "recall"] | |||
| } | |||
| } | |||
| @@ -0,0 +1,59 @@ | |||
| # In the current version, many arguments are not yet used by pipelines, so | |||
| # the tag `[being used]` marks the arguments that are actually used. | |||
| version: v0.1 | |||
| framework: pytorch | |||
| task: text-classification | |||
| model: | |||
| path: bert-base-sst2 | |||
| attention_probs_dropout_prob: 0.1 | |||
| bos_token_id: 0 | |||
| eos_token_id: 2 | |||
| hidden_act: elu | |||
| hidden_dropout_prob: 0.1 | |||
| hidden_size: 768 | |||
| initializer_range: 0.02 | |||
| intermediate_size: 3072 | |||
| layer_norm_eps: 1e-05 | |||
| max_position_embeddings: 514 | |||
| model_type: roberta | |||
| num_attention_heads: 12 | |||
| num_hidden_layers: 12 | |||
| pad_token_id: 1 | |||
| type_vocab_size: 1 | |||
| vocab_size: 50265 | |||
| num_classes: 5 | |||
| col_index: &col_indexs | |||
| text_col: 0 | |||
| label_col: 1 | |||
| dataset: | |||
| train: | |||
| <<: *col_indexs | |||
| file: ~ | |||
| valid: | |||
| <<: *col_indexs | |||
| file: glue/sst2 # [being used] | |||
| test: | |||
| <<: *col_indexs | |||
| file: ~ | |||
| preprocessor: | |||
| type: Tokenize | |||
| tokenizer_name: /workspace/bert-base-sst2 | |||
| train: | |||
| batch_size: 256 | |||
| learning_rate: 0.00001 | |||
| lr_scheduler_type: cosine | |||
| num_steps: 100000 | |||
| evaluation: # [being used] | |||
| model_path: .cache/easynlp/bert-base-sst2 | |||
| max_sequence_length: 128 | |||
| batch_size: 32 | |||
| metrics: | |||
| - accuracy | |||
| - f1 | |||
| @@ -11,6 +11,7 @@ MaasLib doc | |||
| :maxdepth: 2 | |||
| :caption: USER GUIDE | |||
| quick_start.md | |||
| develop.md | |||
| .. toctree:: | |||
| @@ -0,0 +1,64 @@ | |||
| # 快速开始 | |||
| ## 环境准备 | |||
| 方式一: whl包安装, 执行如下命令 | |||
| ```shell | |||
| pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas_lib-0.1.0-py3-none-any.whl | |||
| ``` | |||
| 方式二: 源码环境指定, 适合本地开发调试使用,修改源码后可以直接执行 | |||
| ```shell | |||
| git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git maaslib | |||
| cd maaslib | |||
| git fetch origin release/0.1 | |||
| git checkout release/0.1 | |||
| #安装依赖 | |||
| pip install -r requirements.txt | |||
| # 设置PYTHONPATH | |||
| export PYTHONPATH=`pwd` | |||
| ``` | |||
| 备注: mac arm cpu暂时由于依赖包版本问题会导致requirements暂时无法安装,请使用mac intel cpu, linux cpu/gpu机器测试。 | |||
| ## 训练 | |||
| to be done | |||
| ## 评估 | |||
| to be done | |||
| ## 推理 | |||
| to be done | |||
| <!-- pipeline函数提供了简洁的推理接口,示例如下 | |||
| 注: 这里提供的接口是完成和modelhub打通后的接口,暂时不支持使用。pipeline使用示例请参考 [pipeline tutorial](tutorials/pipeline.md)给出的示例。 | |||
| ```python | |||
| import cv2 | |||
| from maas_lib.pipelines import pipeline | |||
| # 根据任务名创建pipeline | |||
| img_matting = pipeline('image-matting') | |||
| # 根据任务和模型名创建pipeline | |||
| img_matting = pipeline('image-matting', model='damo/image-matting-person') | |||
| # 自定义模型和预处理创建pipeline | |||
| model = Model.from_pretrained('damo/xxx') | |||
| preprocessor = Preprocessor.from_pretrained(cfg) | |||
| img_matting = pipeline('image-matting', model=model, preprocessor=preprocessor) | |||
| # 推理 | |||
| result = img_matting( | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
| ) | |||
| # 保存结果图片 | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| ``` --> | |||
| @@ -11,6 +11,9 @@ | |||
| * 指定特定预处理、特定模型进行推理 | |||
| * 不同场景推理任务示例 | |||
| ## 环境准备 | |||
| 详细步骤可以参考 [快速开始](../quick_start.md) | |||
| ## Pipeline基本用法 | |||
| 1. pipeline函数支持指定特定任务名称,加载任务默认模型,创建对应Pipeline对象 | |||
| @@ -21,7 +24,7 @@ | |||
| ```shell | |||
| wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/matting_person.pb | |||
| ``` | |||
| 执行python命令 | |||
| 执行如下python代码 | |||
| ```python | |||
| >>> from maas_lib.pipelines import pipeline | |||
| >>> img_matting = pipeline(task='image-matting', model_path='matting_person.pb') | |||
| @@ -36,7 +39,7 @@ | |||
| pipeline对象也支持传入一个列表输入,返回对应输出列表,每个元素对应输入样本的返回结果 | |||
| ```python | |||
| results = img_matting( | |||
| >>> results = img_matting( | |||
| [ | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png', | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png', | |||
| @@ -46,8 +49,8 @@ | |||
| 如果pipeline对应有一些后处理参数,也支持通过调用时候传入. | |||
| ```python | |||
| pipe = pipeline(task_name) | |||
| result = pipe(input, post_process_args) | |||
| >>> pipe = pipeline(task_name) | |||
| >>> result = pipe(input, post_process_args) | |||
| ``` | |||
| ## 指定预处理、模型进行推理 | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .base import Model | |||
| from .builder import MODELS | |||
| from .builder import MODELS, build_model | |||
| @@ -1,15 +1,23 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict, List, Tuple, Union | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models.builder import build_model | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| class Model(ABC): | |||
| def __init__(self, *args, **kwargs): | |||
| pass | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| self.model_dir = model_dir | |||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| return self.post_process(self.forward(input)) | |||
| @@ -26,4 +34,22 @@ class Model(ABC): | |||
| @classmethod | |||
| def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): | |||
| raise NotImplementedError('from_pretrained has not been implemented') | |||
| """ Instantiate a model from local directory or remote model repo | |||
| """ | |||
| if osp.exists(model_name_or_path): | |||
| local_model_dir = model_name_or_path | |||
| else: | |||
| local_model_dir = snapshot_download(model_name_or_path) | |||
| # else: | |||
| # raise ValueError( | |||
| # 'Remote model repo {model_name_or_path} does not exists') | |||
| cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | |||
| task_name = cfg.task | |||
| model_cfg = cfg.model | |||
| # TODO @wenmeng.zwm may should mannually initialize model after model building | |||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
| model_cfg.type = model_cfg.model_type | |||
| model_cfg.model_dir = local_model_dir | |||
| return build_model(model_cfg, task_name) | |||
| @@ -14,30 +14,21 @@ __all__ = ['SequenceClassificationModel'] | |||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
| class SequenceClassificationModel(Model): | |||
| def __init__(self, | |||
| model_dir: str, | |||
| model_cls: Optional[Any] = None, | |||
| *args, | |||
| **kwargs): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | |||
| # Predictor.__init__(self, *args, **kwargs) | |||
| """initialize the sequence classification model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, model_cls, *args, **kwargs) | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from easynlp.appzoo import SequenceClassification | |||
| from easynlp.core.predictor import get_model_predictor | |||
| self.model_dir = model_dir | |||
| model_cls = SequenceClassification if not model_cls else model_cls | |||
| self.model = get_model_predictor( | |||
| model_dir=model_dir, | |||
| model_cls=model_cls, | |||
| model_dir=self.model_dir, | |||
| model_cls=SequenceClassification, | |||
| input_keys=[('input_ids', torch.LongTensor), | |||
| ('attention_mask', torch.LongTensor), | |||
| ('token_type_ids', torch.LongTensor)], | |||
| @@ -59,4 +50,3 @@ class SequenceClassificationModel(Model): | |||
| } | |||
| """ | |||
| return self.model.predict(input) | |||
| ... | |||
| @@ -1,10 +1,17 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from multiprocessing.sharedctypes import Value | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models import Model | |||
| from maas_lib.preprocessors import Preprocessor | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray'] | |||
| @@ -17,10 +24,38 @@ class Pipeline(ABC): | |||
| def __init__(self, | |||
| config_file: str = None, | |||
| model: Model = None, | |||
| model: Union[Model, str] = None, | |||
| preprocessor: Preprocessor = None, | |||
| **kwargs): | |||
| self.model = model | |||
| """ Base class for pipeline. | |||
| If config_file is provided, model and preprocessor will be | |||
| instantiated from corresponding config. Otherwise model | |||
| and preprocessor will be constructed separately. | |||
| Args: | |||
| config_file(str, optional): Filepath to configuration file. | |||
| model: Model name or model object | |||
| preprocessor: Preprocessor object | |||
| """ | |||
| if config_file is not None: | |||
| self.cfg = Config.from_file(config_file) | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| model = snapshot_download(model) | |||
| if is_model_name(model): | |||
| self.model = Model.from_pretrained(model) | |||
| else: | |||
| self.model = model | |||
| elif isinstance(model, Model): | |||
| self.model = model | |||
| else: | |||
| if model: | |||
| raise ValueError( | |||
| f'model type is either str or Model, but got type {type(model)}' | |||
| ) | |||
| self.preprocessor = preprocessor | |||
| def __call__(self, input: Union[Input, List[Input]], *args, | |||
| @@ -1,12 +1,17 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_lib.models.base import Model | |||
| from maas_lib.utils.config import ConfigDict | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.config import Config, ConfigDict | |||
| from maas_lib.utils.constant import CONFIGFILE, Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
| from .base import Pipeline | |||
| from .util import is_model_name | |||
| PIPELINES = Registry('pipelines') | |||
| @@ -57,23 +62,26 @@ def pipeline(task: str = None, | |||
| >>> resnet = Model.from_pretrained('Resnet') | |||
| >>> p = pipeline('image-classification', model=resnet) | |||
| """ | |||
| if task is not None and pipeline_name is None: | |||
| if model is None or isinstance(model, Model): | |||
| # get default pipeline for this task | |||
| assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' | |||
| pipeline_name = list(PIPELINES.modules[task].keys())[0] | |||
| cfg = dict(type=pipeline_name, **kwargs) | |||
| if model is not None: | |||
| cfg['model'] = model | |||
| if preprocessor is not None: | |||
| cfg['preprocessor'] = preprocessor | |||
| else: | |||
| assert isinstance(model, str), \ | |||
| f'model should be either str or Model, but got {type(model)}' | |||
| # TODO @wenmeng.zwm determine pipeline_name according to task and model | |||
| elif pipeline_name is not None: | |||
| cfg = dict(type=pipeline_name) | |||
| else: | |||
| if task is None and pipeline_name is None: | |||
| raise ValueError('task or pipeline_name is required') | |||
| if pipeline_name is None: | |||
| # get default pipeline for this task | |||
| assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' | |||
| pipeline_name = get_default_pipeline(task) | |||
| cfg = ConfigDict(type=pipeline_name) | |||
| if model: | |||
| assert isinstance(model, (str, Model)), \ | |||
| f'model should be either str or Model, but got {type(model)}' | |||
| cfg.model = model | |||
| if preprocessor is not None: | |||
| cfg.preprocessor = preprocessor | |||
| return build_pipeline(cfg, task_name=task) | |||
def get_default_pipeline(task):
    """Return the first pipeline name stored under ``task`` in ``PIPELINES``.

    Args:
        task: Key into ``PIPELINES.modules``.

    Returns:
        The first key of ``PIPELINES.modules[task]``.
    """
    registered = PIPELINES.modules[task]
    names = list(registered.keys())
    return names[0]
| @@ -1,3 +1,4 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import cv2 | |||
| @@ -23,8 +24,9 @@ logger = get_logger() | |||
| Tasks.image_matting, module_name=Tasks.image_matting) | |||
| class ImageMatting(Pipeline): | |||
| def __init__(self, model_path: str): | |||
| super().__init__() | |||
| def __init__(self, model: str): | |||
| super().__init__(model=model) | |||
| model_path = osp.join(self.model, 'matting_person.pb') | |||
| config = tf.ConfigProto(allow_soft_placement=True) | |||
| config.gpu_options.allow_growth = True | |||
| @@ -0,0 +1,29 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
def is_model_name(model):
    """Tell whether ``model`` refers to a model that provides ``CONFIGFILE``.

    Args:
        model (str): A local path, or an identifier that is passed to
            ``model_file_download`` when no such local path exists.

    Returns:
        bool: For an existing local path, whether ``CONFIGFILE`` exists
        inside it. Otherwise ``CONFIGFILE`` is downloaded for ``model`` and
        the result is ``False`` only when the downloaded JSON contains a
        top-level ``'Code'`` key.
    """
    if osp.exists(model):
        return osp.exists(osp.join(model, CONFIGFILE))

    # TODO @wenmeng.zwm use exception instead of following tricky logic
    # NOTE(review): a top-level 'Code' key presumably marks an error payload
    # returned in place of the real config file -- confirm against maas_hub.
    cfg_file = model_file_download(model, CONFIGFILE)
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default locale encoding.
    with open(cfg_file, 'r', encoding='utf-8') as infile:
        cfg = json.load(infile)
    return 'Code' not in cfg
| @@ -0,0 +1,30 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| from maas_lib.trainers import build_trainer | |||
def parse_args():
    """Parse the command line of the evaluation script.

    Returns:
        argparse.Namespace: ``config`` (required positional) plus the
        optional ``trainer_name`` and ``checkpoint_path``, both of which
        default to ``None``.
    """
    arg_parser = argparse.ArgumentParser(description='evaluate a model')
    arg_parser.add_argument('config', type=str, help='config file path')
    optional_flags = (
        ('--trainer_name', 'name for trainer'),
        ('--checkpoint_path', 'checkpoint to be evaluated'),
    )
    for flag, help_text in optional_flags:
        arg_parser.add_argument(flag, type=str, default=None, help=help_text)
    return arg_parser.parse_args()
def main():
    """Build the trainer chosen on the command line and run its evaluation."""
    cli = parse_args()
    trainer = build_trainer(cli.trainer_name, dict(cfg_file=cli.config))
    trainer.evaluate(cli.checkpoint_path)


if __name__ == '__main__':
    main()
| @@ -0,0 +1,25 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| from maas_lib.trainers import build_trainer | |||
def parse_args():
    """Parse the command line of the training script.

    Returns:
        argparse.Namespace: ``config`` (required positional) and
        ``trainer_name`` (optional positional, ``None`` when omitted).
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='config file path', type=str)
    # Without nargs='?' a positional argument is always required and its
    # ``default`` is never used; nargs='?' makes the declared default of
    # None effective while still accepting the name positionally.
    parser.add_argument(
        'trainer_name',
        nargs='?',
        help='name for trainer',
        type=str,
        default=None)
    return parser.parse_args()
def main():
    """Build the trainer chosen on the command line and start training."""
    cli = parse_args()
    trainer = build_trainer(cli.trainer_name, dict(cfg_file=cli.config))
    trainer.train()


if __name__ == '__main__':
    main()
| @@ -0,0 +1,3 @@ | |||
| from .base import DummyTrainer | |||
| from .builder import build_trainer | |||
| from .nlp import SequenceClassificationTrainer | |||
| @@ -0,0 +1,86 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Callable, Dict, List, Optional, Tuple, Union | |||
| from maas_lib.trainers.builder import TRAINERS | |||
| from maas_lib.utils.config import Config | |||
class BaseTrainer(ABC):
    """ Base class for trainer which can not be instantiated.

    BaseTrainer defines the necessary interface (``train`` and ``evaluate``)
    and provides a default implementation for basic initialization such as
    parsing the config file and parsing commandline args.
    """

    def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None):
        """ Trainer basic init, should be called in derived class

        Args:
            cfg_file: Path to configuration file.
            arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`.
                When omitted, ``self.args`` is set to ``None``.
        """
        self.cfg = Config.from_file(cfg_file)
        # Compare against None explicitly: the parameter is Optional, and a
        # callable object is not guaranteed to be truthy.
        if arg_parse_fn is not None:
            self.args = self.cfg.to_args(arg_parse_fn)
        else:
            self.args = None

    @abstractmethod
    def train(self, *args, **kwargs):
        """ Train (and evaluate) process

        Train process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function
        """
        pass

    @abstractmethod
    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """ Evaluation process

        Evaluation process should be implemented for specific task or
        model, related parameters have been initialized in
        ``BaseTrainer.__init__`` and should be used in this function

        Args:
            checkpoint_path: Checkpoint to be evaluated.

        Returns:
            Mapping from metric name to metric value.
        """
        pass
@TRAINERS.register_module(module_name='dummy')
class DummyTrainer(BaseTrainer):
    """Placeholder trainer that only prints the config sections it reads."""

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ Dummy Trainer.

        Args:
            cfg_file: Path to configuration file.
            *args: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """ Print the ``train`` section of the loaded configuration.

        No training is performed.
        """
        train_section = self.cfg.train
        print(f'train cfg {train_section}')

    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """ Print the ``evaluation`` config section and the checkpoint path.

        No evaluation is performed and nothing is returned.

        Args:
            checkpoint_path: Value that is echoed to stdout.
        """
        eval_section = self.cfg.evaluation
        print(f'eval cfg {eval_section}')
        print(f'checkpoint_path {checkpoint_path}')
| @@ -0,0 +1,21 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from maas_lib.utils.config import ConfigDict | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
| TRAINERS = Registry('trainers') | |||
def build_trainer(name: str = None, default_args: dict = None):
    """ Build a trainer from the ``TRAINERS`` registry by name.

    Args:
        name (str, optional): Trainer name; when None, the name
            ``'Trainer'`` is looked up instead.
        default_args (dict, optional): Default initialization arguments,
            forwarded to ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` returns for the requested type.
    """
    trainer_type = 'Trainer' if name is None else name
    return build_from_cfg(
        {'type': trainer_type}, TRAINERS, default_args=default_args)
| @@ -0,0 +1 @@ | |||
| from .sequence_classification_trainer import SequenceClassificationTrainer | |||
| @@ -0,0 +1,226 @@ | |||
| import time | |||
| from typing import Callable, Dict, List, Optional, Tuple, Union | |||
| import numpy as np | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.logger import get_logger | |||
| from ..base import BaseTrainer | |||
| from ..builder import TRAINERS | |||
| # __all__ = ["SequenceClassificationTrainer"] | |||
| PATH = None | |||
| logger = get_logger(PATH) | |||
@TRAINERS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationTrainer(BaseTrainer):
    """Trainer for sequence classification; only ``evaluate`` does real work."""

    def __init__(self, cfg_file: str, *args, **kwargs):
        """ A trainer is used for Sequence Classification

        Based on Config file (*.yaml or *.json), the trainer trains or
        evaluates on a dataset

        Args:
            cfg_file (str): the path of config file
            *args: accepted but not used.
            **kwargs: accepted but not used.
        """
        super().__init__(cfg_file)

    def train(self, *args, **kwargs):
        """Training is not implemented yet; only a log record is emitted."""
        logger.info('Train')

    def __attr_is_exist(self, attr: str) -> Tuple[str, object]:
        """get attribute from config, if the attribute does not exist (or is
        empty/falsy), return False in its place

        Example:

        >>> self.__attr_is_exist("model path")
        out: (model-path, "/workspace/bert-base-sst2")
        >>> self.__attr_is_exist("model weights")
        out: (model-weights, False)

        Args:
            attr (str): attribute str, "model path" -> config["model"][path]

        Returns:
            Tuple[str, object]: [target attribute name joined with '-',
            the target attribute or False]
        """
        paths = attr.split(' ')
        attr_str: str = '-'.join(paths)
        target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None

        # Walk the remaining path components one level at a time.
        for path_ in paths[1:]:
            if not hasattr(target, path_):
                return attr_str, False
            target = target[path_]

        # Any falsy value (None, '', 0, empty container) counts as missing.
        return attr_str, (target if target else False)

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """evaluate a dataset

        evaluate a dataset via a specific model from the `checkpoint_path`
        path, if the `checkpoint_path` is not given, read it from the config
        file (``evaluation.model_path``).

        Args:
            checkpoint_path (Optional[str], optional): the model path.
                Defaults to None.
            **kwargs: ``max_sequence_length`` may be passed here to override
                ``evaluation.max_sequence_length`` from the config file.

        Returns:
            Dict[str, float]: the results about the evaluation

        Raises:
            AttributeError: if ``evaluation.metrics``,
                ``evaluation.batch_size`` or ``dataset.valid.file`` is
                missing from the config file.
            ValueError: if neither the argument nor the config file provides
                the checkpoint path or the max sequence length.
            NotImplementedError: if a configured metric is not supported.
            RuntimeError: if the logits have an unsupported number of
                dimensions.

        Example:
            {"accuracy": 0.5091743119266054, "f1": 0.673780487804878}
        """
        # Imported here rather than at module level, so importing this
        # module does not require these packages.
        import torch
        from easynlp.appzoo import load_dataset
        from easynlp.appzoo.dataset import GeneralDataset
        from easynlp.appzoo.sequence_classification.model import SequenceClassification
        from easynlp.utils import losses
        from sklearn.metrics import f1_score
        from torch.utils.data import DataLoader

        raise_str = 'Attribute {} is not given in config file!'

        metrics = self.__attr_is_exist('evaluation metrics')
        eval_batch_size = self.__attr_is_exist('evaluation batch_size')
        test_dataset_path = self.__attr_is_exist('dataset valid file')

        for attr_ in (metrics, eval_batch_size, test_dataset_path):
            if not attr_[-1]:
                raise AttributeError(raise_str.format(attr_[0]))

        if not checkpoint_path:
            checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1]
            if not checkpoint_path:
                raise ValueError(
                    'Argument checkpoint_path must be passed if the evaluation-model_path is not given in config file!'
                )

        max_sequence_length = kwargs.get(
            'max_sequence_length',
            self.__attr_is_exist('evaluation max_sequence_length')[-1])
        if not max_sequence_length:
            raise ValueError(
                'Argument max_sequence_length must be passed '
                'if the evaluation-max_sequence_length does not exist in config file!'
            )

        # get the raw online dataset
        raw_dataset = load_dataset(*test_dataset_path[-1].split('/'))
        valid_dataset = raw_dataset['validation']

        # generate a standard dataloader
        pre_dataset = GeneralDataset(valid_dataset, checkpoint_path,
                                     max_sequence_length)
        valid_dataloader = DataLoader(
            pre_dataset,
            batch_size=eval_batch_size[-1],
            shuffle=False,
            collate_fn=pre_dataset.batch_fn)
        # Number of batches, including a final partial one; computed once
        # because it does not change while iterating.
        num_batches = len(valid_dataloader)

        # generate a model
        model = SequenceClassification(checkpoint_path)

        # copy from easynlp (start)
        model.eval()
        total_loss = 0
        total_steps = 0
        hit_num = 0
        total_num = 0

        logits_list = list()
        y_trues = list()

        total_spent_time = 0.0
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        for step, batch in enumerate(valid_dataloader, start=1):
            try:
                batch = {
                    key:
                    val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
            except RuntimeError:
                # Best effort: continue with the batch as it was loaded, but
                # leave a trace instead of hiding the failure.
                logger.warning(
                    'Failed to move batch to {}, using it as loaded'.format(
                        device))
                batch = {key: val for key, val in batch.items()}

            infer_start_time = time.time()
            with torch.no_grad():
                label_ids = batch.pop('label_ids')
                outputs = model(batch)
            infer_end_time = time.time()
            total_spent_time += infer_end_time - infer_start_time

            assert 'logits' in outputs
            logits = outputs['logits']

            y_trues.extend(label_ids.tolist())
            logits_list.extend(logits.tolist())
            hit_num += torch.sum(
                torch.argmax(logits, dim=-1) == label_ids).item()
            # Actual number of samples in this batch; the last batch may be
            # smaller than the configured batch size.
            total_num += label_ids.shape[0]

            if len(logits.shape) == 1 or logits.shape[-1] == 1:
                tmp_loss = losses.mse_loss(logits, label_ids)
            elif len(logits.shape) == 2:
                tmp_loss = losses.cross_entropy(logits, label_ids)
            else:
                raise RuntimeError(
                    'Unsupported logits shape: {}'.format(tuple(logits.shape)))

            total_loss += tmp_loss.mean().item()
            total_steps += 1
            if step % 100 == 0:
                logger.info('Eval: {}/{} steps finished'.format(
                    step, num_batches))

        # Per-sample time is based on the samples actually processed.
        logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
            total_spent_time, total_spent_time * 1000 / total_num))

        eval_loss = total_loss / total_steps
        logger.info('Eval loss: {}'.format(eval_loss))

        logits_list = np.array(logits_list)
        eval_outputs = list()
        for metric in metrics[-1]:
            if metric.endswith('accuracy'):
                acc = hit_num / total_num
                logger.info('Accuracy: {}'.format(acc))
                eval_outputs.append(('accuracy', acc))
            elif metric == 'f1':
                # Predicted class ids, shared by all F1 variants below.
                y_preds = np.argmax(logits_list, axis=-1)
                if model.config.num_labels == 2:
                    f1 = f1_score(y_trues, y_preds)
                    logger.info('F1: {}'.format(f1))
                    eval_outputs.append(('f1', f1))
                else:
                    f1 = f1_score(y_trues, y_preds, average='macro')
                    logger.info('Macro F1: {}'.format(f1))
                    eval_outputs.append(('macro-f1', f1))
                    f1 = f1_score(y_trues, y_preds, average='micro')
                    logger.info('Micro F1: {}'.format(f1))
                    eval_outputs.append(('micro-f1', f1))
            else:
                raise NotImplementedError('Metric %s not implemented' % metric)
        # copy from easynlp (end)

        return dict(eval_outputs)
| @@ -62,3 +62,9 @@ class InputFields(object): | |||
| img = 'img' | |||
| text = 'text' | |||
| audio = 'audio' | |||
| # configuration filename | |||
| # in order to avoid conflict with huggingface | |||
| # config file we use maas_config instead | |||
| CONFIGFILE = 'maas_config.json' | |||
| @@ -1,5 +1,6 @@ | |||
| http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl | |||
| #https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl | |||
| tensorflow | |||
| torch==1.9.1 | |||
| torchaudio==0.9.1 | |||
| torchvision==0.10.1 | |||
| #--find-links https://download.pytorch.org/whl/torch_stable.html | |||
| torch<1.10,>=1.8.0 | |||
| torchaudio | |||
| torchvision | |||
| @@ -1,4 +1,5 @@ | |||
| addict | |||
| https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl | |||
| numpy | |||
| opencv-python-headless | |||
| Pillow | |||
| @@ -113,26 +113,39 @@ def parse_requirements(fname='requirements.txt', with_version=True): | |||
| if line.startswith('http'): | |||
| print('skip http requirements %s' % line) | |||
| continue | |||
| if line and not line.startswith('#'): | |||
| if line and not line.startswith('#') and not line.startswith( | |||
| '--'): | |||
| for info in parse_line(line): | |||
| yield info | |||
| elif line and line.startswith('--find-links'): | |||
| eles = line.split() | |||
| for e in eles: | |||
| e = e.strip() | |||
| if 'http' in e: | |||
| info = dict(dependency_links=e) | |||
| yield info | |||
| def gen_packages_items(): | |||
| items = [] | |||
| deps_link = [] | |||
| if exists(require_fpath): | |||
| for info in parse_require_file(require_fpath): | |||
| parts = [info['package']] | |||
| if with_version and 'version' in info: | |||
| parts.extend(info['version']) | |||
| if not sys.version.startswith('3.4'): | |||
| # apparently package_deps are broken in 3.4 | |||
| platform_deps = info.get('platform_deps') | |||
| if platform_deps is not None: | |||
| parts.append(';' + platform_deps) | |||
| item = ''.join(parts) | |||
| yield item | |||
| packages = list(gen_packages_items()) | |||
| return packages | |||
| if 'dependency_links' not in info: | |||
| parts = [info['package']] | |||
| if with_version and 'version' in info: | |||
| parts.extend(info['version']) | |||
| if not sys.version.startswith('3.4'): | |||
| # apparently package_deps are broken in 3.4 | |||
| platform_deps = info.get('platform_deps') | |||
| if platform_deps is not None: | |||
| parts.append(';' + platform_deps) | |||
| item = ''.join(parts) | |||
| items.append(item) | |||
| else: | |||
| deps_link.append(info['dependency_links']) | |||
| return items, deps_link | |||
| return gen_packages_items() | |||
| def pack_resource(): | |||
| @@ -155,7 +168,7 @@ if __name__ == '__main__': | |||
| # write_version_py() | |||
| pack_resource() | |||
| os.chdir('package') | |||
| install_requires = parse_requirements('requirements.txt') | |||
| install_requires, deps_link = parse_requirements('requirements.txt') | |||
| setup( | |||
| name='maas-lib', | |||
| version=get_version(), | |||
| @@ -180,4 +193,5 @@ if __name__ == '__main__': | |||
| license='Apache License 2.0', | |||
| tests_require=parse_requirements('requirements/tests.txt'), | |||
| install_requires=install_requires, | |||
| dependency_links=deps_link, | |||
| zip_safe=False) | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| @@ -18,15 +19,26 @@ class ImageMattingTest(unittest.TestCase): | |||
| def test_run(self): | |||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
| '.com/data/test/maas/image_matting/matting_person.pb' | |||
| with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile: | |||
| ofile.write(File.read(model_path)) | |||
| img_matting = pipeline(Tasks.image_matting, model_path=ofile.name) | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| model_file = osp.join(tmp_dir, 'matting_person.pb') | |||
| with open(model_file, 'wb') as ofile: | |||
| ofile.write(File.read(model_path)) | |||
| img_matting = pipeline(Tasks.image_matting, model=tmp_dir) | |||
| result = img_matting( | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
| ) | |||
| cv2.imwrite('result.png', result['output_png']) | |||
def test_run_modelhub(self):
    """Run image matting with the model id 'damo/image-matting-person'.

    The pipeline is built from a model id instead of a local directory,
    applied to a remote test image, and the 'output_png' entry of the
    result is written to 'result.png'.
    """
    test_image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
    matting_pipeline = pipeline(
        Tasks.image_matting, model='damo/image-matting-person')
    matting_output = matting_pipeline(test_image_url)
    # The relative path means the file lands in the current working directory.
    cv2.imwrite('result.png', matting_output['output_png'])
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,11 +1,11 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from pathlib import Path | |||
| from maas_lib.fileio import File | |||
| from maas_lib.models import Model | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| @@ -13,15 +13,15 @@ from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| class SequenceClassificationTest(unittest.TestCase): | |||
| def predict(self, pipeline: SequenceClassificationPipeline): | |||
| def predict(self, pipeline_ins: SequenceClassificationPipeline): | |||
| from easynlp.appzoo import load_dataset | |||
| set = load_dataset('glue', 'sst2') | |||
| data = set['test']['sentence'][:3] | |||
| results = pipeline(data[0]) | |||
| results = pipeline_ins(data[0]) | |||
| print(results) | |||
| results = pipeline(data[1]) | |||
| results = pipeline_ins(data[1]) | |||
| print(results) | |||
| print(data) | |||
| @@ -29,22 +29,34 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| def test_run(self): | |||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip') | |||
| with open(tmp_file, 'wb') as ofile: | |||
| cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' | |||
| cache_path = Path(cache_path_str) | |||
| if not cache_path.exists(): | |||
| cache_path.parent.mkdir(parents=True, exist_ok=True) | |||
| cache_path.touch(exist_ok=True) | |||
| with cache_path.open('wb') as ofile: | |||
| ofile.write(File.read(model_url)) | |||
| with zipfile.ZipFile(tmp_file, 'r') as zipf: | |||
| zipf.extractall(tmp_dir) | |||
| path = osp.join(tmp_dir, 'bert-base-sst2') | |||
| print(path) | |||
| model = SequenceClassificationModel(path) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| path, first_sequence='sentence', second_sequence=None) | |||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | |||
| self.predict(pipeline1) | |||
| pipeline2 = pipeline( | |||
| 'text-classification', model=model, preprocessor=preprocessor) | |||
| print(pipeline2('Hello world!')) | |||
| with zipfile.ZipFile(cache_path_str, 'r') as zipf: | |||
| zipf.extractall(cache_path.parent) | |||
| path = r'.cache/easynlp/bert-base-sst2' | |||
| model = SequenceClassificationModel(path) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| path, first_sequence='sentence', second_sequence=None) | |||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | |||
| self.predict(pipeline1) | |||
| pipeline2 = pipeline( | |||
| 'text-classification', model=model, preprocessor=preprocessor) | |||
| print(pipeline2('Hello world!')) | |||
def test_run_modelhub(self):
    """Run text classification with the model id 'damo/bert-base-sst2'.

    The model is loaded through ``Model.from_pretrained``; the preprocessor
    is initialised from the loaded model's ``model_dir``; the resulting
    pipeline is handed to ``self.predict``.
    """
    hub_model = Model.from_pretrained('damo/bert-base-sst2')
    hub_preprocessor = SequenceClassificationPreprocessor(
        hub_model.model_dir, first_sequence='sentence', second_sequence=None)
    text_cls_pipeline = pipeline(
        task='text-classification',
        model=hub_model,
        preprocessor=hub_preprocessor)
    self.predict(text_cls_pipeline)
| if __name__ == '__main__': | |||
| @@ -0,0 +1,38 @@ | |||
| import unittest | |||
| import zipfile | |||
| from pathlib import Path | |||
| from maas_lib.fileio import File | |||
| from maas_lib.trainers import build_trainer | |||
| from maas_lib.utils.logger import get_logger | |||
| logger = get_logger() | |||
class SequenceClassificationTrainerTest(unittest.TestCase):
    """End-to-end run of the 'bert-sentiment-analysis' trainer."""

    def test_sequence_classification(self):
        """Fetch and unpack bert-base-sst2, then train and evaluate.

        The archive is cached at '.cache/easynlp/bert-base-sst2.zip'
        (relative to the current working directory) and is downloaded only
        when that file does not exist yet. It is extracted next to the
        archive before the trainer is built from the YAML config.
        """
        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
        cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
        cache_path = Path(cache_path_str)

        if not cache_path.exists():
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            # Download before touching the cache location, and publish the
            # file with a rename. If the download or the write fails, no
            # empty or truncated archive is left at cache_path, so the next
            # run retries instead of failing on a corrupt zip forever.
            payload = File.read(model_url)
            partial_path = cache_path.with_name(cache_path.name + '.part')
            with partial_path.open('wb') as ofile:
                ofile.write(payload)
            partial_path.replace(cache_path)

        with zipfile.ZipFile(cache_path_str, 'r') as zipf:
            zipf.extractall(cache_path.parent)

        # NOTE(review): the YAML config presumably points at the extracted
        # model directory - confirm against the config file.
        path: str = './configs/nlp/sequence_classification_trainer.yaml'
        default_args = dict(cfg_file=path)
        trainer = build_trainer('bert-sentiment-analysis', default_args)
        trainer.train()
        trainer.evaluate()
| if __name__ == '__main__': | |||
| unittest.main() | |||
| ... | |||
| @@ -0,0 +1,19 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from maas_lib.trainers import build_trainer | |||
class DummyTrainerTest(unittest.TestCase):
    """Smoke test for the trainer registered under the name 'dummy'."""

    def test_dummy(self):
        """Build the 'dummy' trainer from the example config, then train and evaluate."""
        trainer_kwargs = {'cfg_file': 'configs/examples/train.json'}
        dummy_trainer = build_trainer('dummy', trainer_kwargs)
        dummy_trainer.train()
        dummy_trainer.evaluate()
| if __name__ == '__main__': | |||
| unittest.main() | |||