diff --git a/Makefile.docker b/Makefile.docker
index bbac840e..97400318 100644
--- a/Makefile.docker
+++ b/Makefile.docker
@@ -6,7 +6,8 @@ DOCKER_FULL_NAME = $(DOCKER_REGISTRY)/$(DOCKER_ORG)/$(DOCKER_IMAGE)
 # CUDA_VERSION = 11.3
 # CUDNN_VERSION = 8
 BASE_RUNTIME = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
-BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+# BASE_DEVEL = reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+BASE_DEVEL = pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
 
 MODELSCOPE_VERSION = $(shell git describe --tags --always)
diff --git a/docker/pytorch.dockerfile b/docker/pytorch.dockerfile
index 73c35af1..4862cab6 100644
--- a/docker/pytorch.dockerfile
+++ b/docker/pytorch.dockerfile
@@ -8,13 +8,29 @@
 # For reference:
 # https://docs.docker.com/develop/develop-images/build_enhancements/
 
-#ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
-#FROM ${BASE_IMAGE} as dev-base
+# ARG BASE_IMAGE=reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04
+# FROM ${BASE_IMAGE} as dev-base
 
-FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
+# FROM reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 as dev-base
+FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
+# FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
 
 # config pip source
 RUN mkdir /root/.pip
 COPY docker/rcfiles/pip.conf.tsinghua /root/.pip/pip.conf
+COPY docker/rcfiles/sources.list.aliyun /etc/apt/sources.list
+
+# Install essential Ubuntu packages
+RUN apt-get update &&\
+    apt-get install -y software-properties-common \
+    build-essential \
+    git \
+    wget \
+    vim \
+    curl \
+    zip \
+    zlib1g-dev \
+    unzip \
+    pkg-config
 
 # install modelscope and its python env
 WORKDIR /opt/modelscope
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2c2a0017..50ac2fa0 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -76,7 +76,7 @@ exclude_patterns = ['build', 'Thumbs.db', '.DS_Store']
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = 'sphinx_book_theme'
 html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
 
 html_theme_options = {}
diff --git a/docs/source/develop.md b/docs/source/develop.md
index f0c8b8b0..f96590b0 100644
--- a/docs/source/develop.md
+++ b/docs/source/develop.md
@@ -34,13 +34,68 @@ make linter
 ```
 
 ## 2. Test
-### 2.1 Unit test
+
+### 2.1 Test level
+
+There are mainly three test levels:
+
+* level 0: tests for basic interfaces and functions of the framework, such as `tests/trainers/test_trainer_base.py`
+* level 1: important functional tests which cover the end-to-end workflow, such as `tests/pipelines/test_image_matting.py`
+* level 2: scenario tests for all implemented modules, such as models and pipelines in different algorithm fields.
+
+The default test level is 0, which only runs cases of level 0; you can set the test level
+via the environment variable `TEST_LEVEL`.
+For more details, you can refer to [test-doc](https://alidocs.dingtalk.com/i/nodes/mdvQnONayjBJKLXy1Bp38PY2MeXzp5o0?dontjump=true&nav=spaces&navQuery=spaceId%3Dnb9XJNlZxbgrOXyA)
+
 ```bash
+# run all tests
+TEST_LEVEL=2 make test
+
+# run important functional tests
+TEST_LEVEL=1 make test
+
+# run core UT and basic functional tests
 make test
 ```
 
-### 2.2 Test data
-TODO
+When writing test cases, you should assign a test level to your test case using
+the following code. If left at the default, the test level will be 0 and the
+case will run in every test stage.
+
+File: test_module.py
+```python
+from modelscope.utils.test_utils import test_level
+
+class ImageCartoonTest(unittest.TestCase):
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        pass
+```
+
+### 2.2 Run tests
+
+1. Run your own single test case to verify your newly implemented function. You can run your
+test file directly; if it fails to run, please check whether the variable `TEST_LEVEL`
+exists in the environment and unset it.
+```bash
+python tests/path/to/your_test.py
+```
+
+2. Remember to run the core tests in your local environment before starting a code review;
+by default this only runs test cases with level 0.
+```bash
+make test
+```
+
+3. After you start a code review, CI tests will be triggered, which run test cases with level 1.
+
+4. Daily regression tests run all cases at 0:00 each day using the master branch.
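+
+You can also invoke the test runner directly and choose the level with its
+`--level` flag (assuming you run from the repository root):
+```bash
+# run all test cases via the test runner
+python tests/run.py --level 2
+```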
 
 ## Code Review
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index 170e525e..d9a89d35 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -2,4 +2,4 @@
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
diff --git a/modelscope/models/base.py b/modelscope/models/base.py
index e641236d..88b1e3b0 100644
--- a/modelscope/models/base.py
+++ b/modelscope/models/base.py
@@ -2,14 +2,13 @@
 import os.path as osp
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple, Union
+from typing import Dict, Union
 
-from maas_hub.file_download import model_file_download
 from maas_hub.snapshot_download import snapshot_download
 
 from modelscope.models.builder import build_model
 from modelscope.utils.config import Config
-from modelscope.utils.constant import CONFIGFILE
+from modelscope.utils.constant import ModelFile
 from modelscope.utils.hub import get_model_cache_dir
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -21,16 +20,24 @@ class Model(ABC):
         self.model_dir = model_dir
 
     def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        return self.post_process(self.forward(input))
+        return self.postprocess(self.forward(input))
 
     @abstractmethod
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         pass
 
-    def post_process(self, input: Dict[str, Tensor],
-                     **kwargs) -> Dict[str, Tensor]:
-        # model specific postprocess, implementation is optional
-        # will be called in Pipeline and evaluation loop(in the future)
+    def postprocess(self, input: Dict[str, Tensor],
+                    **kwargs) -> Dict[str, Tensor]:
+        """ Model-specific postprocessing that converts the model output to
+        the standard model outputs.
+
+        Args:
+            input: input data
+
+        Returns:
+            dict of results: a dict containing outputs of the model, each
+                output should have the standard output name.
+ """ return input @classmethod @@ -47,7 +54,8 @@ class Model(ABC): # raise ValueError( # 'Remote model repo {model_name_or_path} does not exists') - cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) + cfg = Config.from_file( + osp.join(local_model_dir, ModelFile.CONFIGURATION)) task_name = cfg.task model_cfg = cfg.model # TODO @wenmeng.zwm may should manually initialize model after model building diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index 37e6dd3c..6c3c17c0 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -1,3 +1,4 @@ +from .sentence_similarity_model import * # noqa F403 from .sequence_classification_model import * # noqa F403 from .text_generation_model import * # noqa F403 from .zero_shot_classification_model import * diff --git a/modelscope/models/nlp/sentence_similarity_model.py b/modelscope/models/nlp/sentence_similarity_model.py new file mode 100644 index 00000000..98daac92 --- /dev/null +++ b/modelscope/models/nlp/sentence_similarity_model.py @@ -0,0 +1,88 @@ +import os +from typing import Any, Dict + +import json +import numpy as np +import torch +from sofa import SbertModel +from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel +from torch import nn + +from modelscope.utils.constant import Tasks +from ..base import Model, Tensor +from ..builder import MODELS + +__all__ = ['SbertForSentenceSimilarity'] + + +class SbertTextClassifier(SbertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.encoder = SbertModel(config, add_pooling_layer=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, input_ids=None, token_type_ids=None): + outputs = self.encoder( + input_ids, + token_type_ids=token_type_ids, + return_dict=None, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +@MODELS.register_module( + Tasks.sentence_similarity, + module_name=r'sbert-base-chinese-sentence-similarity') +class SbertForSentenceSimilarity(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the sentence similarity model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. 
+ """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + + self.model = SbertTextClassifier.from_pretrained( + model_dir, num_labels=2) + self.model.eval() + self.label_path = os.path.join(self.model_dir, 'label_mapping.json') + with open(self.label_path) as f: + self.label_mapping = json.load(f) + self.id2label = {idx: name for name, idx in self.label_mapping.items()} + + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + input_ids = torch.tensor(input['input_ids'], dtype=torch.long) + token_type_ids = torch.tensor( + input['token_type_ids'], dtype=torch.long) + with torch.no_grad(): + logits = self.model(input_ids, token_type_ids) + probs = logits.softmax(-1).numpy() + pred = logits.argmax(-1).numpy() + logits = logits.numpy() + res = {'predictions': pred, 'probabilities': probs, 'logits': logits} + return res diff --git a/modelscope/models/nlp/sequence_classification_model.py b/modelscope/models/nlp/sequence_classification_model.py index 6ced7a4e..a3cc4b68 100644 --- a/modelscope/models/nlp/sequence_classification_model.py +++ b/modelscope/models/nlp/sequence_classification_model.py @@ -1,5 +1,7 @@ +import os from typing import Any, Dict +import json import numpy as np from modelscope.utils.constant import Tasks @@ -34,6 +36,11 @@ class BertForSequenceClassification(Model): ('token_type_ids', torch.LongTensor)], output_keys=['predictions', 'probabilities', 'logits']) + self.label_path = os.path.join(self.model_dir, 'label_mapping.json') + with open(self.label_path) as f: + self.label_mapping = json.load(f) + self.id2label = {idx: name for name, idx in self.label_mapping.items()} + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: """return the result by the model @@ -50,3 +57,13 @@ class BertForSequenceClassification(Model): } """ return self.model.predict(input) + + def postprocess(self, inputs: Dict[str, np.ndarray], + **kwargs) -> Dict[str, np.ndarray]: + # N x num_classes + probs = inputs['probabilities'] + result = { + 'probs': probs, + } + + return result diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index f4d4d1b7..1da65213 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -12,10 +12,11 @@ from modelscope.pydatasets import PyDataset from modelscope.utils.config import Config from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.logger import get_logger +from .outputs import TASK_OUTPUTS from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] -Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] +Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] InputModel = Union[str, Model] output_keys = [ @@ -106,8 +107,25 @@ class Pipeline(ABC): out = self.preprocess(input) out = self.forward(out) out = self.postprocess(out, **post_kwargs) + self._check_output(out) return out + def _check_output(self, input): + # this attribute is dynamically attached by registry + # when cls is registered in registry using task name + task_name = self.group_key + if task_name not in TASK_OUTPUTS: + logger.warning(f'task {task_name} output keys are 
+            logger.warning(f'task {task_name} output keys are not defined')
+            return
+        output_keys = TASK_OUTPUTS[task_name]
+        missing_keys = []
+        for k in output_keys:
+            if k not in input:
+                missing_keys.append(k)
+        if len(missing_keys) > 0:
+            raise ValueError(f'expected output keys are {output_keys}, '
+                             f'but {missing_keys} are missing')
+
     def preprocess(self, inputs: Input) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
@@ -125,4 +143,14 @@
 
     @abstractmethod
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """ If the current pipeline supports model reuse, common postprocess
+        code should be written here.
+
+        Args:
+            inputs: input data
+
+        Returns:
+            dict of results: a dict containing outputs of the model, each
+                output should have the standard output name.
+        """
         raise NotImplementedError('postprocess')
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 5c5190c0..a4f15de2 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -3,21 +3,20 @@
 import os.path as osp
 from typing import List, Union
 
-import json
-from maas_hub.file_download import model_file_download
-
 from modelscope.models.base import Model
 from modelscope.utils.config import Config, ConfigDict
-from modelscope.utils.constant import CONFIGFILE, Tasks
+from modelscope.utils.constant import Tasks
 from modelscope.utils.registry import Registry, build_from_cfg
 from .base import Pipeline
-from .util import is_model_name
 
 PIPELINES = Registry('pipelines')
 
 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
-    Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
+    Tasks.sentence_similarity:
+    ('sbert-base-chinese-sentence-similarity',
+     'damo/nlp_structbert_sentence-similarity_chinese-base'),
+    Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
     Tasks.text_classification: ('bert-sentiment-analysis',
                                 'damo/bert-base-sst2'),
     Tasks.zero_shot_classification:
diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py
index 6f3ff5f5..0c60dfa7 100644
--- a/modelscope/pipelines/cv/image_matting_pipeline.py
+++ b/modelscope/pipelines/cv/image_matting_pipeline.py
@@ -1,5 +1,5 @@
 import os.path as osp
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict
 
 import cv2
 import numpy as np
@@ -7,7 +7,7 @@ import PIL
 
 from modelscope.pipelines.base import Input
 from modelscope.preprocessors import load_image
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
 from ..base import Pipeline
 from ..builder import PIPELINES
@@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline):
         import tensorflow as tf
         if tf.__version__ >= '2.0':
             tf = tf.compat.v1
-        model_path = osp.join(self.model, 'matting_person.pb')
+        model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)
 
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
diff --git a/modelscope/pipelines/multi_modal/image_captioning.py b/modelscope/pipelines/multi_modal/image_captioning.py
index 91180e23..3e5f49d0 100644
--- a/modelscope/pipelines/multi_modal/image_captioning.py
+++ b/modelscope/pipelines/multi_modal/image_captioning.py
@@ -84,8 +84,11 @@ class ImageCaptionPipeline(Pipeline):
                 s = torch.cat([s, self.eos_item])
             return s
 
-        patch_image = self.patch_resize_transform(
-            load_image(input)).unsqueeze(0)
+        if isinstance(input, Image.Image):
+            patch_image = self.patch_resize_transform(input).unsqueeze(0)
+        else:
+            patch_image = self.patch_resize_transform(
+                load_image(input)).unsqueeze(0)
         patch_mask = torch.tensor([True])
         text = 'what does the image describe?'
         src_text = encode_text(
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 02f4fbfa..b8a4614f 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -1,3 +1,4 @@
+from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
 from .zero_shot_classification_pipeline import *
diff --git a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py
new file mode 100644
index 00000000..44d91756
--- /dev/null
+++ b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py
@@ -0,0 +1,65 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import SbertForSentenceSimilarity
+from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['SentenceSimilarityPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.sentence_similarity,
+    module_name=r'sbert-base-chinese-sentence-similarity')
+class SentenceSimilarityPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForSentenceSimilarity, str],
+                 preprocessor: SequenceClassificationPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create an NLP sentence similarity pipeline for prediction
+
+        Args:
+            model (str or SbertForSentenceSimilarity): a model id/path or a model instance
+            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \
+            'model must be a single str or SbertForSentenceSimilarity'
+        sc_model = model if isinstance(
+            model,
+            SbertForSentenceSimilarity) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = SequenceClassificationPreprocessor(
+                sc_model.model_dir,
+                first_sequence='first_sequence',
+                second_sequence='second_sequence')
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+
+        assert hasattr(self.model, 'id2label'), \
+            'id2label map should be initialized in init function.'
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the model prediction outputs
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+
+        probs = inputs['probabilities'][0]
+        num_classes = probs.shape[0]
+        top_indices = np.argpartition(probs, -num_classes)[-num_classes:]
+        cls_ids = top_indices[np.argsort(-probs[top_indices], axis=-1)]
+        probs = probs[cls_ids].tolist()
+        cls_names = [self.model.id2label[cid] for cid in cls_ids]
+        b = 0
+        return {'scores': probs[b], 'labels': cls_names[b]}
diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py
index 5a14f136..9d2e4273 100644
--- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py
+++ b/modelscope/pipelines/nlp/sequence_classification_pipeline.py
@@ -41,50 +41,29 @@ class SequenceClassificationPipeline(Pipeline):
             second_sequence=None)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
 
-        from easynlp.utils import io
-        self.label_path = os.path.join(sc_model.model_dir,
-                                       'label_mapping.json')
-        with io.open(self.label_path) as f:
-            self.label_mapping = json.load(f)
-        self.label_id_to_name = {
-            idx: name
-            for name, idx in self.label_mapping.items()
-        }
+        assert hasattr(self.model, 'id2label'), \
+            'id2label map should be initialized in init function.'
 
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def postprocess(self,
+                    inputs: Dict[str, Any],
+                    topk: int = 5) -> Dict[str, str]:
         """process the prediction results
 
         Args:
-            inputs (Dict[str, Any]): _description_
+            inputs (Dict[str, Any]): input data dict
+            topk (int): return the top-k classification results.
 
         Returns:
             Dict[str, str]: the prediction results
         """
+        # NxC np.ndarray
+        probs = inputs['probs'][0]
+        num_classes = probs.shape[0]
+        topk = min(topk, num_classes)
+        top_indices = np.argpartition(probs, -topk)[-topk:]
+        cls_ids = top_indices[np.argsort(-probs[top_indices])]
+        probs = probs[cls_ids].tolist()
-        probs = inputs['probabilities']
-        logits = inputs['logits']
-        predictions = np.argsort(-probs, axis=-1)
-        preds = predictions[0]
-        b = 0
-        new_result = list()
-        for pred in preds:
-            new_result.append({
-                'pred': self.label_id_to_name[pred],
-                'prob': float(probs[b][pred]),
-                'logit': float(logits[b][pred])
-            })
-        new_results = list()
-        new_results.append({
-            'id':
-            inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
-            'output':
-            new_result,
-            'predictions':
-            new_result[0]['pred'],
-            'probabilities':
-            ','.join([str(t) for t in inputs['probabilities'][b]]),
-            'logits':
-            ','.join([str(t) for t in inputs['logits'][b]])
-        })
+        cls_names = [self.model.id2label[cid] for cid in cls_ids]
 
-        return new_results[0]
+        return {'scores': probs, 'labels': cls_names}
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index 7ad2b67f..ea30a115 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -56,4 +56,4 @@ class TextGenerationPipeline(Pipeline):
             '').split('[SEP]')[0].replace('[CLS]',
                                           '').replace('[SEP]',
                                                       '').replace('[UNK]', '')
-        return {'pred_string': pred_string}
+        return {'text': pred_string}
diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py
new file mode 100644
index 00000000..1389abd3
--- /dev/null
+++ b/modelscope/pipelines/outputs.py
@@ -0,0 +1,102 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
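+
+# NOTE: a pipeline's postprocess() is expected to return a dict containing at
+# least the keys listed for its task below; Pipeline._check_output() raises a
+# ValueError when any expected key is missing.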
+
+from modelscope.utils.constant import Tasks
+
+TASK_OUTPUTS = {
+
+    # ============ vision tasks ===================
+
+    # image classification result for single sample
+    # {
+    #     "labels": ["dog", "horse", "cow", "cat"],
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.image_classification: ['scores', 'labels'],
+    Tasks.image_tagging: ['scores', 'labels'],
+
+    # object detection result for single sample
+    # {
+    #     "boxes": [
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #     ],
+    #     "labels": ["dog", "horse", "cow", "cat"],
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.object_detection: ['scores', 'labels', 'boxes'],
+
+    # instance segmentation result for single sample
+    # {
+    #     "masks": [
+    #         np.array in bgr channel order
+    #     ],
+    #     "labels": ["dog", "horse", "cow", "cat"],
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.image_segmentation: ['scores', 'labels', 'masks'],
+
+    # image generation/editing/matting result for single sample
+    # {
+    #     "output_png": np.array with shape(h, w, 4)
+    #                   for matting or (h, w, 3) for general purpose
+    # }
+    Tasks.image_editing: ['output_png'],
+    Tasks.image_matting: ['output_png'],
+    Tasks.image_generation: ['output_png'],
+
+    # pose estimation result for single sample
+    # {
+    #     "poses": np.array with shape [num_pose, num_keypoint, 3],
+    #              each keypoint is an array [x, y, score]
+    #     "boxes": np.array with shape [num_pose, 4], each box is
+    #              [x1, y1, x2, y2]
+    # }
+    Tasks.pose_estimation: ['poses', 'boxes'],
+
+    # ============ nlp tasks ===================
+
+    # text classification result for single sample
+    # {
+    #     "labels": ["happy", "sad", "calm", "angry"],
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.text_classification: ['scores', 'labels'],
+
+    # text generation result for single sample
+    # {
+    #     "text": "this is text generated by a model."
+    # }
+    Tasks.text_generation: ['text'],
+
+    # ============ audio tasks ===================
+
+    # ============ multi-modal tasks ===================
+
+    # image caption result for single sample
+    # {
+    #     "caption": "this is an image caption text."
+    # }
+    Tasks.image_captioning: ['caption'],
+
+    # visual grounding result for single sample
+    # {
+    #     "boxes": [
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #         [x1, y1, x2, y2],
+    #     ],
+    #     "scores": [0.9, 0.1, 0.05, 0.05]
+    # }
+    Tasks.visual_grounding: ['boxes', 'scores'],
+
+    # text_to_image result for a single sample
+    # {
+    #     "image": np.ndarray with shape [height, width, 3]
+    # }
+    Tasks.text_to_image_synthesis: ['image']
+}
diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py
index 43a7ac5a..37c9c929 100644
--- a/modelscope/pipelines/util.py
+++ b/modelscope/pipelines/util.py
@@ -1,14 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import os
 import os.path as osp
 from typing import List, Union
 
-import json
 from maas_hub.file_download import model_file_download
-from matplotlib.pyplot import get
 
 from modelscope.utils.config import Config
-from modelscope.utils.constant import CONFIGFILE
+from modelscope.utils.constant import ModelFile
 from modelscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -29,14 +26,14 @@ def is_model_name(model: Union[str, List]):
 
     def is_model_name_impl(model):
         if osp.exists(model):
-            cfg_file = osp.join(model, CONFIGFILE)
+            cfg_file = osp.join(model, ModelFile.CONFIGURATION)
             if osp.exists(cfg_file):
                 return is_config_has_model(cfg_file)
             else:
                 return False
         else:
             try:
-                cfg_file = model_file_download(model, CONFIGFILE)
+                cfg_file = model_file_download(model, ModelFile.CONFIGURATION)
                 return is_config_has_model(cfg_file)
             except Exception:
                 return False
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index b9a6901d..81ca1007 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -5,4 +5,3 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor, ZeroShotClassificationPreprocessor
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 4ee3ee6a..0904fdcf 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -10,7 +10,10 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 
-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = [
+    'Tokenize', 'SequenceClassificationPreprocessor',
+    'TextGenerationPreprocessor'
+]
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
@@ -28,7 +31,7 @@ class Tokenize(Preprocessor):
 
 
 @PREPROCESSORS.register_module(
-    Fields.nlp, module_name=r'bert-sentiment-analysis')
+    Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):
 
     def __init__(self, model_dir: str, *args, **kwargs):
@@ -48,21 +51,46 @@ class SequenceClassificationPreprocessor(Preprocessor):
 
         self.sequence_length = kwargs.pop('sequence_length', 128)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+        print(f'this is the tokenizer {self.tokenizer}')
 
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
         """process the raw input data
 
         Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
+            data (str or tuple):
+                sentence1 (str): a sentence
+                    Example:
+                        'you are so handsome.'
+                or
+                (sentence1, sentence2)
+                    sentence1 (str): a sentence
+                        Example:
+                            'you are so handsome.'
+                    sentence2 (str): a sentence
+                        Example:
+                            'you are so beautiful.'
 
         Returns:
             Dict[str, Any]: the preprocessed data
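+
+        Example (illustrative call, assuming `preprocessor` is an instance
+        of this class):
+            >>> preprocessor(('you are so handsome.', 'you are so beautiful.'))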
         """
-        new_data = {self.first_sequence: data}
+        if not isinstance(data, tuple):
+            data = (
+                data,
+                None,
+            )
+
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }
+
         # preprocess the data for the model input
 
         rst = {
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index f1eb1fbd..444348cd 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -32,6 +32,7 @@ class Tasks(object):
     # nlp tasks
     zero_shot_classification = 'zero-shot-classification'
     sentiment_analysis = 'sentiment-analysis'
+    sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'
@@ -51,7 +52,7 @@ class Tasks(object):
     text_to_speech = 'text-to-speech'
     speech_signal_process = 'speech-signal-process'
 
-    # multi-media
+    # multi-modal tasks
     image_captioning = 'image-captioning'
     visual_grounding = 'visual-grounding'
     text_to_image_synthesis = 'text-to-image-synthesis'
@@ -72,5 +73,16 @@ class Hubs(object):
     huggingface = 'huggingface'
 
 
-# configuration filename
-CONFIGFILE = 'configuration.json'
+class ModelFile(object):
+    CONFIGURATION = 'configuration.json'
+    README = 'README.md'
+    TF_SAVED_MODEL_FILE = 'saved_model.pb'
+    TF_GRAPH_FILE = 'tf_graph.pb'
+    TF_CHECKPOINT_FOLDER = 'tf_ckpts'
+    TF_CKPT_PREFIX = 'ckpt-'
+    TORCH_MODEL_FILE = 'pytorch_model.pt'
+    TORCH_MODEL_BIN_FILE = 'pytorch_model.bin'
+
+
+TENSORFLOW = 'tensorflow'
+PYTORCH = 'pytorch'
diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py
index 73a938ea..319e54cb 100644
--- a/modelscope/utils/registry.py
+++ b/modelscope/utils/registry.py
@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import inspect
-from email.policy import default
 
 from modelscope.utils.logger import get_logger
 
@@ -70,6 +69,7 @@ class Registry(object):
                 f'{self._name}[{group_key}]')
 
         self._modules[group_key][module_name] = module_cls
+        module_cls.group_key = group_key
 
         if module_name in self._modules[default_group]:
             if id(self._modules[default_group][module_name]) == id(module_cls):
diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py
new file mode 100644
index 00000000..c8ea0442
--- /dev/null
+++ b/modelscope/utils/test_utils.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python
+# Copyright (c) Alibaba, Inc. and its affiliates.
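+#
+# Helpers for gating test cases by level. Note that the environment variable
+# TEST_LEVEL, when set, takes precedence over both the module default and any
+# value passed to set_test_level(), because test_level() re-reads the
+# environment on every call.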
+
+import os
+
+TEST_LEVEL = 2
+TEST_LEVEL_STR = 'TEST_LEVEL'
+
+
+def test_level():
+    global TEST_LEVEL
+    if TEST_LEVEL_STR in os.environ:
+        TEST_LEVEL = int(os.environ[TEST_LEVEL_STR])
+
+    return TEST_LEVEL
+
+
+def set_test_level(level: int):
+    global TEST_LEVEL
+    TEST_LEVEL = level
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 25373976..2436f5af 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,6 +1,7 @@
 docutils==0.16.0
 recommonmark
 sphinx==4.0.2
+sphinx-book-theme
 sphinx-copybutton
 sphinx_markdown_tables
 sphinx_rtd_theme==0.5.2
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 47a11cbc..43684a06 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,10 +1,10 @@
 addict
 datasets
 easydict
-https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
+https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.2.dev0-py3-none-any.whl
 numpy
 opencv-python-headless
-Pillow
+Pillow>=6.2.0
 pyyaml
 requests
 tokenizers<=0.10.3
diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py
index 14f646a9..73aebfdf 100644
--- a/tests/pipelines/test_base.py
+++ b/tests/pipelines/test_base.py
@@ -35,9 +35,10 @@ class CustomPipelineTest(unittest.TestCase):
             CustomPipeline1()
 
     def test_custom(self):
+        dummy_task = 'dummy-task'
 
         @PIPELINES.register_module(
-            group_key=Tasks.image_tagging, module_name='custom-image')
+            group_key=dummy_task, module_name='custom-image')
         class CustomImagePipeline(Pipeline):
 
             def __init__(self,
@@ -67,32 +68,29 @@ class CustomPipelineTest(unittest.TestCase):
                 outputs['filename'] = inputs['url']
                 img = inputs['img']
                 new_image = img.resize((img.width // 2, img.height // 2))
-                outputs['resize_image'] = np.array(new_image)
-                outputs['dummy_result'] = 'dummy_result'
+                outputs['output_png'] = np.array(new_image)
                 return outputs
 
             def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                 return inputs
 
         self.assertTrue('custom-image' in PIPELINES.modules[default_group])
-        add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
+        add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
         pipe = pipeline(pipeline_name='custom-image')
-        pipe2 = pipeline(Tasks.image_tagging)
+        pipe2 = pipeline(dummy_task)
         self.assertTrue(type(pipe) is type(pipe2))
         img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
                  'aliyuncs.com/data/test/images/image1.jpg'
         output = pipe(img_url)
         self.assertEqual(output['filename'], img_url)
-        self.assertEqual(output['resize_image'].shape, (318, 512, 3))
-        self.assertEqual(output['dummy_result'], 'dummy_result')
+        self.assertEqual(output['output_png'].shape, (318, 512, 3))
 
         outputs = pipe([img_url for i in range(4)])
         self.assertEqual(len(outputs), 4)
         for out in outputs:
             self.assertEqual(out['filename'], img_url)
-            self.assertEqual(out['resize_image'].shape, (318, 512, 3))
-            self.assertEqual(out['dummy_result'], 'dummy_result')
+            self.assertEqual(out['output_png'].shape, (318, 512, 3))
 
 
 if __name__ == '__main__':
diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py
index 5584d0e2..4fac4658 100644
--- a/tests/pipelines/test_image_captioning.py
+++ b/tests/pipelines/test_image_captioning.py
@@ -7,11 +7,12 @@ import unittest
 from modelscope.fileio import File
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
 
 
 class ImageCaptionTest(unittest.TestCase):
 
-    @unittest.skip('skip long test')
+    @unittest.skip('skip before model is restored in model hub')
     def test_run(self):
         model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py
index 53006317..ba5d05ad 100644
--- a/tests/pipelines/test_image_matting.py
+++ b/tests/pipelines/test_image_matting.py
@@ -9,25 +9,27 @@ import cv2
 from modelscope.fileio import File
 from modelscope.pipelines import pipeline
 from modelscope.pydatasets import PyDataset
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
 
 
 class ImageMattingTest(unittest.TestCase):
 
     def setUp(self) -> None:
-        self.model_id = 'damo/image-matting-person'
+        self.model_id = 'damo/cv_unet_image-matting_damo'
         # switch to False if downloading everytime is not desired
         purge_cache = True
         if purge_cache:
             shutil.rmtree(
                 get_model_cache_dir(self.model_id), ignore_errors=True)
 
-    def test_run(self):
+    @unittest.skip('deprecated, download model from model hub instead')
+    def test_run_with_direct_file_download(self):
         model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                      '.com/data/test/maas/image_matting/matting_person.pb'
         with tempfile.TemporaryDirectory() as tmp_dir:
-            model_file = osp.join(tmp_dir, 'matting_person.pb')
+            model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE)
             with open(model_file, 'wb') as ofile:
                 ofile.write(File.read(model_path))
             img_matting = pipeline(Tasks.image_matting, model=tmp_dir)
@@ -37,6 +39,7 @@ class ImageMattingTest(unittest.TestCase):
             )
             cv2.imwrite('result.png', result['output_png'])
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_dataset(self):
         input_location = [
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
@@ -51,6 +54,7 @@ class ImageMattingTest(unittest.TestCase):
         cv2.imwrite('result.png', next(result)['output_png'])
         print(f'Output written to {osp.abspath("result.png")}')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_modelhub(self):
         img_matting = pipeline(Tasks.image_matting, model=self.model_id)
 
@@ -60,6 +64,7 @@ class ImageMattingTest(unittest.TestCase):
         cv2.imwrite('result.png', result['output_png'])
         print(f'Output written to {osp.abspath("result.png")}')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_modelhub_default_model(self):
         img_matting = pipeline(Tasks.image_matting)
diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py
index 6f352e42..ed912b1c 100644
--- a/tests/pipelines/test_person_image_cartoon.py
+++ b/tests/pipelines/test_person_image_cartoon.py
@@ -8,6 +8,7 @@ import cv2
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.base import Pipeline
 from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
 
 
 class ImageCartoonTest(unittest.TestCase):
@@ -36,10 +37,12 @@ class ImageCartoonTest(unittest.TestCase):
         img_cartoon = pipeline(Tasks.image_generation, model=model_dir)
         self.pipeline_inference(img_cartoon, self.test_image)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_modelhub(self):
         img_cartoon = pipeline(Tasks.image_generation, model=self.model_id)
         self.pipeline_inference(img_cartoon, self.test_image)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_modelhub_default_model(self):
         img_cartoon = pipeline(Tasks.image_generation)
         self.pipeline_inference(img_cartoon, self.test_image)
diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py
new file mode 100644
index 00000000..ac2ff4fb
--- /dev/null
+++ b/tests/pipelines/test_sentence_similarity.py
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForSentenceSimilarity
+from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
+from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceSimilarityTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+    sentence1 = '今天气温比昨天高么?'
+    sentence2 = '今天湿度比昨天高么?'
+
+    def setUp(self) -> None:
+        # switch to False if downloading every time is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                get_model_cache_dir(self.model_id), ignore_errors=True)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = SequenceClassificationPreprocessor(cache_path)
+        model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
+        pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
+        print('test1')
+        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+              f'pipeline1:{pipeline1(input=(self.sentence1, self.sentence2))}')
+        print()
+        print(
+            f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+            f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_similarity,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_similarity, model=self.model_id)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.sentence_similarity)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py
index 7f6dc77c..01fdd29b 100644
--- a/tests/pipelines/test_text_classification.py
+++ b/tests/pipelines/test_text_classification.py
@@ -12,6 +12,7 @@ from modelscope.preprocessors import SequenceClassificationPreprocessor
 from modelscope.pydatasets import PyDataset
 from modelscope.utils.constant import Hubs, Tasks
 from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
 
 
 class SequenceClassificationTest(unittest.TestCase):
@@ -43,6 +44,7 @@ class SequenceClassificationTest(unittest.TestCase):
                 break
             print(r)
 
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
         model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                     '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
@@ -67,6 +69,7 @@ class SequenceClassificationTest(unittest.TestCase):
             Tasks.text_classification, model=model, preprocessor=preprocessor)
         print(pipeline2('Hello world!'))
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
@@ -77,6 +80,7 @@ class SequenceClassificationTest(unittest.TestCase):
             preprocessor=preprocessor)
         self.predict(pipeline_ins)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_name(self):
         text_classification = pipeline(
             task=Tasks.text_classification, model=self.model_id)
@@ -85,6 +89,7 @@ class SequenceClassificationTest(unittest.TestCase):
             'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
         self.printDataset(result)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_default_model(self):
         text_classification = pipeline(task=Tasks.text_classification)
         result = text_classification(
@@ -92,6 +97,7 @@ class SequenceClassificationTest(unittest.TestCase):
             'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
         self.printDataset(result)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_dataset(self):
         model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index d8f1b495..f98e135d 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -8,6 +8,7 @@ from modelscope.models.nlp import PalmForTextGenerationModel
 from modelscope.pipelines import TextGenerationPipeline, pipeline
 from modelscope.preprocessors import TextGenerationPreprocessor
 from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
 
 
 class TextGenerationTest(unittest.TestCase):
@@ -15,7 +16,7 @@ class TextGenerationTest(unittest.TestCase):
     input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
     input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
 
-    @unittest.skip('skip temporarily to save test time')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TextGenerationPreprocessor(
@@ -29,6 +30,7 @@ class TextGenerationTest(unittest.TestCase):
         print()
         print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
         preprocessor = TextGenerationPreprocessor(
@@ -37,11 +39,13 @@ class TextGenerationTest(unittest.TestCase):
             task=Tasks.text_generation, model=model, preprocessor=preprocessor)
         print(pipeline_ins(self.input1))
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
             task=Tasks.text_generation, model=self.model_id)
         print(pipeline_ins(self.input2))
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_default_model(self):
         pipeline_ins = pipeline(task=Tasks.text_generation)
         print(pipeline_ins(self.input2))
diff --git a/tests/preprocessors/test_image.py b/tests/preprocessors/test_image.py
new file mode 100644
index 00000000..cfa7b11d
--- /dev/null
+++ b/tests/preprocessors/test_image.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+import PIL
+
+from modelscope.preprocessors import load_image
+from modelscope.utils.logger import get_logger
+
+
+class ImagePreprocessorTest(unittest.TestCase):
+
+    def test_load(self):
+        img = load_image(
+            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
+        )
+        self.assertTrue(isinstance(img, PIL.Image.Image))
+        self.assertEqual(img.size, (948, 533))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/run.py b/tests/run.py
index 25404d7a..9f5d62a7 100644
--- a/tests/run.py
+++ b/tests/run.py
@@ -7,6 +7,11 @@ import sys
 import unittest
 from fnmatch import fnmatch
 
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import set_test_level, test_level
+
+logger = get_logger()
+
 
 def gather_test_cases(test_dir, pattern, list_tests):
     case_list = []
@@ -49,5 +54,9 @@ if __name__ == '__main__':
         '--pattern', default='test_*.py', help='test file pattern')
     parser.add_argument(
         '--test_dir', default='tests', help='directory to be tested')
+    parser.add_argument(
+        '--level', type=int, default=0, help='2 -- all, 1 -- p1, 0 -- p0')
     args = parser.parse_args()
+    set_test_level(args.level)
+    logger.info(f'TEST LEVEL: {test_level()}')
     main(args)
diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py
index fb7044e8..a3770f0d 100644
--- a/tests/utils/test_config.py
+++ b/tests/utils/test_config.py
@@ -1,11 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import argparse
-import os.path as osp
 import tempfile
 import unittest
-from pathlib import Path
 
-from modelscope.fileio import dump, load
 from modelscope.utils.config import Config
 
 obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}