# Conflicts: # modelscope/preprocessors/__init__.pymaster
| @@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE) | |||
| # CUDA_VERSION = 11.3 | |||
| # CUDNN_VERSION = 8 | |||
| BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
| BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
| # BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
| BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||
| MODELSCOPE_VERSION = $(shell git describe --tags --always) | |||
| @@ -8,13 +8,29 @@ | |||
| # For reference: | |||
| # https://docs.docker.com/develop/develop-images/build_enhancements/ | |||
| #ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
| #FROM ${BASE_IMAGE} as dev-base | |||
| # ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 | |||
| # FROM ${BASE_IMAGE} as dev-base | |||
| FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||
| # FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base | |||
| FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel | |||
| # FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime | |||
| # config pip source | |||
| RUN mkdir /root/.pip | |||
| COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf | |||
| COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list | |||
| # Install essential Ubuntu packages | |||
| RUN apt-get update &&\ | |||
| apt-get install -y software-properties-common \ | |||
| build-essential \ | |||
| git \ | |||
| wget \ | |||
| vim \ | |||
| curl \ | |||
| zip \ | |||
| zlib1g-dev \ | |||
| unzip \ | |||
| pkg-config | |||
| # install modelscope and its python env | |||
| WORKDIR /opt/modelscope | |||
| @@ -76,7 +76,7 @@ exclude_patterns = ['build', 'Thumbs.db', '.DS_Store'] | |||
| # The theme to use for HTML and HTML Help pages. See the documentation for | |||
| # a list of builtin themes. | |||
| # | |||
| html_theme = 'sphinx_rtd_theme' | |||
| html_theme = 'sphinx_book_theme' | |||
| html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] | |||
| html_theme_options = {} | |||
| @@ -34,13 +34,62 @@ make linter | |||
| ``` | |||
| ## 2. Test | |||
| ### 2.1 Unit test | |||
| ### 2.1 Test level | |||
| There are mainly three test levels: | |||
| * level 0: tests for basic interface and function of framework, such as `tests/trainers/test_trainer_base.py` | |||
| * level 1: important functional test which test end2end workflow, such as `tests/pipelines/test_image_matting.py` | |||
| * level 2: scenario tests for all the implemented modules such as model, pipeline in different algorithm filed. | |||
| Default test level is 0, which will only run those cases of level 0, you can set test level | |||
| via environment variable `TEST_LEVEL`. For more details, you can refer to [test-doc](https://alidocs.dingtalk.com/i/nodes/mdvQnONayjBJKLXy1Bp38PY2MeXzp5o0?dontjump=true&nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA) | |||
| ```bash | |||
| # run all tests | |||
| TEST_LEVEL=2 make test | |||
| # run important functional tests | |||
| TEST_LEVEL=1 make test | |||
| # run core UT and basic functional tests | |||
| make test | |||
| ``` | |||
| ### 2.2 Test data | |||
| TODO | |||
| When writing test cases, you should assign a test level for your test case using | |||
| following code. If left default, the test level will be 0, it will run in each | |||
| test stage. | |||
| File test_module.py | |||
| ```python | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageCartoonTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| pass | |||
| ``` | |||
| ### 2.2 Run tests | |||
| 1. Run your own single test case to test your self-implemented function. You can run your | |||
| test file directly, if it fails to run, pls check if variable `TEST_LEVEL` | |||
| exists in the environment and unset it. | |||
| ```bash | |||
| python tests/path/to/your_test.py | |||
| ``` | |||
| 2. Remember to run core tests in local environment before start a codereview, by default it will | |||
| only run test cases with level 0. | |||
| ```bash | |||
| make tests | |||
| ``` | |||
| 3. After you start a code review, ci tests will be triggered which will run test cases with level 1 | |||
| 4. Daily regression tests will run all cases at 0 am each day using master branch. | |||
| ## Code Review | |||
| @@ -2,4 +2,4 @@ | |||
| from .base import Model | |||
| from .builder import MODELS, build_model | |||
| from .nlp import BertForSequenceClassification | |||
| from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity | |||
| @@ -2,14 +2,13 @@ | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict, List, Tuple, Union | |||
| from typing import Dict, Union | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from modelscope.models.builder import build_model | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import CONFIGFILE | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -21,16 +20,24 @@ class Model(ABC): | |||
| self.model_dir = model_dir | |||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| return self.post_process(self.forward(input)) | |||
| return self.postprocess(self.forward(input)) | |||
| @abstractmethod | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| pass | |||
| def post_process(self, input: Dict[str, Tensor], | |||
| **kwargs) -> Dict[str, Tensor]: | |||
| # model specific postprocess, implementation is optional | |||
| # will be called in Pipeline and evaluation loop(in the future) | |||
| def postprocess(self, input: Dict[str, Tensor], | |||
| **kwargs) -> Dict[str, Tensor]: | |||
| """ Model specific postprocess and convert model output to | |||
| standard model outputs. | |||
| Args: | |||
| inputs: input data | |||
| Return: | |||
| dict of results: a dict containing outputs of model, each | |||
| output should have the standard output name. | |||
| """ | |||
| return input | |||
| @classmethod | |||
| @@ -47,7 +54,8 @@ class Model(ABC): | |||
| # raise ValueError( | |||
| # 'Remote model repo {model_name_or_path} does not exists') | |||
| cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | |||
| cfg = Config.from_file( | |||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||
| task_name = cfg.task | |||
| model_cfg = cfg.model | |||
| # TODO @wenmeng.zwm may should manually initialize model after model building | |||
| @@ -1,3 +1,4 @@ | |||
| from .sentence_similarity_model import * # noqa F403 | |||
| from .sequence_classification_model import * # noqa F403 | |||
| from .text_generation_model import * # noqa F403 | |||
| from .zero_shot_classification_model import * | |||
| @@ -0,0 +1,88 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from sofa import SbertModel | |||
| from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel | |||
| from torch import nn | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| __all__ = ['SbertForSentenceSimilarity'] | |||
| class SbertTextClassifier(SbertPreTrainedModel): | |||
| def __init__(self, config): | |||
| super().__init__(config) | |||
| self.num_labels = config.num_labels | |||
| self.config = config | |||
| self.encoder = SbertModel(config, add_pooling_layer=True) | |||
| self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
| self.classifier = nn.Linear(config.hidden_size, config.num_labels) | |||
| def forward(self, input_ids=None, token_type_ids=None): | |||
| outputs = self.encoder( | |||
| input_ids, | |||
| token_type_ids=token_type_ids, | |||
| return_dict=None, | |||
| ) | |||
| pooled_output = outputs[1] | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| return logits | |||
| @MODELS.register_module( | |||
| Tasks.sentence_similarity, | |||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||
| class SbertForSentenceSimilarity(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the sentence similarity model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||
| default loader to load model weights, by default None. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.model_dir = model_dir | |||
| self.model = SbertTextClassifier.from_pretrained( | |||
| model_dir, num_labels=2) | |||
| self.model.eval() | |||
| self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||
| with open(self.label_path) as f: | |||
| self.label_mapping = json.load(f) | |||
| self.id2label = {idx: name for name, idx in self.label_mapping.items()} | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| input_ids = torch.tensor(input['input_ids'], dtype=torch.long) | |||
| token_type_ids = torch.tensor( | |||
| input['token_type_ids'], dtype=torch.long) | |||
| with torch.no_grad(): | |||
| logits = self.model(input_ids, token_type_ids) | |||
| probs = logits.softmax(-1).numpy() | |||
| pred = logits.argmax(-1).numpy() | |||
| logits = logits.numpy() | |||
| res = {'predictions': pred, 'probabilities': probs, 'logits': logits} | |||
| return res | |||
| @@ -1,5 +1,7 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -34,6 +36,11 @@ class BertForSequenceClassification(Model): | |||
| ('token_type_ids', torch.LongTensor)], | |||
| output_keys=['predictions', 'probabilities', 'logits']) | |||
| self.label_path = os.path.join(self.model_dir, 'label_mapping.json') | |||
| with open(self.label_path) as f: | |||
| self.label_mapping = json.load(f) | |||
| self.id2label = {idx: name for name, idx in self.label_mapping.items()} | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| @@ -50,3 +57,13 @@ class BertForSequenceClassification(Model): | |||
| } | |||
| """ | |||
| return self.model.predict(input) | |||
| def postprocess(self, inputs: Dict[str, np.ndarray], | |||
| **kwargs) -> Dict[str, np.ndarray]: | |||
| # N x num_classes | |||
| probs = inputs['probabilities'] | |||
| result = { | |||
| 'probs': probs, | |||
| } | |||
| return result | |||
| @@ -12,10 +12,11 @@ from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.logger import get_logger | |||
| from .outputs import TASK_OUTPUTS | |||
| from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||
| Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||
| InputModel = Union[str, Model] | |||
| output_keys = [ | |||
| @@ -106,8 +107,25 @@ class Pipeline(ABC): | |||
| out = self.preprocess(input) | |||
| out = self.forward(out) | |||
| out = self.postprocess(out, **post_kwargs) | |||
| self._check_output(out) | |||
| return out | |||
| def _check_output(self, input): | |||
| # this attribute is dynamically attached by registry | |||
| # when cls is registered in registry using task name | |||
| task_name = self.group_key | |||
| if task_name not in TASK_OUTPUTS: | |||
| logger.warning(f'task {task_name} output keys are missing') | |||
| return | |||
| output_keys = TASK_OUTPUTS[task_name] | |||
| missing_keys = [] | |||
| for k in output_keys: | |||
| if k not in input: | |||
| missing_keys.append(k) | |||
| if len(missing_keys) > 0: | |||
| raise ValueError(f'expected output keys are {output_keys}, ' | |||
| f'those {missing_keys} are missing') | |||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | |||
| """ | |||
| @@ -125,4 +143,14 @@ class Pipeline(ABC): | |||
| @abstractmethod | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """ If current pipeline support model reuse, common postprocess | |||
| code should be write here. | |||
| Args: | |||
| inputs: input data | |||
| Return: | |||
| dict of results: a dict containing outputs of model, each | |||
| output should have the standard output name. | |||
| """ | |||
| raise NotImplementedError('postprocess') | |||
| @@ -3,21 +3,20 @@ | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| from modelscope.models.base import Model | |||
| from modelscope.utils.config import Config, ConfigDict | |||
| from modelscope.utils.constant import CONFIGFILE, Tasks | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.registry import Registry, build_from_cfg | |||
| from .base import Pipeline | |||
| from .util import is_model_name | |||
| PIPELINES = Registry('pipelines') | |||
| DEFAULT_MODEL_FOR_PIPELINE = { | |||
| # TaskName: (pipeline_module_name, model_repo) | |||
| Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | |||
| Tasks.sentence_similarity: | |||
| ('sbert-base-chinese-sentence-similarity', | |||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | |||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'), | |||
| Tasks.text_classification: | |||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | |||
| Tasks.zero_shot_classification: | |||
| @@ -1,5 +1,5 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| @@ -7,7 +7,7 @@ import PIL | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): | |||
| import tensorflow as tf | |||
| if tf.__version__ >= '2.0': | |||
| tf = tf.compat.v1 | |||
| model_path = osp.join(self.model, 'matting_person.pb') | |||
| model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) | |||
| config = tf.ConfigProto(allow_soft_placement=True) | |||
| config.gpu_options.allow_growth = True | |||
| @@ -84,8 +84,11 @@ class ImageCaptionPipeline(Pipeline): | |||
| s = torch.cat([s, self.eos_item]) | |||
| return s | |||
| patch_image = self.patch_resize_transform( | |||
| load_image(input)).unsqueeze(0) | |||
| if isinstance(input, Image.Image): | |||
| patch_image = self.patch_resize_transform(input).unsqueeze(0) | |||
| else: | |||
| patch_image = self.patch_resize_transform( | |||
| load_image(input)).unsqueeze(0) | |||
| patch_mask = torch.tensor([True]) | |||
| text = 'what does the image describe?' | |||
| src_text = encode_text( | |||
| @@ -1,3 +1,4 @@ | |||
| from .sentence_similarity_pipeline import * # noqa F403 | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .zero_shot_classification_pipeline import * | |||
| @@ -0,0 +1,65 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import json | |||
| import numpy as np | |||
| from modelscope.models.nlp import SbertForSentenceSimilarity | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from ...models import Model | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| __all__ = ['SentenceSimilarityPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.sentence_similarity, | |||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||
| class SentenceSimilarityPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SbertForSentenceSimilarity, str], | |||
| preprocessor: SequenceClassificationPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp sentence similarity pipeline for prediction | |||
| Args: | |||
| model (SbertForSentenceSimilarity): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \ | |||
| 'model must be a single str or SbertForSentenceSimilarity' | |||
| sc_model = model if isinstance( | |||
| model, | |||
| SbertForSentenceSimilarity) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| sc_model.model_dir, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence') | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(self.model, 'id2label'), \ | |||
| 'id2label map should be initalizaed in init function.' | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| probs = inputs['probabilities'][0] | |||
| num_classes = probs.shape[0] | |||
| top_indices = np.argpartition(probs, -num_classes)[-num_classes:] | |||
| cls_ids = top_indices[np.argsort(-probs[top_indices], axis=-1)] | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| b = 0 | |||
| return {'scores': probs[b], 'labels': cls_names[b]} | |||
| @@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline): | |||
| second_sequence=None) | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| from easynlp.utils import io | |||
| self.label_path = os.path.join(sc_model.model_dir, | |||
| 'label_mapping.json') | |||
| with io.open(self.label_path) as f: | |||
| self.label_mapping = json.load(f) | |||
| self.label_id_to_name = { | |||
| idx: name | |||
| for name, idx in self.label_mapping.items() | |||
| } | |||
| assert hasattr(self.model, 'id2label'), \ | |||
| 'id2label map should be initalizaed in init function.' | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| topk: int = 5) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| inputs (Dict[str, Any]): input data dict | |||
| topk (int): return topk classification result. | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| # NxC np.ndarray | |||
| probs = inputs['probs'][0] | |||
| num_classes = probs.shape[0] | |||
| topk = min(topk, num_classes) | |||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||
| probs = probs[cls_ids].tolist() | |||
| probs = inputs['probabilities'] | |||
| logits = inputs['logits'] | |||
| predictions = np.argsort(-probs, axis=-1) | |||
| preds = predictions[0] | |||
| b = 0 | |||
| new_result = list() | |||
| for pred in preds: | |||
| new_result.append({ | |||
| 'pred': self.label_id_to_name[pred], | |||
| 'prob': float(probs[b][pred]), | |||
| 'logit': float(logits[b][pred]) | |||
| }) | |||
| new_results = list() | |||
| new_results.append({ | |||
| 'id': | |||
| inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), | |||
| 'output': | |||
| new_result, | |||
| 'predictions': | |||
| new_result[0]['pred'], | |||
| 'probabilities': | |||
| ','.join([str(t) for t in inputs['probabilities'][b]]), | |||
| 'logits': | |||
| ','.join([str(t) for t in inputs['logits'][b]]) | |||
| }) | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| return new_results[0] | |||
| return {'scores': probs, 'labels': cls_names} | |||
| @@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline): | |||
| '').split('[SEP]')[0].replace('[CLS]', | |||
| '').replace('[SEP]', | |||
| '').replace('[UNK]', '') | |||
| return {'pred_string': pred_string} | |||
| return {'text': pred_string} | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from modelscope.utils.constant import Tasks | |||
| TASK_OUTPUTS = { | |||
| # ============ vision tasks =================== | |||
| # image classification result for single sample | |||
| # { | |||
| # "labels": ["dog", "horse", "cow", "cat"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.image_classification: ['scores', 'labels'], | |||
| Tasks.image_tagging: ['scores', 'labels'], | |||
| # object detection result for single sample | |||
| # { | |||
| # "boxes": [ | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # ], | |||
| # "labels": ["dog", "horse", "cow", "cat"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.object_detection: ['scores', 'labels', 'boxes'], | |||
| # instance segmentation result for single sample | |||
| # { | |||
| # "masks": [ | |||
| # np.array in bgr channel order | |||
| # ], | |||
| # "labels": ["dog", "horse", "cow", "cat"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.image_segmentation: ['scores', 'labels', 'boxes'], | |||
| # image generation/editing/matting result for single sample | |||
| # { | |||
| # "output_png": np.array with shape(h, w, 4) | |||
| # for matting or (h, w, 3) for general purpose | |||
| # } | |||
| Tasks.image_editing: ['output_png'], | |||
| Tasks.image_matting: ['output_png'], | |||
| Tasks.image_generation: ['output_png'], | |||
| # pose estimation result for single sample | |||
| # { | |||
| # "poses": np.array with shape [num_pose, num_keypoint, 3], | |||
| # each keypoint is a array [x, y, score] | |||
| # "boxes": np.array with shape [num_pose, 4], each box is | |||
| # [x1, y1, x2, y2] | |||
| # } | |||
| Tasks.pose_estimation: ['poses', 'boxes'], | |||
| # ============ nlp tasks =================== | |||
| # text classification result for single sample | |||
| # { | |||
| # "labels": ["happy", "sad", "calm", "angry"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.text_classification: ['scores', 'labels'], | |||
| # text generation result for single sample | |||
| # { | |||
| # "text": "this is text generated by a model." | |||
| # } | |||
| Tasks.text_generation: ['text'], | |||
| # ============ audio tasks =================== | |||
| # ============ multi-modal tasks =================== | |||
| # image caption result for single sample | |||
| # { | |||
| # "caption": "this is an image caption text." | |||
| # } | |||
| Tasks.image_captioning: ['caption'], | |||
| # visual grounding result for single sample | |||
| # { | |||
| # "boxes": [ | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # [x1, y1, x2, y2], | |||
| # ], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.visual_grounding: ['boxes', 'scores'], | |||
| # text_to_image result for a single sample | |||
| # { | |||
| # "image": np.ndarray with shape [height, width, 3] | |||
| # } | |||
| Tasks.text_to_image_synthesis: ['image'] | |||
| } | |||
| @@ -1,14 +1,11 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| from matplotlib.pyplot import get | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import CONFIGFILE | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @@ -29,14 +26,14 @@ def is_model_name(model: Union[str, List]): | |||
| def is_model_name_impl(model): | |||
| if osp.exists(model): | |||
| cfg_file = osp.join(model, CONFIGFILE) | |||
| cfg_file = osp.join(model, ModelFile.CONFIGURATION) | |||
| if osp.exists(cfg_file): | |||
| return is_config_has_model(cfg_file) | |||
| else: | |||
| return False | |||
| else: | |||
| try: | |||
| cfg_file = model_file_download(model, CONFIGFILE) | |||
| cfg_file = model_file_download(model, ModelFile.CONFIGURATION) | |||
| return is_config_has_model(cfg_file) | |||
| except Exception: | |||
| return False | |||
| @@ -5,4 +5,3 @@ from .builder import PREPROCESSORS, build_preprocessor | |||
| from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .nlp import * # noqa F403 | |||
| from .nlp import TextGenerationPreprocessor, ZeroShotClassificationPreprocessor | |||
| @@ -10,7 +10,10 @@ from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| __all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] | |||
| __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor' | |||
| ] | |||
| @PREPROCESSORS.register_module(Fields.nlp) | |||
| @@ -28,7 +31,7 @@ class Tokenize(Preprocessor): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=r'bert-sentiment-analysis') | |||
| Fields.nlp, module_name=r'bert-sequence-classification') | |||
| class SequenceClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -48,21 +51,42 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) | |||
| print(f'this is the tokenzier {self.tokenizer}') | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| @type_assert(object, (str, tuple)) | |||
| def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| data (str or tuple): | |||
| sentence1 (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| or | |||
| (sentence1, sentence2) | |||
| sentence1 (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| sentence2 (str): a sentence | |||
| Example: | |||
| 'you are so beautiful.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| new_data = {self.first_sequence: data} | |||
| if not isinstance(data, tuple): | |||
| data = ( | |||
| data, | |||
| None, | |||
| ) | |||
| sentence1, sentence2 = data | |||
| new_data = { | |||
| self.first_sequence: sentence1, | |||
| self.second_sequence: sentence2 | |||
| } | |||
| # preprocess the data for the model input | |||
| rst = { | |||
| @@ -32,6 +32,7 @@ class Tasks(object): | |||
| # nlp tasks | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| sentence_similarity = 'sentence-similarity' | |||
| text_classification = 'text-classification' | |||
| relation_extraction = 'relation-extraction' | |||
| zero_shot = 'zero-shot' | |||
| @@ -51,7 +52,7 @@ class Tasks(object): | |||
| text_to_speech = 'text-to-speech' | |||
| speech_signal_process = 'speech-signal-process' | |||
| # multi-media | |||
| # multi-modal tasks | |||
| image_captioning = 'image-captioning' | |||
| visual_grounding = 'visual-grounding' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| @@ -72,5 +73,16 @@ class Hubs(object): | |||
| huggingface = 'huggingface' | |||
| # configuration filename | |||
| CONFIGFILE = 'configuration.json' | |||
| class ModelFile(object): | |||
| CONFIGURATION = 'configuration.json' | |||
| README = 'README.md' | |||
| TF_SAVED_MODEL_FILE = 'saved_model.pb' | |||
| TF_GRAPH_FILE = 'tf_graph.pb' | |||
| TF_CHECKPOINT_FOLDER = 'tf_ckpts' | |||
| TF_CKPT_PREFIX = 'ckpt-' | |||
| TORCH_MODEL_FILE = 'pytorch_model.pt' | |||
| TORCH_MODEL_BIN_FILE = 'pytorch_model.bin' | |||
| TENSORFLOW = 'tensorflow' | |||
| PYTORCH = 'pytorch' | |||
| @@ -1,7 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import inspect | |||
| from email.policy import default | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -70,6 +69,7 @@ class Registry(object): | |||
| f'{self._name}[{group_key}]') | |||
| self._modules[group_key][module_name] = module_cls | |||
| module_cls.group_key = group_key | |||
| if module_name in self._modules[default_group]: | |||
| if id(self._modules[default_group][module_name]) == id(module_cls): | |||
| @@ -0,0 +1,20 @@ | |||
| #!/usr/bin/env python | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| TEST_LEVEL = 2 | |||
| TEST_LEVEL_STR = 'TEST_LEVEL' | |||
| def test_level(): | |||
| global TEST_LEVEL | |||
| if TEST_LEVEL_STR in os.environ: | |||
| TEST_LEVEL = int(os.environ[TEST_LEVEL_STR]) | |||
| return TEST_LEVEL | |||
| def set_test_level(level: int): | |||
| global TEST_LEVEL | |||
| TEST_LEVEL = level | |||
| @@ -1,6 +1,7 @@ | |||
| docutils==0.16.0 | |||
| recommonmark | |||
| sphinx==4.0.2 | |||
| sphinx-book-theme | |||
| sphinx-copybutton | |||
| sphinx_markdown_tables | |||
| sphinx_rtd_theme==0.5.2 | |||
| @@ -1,10 +1,10 @@ | |||
| addict | |||
| datasets | |||
| easydict | |||
| https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl | |||
| https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.2.dev0-py3-none-any.whl | |||
| numpy | |||
| opencv-python-headless | |||
| Pillow | |||
| Pillow>=6.2.0 | |||
| pyyaml | |||
| requests | |||
| tokenizers<=0.10.3 | |||
| @@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase): | |||
| CustomPipeline1() | |||
| def test_custom(self): | |||
| dummy_task = 'dummy-task' | |||
| @PIPELINES.register_module( | |||
| group_key=Tasks.image_tagging, module_name='custom-image') | |||
| group_key=dummy_task, module_name='custom-image') | |||
| class CustomImagePipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase): | |||
| outputs['filename'] = inputs['url'] | |||
| img = inputs['img'] | |||
| new_image = img.resize((img.width // 2, img.height // 2)) | |||
| outputs['resize_image'] = np.array(new_image) | |||
| outputs['dummy_result'] = 'dummy_result' | |||
| outputs['output_png'] = np.array(new_image) | |||
| return outputs | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||
| add_default_pipeline_info(Tasks.image_tagging, 'custom-image') | |||
| add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) | |||
| pipe = pipeline(pipeline_name='custom-image') | |||
| pipe2 = pipeline(Tasks.image_tagging) | |||
| pipe2 = pipeline(dummy_task) | |||
| self.assertTrue(type(pipe) is type(pipe2)) | |||
| img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | |||
| 'aliyuncs.com/data/test/images/image1.jpg' | |||
| output = pipe(img_url) | |||
| self.assertEqual(output['filename'], img_url) | |||
| self.assertEqual(output['resize_image'].shape, (318, 512, 3)) | |||
| self.assertEqual(output['dummy_result'], 'dummy_result') | |||
| self.assertEqual(output['output_png'].shape, (318, 512, 3)) | |||
| outputs = pipe([img_url for i in range(4)]) | |||
| self.assertEqual(len(outputs), 4) | |||
| for out in outputs: | |||
| self.assertEqual(out['filename'], img_url) | |||
| self.assertEqual(out['resize_image'].shape, (318, 512, 3)) | |||
| self.assertEqual(out['dummy_result'], 'dummy_result') | |||
| self.assertEqual(out['output_png'].shape, (318, 512, 3)) | |||
| if __name__ == '__main__': | |||
| @@ -7,11 +7,12 @@ import unittest | |||
| from modelscope.fileio import File | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageCaptionTest(unittest.TestCase): | |||
| @unittest.skip('skip long test') | |||
| @unittest.skip('skip before model is restored in model hub') | |||
| def test_run(self): | |||
| model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt' | |||
| @@ -9,25 +9,27 @@ import cv2 | |||
| from modelscope.fileio import File | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageMattingTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/image-matting-person' | |||
| self.model_id = 'damo/cv_unet_image-matting_damo' | |||
| # switch to False if downloading everytime is not desired | |||
| purge_cache = True | |||
| if purge_cache: | |||
| shutil.rmtree( | |||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| def test_run(self): | |||
| @unittest.skip('deprecated, download model from model hub instead') | |||
| def test_run_with_direct_file_download(self): | |||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
| '.com/data/test/maas/image_matting/matting_person.pb' | |||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||
| model_file = osp.join(tmp_dir, 'matting_person.pb') | |||
| model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE) | |||
| with open(model_file, 'wb') as ofile: | |||
| ofile.write(File.read(model_path)) | |||
| img_matting = pipeline(Tasks.image_matting, model=tmp_dir) | |||
| @@ -37,6 +39,7 @@ class ImageMattingTest(unittest.TestCase): | |||
| ) | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_dataset(self): | |||
| input_location = [ | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
| @@ -51,6 +54,7 @@ class ImageMattingTest(unittest.TestCase): | |||
| cv2.imwrite('result.png', next(result)['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
| @@ -60,6 +64,7 @@ class ImageMattingTest(unittest.TestCase): | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| img_matting = pipeline(Tasks.image_matting) | |||
| @@ -8,6 +8,7 @@ import cv2 | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageCartoonTest(unittest.TestCase): | |||
| @@ -36,10 +37,12 @@ class ImageCartoonTest(unittest.TestCase): | |||
| img_cartoon = pipeline(Tasks.image_generation, model=model_dir) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| img_cartoon = pipeline(Tasks.image_generation, model=self.model_id) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| img_cartoon = pipeline(Tasks.image_generation) | |||
| self.pipeline_inference(img_cartoon, self.test_image) | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import shutil | |||
| import unittest | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForSentenceSimilarity | |||
| from modelscope.pipelines import SentenceSimilarityPipeline, pipeline | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.test_utils import test_level | |||
| class SentenceSimilarityTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| sentence1 = '今天气温比昨天高么?' | |||
| sentence2 = '今天湿度比昨天高么?' | |||
| def setUp(self) -> None: | |||
| # switch to False if downloading everytime is not desired | |||
| purge_cache = True | |||
| if purge_cache: | |||
| shutil.rmtree( | |||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = SequenceClassificationPreprocessor(cache_path) | |||
| model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer) | |||
| pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.sentence_similarity, model=model, preprocessor=tokenizer) | |||
| print('test1') | |||
| print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||
| f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}') | |||
| print() | |||
| print( | |||
| f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||
| f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = SequenceClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_similarity, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_similarity, model=self.model_id) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.sentence_similarity) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -12,6 +12,7 @@ from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.constant import Hubs, Tasks | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.test_utils import test_level | |||
| class SequenceClassificationTest(unittest.TestCase): | |||
| @@ -43,6 +44,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| break | |||
| print(r) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||
| @@ -67,6 +69,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| Tasks.text_classification, model=model, preprocessor=preprocessor) | |||
| print(pipeline2('Hello world!')) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| @@ -77,6 +80,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| preprocessor=preprocessor) | |||
| self.predict(pipeline_ins) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| text_classification = pipeline( | |||
| task=Tasks.text_classification, model=self.model_id) | |||
| @@ -85,6 +89,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| 'glue', name='sst2', target='sentence', hub=Hubs.huggingface)) | |||
| self.printDataset(result) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| text_classification = pipeline(task=Tasks.text_classification) | |||
| result = text_classification( | |||
| @@ -92,6 +97,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| 'glue', name='sst2', target='sentence', hub=Hubs.huggingface)) | |||
| self.printDataset(result) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_dataset(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| @@ -8,6 +8,7 @@ from modelscope.models.nlp import PalmForTextGenerationModel | |||
| from modelscope.pipelines import TextGenerationPipeline, pipeline | |||
| from modelscope.preprocessors import TextGenerationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TextGenerationTest(unittest.TestCase): | |||
| @@ -15,7 +16,7 @@ class TextGenerationTest(unittest.TestCase): | |||
| input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'" | |||
| input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'" | |||
| @unittest.skip('skip temporarily to save test time') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = TextGenerationPreprocessor( | |||
| @@ -29,6 +30,7 @@ class TextGenerationTest(unittest.TestCase): | |||
| print() | |||
| print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = TextGenerationPreprocessor( | |||
| @@ -37,11 +39,13 @@ class TextGenerationTest(unittest.TestCase): | |||
| task=Tasks.text_generation, model=model, preprocessor=preprocessor) | |||
| print(pipeline_ins(self.input1)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.text_generation, model=self.model_id) | |||
| print(pipeline_ins(self.input2)) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.text_generation) | |||
| print(pipeline_ins(self.input2)) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import PIL | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.logger import get_logger | |||
| class ImagePreprocessorTest(unittest.TestCase): | |||
| def test_load(self): | |||
| img = load_image( | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
| ) | |||
| self.assertTrue(isinstance(img, PIL.Image.Image)) | |||
| self.assertEqual(img.size, (948, 533)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -7,6 +7,11 @@ import sys | |||
| import unittest | |||
| from fnmatch import fnmatch | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import set_test_level, test_level | |||
| logger = get_logger() | |||
| def gather_test_cases(test_dir, pattern, list_tests): | |||
| case_list = [] | |||
| @@ -49,5 +54,9 @@ if __name__ == '__main__': | |||
| '--pattern', default='test_*.py', help='test file pattern') | |||
| parser.add_argument( | |||
| '--test_dir', default='tests', help='directory to be tested') | |||
| parser.add_argument( | |||
| '--level', default=0, help='2 -- all, 1 -- p1, 0 -- p0') | |||
| args = parser.parse_args() | |||
| set_test_level(args.level) | |||
| logger.info(f'TEST LEVEL: {test_level()}') | |||
| main(args) | |||
| @@ -1,11 +1,8 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| from pathlib import Path | |||
| from modelscope.fileio import dump, load | |||
| from modelscope.utils.config import Config | |||
| obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | |||