diff --git a/.dev_scripts/run_docker.sh b/.dev_scripts/run_docker.sh new file mode 100644 index 00000000..8999458a --- /dev/null +++ b/.dev_scripts/run_docker.sh @@ -0,0 +1,7 @@ +#sudo docker run --name zwm_maas -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/tensorflow-training:2.3-gpu-py36-cu101-ubuntu18.04 bash +#sudo docker run --name zwm_maas_pytorch -v /home/wenmeng.zwm/workspace:/home/wenmeng.zwm/workspace --net host -ti reg.docker.alibaba-inc.com/pai-dlc/pytorch-training:1.10PAI-gpu-py36-cu113-ubuntu18.04 bash +CONTAINER_NAME=modelscope-dev +IMAGE_NAME=registry.cn-shanghai.aliyuncs.com/modelscope/modelscope +IMAGE_VERSION=v0.1.1-16-g62856fa-devel +MOUNT_DIR=/home/wenmeng.zwm/workspace +sudo docker run --name $CONTAINER_NAME -v $MOUNT_DIR:$MOUNT_DIR --net host -ti ${IMAGE_NAME}:${IMAGE_VERSION} bash diff --git a/.gitattributes b/.gitattributes index 9c607acc..b2724f28 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ *.png filter=lfs diff=lfs merge=lfs -text *.jpg filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index cc9ef477..05929ea9 100644 --- a/.gitignore +++ b/.gitignore @@ -124,7 +124,3 @@ replace.sh # Pytorch *.pth - - -# audio -*.wav diff --git a/data/test/audios/kws_bofangyinyue.wav b/data/test/audios/kws_bofangyinyue.wav new file mode 100644 index 00000000..c8bf69b7 --- /dev/null +++ b/data/test/audios/kws_bofangyinyue.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a72a7b8d1e8be6ebaa09aeee0d71472569bc62cc4872ecfdbd1651bb3d03eaba +size 69110 diff --git a/data/test/audios/kws_xiaoyunxiaoyun.wav b/data/test/audios/kws_xiaoyunxiaoyun.wav new file mode 100644 index 00000000..8afe6b7c --- /dev/null +++ b/data/test/audios/kws_xiaoyunxiaoyun.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6b1671bcfa872278c99490cd1acb08297b8df4dc78f268e4b6a582b4364e4a1 +size 297684 diff --git a/docker/pytorch.dockerfile b/docker/pytorch.dockerfile index 4862cab6..a1fe5b15 100644 --- a/docker/pytorch.dockerfile +++ b/docker/pytorch.dockerfile @@ -30,7 +30,8 @@ RUN apt-get update &&\ zip \ zlib1g-dev \ unzip \ - pkg-config + pkg-config \ + libsndfile1 # install modelscope and its python env WORKDIR /opt/modelscope diff --git a/docs/source/index.rst b/docs/source/index.rst index 3b223531..e93c7aed 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,6 +13,7 @@ ModelScope doc quick_start.md develop.md + faq.md .. toctree:: :maxdepth: 2 @@ -20,6 +21,8 @@ ModelScope doc tutorials/index + + .. 
toctree:: :maxdepth: 2 :caption: Changelog diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index de416f08..54e04fc2 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -1,72 +1,55 @@ # 快速开始 - +ModelScope Library目前支持tensorflow,pytorch深度学习框架进行模型训练、推理, 在Python 3.7+, Pytorch 1.8+, Tensorflow1.15+,Tensorflow 2.6上测试可运行。 +注: 当前(630)版本仅支持python3.7 以及linux环境,其他环境(mac,windows等)支持预计730完成。 ## python环境配置 首先,参考[文档](https://docs.anaconda.com/anaconda/install/) 安装配置Anaconda环境 安装完成后,执行如下命令为modelscope library创建对应的python环境。 ```shell -conda create -n modelscope python=3.6 +conda create -n modelscope python=3.7 conda activate modelscope ``` -检查python和pip命令是否切换到conda环境下。 +## 安装深度学习框架 +* 安装pytorch[参考链接](https://pytorch.org/get-started/locally/) ```shell -which python -# ~/workspace/anaconda3/envs/modelscope/bin/python - -which pip -# ~/workspace/anaconda3/envs/modelscope/bin/pip +pip install torch torchvision ``` -注: 本项目只支持`python3`环境,请勿使用python2环境。 - -## 第三方依赖安装 - -ModelScope Library目前支持tensorflow,pytorch两大深度学习框架进行模型训练、推理, 在Python 3.6+, Pytorch 1.8+, Tensorflow 2.6上测试可运行,用户可以根据所选模型对应的计算框架进行安装,可以参考如下链接进行安装所需框架: - -* [Pytorch安装指导](https://pytorch.org/get-started/locally/) -* [Tensorflow安装指导](https://www.tensorflow.org/install/pip) - -部分第三方依赖库需要提前安装numpy -``` -pip install numpy +* 安装Tensorflow[参考链接](https://www.tensorflow.org/install/pip) +```shell +pip install --upgrade tensorflow ``` - ## ModelScope library 安装 注: 如果在安装过程中遇到错误,请前往[常见问题](faq.md)查找解决方案。 ### pip安装 +执行如下命令: ```shell -pip install -r http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/modelscope.txt +pip install model_scope[all] -f https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/repo.html ``` - -安装成功后,可以执行如下命令进行验证安装是否正确 -```shell -python -c "from modelscope.pipelines import pipeline;print(pipeline('image-matting',model='damo/image-matting-person')('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'))" -``` - - ### 使用源码安装 - 适合本地开发调试使用,修改源码后可以直接执行 +下载源码可以直接clone代码到本地 ```shell git clone git@gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib.git modelscope git fetch origin master git checkout master - cd modelscope - -#安装依赖 +``` +安装依赖并设置PYTHONPATH +```shell pip install -r requirements.txt - -# 设置PYTHONPATH export PYTHONPATH=`pwd` ``` - +### 安装验证 安装成功后,可以执行如下命令进行验证安装是否正确 ```shell -python -c "from modelscope.pipelines import pipeline;print(pipeline('image-matting',model='damo/image-matting-person')('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'))" +python -c "from modelscope.pipelines import pipeline;print(pipeline('word-segmentation')('今天天气不错,适合 出去游玩'))" +{'output': '今天 天气 不错 , 适合 出去 游玩'} ``` +## 推理 +pipeline函数提供了简洁的推理接口,相关介绍和示例请参考[pipeline使用教程](tutorials/pipeline.md) ## 训练 @@ -75,46 +58,3 @@ to be done ## 评估 to be done - -## 推理 - -pipeline函数提供了简洁的推理接口,示例如下, 更多pipeline介绍和示例请参考[pipeline使用教程](tutorials/pipeline.md) - -```python -import cv2 -import os.path as osp -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks - -# 根据任务名创建pipeline -img_matting = pipeline(Tasks.image_matting, model='damo/image-matting-person') - -# 直接提供图像文件的url作为pipeline推理的输入 -result = img_matting( - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' -) -cv2.imwrite('result.png', result['output_png']) -print(f'Output written to {osp.abspath("result.png")}') - -``` - -此外,pipeline接口也能接收Dataset作为输入,上面的代码同样可以实现为 - -```python -import cv2 
-import os.path as osp -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.msdatasets import MsDataset - -# 使用图像url构建MsDataset,此处也可通过 input_location = '/dir/to/images' 来使用本地文件夹 -input_location = [ - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' -] -dataset = MsDataset.load(input_location, target='image') -img_matting = pipeline(Tasks.image_matting, model='damo/image-matting-person') -# 输入为MsDataset时,输出的结果为迭代器 -result = img_matting(dataset) -cv2.imwrite('result.png', next(result)['output_png']) -print(f'Output written to {osp.abspath("result.png")}') -``` diff --git a/docs/source/tutorials/pipeline.md b/docs/source/tutorials/pipeline.md index cc851278..1134f417 100644 --- a/docs/source/tutorials/pipeline.md +++ b/docs/source/tutorials/pipeline.md @@ -1,84 +1,62 @@ # Pipeline使用教程 - -本文将简单介绍如何使用`pipeline`函数加载模型进行推理。`pipeline`函数支持按照任务类型、模型名称从模型仓库 -拉取模型进行进行推理,当前支持的任务有 - -* 人像抠图 (image-matting) -* 基于bert的语义情感分析 (bert-sentiment-analysis) - -本文将从如下方面进行讲解如何使用Pipeline模块: +本文简单介绍如何使用`pipeline`函数加载模型进行推理。`pipeline`函数支持按照任务类型、模型名称从模型仓库拉取模型进行进行推理,包含以下几个方面: * 使用pipeline()函数进行推理 * 指定特定预处理、特定模型进行推理 * 不同场景推理任务示例 - ## 环境准备 详细步骤可以参考 [快速开始](../quick_start.md) - ## Pipeline基本用法 +下面以中文分词任务为例,说明pipeline函数的基本用法 -1. pipeline函数支持指定特定任务名称,加载任务默认模型,创建对应Pipeline对象 +1. pipeline函数支持指定特定任务名称,加载任务默认模型,创建对应pipeline对象 执行如下python代码 ```python - >>> from modelscope.pipelines import pipeline - >>> img_matting = pipeline(task='image-matting', model='damo/image-matting-person') + from modelscope.pipelines import pipeline + word_segmentation = pipeline('word-segmentation') ``` -2. 传入单张图像url进行处理 +2. 输入文本 ``` python - >>> import cv2 - >>> result = img_matting('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png') - >>> cv2.imwrite('result.png', result['output_png']) - >>> import os.path as osp - >>> print(f'result file path is {osp.abspath("result.png")}') + input = '今天天气不错,适合出去游玩' + print(word_segmentation(input)) + {'output': '今天 天气 不错 , 适合 出去 游玩'} ``` - pipeline对象也支持传入一个列表输入,返回对应输出列表,每个元素对应输入样本的返回结果 - ```python - >>> results = img_matting( - [ - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png', - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png', - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png', - ]) - ``` +3. 输入多条样本 + +pipeline对象也支持传入多个样本列表输入,返回对应输出列表,每个元素对应输入样本的返回结果 - 如果pipeline对应有一些后处理参数,也支持通过调用时候传入. ```python - >>> pipe = pipeline(task_name) - >>> result = pipe(input, post_process_args) + inputs = ['今天天气不错,适合出去游玩','这本书很好,建议你看看'] + print(word_segmentation(inputs)) + [{'output': '今天 天气 不错 , 适合 出去 游玩'}, {'output': '这 本 书 很 好 , 建议 你 看看'}] ``` - ## 指定预处理、模型进行推理 pipeline函数支持传入实例化的预处理对象、模型对象,从而支持用户在推理过程中定制化预处理、模型。 -下面以文本情感分类为例进行介绍。 -由于demo模型为EasyNLP提供的模型,首先,安装EasyNLP -```shell -pip install https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl -``` - - -下载模型文件 -```shell -wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip && unzip bert-base-sst2.zip -``` - -创建tokenizer和模型 +1. 
首先,创建预处理方法和模型 ```python ->>> from modelscope.models import Model ->>> from modelscope.preprocessors import SequenceClassificationPreprocessor ->>> model = Model.from_pretrained('damo/bert-base-sst2') ->>> tokenizer = SequenceClassificationPreprocessor( - model.model_dir, first_sequence='sentence', second_sequence=None) +from modelscope.models import Model +from modelscope.preprocessors import TokenClassifcationPreprocessor +model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base') +tokenizer = TokenClassifcationPreprocessor(model.model_dir) ``` -使用tokenizer和模型对象创建pipeline +2. 使用tokenizer和模型对象创建pipeline ```python ->>> from modelscope.pipelines import pipeline ->>> semantic_cls = pipeline('text-classification', model=model, preprocessor=tokenizer) ->>> semantic_cls("Hello world!") +from modelscope.pipelines import pipeline +word_seg = pipeline('word-segmentation', model=model, preprocessor=tokenizer) +input = '今天天气不错,适合出去游玩' +print(word_seg(input)) +{'output': '今天 天气 不错 , 适合 出去 游玩'} ``` - ## 不同场景任务推理示例 - -人像抠图、语义分类建上述两个例子。 其他例子未来添加。 +下面以一个图像任务:人像抠图('image-matting')为例,进一步说明pipeline的用法 +```python +import cv2 +import os.path as osp +from modelscope.pipelines import pipeline +img_matting = pipeline('image-matting') +result = img_matting('http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png') +cv2.imwrite('result.png', result['output_png']) +``` diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 6cfad54d..45e39133 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -1,6 +1,8 @@ import os import pickle +import shutil import subprocess +from collections import defaultdict from http.cookiejar import CookieJar from os.path import expanduser from typing import List, Optional, Tuple, Union @@ -8,8 +10,11 @@ from typing import List, Optional, Tuple, Union import requests from modelscope.utils.logger import get_logger +from ..msdatasets.config import DOWNLOADED_DATASETS_PATH, HUB_DATASET_ENDPOINT +from ..utils.constant import DownloadMode from .constants import MODELSCOPE_URL_SCHEME -from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error +from .errors import (InvalidParameter, NotExistError, datahub_raise_on_error, + is_ok, raise_on_error) from .utils.utils import (get_endpoint, get_gitlab_domain, model_id_to_group_owner_name) @@ -18,8 +23,9 @@ logger = get_logger() class HubApi: - def __init__(self, endpoint=None): + def __init__(self, endpoint=None, dataset_endpoint=None): self.endpoint = endpoint if endpoint is not None else get_endpoint() + self.dataset_endpoint = dataset_endpoint if dataset_endpoint is not None else HUB_DATASET_ENDPOINT def login( self, @@ -241,6 +247,70 @@ class HubApi: files.append(file) return files + def list_datasets(self): + path = f'{self.dataset_endpoint}/api/v1/datasets' + headers = None + params = {} + r = requests.get(path, params=params, headers=headers) + r.raise_for_status() + dataset_list = r.json()['Data'] + return [x['Name'] for x in dataset_list] + + def fetch_dataset_scripts(self, + dataset_name: str, + namespace: str, + download_mode: Optional[DownloadMode], + version: Optional[str] = 'master'): + if namespace is None: + raise ValueError( + f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}' + ) + version = version or 'master' + cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name, + namespace, version) + download_mode = DownloadMode(download_mode + or DownloadMode.REUSE_DATASET_IF_EXISTS) + if 
download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists( + cache_dir): + shutil.rmtree(cache_dir) + os.makedirs(cache_dir, exist_ok=True) + datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{namespace}/{dataset_name}' + r = requests.get(datahub_url) + resp = r.json() + datahub_raise_on_error(datahub_url, resp) + dataset_id = resp['Data']['Id'] + datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}' + r = requests.get(datahub_url) + resp = r.json() + datahub_raise_on_error(datahub_url, resp) + file_list = resp['Data'] + if file_list is None: + raise NotExistError( + f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, ' + f'version = {version}] does not exist') + + file_list = file_list['Files'] + local_paths = defaultdict(list) + for file_info in file_list: + file_path = file_info['Path'] + if file_path.endswith('.py'): + datahub_url = f'{self.dataset_endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \ + f'Revision={version}&Path={file_path}' + r = requests.get(datahub_url) + r.raise_for_status() + content = r.json()['Data']['Content'] + local_path = os.path.join(cache_dir, file_path) + if os.path.exists(local_path): + logger.warning( + f"Reusing dataset {dataset_name}'s python file ({local_path})" + ) + local_paths['py'].append(local_path) + continue + with open(local_path, 'w') as f: + f.writelines(content) + local_paths['py'].append(local_path) + return local_paths + class ModelScopeConfig: path_credential = expanduser('~/.modelscope/credentials') diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index 0ee451c2..91c08786 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -1,6 +1,8 @@ MODELSCOPE_URL_SCHEME = 'http://' -DEFAULT_MODELSCOPE_DOMAIN = '47.94.223.21:31090' +DEFAULT_MODELSCOPE_IP = '47.94.223.21' +DEFAULT_MODELSCOPE_DOMAIN = DEFAULT_MODELSCOPE_IP + ':31090' DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102' +DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_IP + ':31752' DEFAULT_MODELSCOPE_GROUP = 'damo' MODEL_ID_SEPARATOR = '/' diff --git a/modelscope/hub/utils/caching.py b/modelscope/hub/utils/caching.py index 7675e49b..fc30fa27 100644 --- a/modelscope/hub/utils/caching.py +++ b/modelscope/hub/utils/caching.py @@ -1,9 +1,7 @@ import hashlib -import logging import os import pickle import tempfile -import time from shutil import move, rmtree from modelscope.utils.logger import get_logger diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index d7a4bfed..90234b67 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -1,15 +1,30 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
- -from .audio.ans.frcrn import FRCRNModel -from .audio.kws import GenericKeyWordSpotting -from .audio.tts.am import SambertNetHifi16k -from .audio.tts.vocoder import Hifigan16k from .base import Model from .builder import MODELS, build_model -from .multi_modal import OfaForImageCaptioning -from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, - SbertForSentenceSimilarity, SbertForSentimentClassification, - SbertForTokenClassification, SbertForZeroShotClassification, - SpaceForDialogIntent, SpaceForDialogModeling, - SpaceForDialogStateTracking, StructBertForMaskedLM, - VecoForMaskedLM) + +try: + from .audio.tts.am import SambertNetHifi16k + from .audio.tts.vocoder import Hifigan16k + +except ModuleNotFoundError as e: + if str(e) == "No module named 'tensorflow'": + pass + else: + raise ModuleNotFoundError(e) + +try: + from .audio.kws import GenericKeyWordSpotting + from .multi_modal import OfaForImageCaptioning + from .nlp import (BertForMaskedLM, BertForSequenceClassification, + SbertForNLI, SbertForSentenceSimilarity, + SbertForSentimentClassification, + SbertForTokenClassification, + SbertForZeroShotClassification, SpaceForDialogIntent, + SpaceForDialogModeling, SpaceForDialogStateTracking, + StructBertForMaskedLM, VecoForMaskedLM) + from .audio.ans.frcrn import FRCRNModel +except ModuleNotFoundError as e: + if str(e) == "No module named 'torch'": + pass + else: + raise ModuleNotFoundError(e) diff --git a/modelscope/models/multi_modal/clip/clip_model.py b/modelscope/models/multi_modal/clip/clip_model.py index 4283886f..839b8d0e 100644 --- a/modelscope/models/multi_modal/clip/clip_model.py +++ b/modelscope/models/multi_modal/clip/clip_model.py @@ -108,7 +108,11 @@ class CLIPForMultiModalEmbedding(Model): return text_ids_tensor, text_mask_tensor def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - output = {'img_embedding': None, 'text_embedding': None} + from modelscope.pipelines.outputs import OutputKeys + output = { + OutputKeys.IMG_EMBEDDING: None, + OutputKeys.TEXT_EMBEDDING: None + } if 'img' in input and input['img'] is not None: input_img = input['img'] if isinstance(input_img, Image.Image): @@ -130,7 +134,8 @@ class CLIPForMultiModalEmbedding(Model): img_embedding = self.clip_model( input_data=img_tensor, input_type='img') - output['img_embedding'] = img_embedding.data.cpu().numpy() + from modelscope.pipelines.outputs import OutputKeys + output[OutputKeys.IMG_EMBEDDING] = img_embedding.data.cpu().numpy() if 'text' in input and input['text'] is not None: text_str = input['text'] diff --git a/modelscope/models/multi_modal/image_captioning_model.py b/modelscope/models/multi_modal/image_captioning_model.py index 0154ac29..5c0a3ddf 100644 --- a/modelscope/models/multi_modal/image_captioning_model.py +++ b/modelscope/models/multi_modal/image_captioning_model.py @@ -76,9 +76,10 @@ class OfaForImageCaptioning(Model): input = fairseq.utils.move_to_cuda(input, device=self._device) results, _ = self.eval_caption(self.task, self.generator, self.models, input) + from ...pipelines.outputs import OutputKeys return { 'image_id': results[0]['image_id'], - 'caption': results[0]['caption'] + OutputKeys.CAPTION: results[0][OutputKeys.CAPTION] } def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/msdatasets/config.py b/modelscope/msdatasets/config.py index 22390ed7..0357e823 100644 --- a/modelscope/msdatasets/config.py +++ b/modelscope/msdatasets/config.py @@ -2,6 +2,8 @@ import os from pathlib import Path # Cache location +from
modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT + DEFAULT_CACHE_HOME = '~/.cache' CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME) DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope/hub') @@ -18,5 +20,5 @@ DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE, DOWNLOADED_DATASETS_PATH = Path( os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) -MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT', - 'http://47.94.223.21:31752') +HUB_DATASET_ENDPOINT = os.environ.get('HUB_DATASET_ENDPOINT', + DEFAULT_MODELSCOPE_DATA_ENDPOINT) diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 90964b36..fa7d1bf2 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -11,7 +11,6 @@ from datasets.utils.file_utils import (is_relative_path, relative_to_absolute_path) from modelscope.msdatasets.config import MS_DATASETS_CACHE -from modelscope.msdatasets.utils.ms_api import MsApi from modelscope.utils.constant import DownloadMode, Hubs from modelscope.utils.logger import get_logger @@ -146,8 +145,9 @@ class MsDataset: use_hf = True elif is_relative_path(dataset_name) and dataset_name.count( '/') == 0: - ms_api = MsApi() - dataset_scripts = ms_api.fetch_dataset_scripts( + from modelscope.hub.api import HubApi + api = HubApi() + dataset_scripts = api.fetch_dataset_scripts( dataset_name, namespace, download_mode, version) if 'py' in dataset_scripts: # dataset copied from hf datasets dataset_name = dataset_scripts['py'][0] diff --git a/modelscope/msdatasets/utils/__init__.py b/modelscope/msdatasets/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/modelscope/msdatasets/utils/ms_api.py b/modelscope/msdatasets/utils/ms_api.py deleted file mode 100644 index c9b49ca1..00000000 --- a/modelscope/msdatasets/utils/ms_api.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import shutil -from collections import defaultdict -from typing import Optional - -import requests - -from modelscope.hub.errors import NotExistError, datahub_raise_on_error -from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH, - MS_HUB_ENDPOINT) -from modelscope.utils.constant import DownloadMode -from modelscope.utils.logger import get_logger - -logger = get_logger() - - -class MsApi: - - def __init__(self, endpoint=MS_HUB_ENDPOINT): - self.endpoint = endpoint - - def list_datasets(self): - path = f'{self.endpoint}/api/v1/datasets' - headers = None - params = {} - r = requests.get(path, params=params, headers=headers) - r.raise_for_status() - dataset_list = r.json()['Data'] - return [x['Name'] for x in dataset_list] - - def fetch_dataset_scripts(self, - dataset_name: str, - namespace: str, - download_mode: Optional[DownloadMode], - version: Optional[str] = 'master'): - if namespace is None: - raise ValueError( - f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}' - ) - version = version or 'master' - cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name, - namespace, version) - download_mode = DownloadMode(download_mode - or DownloadMode.REUSE_DATASET_IF_EXISTS) - if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists( - cache_dir): - shutil.rmtree(cache_dir) - os.makedirs(cache_dir, exist_ok=True) - datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}' - r = requests.get(datahub_url) - resp = r.json() - datahub_raise_on_error(datahub_url, resp) - dataset_id = resp['Data']['Id'] - datahub_url = 
f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}' - r = requests.get(datahub_url) - resp = r.json() - datahub_raise_on_error(datahub_url, resp) - file_list = resp['Data'] - if file_list is None: - raise NotExistError( - f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, ' - f'version = {version}] dose not exist') - - file_list = file_list['Files'] - local_paths = defaultdict(list) - for file_info in file_list: - file_path = file_info['Path'] - if file_path.endswith('.py'): - datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \ - f'Revision={version}&Path={file_path}' - r = requests.get(datahub_url) - r.raise_for_status() - content = r.json()['Data']['Content'] - local_path = os.path.join(cache_dir, file_path) - if os.path.exists(local_path): - logger.warning( - f"Reusing dataset {dataset_name}'s python file ({local_path})" - ) - local_paths['py'].append(local_path) - continue - with open(local_path, 'w') as f: - f.writelines(content) - local_paths['py'].append(local_path) - return local_paths diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index 87ccd49a..c4dc0100 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -1,3 +1,16 @@ -from .kws_kwsbp_pipeline import * # noqa F403 -from .linear_aec_pipeline import LinearAECPipeline -from .text_to_speech_pipeline import * # noqa F403 +try: + from .kws_kwsbp_pipeline import * # noqa F403 + from .linear_aec_pipeline import LinearAECPipeline +except ModuleNotFoundError as e: + if str(e) == "No module named 'torch'": + pass + else: + raise ModuleNotFoundError(e) + +try: + from .text_to_speech_pipeline import * # noqa F403 +except ModuleNotFoundError as e: + if str(e) == "No module named 'tensorflow'": + pass + else: + raise ModuleNotFoundError(e) diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index d9a04a29..536a536a 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -10,6 +10,7 @@ from modelscope.metainfo import Pipelines from modelscope.utils.constant import Tasks from ..base import Input, Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys def audio_norm(x): @@ -108,10 +109,10 @@ class ANSPipeline(Pipeline): current_idx += stride else: outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy() - return {'output_pcm': outputs[:nsamples]} + return {OutputKeys.OUTPUT_PCM: outputs[:nsamples]} def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: if 'output_path' in kwargs.keys(): - sf.write(kwargs['output_path'], inputs['output_pcm'], + sf.write(kwargs['output_path'], inputs[OutputKeys.OUTPUT_PCM], self.SAMPLE_RATE) return inputs diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 4a69976a..45184ad7 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -5,6 +5,8 @@ import stat import subprocess from typing import Any, Dict, List +import json + from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.pipelines.base import Pipeline @@ -39,6 +41,12 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): self._preprocessor = preprocessor self._model = model + self._keywords = None + + if 'keywords' in kwargs.keys(): + self._keywords = kwargs['keywords'] + 
print('self._keywords len: ', len(self._keywords)) + print('self._keywords: ', self._keywords) def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', @@ -197,6 +205,16 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): return rst_dict def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + opts: str = '' + + # setting customized keywords + keywords_json = self._set_customized_keywords() + if len(keywords_json) > 0: + keywords_json_file = os.path.join(inputs['workspace'], + 'keyword_custom.json') + with open(keywords_json_file, 'w') as f: + json.dump(keywords_json, f) + opts = '--keyword-custom ' + keywords_json_file if inputs['kws_set'] == 'roc': inputs['keyword_grammar_path'] = os.path.join( @@ -211,7 +229,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' os.system(kws_cmd) if inputs['kws_set'] in ['pos_testsets', 'roc']: @@ -236,7 +254,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' p = subprocess.Popen(kws_cmd, shell=True) process.append(p) j += 1 @@ -268,7 +286,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): ' --sample-rate=' + inputs['sample_rate'] + \ ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ ' --wave-scp=' + wav_list_path + \ - ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + ' --num-thread=1 ' + opts + ' > ' + dump_log_path + ' 2>&1' p = subprocess.Popen(kws_cmd, shell=True) process.append(p) j += 1 @@ -447,3 +465,29 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): threshold_cur += step return output + + def _set_customized_keywords(self) -> Dict[str, Any]: + if self._keywords is not None: + word_list_inputs = self._keywords + word_list = [] + for i in range(len(word_list_inputs)): + key = word_list_inputs[i] + new_item = {} + if key.__contains__('keyword'): + name = key['keyword'] + new_name: str = '' + for n in range(0, len(name), 1): + new_name += name[n] + new_name += ' ' + new_name = new_name.strip() + new_item['name'] = new_name + + if key.__contains__('threshold'): + threshold1: float = key['threshold'] + new_item['threshold1'] = threshold1 + + word_list.append(new_item) + out = {'word_list': word_list} + return out + else: + return '' diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py index 70562b19..5ceb499f 100644 --- a/modelscope/pipelines/audio/linear_aec_pipeline.py +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -12,6 +12,7 @@ from modelscope.preprocessors.audio import LinearAECAndFbank from modelscope.utils.constant import ModelFile, Tasks from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys FEATURE_MVN = 'feature.DEY.mvn.txt' @@ -120,7 +121,7 @@ class LinearAECPipeline(Pipeline): } """ output_data = self._process(inputs['feature'], inputs['base']) - return {'output_pcm': output_data} + return {OutputKeys.OUTPUT_PCM: output_data} def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: r"""The 
post process. Will save audio to file, if the output_path is given. @@ -140,8 +141,8 @@ class LinearAECPipeline(Pipeline): """ if 'output_path' in kwargs.keys(): wav.write(kwargs['output_path'], self.preprocessor.SAMPLE_RATE, - inputs['output_pcm'].astype(np.int16)) - inputs['output_pcm'] = inputs['output_pcm'] / 32768.0 + inputs[OutputKeys.OUTPUT_PCM].astype(np.int16)) + inputs[OutputKeys.OUTPUT_PCM] = inputs[OutputKeys.OUTPUT_PCM] / 32768.0 return inputs def _process(self, fbanks, mixture): diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index b046e076..aa393ec5 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -1,5 +1,18 @@ -from .action_recognition_pipeline import ActionRecognitionPipeline -from .animal_recog_pipeline import AnimalRecogPipeline -from .image_cartoon_pipeline import ImageCartoonPipeline -from .image_matting_pipeline import ImageMattingPipeline -from .ocr_detection_pipeline import OCRDetectionPipeline +try: + from .action_recognition_pipeline import ActionRecognitionPipeline + from .animal_recog_pipeline import AnimalRecogPipeline +except ModuleNotFoundError as e: + if str(e) == "No module named 'torch'": + pass + else: + raise ModuleNotFoundError(e) + +try: + from .image_cartoon_pipeline import ImageCartoonPipeline + from .image_matting_pipeline import ImageMattingPipeline + from .ocr_detection_pipeline import OCRDetectionPipeline +except ModuleNotFoundError as e: + if str(e) == "No module named 'tensorflow'": + pass + else: + raise ModuleNotFoundError(e) diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index 845f8f9a..fce037d8 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -16,6 +16,7 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys logger = get_logger() @@ -49,7 +50,7 @@ class ActionRecognitionPipeline(Pipeline): def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: pred = self.perform_inference(input['video_data']) output_label = self.label_mapping[str(pred)] - return {'output_label': output_label} + return {OutputKeys.LABELS: output_label} @torch.no_grad() def perform_inference(self, data, max_bsz=4): diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recog_pipeline.py index eee9e844..dd68dab6 100644 --- a/modelscope/pipelines/cv/animal_recog_pipeline.py +++ b/modelscope/pipelines/cv/animal_recog_pipeline.py @@ -18,6 +18,7 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys logger = get_logger() @@ -121,7 +122,9 @@ class AnimalRecogPipeline(Pipeline): label_mapping = f.readlines() score = torch.max(inputs['outputs']) inputs = { - 'scores': score.item(), - 'labels': label_mapping[inputs['outputs'].argmax()].split('\t')[1] + OutputKeys.SCORES: + score.item(), + OutputKeys.LABELS: + label_mapping[inputs['outputs'].argmax()].split('\t')[1] } return inputs diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 717336e9..f6fd3ee2 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ 
b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -17,6 +17,7 @@ from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -94,7 +95,7 @@ class ImageCartoonPipeline(Pipeline): landmarks = self.detect_face(img) if landmarks is None: print('No face detected!') - return {'output_png': None} + return {OutputKeys.OUTPUT_IMG: None} # background process pad_bg, pad_h, pad_w = padTo16x(img_brg) @@ -143,7 +144,7 @@ class ImageCartoonPipeline(Pipeline): res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA) - return {'output_png': res} + return {OutputKeys.OUTPUT_IMG: res} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index b3e27e4b..140d28d7 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -12,6 +12,7 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys logger = get_logger() @@ -60,9 +61,9 @@ class ImageMattingPipeline(Pipeline): def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: with self._session.as_default(): feed_dict = {self.input_name: input['img']} - output_png = self._session.run(self.output, feed_dict=feed_dict) - output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA) - return {'output_png': output_png} + output_img = self._session.run(self.output, feed_dict=feed_dict) + output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA) + return {OutputKeys.OUTPUT_IMG: output_img} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 4856b06b..6b259eaf 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -16,6 +16,7 @@ from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils if tf.__version__ >= '2.0': @@ -174,5 +175,5 @@ class OCRDetectionPipeline(Pipeline): dt_nms = utils.nms_python(dt_n9) dt_polygons = np.array([o[:8] for o in dt_nms]) - result = {'det_polygons': dt_polygons} + result = {OutputKeys.POLYGONS: dt_polygons} return result diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index fdcada89..49b07cce 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -1,3 +1,9 @@ -from .image_captioning_pipeline import ImageCaptionPipeline -from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline -from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline +try: + from .image_captioning_pipeline import ImageCaptionPipeline + from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline + from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline +except ModuleNotFoundError as e: + if str(e) == "No module named 'torch'": + pass + else: + raise 
ModuleNotFoundError(e) diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index f600dec0..eeed9b38 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -1,11 +1,17 @@ -from .dialog_intent_prediction_pipeline import * # noqa F403 -from .dialog_modeling_pipeline import * # noqa F403 -from .dialog_state_tracking_pipeline import * # noqa F403 -from .fill_mask_pipeline import * # noqa F403 -from .nli_pipeline import * # noqa F403 -from .sentence_similarity_pipeline import * # noqa F403 -from .sentiment_classification_pipeline import * # noqa F403 -from .sequence_classification_pipeline import * # noqa F403 -from .text_generation_pipeline import * # noqa F403 -from .word_segmentation_pipeline import * # noqa F403 -from .zero_shot_classification_pipeline import * # noqa F403 +try: + from .dialog_intent_prediction_pipeline import * # noqa F403 + from .dialog_modeling_pipeline import * # noqa F403 + from .dialog_state_tracking_pipeline import * # noqa F403 + from .fill_mask_pipeline import * # noqa F403 + from .nli_pipeline import * # noqa F403 + from .sentence_similarity_pipeline import * # noqa F403 + from .sentiment_classification_pipeline import * # noqa F403 + from .sequence_classification_pipeline import * # noqa F403 + from .text_generation_pipeline import * # noqa F403 + from .word_segmentation_pipeline import * # noqa F403 + from .zero_shot_classification_pipeline import * # noqa F403 +except ModuleNotFoundError as e: + if str(e) == "No module named 'torch'": + pass + else: + raise ModuleNotFoundError(e) diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py index 4323dec6..45844b30 100644 --- a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -8,6 +8,7 @@ from ...preprocessors import DialogIntentPredictionPreprocessor from ...utils.constant import Tasks from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['DialogIntentPredictionPipeline'] @@ -44,9 +45,9 @@ class DialogIntentPredictionPipeline(Pipeline): pos = np.where(pred == np.max(pred)) result = { - 'prediction': pred, - 'label_pos': pos[0], - 'label': self.categories[pos[0][0]] + OutputKeys.PREDICTION: pred, + OutputKeys.LABEL_POS: pos[0], + OutputKeys.LABEL: self.categories[pos[0][0]] } return result diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py index 0261b2e4..bdc1e092 100644 --- a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -8,6 +8,7 @@ from ...preprocessors import DialogModelingPreprocessor from ...utils.constant import Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['DialogModelingPipeline'] @@ -42,7 +43,6 @@ class DialogModelingPipeline(Pipeline): inputs['resp']) assert len(sys_rsp) > 2 sys_rsp = sys_rsp[1:len(sys_rsp) - 1] - - inputs['response'] = sys_rsp + inputs[OutputKeys.RESPONSE] = sys_rsp return inputs diff --git a/modelscope/pipelines/nlp/fill_mask_pipeline.py b/modelscope/pipelines/nlp/fill_mask_pipeline.py index 256f867a..bd4118cc 100644 --- a/modelscope/pipelines/nlp/fill_mask_pipeline.py +++ b/modelscope/pipelines/nlp/fill_mask_pipeline.py @@ -11,6 +11,7 @@ from ...utils.config import Config from 
...utils.constant import ModelFile, Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['FillMaskPipeline'] _type_map = {'veco': 'roberta', 'sbert': 'bert'} @@ -108,4 +109,4 @@ class FillMaskPipeline(Pipeline): pred_string = rep_tokens(pred_string, self.rep_map[process_type]) pred_strings.append(pred_string) - return {'text': pred_strings} + return {OutputKeys.TEXT: pred_strings} diff --git a/modelscope/pipelines/nlp/nli_pipeline.py b/modelscope/pipelines/nlp/nli_pipeline.py index 49dc330f..7ed050be 100644 --- a/modelscope/pipelines/nlp/nli_pipeline.py +++ b/modelscope/pipelines/nlp/nli_pipeline.py @@ -11,6 +11,7 @@ from ...preprocessors import NLIPreprocessor from ...utils.constant import Tasks from ..base import Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['NLIPipeline'] @@ -69,4 +70,4 @@ class NLIPipeline(Pipeline): cls_names = [self.model.id2label[cid] for cid in cls_ids] - return {'scores': probs, 'labels': cls_names} + return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} diff --git a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py index f6bcd72e..4cccd996 100644 --- a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py +++ b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py @@ -10,6 +10,7 @@ from ...preprocessors import SequenceClassificationPreprocessor from ...utils.constant import Tasks from ..base import Input, Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['SentenceSimilarityPipeline'] @@ -69,4 +70,4 @@ class SentenceSimilarityPipeline(Pipeline): probs = probs[cls_ids].tolist() cls_names = [self.model.id2label[cid] for cid in cls_ids] b = 0 - return {'scores': probs[b], 'labels': cls_names[b]} + return {OutputKeys.SCORES: probs[b], OutputKeys.LABELS: cls_names[b]} diff --git a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py index 9291ed44..2afe64d9 100644 --- a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py @@ -13,6 +13,7 @@ from ...preprocessors import SentimentClassificationPreprocessor from ...utils.constant import Tasks from ..base import Input, Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['SentimentClassificationPipeline'] @@ -73,5 +74,4 @@ class SentimentClassificationPipeline(Pipeline): probs = probs[cls_ids].tolist() cls_names = [self.model.id2label[cid] for cid in cls_ids] - - return {'scores': probs, 'labels': cls_names} + return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py index 43c81d60..ec765f55 100644 --- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sequence_classification_pipeline.py @@ -9,6 +9,7 @@ from modelscope.utils.constant import Tasks from ...models import Model from ..base import Input, Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['SequenceClassificationPipeline'] @@ -64,4 +65,4 @@ class SequenceClassificationPipeline(Pipeline): cls_names = [self.model.id2label[cid] for cid in cls_ids] - return {'scores': probs, 'labels': cls_names} + return {OutputKeys.SCORES: probs, OutputKeys.LABELS: 
cls_names} diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index 8f55cce0..d5e9e58b 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -9,6 +9,7 @@ from ...preprocessors import TextGenerationPreprocessor from ...utils.constant import Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['TextGenerationPipeline'] @@ -70,4 +71,4 @@ class TextGenerationPipeline(Pipeline): for _old, _new in replace_tokens_roberta: pred_string = pred_string.replace(_old, _new) pred_string.strip() - return {'text': pred_string} + return {OutputKeys.TEXT: pred_string} diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py index 9501efb7..66b333cb 100644 --- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py +++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -9,6 +9,7 @@ from ...preprocessors import TokenClassifcationPreprocessor from ...utils.constant import Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['WordSegmentationPipeline'] @@ -73,7 +74,4 @@ class WordSegmentationPipeline(Pipeline): if chunk: chunks.append(chunk) seg_result = ' '.join(chunks) - rst = { - 'output': seg_result, - } - return rst + return {OutputKeys.OUTPUT: seg_result} diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index 375e9093..a7ea1e9a 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -14,6 +14,7 @@ from ...preprocessors import ZeroShotClassificationPreprocessor from ...utils.constant import Tasks from ..base import Input, Pipeline from ..builder import PIPELINES +from ..outputs import OutputKeys __all__ = ['ZeroShotClassificationPipeline'] @@ -82,7 +83,7 @@ class ZeroShotClassificationPipeline(Pipeline): scores = softmax(logits, axis=-1) reversed_index = list(reversed(scores.argsort())) result = { - 'labels': [candidate_labels[i] for i in reversed_index], - 'scores': [scores[i].item() for i in reversed_index] + OutputKeys.LABELS: [candidate_labels[i] for i in reversed_index], + OutputKeys.SCORES: [scores[i].item() for i in reversed_index], } return result diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 5af43ab7..8fcf498b 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -2,54 +2,76 @@ from modelscope.utils.constant import Tasks + +class OutputKeys(object): + SCORES = 'scores' + LABEL = 'label' + LABELS = 'labels' + LABEL_POS = 'label_pos' + POSES = 'poses' + CAPTION = 'caption' + BOXES = 'boxes' + TEXT = 'text' + POLYGONS = 'polygons' + OUTPUT = 'output' + OUTPUT_IMG = 'output_img' + OUTPUT_PCM = 'output_pcm' + IMG_EMBEDDING = 'img_embedding' + TEXT_EMBEDDING = 'text_embedding' + RESPONSE = 'response' + PREDICTION = 'prediction' + + TASK_OUTPUTS = { # ============ vision tasks =================== # image classification result for single sample # { - # "labels": ["dog", "horse", "cow", "cat"], # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["dog", "horse", "cow", "cat"], # } - Tasks.image_classification: ['scores', 'labels'], - Tasks.image_tagging: ['scores', 'labels'], + Tasks.image_classification: 
[OutputKeys.SCORES, OutputKeys.LABELS], + Tasks.image_tagging: [OutputKeys.SCORES, OutputKeys.LABELS], # object detection result for single sample # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["dog", "horse", "cow", "cat"], # "boxes": [ # [x1, y1, x2, y2], # [x1, y1, x2, y2], # [x1, y1, x2, y2], # ], - # "labels": ["dog", "horse", "cow", "cat"], - # "scores": [0.9, 0.1, 0.05, 0.05] # } - Tasks.object_detection: ['scores', 'labels', 'boxes'], + Tasks.object_detection: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], # instance segmentation result for single sample # { - # "masks": [ - # np.array in bgr channel order - # ], + # "scores": [0.9, 0.1, 0.05, 0.05], # "labels": ["dog", "horse", "cow", "cat"], - # "scores": [0.9, 0.1, 0.05, 0.05] + # "boxes": [ + # np.array in bgr channel order + # ] # } - Tasks.image_segmentation: ['scores', 'labels', 'boxes'], + Tasks.image_segmentation: + [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.BOXES], # image generation/editing/matting result for single sample # { - # "output_png": np.array with shape(h, w, 4) + # "output_img": np.array with shape(h, w, 4) # for matting or (h, w, 3) for general purpose # } - Tasks.image_editing: ['output_png'], - Tasks.image_matting: ['output_png'], - Tasks.image_generation: ['output_png'], + Tasks.image_editing: [OutputKeys.OUTPUT_IMG], + Tasks.image_matting: [OutputKeys.OUTPUT_IMG], + Tasks.image_generation: [OutputKeys.OUTPUT_IMG], # action recognition result for single video # { # "output_label": "abseiling" # } - Tasks.action_recognition: ['output_label'], + Tasks.action_recognition: [OutputKeys.LABELS], # pose estimation result for single sample # { @@ -58,48 +80,55 @@ TASK_OUTPUTS = { # "boxes": np.array with shape [num_pose, 4], each box is # [x1, y1, x2, y2] # } - Tasks.pose_estimation: ['poses', 'boxes'], + Tasks.pose_estimation: [OutputKeys.POSES, OutputKeys.BOXES], # ocr detection result for single sample # { - # "det_polygons": np.array with shape [num_text, 8], each box is + # "polygons": np.array with shape [num_text, 8], each polygon is # [x1, y1, x2, y2, x3, y3, x4, y4] # } - Tasks.ocr_detection: ['det_polygons'], + Tasks.ocr_detection: [OutputKeys.POLYGONS], # ============ nlp tasks =================== # text classification result for single sample # { - # "labels": ["happy", "sad", "calm", "angry"], # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["happy", "sad", "calm", "angry"], # } - Tasks.text_classification: ['scores', 'labels'], + Tasks.text_classification: [OutputKeys.SCORES, OutputKeys.LABELS], # text generation result for single sample # { - # "text": "this is text generated by a model." + # "text": "this is the text generated by a model." # } - Tasks.text_generation: ['text'], + Tasks.text_generation: [OutputKeys.TEXT], # fill mask result for single sample # { # "text": "this is the text which masks filled by model." 
# } - Tasks.fill_mask: ['text'], + Tasks.fill_mask: [OutputKeys.TEXT], # word segmentation result for single sample # { # "output": "今天 天气 不错 , 适合 出去 游玩" # } - Tasks.word_segmentation: ['output'], + Tasks.word_segmentation: [OutputKeys.OUTPUT], # sentence similarity result for single sample # { - # "labels": "1", # "scores": 0.9 + # "labels": "1", # } - Tasks.sentence_similarity: ['scores', 'labels'], + Tasks.sentence_similarity: [OutputKeys.SCORES, OutputKeys.LABELS], + + # sentiment classification result for single sample + # { + # "labels": ["happy", "sad", "calm", "angry"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.sentiment_classification: [OutputKeys.SCORES, OutputKeys.LABELS], # sentiment classification result for single sample # { @@ -110,10 +139,43 @@ TASK_OUTPUTS = { # zero-shot classification result for single sample # { + # "scores": [0.9, 0.1, 0.05, 0.05] + # "labels": ["happy", "sad", "calm", "angry"], + # } + Tasks.zero_shot_classification: [OutputKeys.SCORES, OutputKeys.LABELS], + + # nli result for single sample + # { # "labels": ["happy", "sad", "calm", "angry"], # "scores": [0.9, 0.1, 0.05, 0.05] # } - Tasks.zero_shot_classification: ['scores', 'labels'], + Tasks.nli: [OutputKeys.SCORES, OutputKeys.LABELS], + + # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, + # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, + # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, + # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, + # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, + # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, + # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, + # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, + # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, + # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, + # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, + # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, + # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, + # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, + # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, + # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, + # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, + # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, + # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, + # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} + Tasks.dialog_intent_prediction: + [OutputKeys.PREDICTION, OutputKeys.LABEL_POS, OutputKeys.LABEL], + + # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] + Tasks.dialog_modeling: [OutputKeys.RESPONSE], # nli result for single sample # { @@ -189,7 +251,7 @@ TASK_OUTPUTS = { # { # "output_pcm": np.array with shape(samples,) and dtype float32 # } - Tasks.speech_signal_process: ['output_pcm'], + Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], # ============ multi-modal tasks =================== @@ -197,14 +259,15 @@ TASK_OUTPUTS = { # { # "caption": "this is an image caption text." 
# } - Tasks.image_captioning: ['caption'], + Tasks.image_captioning: [OutputKeys.CAPTION], # multi-modal embedding result for single sample # { # "img_embedding": np.array with shape [1, D], # "text_embedding": np.array with shape [1, D] # } - Tasks.multi_modal_embedding: ['img_embedding', 'text_embedding'], + Tasks.multi_modal_embedding: + [OutputKeys.IMG_EMBEDDING, OutputKeys.TEXT_EMBEDDING], # visual grounding result for single sample # { @@ -215,11 +278,11 @@ TASK_OUTPUTS = { # ], # "scores": [0.9, 0.1, 0.05, 0.05] # } - Tasks.visual_grounding: ['boxes', 'scores'], + Tasks.visual_grounding: [OutputKeys.BOXES, OutputKeys.SCORES], # text_to_image result for a single sample # { - # "image": np.ndarray with shape [height, width, 3] + # "output_img": np.ndarray with shape [height, width, 3] # } - Tasks.text_to_image_synthesis: ['image'] + Tasks.text_to_image_synthesis: [OutputKeys.OUTPUT_IMG] } diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 742a6152..962e9f6e 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -1,14 +1,21 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .audio import LinearAECAndFbank from .base import Preprocessor from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image from .kws import WavToLists -from .multi_modal import * # noqa F403 -from .nlp import * # noqa F403 -from .space.dialog_intent_prediction_preprocessor import * # noqa F403 -from .space.dialog_modeling_preprocessor import * # noqa F403 -from .space.dialog_state_tracking_preprocessor import * # noqa F403 from .text_to_speech import * # noqa F403 + +try: + from .audio import LinearAECAndFbank + from .multi_modal import * # noqa F403 + from .nlp import * # noqa F403 + from .space.dialog_intent_prediction_preprocessor import * # noqa F403 + from .space.dialog_modeling_preprocessor import * # noqa F403 + from .space.dialog_state_tracking_preprocessor import * # noqa F403 +except ModuleNotFoundError as e: + if str(e) == "No module named 'tensorflow'": + pass + else: + raise ModuleNotFoundError(e) diff --git a/modelscope/preprocessors/space/fields/dst_processors.py b/modelscope/preprocessors/space/fields/dst_processors.py new file mode 100644 index 00000000..22e06eec --- /dev/null +++ b/modelscope/preprocessors/space/fields/dst_processors.py @@ -0,0 +1,1523 @@ +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +import json +import numpy as np +import six +from tqdm import tqdm + +logger = logging.getLogger(__name__) +USER_NAME = 'User' +SYSTEM_NAME = 'System' +DIALOG_ACT = 'Dialog_Act' + +utter1 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food." 
+} +history_states1 = [ + {}, +] +utter2 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food.", + 'System-1': + 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', + 'Dialog_Act-1': { + 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], + ['pricerange', 'that price range']] + }, + 'User-2': + 'I am looking for an expensive indian restaurant in the area of centre.', +} + +history_states2 = [{}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}] + +utter3 = { + 'User-1': + "I'd really like to take my client out to a nice restaurant that serves indian food.", + 'System-1': + 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', + 'Dialog_Act-1': { + 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], + ['pricerange', 'that price range']] + }, + 'User-2': + 'I am looking for an expensive indian restaurant in the area of centre.', + 'System-2': + 'Might I recommend Saffron Brasserie? That is an expensive Indian restaurant ' + 'in the center of town. I can book a table for you, if you like.', + 'Dialog_Act-2': { + 'Restaurant-Recommend': [['area', 'center of town'], + ['food', 'Indian'], + ['name', 'Saffron Brasserie'], + ['pricerange', 'expensive']] + }, + 'User-3': + 'Sure thing, please book for 6 people at 19:30 on Saturday.' 
+} + +history_states3 = [{}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}, { + 'attraction': { + 'book': { + 'booked': [] + }, + 'semi': { + 'area': '', + 'name': '', + 'type': '' + } + }, + 'hospital': { + 'book': { + 'booked': [] + }, + 'semi': { + 'department': '' + } + }, + 'hotel': { + 'book': { + 'booked': [{ + 'name': 'alexander bed and breakfast', + 'reference': 'JXVKZ7KV' + }], + 'day': + 'sunday', + 'people': + '6', + 'stay': + '4' + }, + 'semi': { + 'area': '', + 'internet': 'yes', + 'name': 'alexander bed and breakfast', + 'parking': 'yes', + 'pricerange': 'cheap', + 'stars': '', + 'type': 'guesthouse' + } + }, + 'police': { + 'book': { + 'booked': [] + }, + 'semi': {} + }, + 'restaurant': { + 'book': { + 'booked': [{ + 'name': 'ask', + 'reference': 'Y2Y8QYBY' + }], + 'day': 'sunday', + 'people': '6', + 'time': '18:45' + }, + 'semi': { + 'area': 'centre', + 'food': 'italian', + 'name': 'ask', + 'pricerange': 'cheap' + } + }, + 'taxi': { + 'book': { + 'booked': [] + }, + 'semi': { + 'arriveBy': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + }, + 'train': { + 'book': { + 'booked': [], + 'people': '' + }, + 'semi': { + 'arriveBy': '', + 'day': '', + 'departure': '', + 'destination': '', + 'leaveAt': '' + } + } +}, {}] + + +class DSTProcessor(object): + ACTS_DICT = { + 'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time', + } + + LABEL_MAPS = {} # Loaded from file + + def __init__(self): + # Required for mapping slot names in dialogue_acts.json file + # to proper designations. 
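+        # ACTS_DICT above provides that mapping; LABEL_MAPS is filled at
+        # runtime from the label_maps argument of create_example / create_examples.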
+ pass + + def _convert_inputs_to_utterances(self, inputs: dict, + history_states: list): + """This method is to generate the utterances with user, sys, dialog_acts and metadata, + while metadata is from the history_states or the output from the inference pipline""" + + utterances = [] + user_inputs = [] + sys_gen_inputs = [] + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == USER_NAME: + user_inputs.insert(int(turn) - 1, inputs[item]) + elif name == SYSTEM_NAME: + sys_gen_inputs.insert(int(turn) - 1, inputs[item]) + else: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + + # user is leading the topic should aways larger than sys and dialog acts + assert len(user_inputs) - 1 == len(sys_gen_inputs) + assert len(user_inputs) - 1 == len(dialog_acts_inputs) + # the history states record both user and sys states + assert len(history_states) == len(user_inputs) + len(sys_gen_inputs) + + # the dialog_act at user turn is useless + for i, item in enumerate(history_states): + utterance = {} + # the dialog_act at user turn is useless + utterance['dialog_act'] = dialog_acts_inputs[ + i // 2] if i % 2 == 1 else {} + utterance['text'] = sys_gen_inputs[ + i // 2] if i % 2 == 1 else user_inputs[i // 2] + utterance['metadata'] = item + utterance['span_info'] = [] + utterances.append(utterance) + + return utterances + + def _load_acts(self, inputs: dict, dialog_id='example.json'): + dialog_acts_inputs = [] + for i, item in enumerate(inputs): + name, turn = item.split('-') + if name == DIALOG_ACT: + dialog_acts_inputs.insert(int(turn) - 1, inputs[item]) + s_dict = {} + + for j, item in enumerate(dialog_acts_inputs): + if isinstance(item, dict): + for a in item: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or \ + aa[1] == 'select' or aa[1] == 'book': + for i in item[a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = dialog_id, str(int(j) + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + + return s_dict + + +class multiwoz22Processor(DSTProcessor): + + def __init__(self): + super().__init__() + + def normalize_time(self, text): + text = re.sub(r'(\d{1})(a\.?m\.?|p\.?m\.?)', r'\1 \2', + text) # am/pm without space + text = re.sub(r'(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)', r'\1\2:00 \3', + text) # am/pm short to long form + text = re.sub( + r'(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)', + r'\1\2 \3:\4\5', text) # Missing separator + text = re.sub(r'(^| )(\d{2})[;.,](\d{2})', r'\1\2:\3', + text) # Wrong separator + text = re.sub(r'(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)', + r'\1\2 \3:00\4', text) # normalize simple full hour time + text = re.sub(r'(^| )(\d{1}:\d{2})', r'\g<1>0\2', + text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = \ + re.sub( + r'(\d{2})(:\d{2}) ?p\.?m\.?', + lambda x: str(int(x.groups()[0]) + 12 + if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub(r'(^| )24:(\d{2})', r'\g<1>00:\2', + text) # Correct times that use 24 as hour + return text + + def normalize_text(self, text): + text = self.normalize_time(text) + text = re.sub("n't", ' not', text) + text = re.sub('(^| )zero(-| )star([s.,? 
]|$)', r'\g<1>0 star\3', text) + text = re.sub('(^| )one(-| )star([s.,? ]|$)', r'\g<1>1 star\3', text) + text = re.sub('(^| )two(-| )star([s.,? ]|$)', r'\g<1>2 star\3', text) + text = re.sub('(^| )three(-| )star([s.,? ]|$)', r'\g<1>3 star\3', text) + text = re.sub('(^| )four(-| )star([s.,? ]|$)', r'\g<1>4 star\3', text) + text = re.sub('(^| )five(-| )star([s.,? ]|$)', r'\g<1>5 star\3', text) + text = re.sub('archaelogy', 'archaeology', text) # Systematic typo + text = re.sub('guesthouse', 'guest house', text) # Normalization + text = re.sub('(^| )b ?& ?b([.,? ]|$)', r'\1bed and breakfast\2', + text) # Normalization + text = re.sub('bed & breakfast', 'bed and breakfast', + text) # Normalization + return text + + # Loads the dialogue_acts.json and returns a list + # of slot-value pairs. + def load_acts(self, input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + if int(t) % 2 == 0: + continue + # Only process, if turn has annotation + if isinstance(acts[d][t]['dialog_act'], dict): + for a in acts[d][t]['dialog_act']: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' \ + or aa[1] == 'select' or aa[1] == 'book': + for i in acts[d][t]['dialog_act'][a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in self.ACTS_DICT: + slot = self.ACTS_DICT[slot] + key = d, str(int(t) // 2 + 1), slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + # s_dict[key] = list([v]) + return s_dict + + # This should only contain label normalizations. All other mappings should + # be defined in LABEL_MAPS. 
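+    # normalize_label() below rewrites a gold value for a given slot into its
+    # canonical form (times, boolean parking/internet/type slots, "guest house")
+    # before it is matched against the normalized utterance tokens.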
+ def normalize_label(self, slot, value_label): + # Normalization of empty slots + if value_label == '' or value_label == 'not mentioned': + return 'none' + + # Normalization of time slots + if 'leaveAt' in slot or 'arriveBy' in slot or slot == 'restaurant-book_time': + return self.normalize_time(value_label) + + # Normalization + if 'type' in slot or 'name' in slot or 'destination' in slot or 'departure' in slot: + value_label = re.sub('guesthouse', 'guest house', value_label) + + # Map to boolean slots + if slot == 'hotel-parking' or slot == 'hotel-internet': + if value_label == 'yes' or value_label == 'free': + return 'true' + if value_label == 'no': + return 'false' + if slot == 'hotel-type': + if value_label == 'hotel': + return 'true' + if value_label == 'guest house': + return 'false' + + return value_label + + def tokenize(self, utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = self.normalize_text(utt_lower) + utt_tok = [ + tok for tok in map(str.strip, re.split(r'(\W+)', utt_lower)) + if len(tok) > 0 + ] + return utt_tok + + def delex_utt(self, utt, values, unk_token='[UNK]'): + utt_norm = self.tokenize(utt) + for s, vals in values.items(): + for v in vals: + if v != 'none': + v_norm = self.tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + def get_token_pos(self, tok_list, value_label): + find_pos = [] + found = False + label_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value_label)) + if len(item) > 0 + ] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + def check_label_existence(self, value_label, usr_utt_tok): + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + in_usr, usr_pos = self.get_token_pos(usr_utt_tok, + value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + def check_slot_referral(self, value_label, slot, seen_slots): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. 
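+            # A value is "referred" when the user reuses a value that an earlier
+            # slot already holds (coreference), so we scan previously seen slots
+            # for the same value or one of its LABEL_MAPS variants.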
+ if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match('(hotel|restaurant)-book_people', + s) and slot == 'hotel-book_stay': + continue + if re.match('(hotel|restaurant)-book_people', + slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots + or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + def is_in_list(self, tok, value): + found = False + tok_list = [ + item for item in map(str.strip, re.split(r'(\W+)', tok)) + if len(item) > 0 + ] + value_list = [ + item for item in map(str.strip, re.split(r'(\W+)', value)) + if len(item) > 0 + ] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + # Fuzzy matching to label informed slot values + def check_slot_inform(self, value_label, inform_label): + result = False + informed_value = 'none' + vl = ' '.join(self.tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif self.is_in_list(il, vl): + result = True + elif self.is_in_list(vl, il): + result = True + elif il in self.LABEL_MAPS: + for il_variant in self.LABEL_MAPS[il]: + if vl == il_variant: + result = True + break + elif self.is_in_list(il_variant, vl): + result = True + break + elif self.is_in_list(vl, il_variant): + result = True + break + elif vl in self.LABEL_MAPS: + for value_label_variant in self.LABEL_MAPS[vl]: + if value_label_variant == il: + result = True + break + elif self.is_in_list(il, value_label_variant): + result = True + break + elif self.is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + def get_turn_label(self, value_label, inform_label, sys_utt_tok, + usr_utt_tok, slot, seen_slots, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = self.check_label_existence( + value_label, usr_utt_tok) + is_informed, informed_value = self.check_slot_inform( + value_label, inform_label) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' + else: + referred_slot = self.check_slot_referral( + value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + def _create_example(self, + utterances, + sys_inform_dict, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='example.json'): + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + 
utt_tok_list = [[]] + mod_slots_list = [] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print( + 'WARN: Wrong order of system and user utterances. Skipping rest of the dialog %s' + % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), + slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), + str(turn_itr), + slot)] + utt_tok_list.append( + self.delex_utt(utt['text'], inform_dict, + unk_token)) # normalize utterances + else: + utt_tok_list.append(self.tokenize( + utt['text'])) # normalize utterances + + modified_slots = {} + + # If sys utt, extract metadata (identify and collect modified slots) + if is_sys_utt: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = self.normalize_label( + '%s-%s' % (d, s), + booked[0][s]) # normalize labels + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % ( + d, s) if category == 'book' else '%s-%s' % (d, + s) + value_label = self.normalize_label( + cs, utt['metadata'][d][category] + [s]) # normalize labels + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + # Remember modified slots and entire dialog state + if cs in slot_list and cumulative_labels[ + cs] != value_label: + modified_slots[cs] = value_label + cumulative_labels[cs] = value_label + + mod_slots_list.append(modified_slots.copy()) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + hst_utt_tok = [] + hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + + for i in range(0, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + sys_utt_tok = utt_tok_list[i] + usr_utt_tok = utt_tok_list[i + 1] + turn_slots = mod_slots_list[ + i + 1] if len(mod_slots_list) > 1 else {} + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print('%15s %2s %s ||| %s' % + (dialog_id, turn_itr, ' '.join(sys_utt_tok), + ' '.join(usr_utt_tok))) + print('%15s %2s [' % (dialog_id, turn_itr), end='') + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make 
comparison difficult. + value_dict[slot] = value_label + elif label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), slot)] + ]) + inform_slot_dict[slot] = 1 + elif (str(dialog_id), str(turn_itr), + 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([ + self.normalize_label(slot, i) + for i in sys_inform_dict[(str(dialog_id), + str(turn_itr), 'booking-' + + slot.split('-')[1])] + ]) + inform_slot_dict[slot] = 1 + + (informed_value, referred_slot, usr_utt_tok_label, + class_type) = self.get_turn_label( + value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True) + + inform_dict[slot] = informed_value + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list( + diag_seen_slots_value_dict.values()).count( + value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[ + slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[ + slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ + slot] + else: + new_hst_utt_tok_label_dict[slot] = [ + 0 for _ in sys_utt_tok_label + usr_utt_tok_label + + new_hst_utt_tok_label_dict[slot] + ] + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ + slot]: + print('(%s): %s, ' % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] \ + and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. 
+ if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print(']') + + if swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + + example = DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict) + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + turn_itr += 1 + return example + + def create_example(self, + inputs, + history_states, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False, + dialog_id='0'): + utterances = self._convert_inputs_to_utterances(inputs, history_states) + sys_inform_dict = self._load_acts(inputs) + self.LABEL_MAPS = label_maps + example = self._create_example(utterances, sys_inform_dict, set_type, + slot_list, label_maps, append_history, + use_history_labels, swap_utterances, + label_value_repetitions, + delexicalize_sys_utts, unk_token, + analyze) + + return example + + def create_examples(self, + input_file, + acts_file, + set_type, + slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + unk_token='[UNK]', + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = self.load_acts(acts_file) + + with open(input_file, 'r', encoding='utf-8') as reader: + input_data = json.load(reader) + + self.LABEL_MAPS = label_maps + + examples = [] + for dialog_id in tqdm(input_data): + entry = input_data[dialog_id] + utterances = entry['log'] + + example = self._create_example( + utterances, sys_inform_dict, set_type, slot_list, label_maps, + append_history, use_history_labels, swap_utterances, + label_value_repetitions, delexicalize_sys_utts, unk_token, + analyze) + examples.append(example) + + return examples + + +class DSTExample(object): + """ + A single training/test example for the DST dataset. 
+    """
+
+    def __init__(self,
+                 guid,
+                 text_a,
+                 text_b,
+                 history,
+                 text_a_label=None,
+                 text_b_label=None,
+                 history_label=None,
+                 values=None,
+                 inform_label=None,
+                 inform_slot_label=None,
+                 refer_label=None,
+                 diag_state=None,
+                 class_label=None):
+        self.guid = guid
+        self.text_a = text_a
+        self.text_b = text_b
+        self.history = history
+        self.text_a_label = text_a_label
+        self.text_b_label = text_b_label
+        self.history_label = history_label
+        self.values = values
+        self.inform_label = inform_label
+        self.inform_slot_label = inform_slot_label
+        self.refer_label = refer_label
+        self.diag_state = diag_state
+        self.class_label = class_label
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        s = ''
+        s += 'guid: %s' % (self.guid)
+        s += ', text_a: %s' % (self.text_a)
+        s += ', text_b: %s' % (self.text_b)
+        s += ', history: %s' % (self.history)
+        if self.text_a_label:
+            s += ', text_a_label: %s' % (self.text_a_label)
+        if self.text_b_label:
+            s += ', text_b_label: %s' % (self.text_b_label)
+        if self.history_label:
+            s += ', history_label: %s' % (self.history_label)
+        if self.values:
+            s += ', values: %s' % (self.values)
+        if self.inform_label:
+            s += ', inform_label: %s' % (self.inform_label)
+        if self.inform_slot_label:
+            s += ', inform_slot_label: %s' % (self.inform_slot_label)
+        if self.refer_label:
+            s += ', refer_label: %s' % (self.refer_label)
+        if self.diag_state:
+            s += ', diag_state: %s' % (self.diag_state)
+        if self.class_label:
+            s += ', class_label: %s' % (self.class_label)
+        return s
+
+
+class InputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self,
+                 input_ids,
+                 input_ids_unmasked,
+                 input_mask,
+                 segment_ids,
+                 start_pos=None,
+                 end_pos=None,
+                 values=None,
+                 inform=None,
+                 inform_slot=None,
+                 refer_id=None,
+                 diag_state=None,
+                 class_label_id=None,
+                 guid='NONE'):
+        self.guid = guid
+        self.input_ids = input_ids
+        self.input_ids_unmasked = input_ids_unmasked
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.values = values
+        self.inform = inform
+        self.inform_slot = inform_slot
+        self.refer_id = refer_id
+        self.diag_state = diag_state
+        self.class_label_id = class_label_id
+
+
+def convert_examples_to_features(examples,
+                                 slot_list,
+                                 class_types,
+                                 model_type,
+                                 tokenizer,
+                                 max_seq_length,
+                                 slot_value_dropout=0.0):
+    """Loads a data file into a list of `InputBatch`s."""
+
+    if model_type == 'bert':
+        model_specs = {
+            'MODEL_TYPE': 'bert',
+            'CLS_TOKEN': '[CLS]',
+            'UNK_TOKEN': '[UNK]',
+            'SEP_TOKEN': '[SEP]',
+            'TOKEN_CORRECTION': 4
+        }
+    else:
+        logger.error('Unknown model type (%s). Aborting.'
% (model_type)) + exit(1) + + def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, + model_specs, slot_value_dropout): + joint_text_label = [0 for _ in text_label_dict[slot] + ] # joint all slots' label + for slot_text_label in text_label_dict.values(): + for idx, label in enumerate(slot_text_label): + if label == 1: + joint_text_label[idx] = 1 + + text_label = text_label_dict[slot] + tokens = [] + tokens_unmasked = [] + token_labels = [] + for token, token_label, joint_label in zip(text, text_label, + joint_text_label): + token = convert_to_unicode(token) + sub_tokens = tokenizer.tokenize(token) # Most time intensive step + tokens_unmasked.extend(sub_tokens) + if slot_value_dropout == 0.0 or joint_label == 0: + tokens.extend(sub_tokens) + else: + rn_list = np.random.random_sample((len(sub_tokens), )) + for rn, sub_token in zip(rn_list, sub_tokens): + if rn > slot_value_dropout: + tokens.append(sub_token) + else: + tokens.append(model_specs['UNK_TOKEN']) + token_labels.extend([token_label for _ in sub_tokens]) + assert len(tokens) == len(token_labels) + assert len(tokens_unmasked) == len(token_labels) + return tokens, tokens_unmasked, token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + len(history) + if total_length <= max_length: + break + if len(history) > 0: + history.pop() + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, + model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + if len(tokens_a) + len(tokens_b) + len( + history) > max_seq_length - model_specs['TOKEN_CORRECTION']: + logger.info('Truncate Example %s. Total len=%d.' % + (guid, len(tokens_a) + len(tokens_b) + len(history))) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, + max_seq_length - model_specs['TOKEN_CORRECTION']) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, + token_labels_history, max_seq_length, + model_specs): + token_label_ids = [] + token_label_ids.append(0) # [CLS] + for token_label in token_labels_a: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_b: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_history: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + while len(token_label_ids) < max_seq_length: + token_label_ids.append(0) # padding + assert len(token_label_ids) == max_seq_length + return token_label_ids + + def _get_start_end_pos(class_type, token_label_ids, max_seq_length): + if class_type == 'copy_value' and 1 not in token_label_ids: + # logger.warn("copy_value label, but token_label not detected. 
Setting label to 'none'.") + class_type = 'none' + start_pos = 0 + end_pos = 0 + if 1 in token_label_ids: + start_pos = token_label_ids.index(1) + # Parsing is supposed to find only first location of wanted value + if 0 not in token_label_ids[start_pos:]: + end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 + else: + end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 + for i in range(max_seq_length): + if i >= start_pos and i <= end_pos: + assert token_label_ids[i] == 1 + return class_type, start_pos, end_pos + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, + tokenizer, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append(model_specs['CLS_TOKEN']) + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + for token in history: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. 
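+        # (input_ids, input_mask and segment_ids are all padded with 0 so the
+        # three sequences stay aligned at max_seq_length.)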
+ while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + return tokens, input_ids, input_mask, segment_ids + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = ['none'] + slot_list + + features = [] + # Convert single example + for (example_index, example) in enumerate(examples): + if example_index % 1000 == 0: + logger.info('Writing example %d of %d' % + (example_index, len(examples))) + + total_cnt += 1 + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + end_pos_dict = {} + for slot in slot_list: + tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( + example.text_a, example.text_a_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( + example.text_b, example.text_b_label, slot, tokenizer, + model_specs, slot_value_dropout) + tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( + example.history, example.history_label, slot, tokenizer, + model_specs, slot_value_dropout) + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, + model_specs, example.guid) + + if input_text_too_long: + if example_index < 10: + if len(token_labels_a) > len(tokens_a): + logger.info(' tokens_a truncated labels: %s' + % str(token_labels_a[len(tokens_a):])) + if len(token_labels_b) > len(tokens_b): + logger.info(' tokens_b truncated labels: %s' + % str(token_labels_b[len(tokens_b):])) + if len(token_labels_history) > len(tokens_history): + logger.info( + ' tokens_history truncated labels: %s' + % str(token_labels_history[len(tokens_history):])) + + token_labels_a = token_labels_a[:len(tokens_a)] + token_labels_b = token_labels_b[:len(tokens_b)] + token_labels_history = token_labels_history[:len(tokens_history + )] + tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] + tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] + tokens_history_unmasked = tokens_history_unmasked[:len( + tokens_history)] + + assert len(token_labels_a) == len(tokens_a) + assert len(token_labels_b) == len(tokens_b) + assert len(token_labels_history) == len(tokens_history) + assert len(token_labels_a) == len(tokens_a_unmasked) + assert len(token_labels_b) == len(tokens_b_unmasked) + assert len(token_labels_history) == len(tokens_history_unmasked) + token_label_ids = _get_token_label_ids(token_labels_a, + token_labels_b, + token_labels_history, + max_seq_length, model_specs) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + class_label_mod, start_pos_dict[slot], end_pos_dict[ + slot] = _get_start_end_pos(example.class_label[slot], + token_label_ids, max_seq_length) + if class_label_mod != example.class_label[slot]: + example.class_label[slot] = class_label_mod + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_types.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_types.index( + example.class_label[slot]) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids = _get_transformer_input( + tokens_a, tokens_b, tokens_history, max_seq_length, tokenizer, + 
model_specs)
+        if slot_value_dropout > 0.0:
+            _, input_ids_unmasked, _, _ = _get_transformer_input(
+                tokens_a_unmasked, tokens_b_unmasked, tokens_history_unmasked,
+                max_seq_length, tokenizer, model_specs)
+        else:
+            input_ids_unmasked = input_ids
+
+        assert (len(input_ids) == len(input_ids_unmasked))
+
+        if example_index < 10:
+            logger.info('*** Example ***')
+            logger.info('guid: %s' % (example.guid))
+            logger.info('tokens: %s' % ' '.join(tokens))
+            logger.info('input_ids: %s' % ' '.join([str(x)
+                                                    for x in input_ids]))
+            logger.info('input_mask: %s'
+                        % ' '.join([str(x) for x in input_mask]))
+            logger.info('segment_ids: %s'
+                        % ' '.join([str(x) for x in segment_ids]))
+            logger.info('start_pos: %s' % str(start_pos_dict))
+            logger.info('end_pos: %s' % str(end_pos_dict))
+            logger.info('values: %s' % str(value_dict))
+            logger.info('inform: %s' % str(inform_dict))
+            logger.info('inform_slot: %s' % str(inform_slot_dict))
+            logger.info('refer_id: %s' % str(refer_id_dict))
+            logger.info('diag_state: %s' % str(diag_state_dict))
+            logger.info('class_label_id: %s' % str(class_label_id_dict))
+
+        features.append(
+            InputFeatures(
+                guid=example.guid,
+                input_ids=input_ids,
+                input_ids_unmasked=input_ids_unmasked,
+                input_mask=input_mask,
+                segment_ids=segment_ids,
+                start_pos=start_pos_dict,
+                end_pos=end_pos_dict,
+                values=value_dict,
+                inform=inform_dict,
+                inform_slot=inform_slot_dict,
+                refer_id=refer_id_dict,
+                diag_state=diag_state_dict,
+                class_label_id=class_label_id_dict))
+
+    logger.info('========== %d out of %d examples have text too long' %
+                (too_long_cnt, total_cnt))
+
+    return features
+
+
+# From bert.tokenization (TF code)
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+    if six.PY3:
+        if isinstance(text, str):
+            return text
+        elif isinstance(text, bytes):
+            return text.decode('utf-8', 'ignore')
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    elif six.PY2:
+        if isinstance(text, str):
+            return text.decode('utf-8', 'ignore')
+        elif isinstance(text, unicode):
+            return text
+        else:
+            raise ValueError('Unsupported string type: %s' % (type(text)))
+    else:
+        raise ValueError('Not running on Python2 or Python 3?')
+
+
+if __name__ == '__main__':
+    processor = multiwoz22Processor()
+    set_type = 'test'
+    slot_list = [
+        'taxi-leaveAt', 'taxi-destination', 'taxi-departure', 'taxi-arriveBy',
+        'restaurant-book_people', 'restaurant-book_day',
+        'restaurant-book_time', 'restaurant-food', 'restaurant-pricerange',
+        'restaurant-name', 'restaurant-area', 'hotel-book_people',
+        'hotel-book_day', 'hotel-book_stay', 'hotel-name', 'hotel-area',
+        'hotel-parking', 'hotel-pricerange', 'hotel-stars', 'hotel-internet',
+        'hotel-type', 'attraction-type', 'attraction-name', 'attraction-area',
+        'train-book_people', 'train-leaveAt', 'train-destination', 'train-day',
+        'train-arriveBy', 'train-departure'
+    ]
+    append_history = True
+    use_history_labels = True
+    swap_utterances = True
+    label_value_repetitions = True
+    delexicalize_sys_utts = True
+    unk_token = '[UNK]'
+    analyze = False
+    example = processor.create_example(utter1, history_states1, set_type,
+                                       slot_list, {}, append_history,
+                                       use_history_labels, swap_utterances,
+                                       label_value_repetitions,
+                                       delexicalize_sys_utts, unk_token,
+                                       analyze)
+    print(f'created example: {example}')
diff --git a/modelscope/trainers/nlp/sequence_classification_trainer.py b/modelscope/trainers/nlp/sequence_classification_trainer.py
index b2b759fa..7ae5576f 100644
--- 
a/modelscope/trainers/nlp/sequence_classification_trainer.py
+++ b/modelscope/trainers/nlp/sequence_classification_trainer.py
@@ -14,8 +14,7 @@ PATH = None
 logger = get_logger(PATH)
 
 
-@TRAINERS.register_module(
-    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
+@TRAINERS.register_module(module_name=r'bert-sentiment-analysis')
 class SequenceClassificationTrainer(BaseTrainer):
 
     def __init__(self, cfg_file: str, *args, **kwargs):
diff --git a/modelscope/utils/check_requirements.py b/modelscope/utils/check_requirements.py
new file mode 100644
index 00000000..7aad8e4e
--- /dev/null
+++ b/modelscope/utils/check_requirements.py
@@ -0,0 +1,79 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.utils.constant import Fields, Requirements
+from modelscope.utils.import_utils import requires
+
+
+def get_msg(field):
+    msg = f'\n{field} requirements not installed, please execute ' \
+          f'`pip install -r requirements/{field}.txt` or ' \
+          f'`pip install modelscope[{field}]`'
+    return msg
+
+
+class NLPModuleNotFoundError(ModuleNotFoundError):
+
+    def __init__(self, e: ModuleNotFoundError) -> None:
+        e.msg += get_msg(Fields.nlp)
+        super().__init__(e)
+
+
+class CVModuleNotFoundError(ModuleNotFoundError):
+
+    def __init__(self, e: ModuleNotFoundError) -> None:
+        e.msg += get_msg(Fields.cv)
+        super().__init__(e)
+
+
+class AudioModuleNotFoundError(ModuleNotFoundError):
+
+    def __init__(self, e: ModuleNotFoundError) -> None:
+        e.msg += get_msg(Fields.audio)
+        super().__init__(e)
+
+
+class MultiModalModuleNotFoundError(ModuleNotFoundError):
+
+    def __init__(self, e: ModuleNotFoundError) -> None:
+        e.msg += get_msg(Fields.multi_modal)
+        super().__init__(e)
+
+
+def check_nlp():
+    try:
+        requires('nlp models', (
+            Requirements.torch,
+            Requirements.tokenizers,
+        ))
+    except ImportError as e:
+        raise NLPModuleNotFoundError(e)
+
+
+def check_cv():
+    try:
+        requires('cv models', (
+            Requirements.torch,
+            Requirements.tokenizers,
+        ))
+    except ImportError as e:
+        raise CVModuleNotFoundError(e)
+
+
+def check_audio():
+    try:
+        requires('audio models', (
+            Requirements.torch,
+            Requirements.tf,
+        ))
+    except ImportError as e:
+        raise AudioModuleNotFoundError(e)
+
+
+def check_multi_modal():
+    try:
+        requires('multi-modal models', (
+            Requirements.torch,
+            Requirements.tokenizers,
+        ))
+    except ImportError as e:
+        raise MultiModalModuleNotFoundError(e)
diff --git a/modelscope/utils/config.py b/modelscope/utils/config.py
index df9e38fd..79307f17 100644
--- a/modelscope/utils/config.py
+++ b/modelscope/utils/config.py
@@ -17,9 +17,10 @@ from typing import Dict
 import addict
 from yapf.yapflib.yapf_api import FormatCode
 
+from modelscope.utils.import_utils import (import_modules,
+                                           import_modules_from_file,
+                                           validate_py_syntax)
 from modelscope.utils.logger import get_logger
-from modelscope.utils.pymod import (import_modules, import_modules_from_file,
-                                    validate_py_syntax)
 
 if platform.system() == 'Windows':
     import regex as re  # type: ignore
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index ce3b5718..150e9904 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -102,5 +102,18 @@ class ModelFile(object):
     TORCH_MODEL_BIN_FILE = 'pytorch_model.bin'
 
 
+class Requirements(object):
+    """Requirement names for each module
+    """
+    protobuf = 'protobuf'
+    sentencepiece = 'sentencepiece'
+    sklearn = 'sklearn'
+    scipy = 'scipy'
+    timm = 'timm'
+    tokenizers = 'tokenizers'
+    tf = 'tf'
+    torch = 'torch'
+
+
 TENSORFLOW = 'tensorflow'
 PYTORCH = 
'pytorch'
diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py
new file mode 100644
index 00000000..e4192082
--- /dev/null
+++ b/modelscope/utils/import_utils.py
@@ -0,0 +1,324 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from huggingface/transformers.
+import ast
+import functools
+import importlib.util
+import os
+import os.path as osp
+import sys
+import types
+from collections import OrderedDict
+from functools import wraps
+from importlib import import_module
+from itertools import chain
+from types import ModuleType
+from typing import Any
+
+import json
+from packaging import version
+
+from modelscope.utils.constant import Fields
+from modelscope.utils.logger import get_logger
+
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata
+
+logger = get_logger()
+
+
+def import_modules_from_file(py_file: str):
+    """ Import module from a certain file
+
+    Args:
+        py_file: path to a python file to be imported
+
+    Returns:
+        The module name and the imported module object.
+    """
+    dirname, basefile = os.path.split(py_file)
+    if dirname == '':
+        dirname = './'
+    module_name = osp.splitext(basefile)[0]
+    sys.path.insert(0, dirname)
+    validate_py_syntax(py_file)
+    mod = import_module(module_name)
+    sys.path.pop(0)
+    return module_name, mod
+
+
+def import_modules(imports, allow_failed_imports=False):
+    """Import modules from the given list of strings.
+
+    Args:
+        imports (list | str | None): The given module names to be imported.
+        allow_failed_imports (bool): If True, the failed imports will return
+            None. Otherwise, an ImportError is raised. Default: False.
+
+    Returns:
+        list[module] | module | None: The imported modules.
+
+    Examples:
+        >>> osp, sys = import_modules(
+        ... 
['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError( + f'custom_imports must be a list but got type {type(imports)}') + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError( + f'{imp} is of type {type(imp)} and cannot be imported.') + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + logger.warning(f'{imp} failed to import and is ignored.') + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def validate_py_syntax(filename): + with open(filename, 'r', encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + +# following code borrows implementation from huggingface/transformers +ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'}) +USE_TF = os.environ.get('USE_TF', 'AUTO').upper() +USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper() +_torch_version = 'N/A' +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec('torch') is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version('torch') + logger.info(f'PyTorch version {_torch_version} available.') + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info('Disabling PyTorch because USE_TF is set') + _torch_available = False + +_tf_version = 'N/A' +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec('tensorflow') is not None + if _tf_available: + candidates = ( + 'tensorflow', + 'tensorflow-cpu', + 'tensorflow-gpu', + 'tf-nightly', + 'tf-nightly-cpu', + 'tf-nightly-gpu', + 'intel-tensorflow', + 'intel-tensorflow-avx512', + 'tensorflow-rocm', + 'tensorflow-macos', + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse('2'): + pass + else: + logger.info(f'TensorFlow version {_tf_version} available.') +else: + logger.info('Disabling Tensorflow because USE_TORCH is set') + _tf_available = False + +_timm_available = importlib.util.find_spec('timm') is not None +try: + _timm_version = importlib_metadata.version('timm') + logger.debug(f'Successfully imported timm version {_timm_version}') +except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_scipy_available(): + return importlib.util.find_spec('scipy') is not None + + +def is_sklearn_available(): + if importlib.util.find_spec('sklearn') is None: + return False + return is_scipy_available() and importlib.util.find_spec('sklearn.metrics') + + +def is_sentencepiece_available(): + return importlib.util.find_spec('sentencepiece') is not None + + +def 
is_protobuf_available(): + if importlib.util.find_spec('google') is None: + return False + return importlib.util.find_spec('google.protobuf') is not None + + +def is_tokenizers_available(): + return importlib.util.find_spec('tokenizers') is not None + + +def is_timm_available(): + return _timm_available + + +def is_torch_available(): + return _torch_available + + +def is_torch_cuda_available(): + if is_torch_available(): + import torch + + return torch.cuda.is_available() + else: + return False + + +def is_tf_available(): + return _tf_available + + +# docstyle-ignore +PROTOBUF_IMPORT_ERROR = """ +{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and +follow the ones that match your environment. +""" + +# docstyle-ignore +SENTENCEPIECE_IMPORT_ERROR = """ +{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones +that match your environment. +""" + +# docstyle-ignore +SKLEARN_IMPORT_ERROR = """ +{0} requires the scikit-learn library but it was not found in your environment. You can install it with: +``` +pip install -U scikit-learn +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install -U scikit-learn +``` +""" + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +""" + +# docstyle-ignore +TIMM_IMPORT_ERROR = """ +{0} requires the timm library but it was not found in your environment. You can install it with pip: +`pip install timm` +""" + +# docstyle-ignore +TOKENIZERS_IMPORT_ERROR = """ +{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with: +``` +pip install tokenizers +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install tokenizers +``` +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. 
You can install it with pip: +`pip install scipy` +""" + +REQUIREMENTS_MAAPING = OrderedDict([ + ('protobuf', (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), + ('sentencepiece', (is_sentencepiece_available, + SENTENCEPIECE_IMPORT_ERROR)), + ('sklearn', (is_sklearn_available, SKLEARN_IMPORT_ERROR)), + ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ('timm', (is_timm_available, TIMM_IMPORT_ERROR)), + ('tokenizers', (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), + ('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), + ('scipy', (is_scipy_available, SCIPY_IMPORT_ERROR)), +]) + + +def requires(obj, requirements): + if not isinstance(requirements, (list, tuple)): + requirements = [requirements] + if isinstance(obj, str): + name = obj + else: + name = obj.__name__ if hasattr(obj, + '__name__') else obj.__class__.__name__ + checks = (REQUIREMENTS_MAAPING[req] for req in requirements) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError(''.join(failed)) + + +def torch_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_torch_available(): + return func(*args, **kwargs) + else: + raise ImportError(f'Method `{func.__name__}` requires PyTorch.') + + return wrapper + + +def tf_required(func): + # Chose a different decorator name than in tests so it's clear they are not the same. + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_tf_available(): + return func(*args, **kwargs) + else: + raise ImportError(f'Method `{func.__name__}` requires TF.') + + return wrapper diff --git a/modelscope/utils/pymod.py b/modelscope/utils/pymod.py deleted file mode 100644 index 6db6798d..00000000 --- a/modelscope/utils/pymod.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. - -import ast -import os -import os.path as osp -import sys -import types -from importlib import import_module - -from modelscope.utils.logger import get_logger - -logger = get_logger() - - -def import_modules_from_file(py_file: str): - """ Import module from a certrain file - - Args: - py_file: path to a python file to be imported - - Return: - - """ - dirname, basefile = os.path.split(py_file) - if dirname == '': - dirname == './' - module_name = osp.splitext(basefile)[0] - sys.path.insert(0, dirname) - validate_py_syntax(py_file) - mod = import_module(module_name) - sys.path.pop(0) - return module_name, mod - - -def import_modules(imports, allow_failed_imports=False): - """Import modules from the given list of strings. - - Args: - imports (list | str | None): The given module names to be imported. - allow_failed_imports (bool): If True, the failed imports will return - None. Otherwise, an ImportError is raise. Default: False. - - Returns: - list[module] | module | None: The imported modules. - - Examples: - >>> osp, sys = import_modules( - ... 
['os.path', 'sys']) - >>> import os.path as osp_ - >>> import sys as sys_ - >>> assert osp == osp_ - >>> assert sys == sys_ - """ - if not imports: - return - single_import = False - if isinstance(imports, str): - single_import = True - imports = [imports] - if not isinstance(imports, list): - raise TypeError( - f'custom_imports must be a list but got type {type(imports)}') - imported = [] - for imp in imports: - if not isinstance(imp, str): - raise TypeError( - f'{imp} is of type {type(imp)} and cannot be imported.') - try: - imported_tmp = import_module(imp) - except ImportError: - if allow_failed_imports: - logger.warning(f'{imp} failed to import and is ignored.') - imported_tmp = None - else: - raise ImportError - imported.append(imported_tmp) - if single_import: - imported = imported[0] - return imported - - -def validate_py_syntax(filename): - with open(filename, 'r', encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - content = f.read() - try: - ast.parse(content) - except SyntaxError as e: - raise SyntaxError('There are syntax errors in config ' - f'file {filename}: {e}') diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 8009b084..2e1f8672 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -1,7 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import inspect +from typing import List, Tuple, Union +from modelscope.utils.import_utils import requires from modelscope.utils.logger import get_logger default_group = 'default' @@ -52,9 +54,14 @@ class Registry(object): def _register_module(self, group_key=default_group, module_name=None, - module_cls=None): + module_cls=None, + requirements=None): assert isinstance(group_key, str), 'group_key is required and must be str' + + if requirements is not None: + requires(module_cls, requirements) + if group_key not in self._modules: self._modules[group_key] = dict() @@ -70,23 +77,11 @@ class Registry(object): self._modules[group_key][module_name] = module_cls module_cls.group_key = group_key - if module_name in self._modules[default_group]: - if id(self._modules[default_group][module_name]) == id(module_cls): - return - else: - logger.warning(f'{module_name} is already registered in ' - f'{self._name}[{default_group}] and will ' - 'be overwritten') - logger.warning(f'{self._modules[default_group][module_name]}' - f'to {module_cls}') - # also register module in the default group for faster access - # only by module name - self._modules[default_group][module_name] = module_cls - def register_module(self, group_key: str = default_group, module_name: str = None, - module_cls: type = None): + module_cls: type = None, + requirements: Union[List, Tuple] = None): """ Register module Example: @@ -110,17 +105,18 @@ class Registry(object): default group name is 'default' module_name: Module name module_cls: Module class object + requirements: Module necessary requirements """ if not (module_name is None or isinstance(module_name, str)): raise TypeError(f'module_name must be either of None, str,' f'got {type(module_name)}') - if module_cls is not None: self._register_module( group_key=group_key, module_name=module_name, - module_cls=module_cls) + module_cls=module_cls, + requirements=requirements) return module_cls # if module_cls is None, should return a decorator function @@ -128,7 +124,8 @@ class Registry(object): self._register_module( group_key=group_key, module_name=module_name, - module_cls=module_cls) + module_cls=module_cls, + 
requirements=requirements) return module_cls return _register diff --git a/requirements.txt b/requirements.txt index b9b4a1c4..c6e294ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1 @@ -r requirements/runtime.txt --r requirements/pipeline.txt --r requirements/multi-modal.txt --r requirements/nlp.txt --r requirements/audio.txt --r requirements/cv.txt diff --git a/requirements/audio.txt b/requirements/audio.txt index 1f5984ca..4c009d27 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,10 +1,5 @@ #tts h5py -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/pytorch_wavelets-1.3.0-py3-none-any.whl -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp36-cp36m-linux_x86_64.whl; python_version=='3.6' -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp37-cp37m-linux_x86_64.whl; python_version=='3.7' -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp38-cp38-linux_x86_64.whl; python_version=='3.8' -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.2-cp39-cp39-linux_x86_64.whl; python_version=='3.9' inflect keras librosa @@ -14,6 +9,7 @@ nara_wpe numpy protobuf>3,<=3.20 ptflops +pytorch_wavelets==1.3.0 PyWavelets>=1.0.0 scikit-learn SoundFile>0.10 @@ -24,4 +20,5 @@ torch torchaudio torchvision tqdm +ttsfrd==0.0.2 unidecode diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index ad641b63..b96bdd01 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -1,8 +1,6 @@ -datasets -einops +fairseq==maas ftfy>=6.0.3 -https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/fairseq-maas-py3-none-any.whl -https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/ofa-0.0.2-py3-none-any.whl +ofa==0.0.2 pycocoevalcap>=1.2 pycocotools>=2.0.4 rouge_score diff --git a/requirements/nlp.txt b/requirements/nlp.txt index beb5f016..ec8f9513 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -1,3 +1,4 @@ http://ait-public.oss-cn-hangzhou-zmf.aliyuncs.com/jizhu/en_core_web_sm-2.3.1.tar.gz https://alinlp.alibaba-inc.com/pypi/sofa-1.0.5-py3-none-any.whl +sofa==1.0.5 spacy>=2.3.5 diff --git a/requirements/pipeline.txt b/requirements/pipeline.txt deleted file mode 100644 index 64500a6b..00000000 --- a/requirements/pipeline.txt +++ /dev/null @@ -1,6 +0,0 @@ -#https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.4-py2.py3-none-any.whl -# tensorflow -#--find-links https://download.pytorch.org/whl/torch_stable.html -# torch<1.10,>=1.8.0 -# torchaudio -# torchvision diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 6580de53..1fcce7ff 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,16 +1,18 @@ addict datasets easydict +einops filelock>=3.3.0 numpy -opencv-python-headless +opencv-python Pillow>=6.2.0 +protobuf>3,<=3.20 pyyaml requests -requests==2.27.1 scipy -setuptools==58.0.4 +setuptools tokenizers<=0.10.3 +torch tqdm>=4.64.0 -transformers<=4.16.2 +transformers<=4.16.2,>=4.10.3 yapf diff --git a/setup.py b/setup.py index b027c4cb..3b40ac8b 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,8 @@ import shutil import subprocess from setuptools import find_packages, setup +from modelscope.utils.constant import Fields + def readme(): with open('README.md', encoding='utf-8') as f: @@ -169,6 +171,16 @@ if __name__ == '__main__': pack_resource() os.chdir('package') install_requires, deps_link = 
parse_requirements('requirements.txt') + extra_requires = {} + all_requires = [] + for field in dir(Fields): + if field.startswith('_'): + continue + extra_requires[field], _ = parse_requirements( + f'requirements/{field}.txt') + all_requires.append(extra_requires[field]) + extra_requires['all'] = all_requires + setup( name='model-scope', version=get_version(), @@ -193,5 +205,6 @@ if __name__ == '__main__': license='Apache License 2.0', tests_require=parse_requirements('requirements/tests.txt'), install_requires=install_requires, + extras_require=extra_requires, dependency_links=deps_link, zip_safe=False) diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py index c642ed4b..93ebf08f 100644 --- a/tests/pipelines/test_base.py +++ b/tests/pipelines/test_base.py @@ -8,6 +8,7 @@ import PIL from modelscope.pipelines import Pipeline, pipeline from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info +from modelscope.pipelines.outputs import OutputKeys from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from modelscope.utils.registry import default_group @@ -68,28 +69,28 @@ class CustomPipelineTest(unittest.TestCase): outputs['filename'] = inputs['url'] img = inputs['img'] new_image = img.resize((img.width // 2, img.height // 2)) - outputs['output_png'] = np.array(new_image) + outputs[OutputKeys.OUTPUT_IMG] = np.array(new_image) return outputs def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs - self.assertTrue('custom-image' in PIPELINES.modules[default_group]) + self.assertTrue('custom-image' in PIPELINES.modules[dummy_task]) add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True) - pipe = pipeline(pipeline_name='custom-image') + pipe = pipeline(task=dummy_task, pipeline_name='custom-image') pipe2 = pipeline(dummy_task) self.assertTrue(type(pipe) is type(pipe2)) img_url = 'data/test/images/image1.jpg' output = pipe(img_url) self.assertEqual(output['filename'], img_url) - self.assertEqual(output['output_png'].shape, (318, 512, 3)) + self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) outputs = pipe([img_url for i in range(4)]) self.assertEqual(len(outputs), 4) for out in outputs: self.assertEqual(out['filename'], img_url) - self.assertEqual(out['output_png'].shape, (318, 512, 3)) + self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3)) if __name__ == '__main__': diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py index 5fa6ff49..c185d774 100644 --- a/tests/pipelines/test_image_captioning.py +++ b/tests/pipelines/test_image_captioning.py @@ -3,6 +3,7 @@ import unittest from modelscope.pipelines import pipeline +from modelscope.pipelines.outputs import OutputKeys from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level @@ -15,7 +16,7 @@ class ImageCaptionTest(unittest.TestCase): Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en') result = img_captioning('data/test/images/image_captioning.png') - print(result['caption']) + print(result[OutputKeys.CAPTION]) if __name__ == '__main__': diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 48a715f1..22fb127b 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -9,6 +9,7 @@ import cv2 from modelscope.fileio import File from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline +from 
modelscope.pipelines.outputs import OutputKeys from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -29,7 +30,7 @@ class ImageMattingTest(unittest.TestCase): img_matting = pipeline(Tasks.image_matting, model=tmp_dir) result = img_matting('data/test/images/image_matting.png') - cv2.imwrite('result.png', result['output_png']) + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_dataset(self): @@ -41,7 +42,7 @@ class ImageMattingTest(unittest.TestCase): img_matting = pipeline(Tasks.image_matting, model=self.model_id) # note that for dataset output, the inference-output is a Generator that can be iterated. result = img_matting(dataset) - cv2.imwrite('result.png', next(result)['output_png']) + cv2.imwrite('result.png', next(result)[OutputKeys.OUTPUT_IMG]) print(f'Output written to {osp.abspath("result.png")}') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -49,7 +50,7 @@ class ImageMattingTest(unittest.TestCase): img_matting = pipeline(Tasks.image_matting, model=self.model_id) result = img_matting('data/test/images/image_matting.png') - cv2.imwrite('result.png', result['output_png']) + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) print(f'Output written to {osp.abspath("result.png")}') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @@ -57,7 +58,7 @@ class ImageMattingTest(unittest.TestCase): img_matting = pipeline(Tasks.image_matting) result = img_matting('data/test/images/image_matting.png') - cv2.imwrite('result.png', result['output_png']) + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) print(f'Output written to {osp.abspath("result.png")}') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @@ -67,7 +68,7 @@ class ImageMattingTest(unittest.TestCase): img_matting = pipeline(Tasks.image_matting, model=self.model_id) result = img_matting(dataset) for i in range(10): - cv2.imwrite(f'result_{i}.png', next(result)['output_png']) + cv2.imwrite(f'result_{i}.png', next(result)[OutputKeys.OUTPUT_IMG]) print( f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}' ) diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py index e82a4211..b01e3f21 100644 --- a/tests/pipelines/test_key_word_spotting.py +++ b/tests/pipelines/test_key_word_spotting.py @@ -15,8 +15,8 @@ from modelscope.utils.test_utils import test_level KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' -POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' -POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE +POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' +BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' POS_TESTSETS_FILE = 'pos_testsets.tar.gz' POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' @@ -47,12 +47,8 @@ class KeyWordSpottingTest(unittest.TestCase): # wav, neg_testsets, pos_testsets, roc kws_set = 'wav' - # downloading wav file - wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) - if not os.path.exists(wav_file_path): - r = requests.get(POS_WAV_URL) - with open(wav_file_path, 'wb') as f: - f.write(r.content) + # get wav file + wav_file_path = POS_WAV_FILE # downloading kwsbp kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') @@ -70,6 
+66,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -91,9 +88,73 @@ class KeyWordSpottingTest(unittest.TestCase): """ if kws_result.__contains__('keywords'): print('test_run_with_wav keywords: ', kws_result['keywords']) + print('test_run_with_wav confidence: ', kws_result['confidence']) print('test_run_with_wav detected result: ', kws_result['detected']) print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav_by_customized_keywords(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'wav' + + # get wav file + wav_file_path = BOFANGYINYUE_WAV_FILE + + # downloading kwsbp + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + # customized keyword if you need. + # full settings eg. + # keywords = [ + # {'keyword':'你好电视', 'threshold': 0.008}, + # {'keyword':'播放音乐', 'threshold': 0.008} + # ] + keywords = [{'keyword': '播放音乐'}] + + kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor, + keywords=keywords) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('detected')) + """ + kws result json format example: + { + 'wav_count': 1, + 'kws_set': 'wav', + 'wav_time': 9.132938, + 'keywords': ['播放音乐'], + 'detected': True, + 'confidence': 0.660368 + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_wav_by_customized_keywords keywords: ', + kws_result['keywords']) + print('test_run_with_wav_by_customized_keywords confidence: ', + kws_result['confidence']) + print('test_run_with_wav_by_customized_keywords detected result: ', + kws_result['detected']) + print('test_run_with_wav_by_customized_keywords wave time(seconds): ', + kws_result['wav_time']) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): # wav, neg_testsets, pos_testsets, roc @@ -133,6 +194,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -204,6 +266,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) @@ -298,6 +361,7 @@ class KeyWordSpottingTest(unittest.TestCase): self.assertTrue(preprocessor is not None) kwsbp_16k_pipline = pipeline( + task=Tasks.key_word_spotting, pipeline_name=Pipelines.kws_kwsbp, model=model, preprocessor=preprocessor) diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py index f47ca008..505e02cc 
100644 --- a/tests/pipelines/test_person_image_cartoon.py +++ b/tests/pipelines/test_person_image_cartoon.py @@ -7,6 +7,7 @@ import cv2 from modelscope.pipelines import pipeline from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.outputs import OutputKeys from modelscope.utils.constant import Tasks from modelscope.utils.test_utils import test_level @@ -22,7 +23,7 @@ class ImageCartoonTest(unittest.TestCase): def pipeline_inference(self, pipeline: Pipeline, input_location: str): result = pipeline(input_location) if result is not None: - cv2.imwrite('result.png', result['output_png']) + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) print(f'Output written to {osp.abspath("result.png")}') @unittest.skip('deprecated, download model from model hub instead') diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index e92047d6..c371d80a 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -12,7 +12,7 @@ from modelscope.metainfo import Pipelines, Preprocessors from modelscope.models import Model from modelscope.pipelines import pipeline from modelscope.preprocessors import build_preprocessor -from modelscope.utils.constant import Fields +from modelscope.utils.constant import Fields, Tasks from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import test_level @@ -43,6 +43,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): self.assertTrue(voc is not None) sambert_tts = pipeline( + task=Tasks.text_to_speech, pipeline_name=Pipelines.sambert_hifigan_16k_tts, config_file='', model=[am, voc], diff --git a/tests/utils/test_check_requirements.py b/tests/utils/test_check_requirements.py new file mode 100644 index 00000000..2ad19e82 --- /dev/null +++ b/tests/utils/test_check_requirements.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest +from typing import List, Union + +from modelscope.utils.check_requirements import NLPModuleNotFoundError, get_msg +from modelscope.utils.constant import Fields + + +class ImportUtilsTest(unittest.TestCase): + + def test_type_module_not_found(self): + with self.assertRaises(NLPModuleNotFoundError) as ctx: + try: + import not_found + except ModuleNotFoundError as e: + raise NLPModuleNotFoundError(e) + self.assertTrue(get_msg(Fields.nlp) in ctx.exception.msg.msg) + + +if __name__ == '__main__': + unittest.main()
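
The availability probes (`is_torch_available`, `is_tf_available`, ...) together with the import-error templates and the `requires`/`torch_required`/`tf_required` helpers give components a uniform way to fail fast with an informative ImportError. A minimal usage sketch, assuming these helpers live in `modelscope.utils.import_utils` (the registry change in this patch imports `requires` from there); the class and function below are made up for illustration:

```python
# Minimal sketch, assuming the helpers are in modelscope.utils.import_utils;
# MyTorchComponent and to_tensor are hypothetical.
from modelscope.utils.import_utils import requires, torch_required


class MyTorchComponent:

    def __init__(self):
        # Raises ImportError with the matching *_IMPORT_ERROR message for
        # every listed requirement that is not installed.
        requires(self, ['torch', 'tokenizers'])


@torch_required
def to_tensor(data):
    # Reached only when is_torch_available() returns True.
    import torch
    return torch.as_tensor(data)
```

`requires` accepts either an object (its class name is used in the message) or a plain string, and either a single requirement key or a list/tuple of keys from the requirements mapping.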
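
`Registry.register_module` now takes an optional `requirements` argument that is forwarded to `requires()` at registration time, and modules are no longer mirrored into the default group. A hedged sketch of a registration with dependency checks; the `Registry(name)` constructor signature, the group key, and the class are assumptions:

```python
# Hypothetical registration; only the requirements plumbing is from the patch.
from modelscope.utils.registry import Registry

MODELS = Registry('models')  # assumed constructor: Registry(name)


@MODELS.register_module(
    group_key='nlp',
    module_name='my-dummy-model',
    requirements=['torch', 'tokenizers'])
class MyDummyModel:
    # requires(MyDummyModel, ['torch', 'tokenizers']) runs when the class is
    # registered, so a missing dependency fails early with a clear message.
    pass
```

Because the duplicate registration into `default_group` was removed, lookups must use the module's own group key, which is why the base pipeline test switches from `PIPELINES.modules[default_group]` to `PIPELINES.modules[dummy_task]` and passes `task=dummy_task` when building the pipeline.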
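
setup.py now derives per-domain extras from the public attributes of `Fields`, reading one `requirements/<field>.txt` per field and aggregating them into an `all` extra, which is what makes installs such as `pip install model_scope[all]` resolvable. A small sketch of the correspondence the build relies on, assuming every public `Fields` attribute has a matching requirements file; the snippet only prints the expected file names, it does not parse them:

```python
# Illustration of the field -> requirements-file mapping assumed by setup.py.
from modelscope.utils.constant import Fields

fields = [f for f in dir(Fields) if not f.startswith('_')]
for field in fields:
    print(field, '->', f'requirements/{field}.txt')
```

If an attribute name ever differs from its requirements file name (for example an underscore versus a hyphen), `parse_requirements` will not find the file, so the two must stay in sync.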
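
The test updates also standardize on `OutputKeys` constants for reading pipeline results instead of string literals such as `'output_png'` and `'caption'`. Roughly what the updated matting tests do, using the default-model variant from the level-2 test case:

```python
# Index pipeline results with OutputKeys constants, as in the updated tests.
import cv2

from modelscope.pipelines import pipeline
from modelscope.pipelines.outputs import OutputKeys
from modelscope.utils.constant import Tasks

img_matting = pipeline(Tasks.image_matting)  # default model, as in the test
result = img_matting('data/test/images/image_matting.png')
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
```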
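
The new `test_run_with_wav_by_customized_keywords` case exercises the `keywords` argument of the kwsbp pipeline, with optional per-keyword thresholds. A condensed sketch of that path; `<kws-model-id>` and the workspace path are placeholders, and fetching the kwsbp binary into the workspace (done in the test) is omitted:

```python
# Condensed from test_run_with_wav_by_customized_keywords; placeholders noted.
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields, Tasks

model = Model.from_pretrained('<kws-model-id>')  # placeholder model id
preprocessor = build_preprocessor(
    dict(type=Preprocessors.wav_to_lists, workspace='./kws_workspace'),
    Fields.audio)

# Optional per-keyword threshold, e.g. {'keyword': '播放音乐', 'threshold': 0.008}
keywords = [{'keyword': '播放音乐'}]

kws = pipeline(
    task=Tasks.key_word_spotting,
    pipeline_name=Pipelines.kws_kwsbp,
    model=model,
    preprocessor=preprocessor,
    keywords=keywords)

result = kws(kws_type='wav',
             wav_path=['data/test/audios/kws_bofangyinyue.wav', None])
print(result['detected'], result.get('keywords'), result.get('confidence'))
```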