diff --git a/configs/examples/train.json b/configs/examples/train.json new file mode 100644 index 00000000..fbfde923 --- /dev/null +++ b/configs/examples/train.json @@ -0,0 +1,131 @@ +{ + "framework": "pytorch", + + "task": "image_classification", + + "model": { + "type": "Resnet50ForImageClassification", + "pretrained": null, + "backbone": { + "type": "ResNet", + "depth": 50, + "out_indices": [ + 4 + ], + "norm_cfg": { + "type": "BN" + } + }, + "head": { + "type": "ClsHead", + "with_avg_pool": true, + "in_channels": 2048, + "loss_config": { + "type": "CrossEntropyLossWithLabelSmooth", + "label_smooth": 0 + }, + "num_classes": 1000 + } + }, + + "dataset": { + "train": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/train_labeled.txt", + "root": "data/imagenet_raw/train/", + "type": "ClsSourceImageList" + } + }, + "val": { + "type": "ClsDataset", + "data_source": { + "list_file": "data/imagenet_raw/meta/val_labeled.txt", + "root": "data/imagenet_raw/validation/", + "type": "ClsSourceImageList" + } + } + }, + + + "preprocessor":{ + "train": [ + { + "type": "RandomResizedCrop", + "size": 224 + }, + { + "type": "RandomHorizontalFlip" + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ], + "val": [ + { + "type": "Resize", + "size": 256 + }, + { + "type": "CenterCrop", + "size": 224 + }, + { + "type": "ToTensor" + }, + { + "type": "Normalize", + "mean": [ + 0.485, + 0.456, + 0.406 + ], + "std": [ + 0.229, + 0.224, + 0.225 + ] + }, + { + "type": "Collect", + "keys": [ + "img", + "gt_labels" + ] + } + ] + }, + + "train": { + "batch_size": 32, + "learning_rate": 0.00001, + "lr_scheduler_type": "cosine", + "num_epochs": 20 + }, + + "evaluation": { + "batch_size": 32, + "metrics": ["accuracy", "precision", "recall"] + } + +} diff --git a/configs/nlp/sequence_classification_trainer.yaml b/configs/nlp/sequence_classification_trainer.yaml new file mode 100644 index 00000000..17e9028d --- /dev/null +++ b/configs/nlp/sequence_classification_trainer.yaml @@ -0,0 +1,59 @@ +# In current version, many arguments are not used in pipelines, so, +# a tag `[being used]` will indicate which argument is being used +version: v0.1 +framework: pytorch +task: text-classification + +model: + path: bert-base-sst2 + attention_probs_dropout_prob: 0.1 + bos_token_id: 0 + eos_token_id: 2 + hidden_act: elu + hidden_dropout_prob: 0.1 + hidden_size: 768 + initializer_range: 0.02 + intermediate_size: 3072 + layer_norm_eps: 1e-05 + max_position_embeddings: 514 + model_type: roberta + num_attention_heads: 12 + num_hidden_layers: 12 + pad_token_id: 1 + type_vocab_size: 1 + vocab_size: 50265 + num_classes: 5 + + +col_index: &col_indexs + text_col: 0 + label_col: 1 + +dataset: + train: + <<: *col_indexs + file: ~ + valid: + <<: *col_indexs + file: glue/sst2 # [being used] + test: + <<: *col_indexs + file: ~ + +preprocessor: + type: Tokenize + tokenizer_name: /workspace/bert-base-sst2 + +train: + batch_size: 256 + learning_rate: 0.00001 + lr_scheduler_type: cosine + num_steps: 100000 + +evaluation: # [being used] + model_path: .cache/easynlp/bert-base-sst2 + max_sequence_length: 128 + batch_size: 32 + metrics: + - accuracy + - f1 diff --git a/docs/source/index.rst b/docs/source/index.rst index 08e85d0b..0ca63b41 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,6 +11,7 @@ MaasLib doc :maxdepth: 2 :caption: USER 
GUIDE

+   quick_start.md
    develop.md

 .. toctree::
diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md
new file mode 100644
index 00000000..3b483081
--- /dev/null
+++ b/docs/source/quick_start.md
@@ -0,0 +1,64 @@
+# Quick Start
+
+## Environment Setup
+
+Option 1: install the whl package by running the following command
+```shell
+pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas_lib-0.1.0-py3-none-any.whl
+```
+
+Option 2: run from source, suitable for local development and debugging; changes to the source take effect without reinstalling
+```shell
+git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git maaslib
+cd maaslib
+
+git fetch origin release/0.1
+git checkout release/0.1
+
+# install dependencies
+pip install -r requirements.txt
+
+# set PYTHONPATH
+export PYTHONPATH=`pwd`
+```
+
+Note: on Mac machines with ARM CPUs the requirements cannot be installed for now due to dependency version issues; please test on a Mac with an Intel CPU or on a Linux CPU/GPU machine.
+
+
+## Training
+
+to be done
+
+## Evaluation
+
+to be done
+
+## Inference
+to be done
+
diff --git a/docs/source/tutorials/pipeline.md b/docs/source/tutorials/pipeline.md
index 18e100c8..ad73c773 100644
--- a/docs/source/tutorials/pipeline.md
+++ b/docs/source/tutorials/pipeline.md
@@ -11,6 +11,9 @@
 * Run inference with a specified preprocessor and model
 * Inference task examples for different scenarios

+## Environment Setup
+For detailed steps, see [Quick Start](../quick_start.md)
+
 ## Basic Pipeline Usage

 1. The pipeline function supports specifying a task name; it loads the task's default model and creates the corresponding Pipeline object
@@ -21,7 +24,7 @@
    ```shell
    wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/matting_person.pb
    ```
-   Run the python command
+   Run the following python code
    ```python
    >>> from maas_lib.pipelines import pipeline
    >>> img_matting = pipeline(task='image-matting', model_path='matting_person.pb')
@@ -36,7 +39,7 @@
    A pipeline object also accepts a list of inputs and returns a corresponding list of outputs, one element per input sample
    ```python
-   results = img_matting(
+   >>> results = img_matting(
        [
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png',
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png',
@@ -46,8 +49,8 @@
    If a pipeline has post-processing parameters, they can also be passed in when it is called.
    ```python
-   pipe = pipeline(task_name)
-   result = pipe(input, post_process_args)
+   >>> pipe = pipeline(task_name)
+   >>> result = pipe(input, post_process_args)
    ```

 ## Run Inference with a Specified Preprocessor and Model
diff --git a/maas_lib/models/__init__.py b/maas_lib/models/__init__.py
index eeeadd3c..f1ba8980 100644
--- a/maas_lib/models/__init__.py
+++ b/maas_lib/models/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 from .base import Model
-from .builder import MODELS
+from .builder import MODELS, build_model
diff --git a/maas_lib/models/base.py b/maas_lib/models/base.py
index 2781f8a4..cc6c4ec8 100644
--- a/maas_lib/models/base.py
+++ b/maas_lib/models/base.py
@@ -1,15 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp from abc import ABC, abstractmethod from typing import Dict, List, Tuple, Union +from maas_hub.file_download import model_file_download +from maas_hub.snapshot_download import snapshot_download + +from maas_lib.models.builder import build_model +from maas_lib.utils.config import Config +from maas_lib.utils.constant import CONFIGFILE + Tensor = Union['torch.Tensor', 'tf.Tensor'] class Model(ABC): - def __init__(self, *args, **kwargs): - pass + def __init__(self, model_dir, *args, **kwargs): + self.model_dir = model_dir def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: return self.post_process(self.forward(input)) @@ -26,4 +34,22 @@ class Model(ABC): @classmethod def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): - raise NotImplementedError('from_pretrained has not been implemented') + """ Instantiate a model from local directory or remote model repo + """ + if osp.exists(model_name_or_path): + local_model_dir = model_name_or_path + else: + + local_model_dir = snapshot_download(model_name_or_path) + # else: + # raise ValueError( + # 'Remote model repo {model_name_or_path} does not exists') + + cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) + task_name = cfg.task + model_cfg = cfg.model + # TODO @wenmeng.zwm may should mannually initialize model after model building + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + model_cfg.model_dir = local_model_dir + return build_model(model_cfg, task_name) diff --git a/maas_lib/models/nlp/sequence_classification_model.py b/maas_lib/models/nlp/sequence_classification_model.py index e7ec69f7..dbb86105 100644 --- a/maas_lib/models/nlp/sequence_classification_model.py +++ b/maas_lib/models/nlp/sequence_classification_model.py @@ -14,30 +14,21 @@ __all__ = ['SequenceClassificationModel'] Tasks.text_classification, module_name=r'bert-sentiment-analysis') class SequenceClassificationModel(Model): - def __init__(self, - model_dir: str, - model_cls: Optional[Any] = None, - *args, - **kwargs): + def __init__(self, model_dir: str, *args, **kwargs): # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) # Predictor.__init__(self, *args, **kwargs) """initialize the sequence classification model from the `model_dir` path. Args: model_dir (str): the model path. - model_cls (Optional[Any], optional): model loader, if None, use the - default loader to load model weights, by default None. """ - super().__init__(model_dir, model_cls, *args, **kwargs) - + super().__init__(model_dir, *args, **kwargs) from easynlp.appzoo import SequenceClassification from easynlp.core.predictor import get_model_predictor - self.model_dir = model_dir - model_cls = SequenceClassification if not model_cls else model_cls self.model = get_model_predictor( - model_dir=model_dir, - model_cls=model_cls, + model_dir=self.model_dir, + model_cls=SequenceClassification, input_keys=[('input_ids', torch.LongTensor), ('attention_mask', torch.LongTensor), ('token_type_ids', torch.LongTensor)], @@ -59,4 +50,3 @@ class SequenceClassificationModel(Model): } """ return self.model.predict(input) - ... 
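For reference, a minimal usage sketch of the `Model.from_pretrained` resolution flow above (a local directory is used as-is, otherwise the name is resolved via `maas_hub.snapshot_download`), mirroring the model ids used in the tests; the concrete paths are illustrative only:

```python
# Hedged sketch: assumes maas_lib is importable and that the referenced
# directory / repo id exists and contains a maas_config.json.
from maas_lib.models import Model

# A local directory with maas_config.json is used directly.
model = Model.from_pretrained('.cache/easynlp/bert-base-sst2')

# A remote repo id is first materialized with maas_hub.snapshot_download.
# model = Model.from_pretrained('damo/bert-base-sst2')

print(model.model_dir)  # resolved local model directory
```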
diff --git a/maas_lib/pipelines/audio/__file__.py b/maas_lib/pipelines/audio/__init__.py
similarity index 100%
rename from maas_lib/pipelines/audio/__file__.py
rename to maas_lib/pipelines/audio/__init__.py
diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py
index 1d804d1a..9bef4af2 100644
--- a/maas_lib/pipelines/base.py
+++ b/maas_lib/pipelines/base.py
@@ -1,10 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Tuple, Union

+from maas_hub.snapshot_download import snapshot_download
+
 from maas_lib.models import Model
 from maas_lib.preprocessors import Preprocessor
+from maas_lib.utils.config import Config
+from maas_lib.utils.constant import CONFIGFILE
+from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
 Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']

@@ -17,10 +23,38 @@ class Pipeline(ABC):

     def __init__(self,
                  config_file: str = None,
-                 model: Model = None,
+                 model: Union[Model, str] = None,
                  preprocessor: Preprocessor = None,
                  **kwargs):
-        self.model = model
+        """ Base class for pipeline.
+
+        If config_file is provided, model and preprocessor will be
+        instantiated from corresponding config. Otherwise model
+        and preprocessor will be constructed separately.
+
+        Args:
+            config_file(str, optional): Filepath to configuration file.
+            model: Model name or model object
+            preprocessor: Preprocessor object
+        """
+        if config_file is not None:
+            self.cfg = Config.from_file(config_file)
+
+        if isinstance(model, str):
+            if not osp.exists(model):
+                model = snapshot_download(model)
+
+            if is_model_name(model):
+                self.model = Model.from_pretrained(model)
+            else:
+                self.model = model
+        elif isinstance(model, Model):
+            self.model = model
+        else:
+            if model:
+                raise ValueError(
+                    f'model type is either str or Model, but got type {type(model)}'
+                )
         self.preprocessor = preprocessor

     def __call__(self, input: Union[Input, List[Input]], *args,
diff --git a/maas_lib/pipelines/builder.py b/maas_lib/pipelines/builder.py
index da47ba92..703dd33f 100644
--- a/maas_lib/pipelines/builder.py
+++ b/maas_lib/pipelines/builder.py
@@ -1,12 +1,17 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp from typing import Union +import json +from maas_hub.file_download import model_file_download + from maas_lib.models.base import Model -from maas_lib.utils.config import ConfigDict -from maas_lib.utils.constant import Tasks +from maas_lib.utils.config import Config, ConfigDict +from maas_lib.utils.constant import CONFIGFILE, Tasks from maas_lib.utils.registry import Registry, build_from_cfg from .base import Pipeline +from .util import is_model_name PIPELINES = Registry('pipelines') @@ -57,23 +62,26 @@ def pipeline(task: str = None, >>> resnet = Model.from_pretrained('Resnet') >>> p = pipeline('image-classification', model=resnet) """ - if task is not None and pipeline_name is None: - if model is None or isinstance(model, Model): - # get default pipeline for this task - assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' - pipeline_name = list(PIPELINES.modules[task].keys())[0] - cfg = dict(type=pipeline_name, **kwargs) - if model is not None: - cfg['model'] = model - if preprocessor is not None: - cfg['preprocessor'] = preprocessor - else: - assert isinstance(model, str), \ - f'model should be either str or Model, but got {type(model)}' - # TODO @wenmeng.zwm determine pipeline_name according to task and model - elif pipeline_name is not None: - cfg = dict(type=pipeline_name) - else: + if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') + if pipeline_name is None: + # get default pipeline for this task + assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' + pipeline_name = get_default_pipeline(task) + + cfg = ConfigDict(type=pipeline_name) + + if model: + assert isinstance(model, (str, Model)), \ + f'model should be either str or Model, but got {type(model)}' + cfg.model = model + + if preprocessor is not None: + cfg.preprocessor = preprocessor + return build_pipeline(cfg, task_name=task) + + +def get_default_pipeline(task): + return list(PIPELINES.modules[task].keys())[0] diff --git a/maas_lib/pipelines/cv/image_matting.py b/maas_lib/pipelines/cv/image_matting.py index 1d0894bc..73796552 100644 --- a/maas_lib/pipelines/cv/image_matting.py +++ b/maas_lib/pipelines/cv/image_matting.py @@ -1,3 +1,4 @@ +import os.path as osp from typing import Any, Dict, List, Tuple, Union import cv2 @@ -23,8 +24,9 @@ logger = get_logger() Tasks.image_matting, module_name=Tasks.image_matting) class ImageMatting(Pipeline): - def __init__(self, model_path: str): - super().__init__() + def __init__(self, model: str): + super().__init__(model=model) + model_path = osp.join(self.model, 'matting_person.pb') config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True diff --git a/maas_lib/pipelines/util.py b/maas_lib/pipelines/util.py new file mode 100644 index 00000000..3e907359 --- /dev/null +++ b/maas_lib/pipelines/util.py @@ -0,0 +1,29 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp + +import json +from maas_hub.file_download import model_file_download + +from maas_lib.utils.constant import CONFIGFILE + + +def is_model_name(model): + if osp.exists(model): + if osp.exists(osp.join(model, CONFIGFILE)): + return True + else: + return False + else: + # try: + # cfg_file = model_file_download(model, CONFIGFILE) + # except Exception: + # cfg_file = None + # TODO @wenmeng.zwm use exception instead of + # following tricky logic + cfg_file = model_file_download(model, CONFIGFILE) + with open(cfg_file, 'r') as infile: + cfg = json.load(infile) + if 'Code' in cfg: + return False + else: + return True diff --git a/maas_lib/tools/eval.py b/maas_lib/tools/eval.py new file mode 100644 index 00000000..95bf7054 --- /dev/null +++ b/maas_lib/tools/eval.py @@ -0,0 +1,30 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse + +from maas_lib.trainers import build_trainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='evaluate a model') + parser.add_argument('config', help='config file path', type=str) + parser.add_argument( + '--trainer_name', help='name for trainer', type=str, default=None) + parser.add_argument( + '--checkpoint_path', + help='checkpoint to be evaluated', + type=str, + default=None) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + kwargs = dict(cfg_file=args.config) + trainer = build_trainer(args.trainer_name, kwargs) + trainer.evaluate(args.checkpoint_path) + + +if __name__ == '__main__': + main() diff --git a/maas_lib/tools/train.py b/maas_lib/tools/train.py new file mode 100644 index 00000000..f7c2b54b --- /dev/null +++ b/maas_lib/tools/train.py @@ -0,0 +1,25 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import argparse + +from maas_lib.trainers import build_trainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='config file path', type=str) + parser.add_argument( + 'trainer_name', help='name for trainer', type=str, default=None) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + kwargs = dict(cfg_file=args.config) + trainer = build_trainer(args.trainer_name, kwargs) + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/maas_lib/trainers/__init__.py b/maas_lib/trainers/__init__.py new file mode 100644 index 00000000..589f325d --- /dev/null +++ b/maas_lib/trainers/__init__.py @@ -0,0 +1,3 @@ +from .base import DummyTrainer +from .builder import build_trainer +from .nlp import SequenceClassificationTrainer diff --git a/maas_lib/trainers/base.py b/maas_lib/trainers/base.py new file mode 100644 index 00000000..2c11779e --- /dev/null +++ b/maas_lib/trainers/base.py @@ -0,0 +1,86 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from abc import ABC, abstractmethod +from typing import Callable, Dict, List, Optional, Tuple, Union + +from maas_lib.trainers.builder import TRAINERS +from maas_lib.utils.config import Config + + +class BaseTrainer(ABC): + """ Base class for trainer which can not be instantiated. + + BaseTrainer defines necessary interface + and provide default implementation for basic initialization + such as parsing config file and parsing commandline args. + """ + + def __init__(self, cfg_file: str, arg_parse_fn: Optional[Callable] = None): + """ Trainer basic init, should be called in derived class + + Args: + cfg_file: Path to configuration file. + arg_parse_fn: Same as ``parse_fn`` in :obj:`Config.to_args`. 
+ """ + self.cfg = Config.from_file(cfg_file) + if arg_parse_fn: + self.args = self.cfg.to_args(arg_parse_fn) + else: + self.args = None + + @abstractmethod + def train(self, *args, **kwargs): + """ Train (and evaluate) process + + Train process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + pass + + @abstractmethod + def evaluate(self, checkpoint_path: str, *args, + **kwargs) -> Dict[str, float]: + """ Evaluation process + + Evaluation process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + pass + + +@TRAINERS.register_module(module_name='dummy') +class DummyTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ Dummy Trainer. + + Args: + cfg_file: Path to configuration file. + """ + super().__init__(cfg_file) + + def train(self, *args, **kwargs): + """ Train (and evaluate) process + + Train process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + cfg = self.cfg.train + print(f'train cfg {cfg}') + + def evaluate(self, + checkpoint_path: str = None, + *args, + **kwargs) -> Dict[str, float]: + """ Evaluation process + + Evaluation process should be implemented for specific task or + model, releated paramters have been intialized in + ``BaseTrainer.__init__`` and should be used in this function + """ + cfg = self.cfg.evaluation + print(f'eval cfg {cfg}') + print(f'checkpoint_path {checkpoint_path}') diff --git a/maas_lib/trainers/builder.py b/maas_lib/trainers/builder.py new file mode 100644 index 00000000..2165fe58 --- /dev/null +++ b/maas_lib/trainers/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from maas_lib.utils.config import ConfigDict +from maas_lib.utils.constant import Tasks +from maas_lib.utils.registry import Registry, build_from_cfg + +TRAINERS = Registry('trainers') + + +def build_trainer(name: str = None, default_args: dict = None): + """ build trainer given a trainer name + + Args: + name (str, optional): Trainer name, if None, default trainer + will be used. + default_args (dict, optional): Default initialization arguments. 
+ """ + if name is None: + name = 'Trainer' + cfg = dict(type=name) + return build_from_cfg(cfg, TRAINERS, default_args=default_args) diff --git a/maas_lib/trainers/nlp/__init__.py b/maas_lib/trainers/nlp/__init__.py new file mode 100644 index 00000000..6d61da43 --- /dev/null +++ b/maas_lib/trainers/nlp/__init__.py @@ -0,0 +1 @@ +from .sequence_classification_trainer import SequenceClassificationTrainer diff --git a/maas_lib/trainers/nlp/sequence_classification_trainer.py b/maas_lib/trainers/nlp/sequence_classification_trainer.py new file mode 100644 index 00000000..e88eb95e --- /dev/null +++ b/maas_lib/trainers/nlp/sequence_classification_trainer.py @@ -0,0 +1,226 @@ +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np + +from maas_lib.utils.constant import Tasks +from maas_lib.utils.logger import get_logger +from ..base import BaseTrainer +from ..builder import TRAINERS + +# __all__ = ["SequenceClassificationTrainer"] + +PATH = None +logger = get_logger(PATH) + + +@TRAINERS.register_module( + Tasks.text_classification, module_name=r'bert-sentiment-analysis') +class SequenceClassificationTrainer(BaseTrainer): + + def __init__(self, cfg_file: str, *args, **kwargs): + """ A trainer is used for Sequence Classification + + Based on Config file (*.yaml or *.json), the trainer trains or evaluates on a dataset + + Args: + cfg_file (str): the path of config file + Raises: + ValueError: _description_ + """ + super().__init__(cfg_file) + + def train(self, *args, **kwargs): + logger.info('Train') + ... + + def __attr_is_exist(self, attr: str) -> Tuple[Union[str, bool]]: + """get attribute from config, if the attribute does exist, return false + + Example: + >>> self.__attr_is_exist("model path") + out: (model-path, "/workspace/bert-base-sst2") + >>> self.__attr_is_exist("model weights") + out: (model-weights, False) + + Args: + attr (str): attribute str, "model path" -> config["model"][path] + + Returns: + Tuple[Union[str, bool]]:[target attribute name, the target attribute or False] + """ + paths = attr.split(' ') + attr_str: str = '-'.join(paths) + target = self.cfg[paths[0]] if hasattr(self.cfg, paths[0]) else None + + for path_ in paths[1:]: + if not hasattr(target, path_): + return attr_str, False + target = target[path_] + + if target and target != '': + return attr_str, target + return attr_str, False + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + """evaluate a dataset + + evaluate a dataset via a specific model from the `checkpoint_path` path, if the `checkpoint_path` + does not exist, read from the config file. + + Args: + checkpoint_path (Optional[str], optional): the model path. Defaults to None. + + Returns: + Dict[str, float]: the results about the evaluation + Example: + {"accuracy": 0.5091743119266054, "f1": 0.673780487804878} + """ + import torch + from easynlp.appzoo import load_dataset + from easynlp.appzoo.dataset import GeneralDataset + from easynlp.appzoo.sequence_classification.model import SequenceClassification + from easynlp.utils import losses + from sklearn.metrics import f1_score + from torch.utils.data import DataLoader + + raise_str = 'Attribute {} is not given in config file!' 
+ + metrics = self.__attr_is_exist('evaluation metrics') + eval_batch_size = self.__attr_is_exist('evaluation batch_size') + test_dataset_path = self.__attr_is_exist('dataset valid file') + + attrs = [metrics, eval_batch_size, test_dataset_path] + for attr_ in attrs: + if not attr_[-1]: + raise AttributeError(raise_str.format(attr_[0])) + + if not checkpoint_path: + checkpoint_path = self.__attr_is_exist('evaluation model_path')[-1] + if not checkpoint_path: + raise ValueError( + 'Argument checkout_path must be passed if the evaluation-model_path is not given in config file!' + ) + + max_sequence_length = kwargs.get( + 'max_sequence_length', + self.__attr_is_exist('evaluation max_sequence_length')[-1]) + if not max_sequence_length: + raise ValueError( + 'Argument max_sequence_length must be passed ' + 'if the evaluation-max_sequence_length does not exist in config file!' + ) + + # get the raw online dataset + raw_dataset = load_dataset(*test_dataset_path[-1].split('/')) + valid_dataset = raw_dataset['validation'] + + # generate a standard dataloader + pre_dataset = GeneralDataset(valid_dataset, checkpoint_path, + max_sequence_length) + valid_dataloader = DataLoader( + pre_dataset, + batch_size=eval_batch_size[-1], + shuffle=False, + collate_fn=pre_dataset.batch_fn) + + # generate a model + model = SequenceClassification(checkpoint_path) + + # copy from easynlp (start) + model.eval() + total_loss = 0 + total_steps = 0 + total_samples = 0 + hit_num = 0 + total_num = 0 + + logits_list = list() + y_trues = list() + + total_spent_time = 0.0 + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + model.to(device) + for _step, batch in enumerate(valid_dataloader): + try: + batch = { + # key: val.cuda() if isinstance(val, torch.Tensor) else val + # for key, val in batch.items() + key: + val.to(device) if isinstance(val, torch.Tensor) else val + for key, val in batch.items() + } + except RuntimeError: + batch = {key: val for key, val in batch.items()} + + infer_start_time = time.time() + with torch.no_grad(): + label_ids = batch.pop('label_ids') + outputs = model(batch) + infer_end_time = time.time() + total_spent_time += infer_end_time - infer_start_time + + assert 'logits' in outputs + logits = outputs['logits'] + + y_trues.extend(label_ids.tolist()) + logits_list.extend(logits.tolist()) + hit_num += torch.sum( + torch.argmax(logits, dim=-1) == label_ids).item() + total_num += label_ids.shape[0] + + if len(logits.shape) == 1 or logits.shape[-1] == 1: + tmp_loss = losses.mse_loss(logits, label_ids) + elif len(logits.shape) == 2: + tmp_loss = losses.cross_entropy(logits, label_ids) + else: + raise RuntimeError + + total_loss += tmp_loss.mean().item() + total_steps += 1 + total_samples += valid_dataloader.batch_size + if (_step + 1) % 100 == 0: + total_step = len( + valid_dataloader.dataset) // valid_dataloader.batch_size + logger.info('Eval: {}/{} steps finished'.format( + _step + 1, total_step)) + + logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format( + total_spent_time, total_spent_time * 1000 / total_samples)) + + eval_loss = total_loss / total_steps + logger.info('Eval loss: {}'.format(eval_loss)) + + logits_list = np.array(logits_list) + eval_outputs = list() + for metric in metrics[-1]: + if metric.endswith('accuracy'): + acc = hit_num / total_num + logger.info('Accuracy: {}'.format(acc)) + eval_outputs.append(('accuracy', acc)) + elif metric == 'f1': + if model.config.num_labels == 2: + f1 = f1_score(y_trues, np.argmax(logits_list, axis=-1)) + logger.info('F1: 
{}'.format(f1)) + eval_outputs.append(('f1', f1)) + else: + f1 = f1_score( + y_trues, + np.argmax(logits_list, axis=-1), + average='macro') + logger.info('Macro F1: {}'.format(f1)) + eval_outputs.append(('macro-f1', f1)) + f1 = f1_score( + y_trues, + np.argmax(logits_list, axis=-1), + average='micro') + logger.info('Micro F1: {}'.format(f1)) + eval_outputs.append(('micro-f1', f1)) + else: + raise NotImplementedError('Metric %s not implemented' % metric) + # copy from easynlp (end) + + return dict(eval_outputs) diff --git a/maas_lib/utils/constant.py b/maas_lib/utils/constant.py index 0b1a4e75..8f808a6f 100644 --- a/maas_lib/utils/constant.py +++ b/maas_lib/utils/constant.py @@ -62,3 +62,9 @@ class InputFields(object): img = 'img' text = 'text' audio = 'audio' + + +# configuration filename +# in order to avoid conflict with huggingface +# config file we use maas_config instead +CONFIGFILE = 'maas_config.json' diff --git a/requirements/pipeline.txt b/requirements/pipeline.txt index 9e635431..259bbb1b 100644 --- a/requirements/pipeline.txt +++ b/requirements/pipeline.txt @@ -1,5 +1,6 @@ -http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl +#https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl tensorflow -torch==1.9.1 -torchaudio==0.9.1 -torchvision==0.10.1 +#--find-links https://download.pytorch.org/whl/torch_stable.html +torch<1.10,>=1.8.0 +torchaudio +torchvision diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 94be2c62..303a084c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,4 +1,5 @@ addict +https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl numpy opencv-python-headless Pillow diff --git a/setup.py b/setup.py index c9040815..b9044bff 100644 --- a/setup.py +++ b/setup.py @@ -113,26 +113,39 @@ def parse_requirements(fname='requirements.txt', with_version=True): if line.startswith('http'): print('skip http requirements %s' % line) continue - if line and not line.startswith('#'): + if line and not line.startswith('#') and not line.startswith( + '--'): for info in parse_line(line): yield info + elif line and line.startswith('--find-links'): + eles = line.split() + for e in eles: + e = e.strip() + if 'http' in e: + info = dict(dependency_links=e) + yield info def gen_packages_items(): + items = [] + deps_link = [] if exists(require_fpath): for info in parse_require_file(require_fpath): - parts = [info['package']] - if with_version and 'version' in info: - parts.extend(info['version']) - if not sys.version.startswith('3.4'): - # apparently package_deps are broken in 3.4 - platform_deps = info.get('platform_deps') - if platform_deps is not None: - parts.append(';' + platform_deps) - item = ''.join(parts) - yield item - - packages = list(gen_packages_items()) - return packages + if 'dependency_links' not in info: + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + items.append(item) + else: + deps_link.append(info['dependency_links']) + return items, deps_link + + return gen_packages_items() def pack_resource(): @@ -155,7 +168,7 @@ if __name__ == '__main__': # write_version_py() pack_resource() os.chdir('package') - 
install_requires = parse_requirements('requirements.txt') + install_requires, deps_link = parse_requirements('requirements.txt') setup( name='maas-lib', version=get_version(), @@ -180,4 +193,5 @@ if __name__ == '__main__': license='Apache License 2.0', tests_require=parse_requirements('requirements/tests.txt'), install_requires=install_requires, + dependency_links=deps_link, zip_safe=False) diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 7da6c72f..88360994 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp import tempfile import unittest from typing import Any, Dict, List, Tuple, Union @@ -18,15 +19,26 @@ class ImageMattingTest(unittest.TestCase): def test_run(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' - with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile: - ofile.write(File.read(model_path)) - img_matting = pipeline(Tasks.image_matting, model_path=ofile.name) + with tempfile.TemporaryDirectory() as tmp_dir: + model_file = osp.join(tmp_dir, 'matting_person.pb') + with open(model_file, 'wb') as ofile: + ofile.write(File.read(model_path)) + img_matting = pipeline(Tasks.image_matting, model=tmp_dir) result = img_matting( 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' ) cv2.imwrite('result.png', result['output_png']) + def test_run_modelhub(self): + img_matting = pipeline( + Tasks.image_matting, model='damo/image-matting-person') + + result = img_matting( + 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' + ) + cv2.imwrite('result.png', result['output_png']) + if __name__ == '__main__': unittest.main() diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 0f7ba771..e49c480d 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -1,11 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import os -import os.path as osp import tempfile import unittest import zipfile +from pathlib import Path from maas_lib.fileio import File +from maas_lib.models import Model from maas_lib.models.nlp import SequenceClassificationModel from maas_lib.pipelines import SequenceClassificationPipeline, pipeline from maas_lib.preprocessors import SequenceClassificationPreprocessor @@ -13,15 +13,15 @@ from maas_lib.preprocessors import SequenceClassificationPreprocessor class SequenceClassificationTest(unittest.TestCase): - def predict(self, pipeline: SequenceClassificationPipeline): + def predict(self, pipeline_ins: SequenceClassificationPipeline): from easynlp.appzoo import load_dataset set = load_dataset('glue', 'sst2') data = set['test']['sentence'][:3] - results = pipeline(data[0]) + results = pipeline_ins(data[0]) print(results) - results = pipeline(data[1]) + results = pipeline_ins(data[1]) print(results) print(data) @@ -29,22 +29,34 @@ class SequenceClassificationTest(unittest.TestCase): def test_run(self): model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip') - with open(tmp_file, 'wb') as ofile: + cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' + cache_path = Path(cache_path_str) + + if not cache_path.exists(): + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.touch(exist_ok=True) + with cache_path.open('wb') as ofile: ofile.write(File.read(model_url)) - with zipfile.ZipFile(tmp_file, 'r') as zipf: - zipf.extractall(tmp_dir) - path = osp.join(tmp_dir, 'bert-base-sst2') - print(path) - model = SequenceClassificationModel(path) - preprocessor = SequenceClassificationPreprocessor( - path, first_sequence='sentence', second_sequence=None) - pipeline1 = SequenceClassificationPipeline(model, preprocessor) - self.predict(pipeline1) - pipeline2 = pipeline( - 'text-classification', model=model, preprocessor=preprocessor) - print(pipeline2('Hello world!')) + + with zipfile.ZipFile(cache_path_str, 'r') as zipf: + zipf.extractall(cache_path.parent) + path = r'.cache/easynlp/bert-base-sst2' + model = SequenceClassificationModel(path) + preprocessor = SequenceClassificationPreprocessor( + path, first_sequence='sentence', second_sequence=None) + pipeline1 = SequenceClassificationPipeline(model, preprocessor) + self.predict(pipeline1) + pipeline2 = pipeline( + 'text-classification', model=model, preprocessor=preprocessor) + print(pipeline2('Hello world!')) + + def test_run_modelhub(self): + model = Model.from_pretrained('damo/bert-base-sst2') + preprocessor = SequenceClassificationPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + task='text-classification', model=model, preprocessor=preprocessor) + self.predict(pipeline_ins) if __name__ == '__main__': diff --git a/tests/trainers/__init__.py b/tests/trainers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/test_sequence_classification_trainer.py b/tests/trainers/test_sequence_classification_trainer.py new file mode 100644 index 00000000..9846db4f --- /dev/null +++ b/tests/trainers/test_sequence_classification_trainer.py @@ -0,0 +1,38 @@ +import unittest +import zipfile +from pathlib import Path + +from maas_lib.fileio import File +from maas_lib.trainers import build_trainer +from maas_lib.utils.logger import get_logger + +logger = get_logger() + + +class 
SequenceClassificationTrainerTest(unittest.TestCase): + + def test_sequence_classification(self): + model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ + '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' + cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' + cache_path = Path(cache_path_str) + + if not cache_path.exists(): + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.touch(exist_ok=True) + with cache_path.open('wb') as ofile: + ofile.write(File.read(model_url)) + + with zipfile.ZipFile(cache_path_str, 'r') as zipf: + zipf.extractall(cache_path.parent) + + path: str = './configs/nlp/sequence_classification_trainer.yaml' + default_args = dict(cfg_file=path) + trainer = build_trainer('bert-sentiment-analysis', default_args) + trainer.train() + trainer.evaluate() + + +if __name__ == '__main__': + unittest.main() + ... diff --git a/tests/trainers/test_trainer_base.py b/tests/trainers/test_trainer_base.py new file mode 100644 index 00000000..e764d6c9 --- /dev/null +++ b/tests/trainers/test_trainer_base.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +from maas_lib.trainers import build_trainer + + +class DummyTrainerTest(unittest.TestCase): + + def test_dummy(self): + default_args = dict(cfg_file='configs/examples/train.json') + trainer = build_trainer('dummy', default_args) + + trainer.train() + trainer.evaluate() + + +if __name__ == '__main__': + unittest.main()
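As a closing illustration, a minimal sketch of driving the new trainer registry from Python, mirroring the tests above; the config paths are the ones added under `configs/`, and the metric values in the comments are only indicative:

```python
# Hedged sketch: assumes maas_lib and its requirements are installed and the
# repository's configs/ directory is available relative to the working directory.
from maas_lib.trainers import build_trainer

# DummyTrainer is registered under the name 'dummy'; it only prints the
# `train` and `evaluation` sections of the config it is given.
trainer = build_trainer('dummy', dict(cfg_file='configs/examples/train.json'))
trainer.train()
trainer.evaluate()

# The NLP trainer is registered as 'bert-sentiment-analysis' and evaluates the
# model referenced by its YAML config (or by an explicit checkpoint path).
# trainer = build_trainer(
#     'bert-sentiment-analysis',
#     dict(cfg_file='configs/nlp/sequence_classification_trainer.yaml'))
# metrics = trainer.evaluate()  # e.g. {'accuracy': ..., 'f1': ...}
```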