From 6702b29e21e9ad10256ba8adebb796530751c10d Mon Sep 17 00:00:00 2001 From: "yingda.chen" Date: Mon, 27 Jun 2022 11:09:38 +0800 Subject: [PATCH 1/7] [to #42794773]rename pydataset to msdataset Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9165402 --- docs/source/api/modelscope.pydatasets.rst | 8 +++---- docs/source/api/modelscope.rst | 2 +- docs/source/quick_start.md | 10 ++++---- modelscope/hub/file_download.py | 2 +- modelscope/msdatasets/__init__.py | 1 + .../{pydatasets => msdatasets}/config.py | 0 .../ms_dataset.py} | 24 +++++++++---------- .../utils/__init__.py | 0 .../utils/ms_api.py | 2 +- modelscope/pipelines/base.py | 6 ++--- modelscope/pydatasets/__init__.py | 1 - tests/{pydatasets => msdatasets}/__init__.py | 0 .../test_ms_dataset.py} | 19 +++++++-------- tests/pipelines/test_action_recognition.py | 2 +- tests/pipelines/test_image_matting.py | 6 ++--- tests/pipelines/test_text_classification.py | 12 +++++----- 16 files changed, 47 insertions(+), 48 deletions(-) create mode 100644 modelscope/msdatasets/__init__.py rename modelscope/{pydatasets => msdatasets}/config.py (100%) rename modelscope/{pydatasets/py_dataset.py => msdatasets/ms_dataset.py} (96%) rename modelscope/{pydatasets => msdatasets}/utils/__init__.py (100%) rename modelscope/{pydatasets => msdatasets}/utils/ms_api.py (97%) delete mode 100644 modelscope/pydatasets/__init__.py rename tests/{pydatasets => msdatasets}/__init__.py (100%) rename tests/{pydatasets/test_py_dataset.py => msdatasets/test_ms_dataset.py} (88%) diff --git a/docs/source/api/modelscope.pydatasets.rst b/docs/source/api/modelscope.pydatasets.rst index 2508a91f..53b858a8 100644 --- a/docs/source/api/modelscope.pydatasets.rst +++ b/docs/source/api/modelscope.pydatasets.rst @@ -1,7 +1,7 @@ -modelscope.pydatasets package +modelscope.msdatasets package ============================= -.. automodule:: modelscope.pydatasets +.. automodule:: modelscope.msdatasets :members: :undoc-members: :show-inheritance: @@ -9,10 +9,10 @@ modelscope.pydatasets package Submodules ---------- -modelscope.pydatasets.py\_dataset module +modelscope.msdatasets.ms\_dataset module ---------------------------------------- -.. automodule:: modelscope.pydatasets.py_dataset +.. 
automodule:: modelscope.msdatasets.ms_dataset :members: :undoc-members: :show-inheritance: diff --git a/docs/source/api/modelscope.rst b/docs/source/api/modelscope.rst index efab568b..eacdf33d 100644 --- a/docs/source/api/modelscope.rst +++ b/docs/source/api/modelscope.rst @@ -16,7 +16,7 @@ Subpackages modelscope.models modelscope.pipelines modelscope.preprocessors - modelscope.pydatasets + modelscope.msdatasets modelscope.trainers modelscope.utils diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index 7148f27f..de416f08 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -3,7 +3,7 @@ ## python环境配置 首先,参考[文档](https://docs.anaconda.com/anaconda/install/) 安装配置Anaconda环境 -安装完成后,执行如下命令为maas library创建对应的python环境。 +安装完成后,执行如下命令为modelscope library创建对应的python环境。 ```shell conda create -n modelscope python=3.6 conda activate modelscope @@ -105,15 +105,15 @@ import cv2 import os.path as osp from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks -from modelscope.pydatasets import PyDataset +from modelscope.msdatasets import MsDataset -# 使用图像url构建PyDataset,此处也可通过 input_location = '/dir/to/images' 来使用本地文件夹 +# 使用图像url构建MsDataset,此处也可通过 input_location = '/dir/to/images' 来使用本地文件夹 input_location = [ 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' ] -dataset = PyDataset.load(input_location, target='image') +dataset = MsDataset.load(input_location, target='image') img_matting = pipeline(Tasks.image_matting, model='damo/image-matting-person') -# 输入为PyDataset时,输出的结果为迭代器 +# 输入为MsDataset时,输出的结果为迭代器 result = img_matting(dataset) cv2.imwrite('result.png', next(result)['output_png']) print(f'Output written to {osp.abspath("result.png")}') diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index e5c64f1c..b92bf89c 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -187,7 +187,7 @@ def get_file_download_url(model_id: str, file_path: str, revision: str): """ Format file download url according to `model_id`, `revision` and `file_path`. 
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`, - the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md + the resulted download url is: https://modelscope.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md """ download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}' return download_url_template.format( diff --git a/modelscope/msdatasets/__init__.py b/modelscope/msdatasets/__init__.py new file mode 100644 index 00000000..8e0647bb --- /dev/null +++ b/modelscope/msdatasets/__init__.py @@ -0,0 +1 @@ +from .ms_dataset import MsDataset diff --git a/modelscope/pydatasets/config.py b/modelscope/msdatasets/config.py similarity index 100% rename from modelscope/pydatasets/config.py rename to modelscope/msdatasets/config.py diff --git a/modelscope/pydatasets/py_dataset.py b/modelscope/msdatasets/ms_dataset.py similarity index 96% rename from modelscope/pydatasets/py_dataset.py rename to modelscope/msdatasets/ms_dataset.py index 49137253..0466894c 100644 --- a/modelscope/pydatasets/py_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -10,8 +10,8 @@ from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES from datasets.utils.file_utils import (is_relative_path, relative_to_absolute_path) -from modelscope.pydatasets.config import MS_DATASETS_CACHE -from modelscope.pydatasets.utils.ms_api import MsApi +from modelscope.msdatasets.config import MS_DATASETS_CACHE +from modelscope.msdatasets.utils.ms_api import MsApi from modelscope.utils.constant import Hubs from modelscope.utils.logger import get_logger @@ -28,9 +28,9 @@ def format_list(para) -> List: return para -class PyDataset: +class MsDataset: _hf_ds = None # holds the underlying HuggingFace Dataset - """A PyDataset backed by hugging face Dataset.""" + """A MsDataset backed by hugging face Dataset.""" def __init__(self, hf_ds: Dataset, target: Optional[str] = None): self._hf_ds = hf_ds @@ -49,7 +49,7 @@ class PyDataset: @classmethod def from_hf_dataset(cls, hf_ds: Dataset, - target: str = None) -> Union[dict, 'PyDataset']: + target: str = None) -> Union[dict, 'MsDataset']: if isinstance(hf_ds, Dataset): return cls(hf_ds, target) if len(hf_ds.keys()) == 1: @@ -68,8 +68,8 @@ class PyDataset: data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None - ) -> Union[dict, 'PyDataset']: - """Load a PyDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. + ) -> Union[dict, 'MsDataset']: + """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. Args: dataset_name (str): Path or name of the dataset. @@ -82,7 +82,7 @@ class PyDataset: hub (Hubs, optional): When loading from a remote hub, where it is from Returns: - PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset. + MsDataset (obj:`MsDataset`): MsDataset object for a certain dataset. 
""" if hub == Hubs.huggingface: dataset = hf_load_dataset( @@ -92,9 +92,9 @@ class PyDataset: split=split, data_dir=data_dir, data_files=data_files) - return PyDataset.from_hf_dataset(dataset, target=target) + return MsDataset.from_hf_dataset(dataset, target=target) else: - return PyDataset._load_ms_dataset( + return MsDataset._load_ms_dataset( dataset_name, target=target, subset_name=subset_name, @@ -114,7 +114,7 @@ class PyDataset: data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None - ) -> Union[dict, 'PyDataset']: + ) -> Union[dict, 'MsDataset']: if isinstance(dataset_name, str): use_hf = False if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \ @@ -153,7 +153,7 @@ class PyDataset: else: raise TypeError('path must be a str or a list, but got' f' {type(dataset_name)}') - return PyDataset.from_hf_dataset(dataset, target=target) + return MsDataset.from_hf_dataset(dataset, target=target) def to_torch_dataset_with_processors( self, diff --git a/modelscope/pydatasets/utils/__init__.py b/modelscope/msdatasets/utils/__init__.py similarity index 100% rename from modelscope/pydatasets/utils/__init__.py rename to modelscope/msdatasets/utils/__init__.py diff --git a/modelscope/pydatasets/utils/ms_api.py b/modelscope/msdatasets/utils/ms_api.py similarity index 97% rename from modelscope/pydatasets/utils/ms_api.py rename to modelscope/msdatasets/utils/ms_api.py index 04052cc4..fc3bcca2 100644 --- a/modelscope/pydatasets/utils/ms_api.py +++ b/modelscope/msdatasets/utils/ms_api.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from modelscope.pydatasets.config import (DOWNLOADED_DATASETS_PATH, +from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH, MS_HUB_ENDPOINT) from modelscope.utils.logger import get_logger diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 7e32f543..2f5d5dcc 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -6,15 +6,15 @@ from typing import Any, Dict, Generator, List, Union from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.base import Model +from modelscope.msdatasets import MsDataset from modelscope.preprocessors import Preprocessor -from modelscope.pydatasets import PyDataset from modelscope.utils.config import Config from modelscope.utils.logger import get_logger from .outputs import TASK_OUTPUTS from .util import is_model, is_official_hub_path Tensor = Union['torch.Tensor', 'tf.Tensor'] -Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] +Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray'] InputModel = Union[str, Model] output_keys = [ @@ -85,7 +85,7 @@ class Pipeline(ABC): for ele in input: output.append(self._process_single(ele, *args, **post_kwargs)) - elif isinstance(input, PyDataset): + elif isinstance(input, MsDataset): return self._process_iterator(input, *args, **post_kwargs) else: diff --git a/modelscope/pydatasets/__init__.py b/modelscope/pydatasets/__init__.py deleted file mode 100644 index a1ed1d93..00000000 --- a/modelscope/pydatasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .py_dataset import PyDataset diff --git a/tests/pydatasets/__init__.py b/tests/msdatasets/__init__.py similarity index 100% rename from tests/pydatasets/__init__.py rename to tests/msdatasets/__init__.py diff --git a/tests/pydatasets/test_py_dataset.py b/tests/msdatasets/test_ms_dataset.py similarity index 88% rename from 
tests/pydatasets/test_py_dataset.py rename to tests/msdatasets/test_ms_dataset.py index e84f240a..de413d5f 100644 --- a/tests/pydatasets/test_py_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -3,10 +3,9 @@ import unittest import datasets as hfdata from modelscope.models import Model +from modelscope.msdatasets import MsDataset from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.preprocessors.base import Preprocessor -from modelscope.pydatasets import PyDataset -from modelscope.utils.constant import Hubs from modelscope.utils.test_utils import require_tf, require_torch, test_level @@ -31,15 +30,15 @@ class ImgPreprocessor(Preprocessor): } -class PyDatasetTest(unittest.TestCase): +class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_ds_basic(self): - ms_ds_full = PyDataset.load('squad') + ms_ds_full = MsDataset.load('squad') ms_ds_full_hf = hfdata.load_dataset('squad') - ms_ds_train = PyDataset.load('squad', split='train') + ms_ds_train = MsDataset.load('squad', split='train') ms_ds_train_hf = hfdata.load_dataset('squad', split='train') - ms_image_train = PyDataset.from_hf_dataset( + ms_image_train = MsDataset.from_hf_dataset( hfdata.load_dataset('beans', split='train')) self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0]) self.assertEqual(ms_ds_full['validation'][0], @@ -58,7 +57,7 @@ class PyDatasetTest(unittest.TestCase): nlp_model.model_dir, first_sequence='context', second_sequence=None) - ms_ds_train = PyDataset.load('squad', split='train') + ms_ds_train = MsDataset.load('squad', split='train') pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor) import torch dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) @@ -75,7 +74,7 @@ class PyDatasetTest(unittest.TestCase): nlp_model.model_dir, first_sequence='context', second_sequence=None) - ms_ds_train = PyDataset.load('squad', split='train') + ms_ds_train = MsDataset.load('squad', split='train') tf_dataset = ms_ds_train.to_tf_dataset( batch_size=5, shuffle=True, @@ -86,7 +85,7 @@ class PyDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @require_torch def test_to_torch_dataset_img(self): - ms_image_train = PyDataset.from_hf_dataset( + ms_image_train = MsDataset.from_hf_dataset( hfdata.load_dataset('beans', split='train')) pt_dataset = ms_image_train.to_torch_dataset( preprocessors=ImgPreprocessor( @@ -100,7 +99,7 @@ class PyDatasetTest(unittest.TestCase): def test_to_tf_dataset_img(self): import tensorflow as tf tf.compat.v1.enable_eager_execution() - ms_image_train = PyDataset.load('beans', split='train') + ms_image_train = MsDataset.load('beans', split='train') tf_dataset = ms_image_train.to_tf_dataset( batch_size=5, shuffle=True, diff --git a/tests/pipelines/test_action_recognition.py b/tests/pipelines/test_action_recognition.py index b524ca18..6f608041 100644 --- a/tests/pipelines/test_action_recognition.py +++ b/tests/pipelines/test_action_recognition.py @@ -8,8 +8,8 @@ import unittest import cv2 from modelscope.fileio import File +from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.pydatasets import PyDataset from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 1b547e14..de60ff0b 100644 --- 
a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -7,8 +7,8 @@ import unittest import cv2 from modelscope.fileio import File +from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.pydatasets import PyDataset from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.test_utils import test_level @@ -37,7 +37,7 @@ class ImageMattingTest(unittest.TestCase): # alternatively: # input_location = '/dir/to/images' - dataset = PyDataset.load(input_location, target='image') + dataset = MsDataset.load(input_location, target='image') img_matting = pipeline(Tasks.image_matting, model=self.model_id) # note that for dataset output, the inference-output is a Generator that can be iterated. result = img_matting(dataset) @@ -62,7 +62,7 @@ class ImageMattingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_modelscope_dataset(self): - dataset = PyDataset.load('beans', split='train', target='image') + dataset = MsDataset.load('beans', split='train', target='image') img_matting = pipeline(Tasks.image_matting, model=self.model_id) result = img_matting(dataset) for i in range(10): diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 9e5f15b9..f913490c 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -3,9 +3,9 @@ import shutil import unittest from modelscope.models import Model +from modelscope.msdatasets import MsDataset from modelscope.pipelines import SequenceClassificationPipeline, pipeline from modelscope.preprocessors import SequenceClassificationPreprocessor -from modelscope.pydatasets import PyDataset from modelscope.utils.constant import Hubs, Tasks from modelscope.utils.test_utils import test_level @@ -28,7 +28,7 @@ class SequenceClassificationTest(unittest.TestCase): print(data) - def printDataset(self, dataset: PyDataset): + def printDataset(self, dataset: MsDataset): for i, r in enumerate(dataset): if i > 10: break @@ -50,7 +50,7 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline( task=Tasks.text_classification, model=self.model_id) result = text_classification( - PyDataset.load( + MsDataset.load( 'glue', subset_name='sst2', split='train', @@ -62,7 +62,7 @@ class SequenceClassificationTest(unittest.TestCase): def test_run_with_default_model(self): text_classification = pipeline(task=Tasks.text_classification) result = text_classification( - PyDataset.load( + MsDataset.load( 'glue', subset_name='sst2', split='train', @@ -78,7 +78,7 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline( Tasks.text_classification, model=model, preprocessor=preprocessor) # loaded from huggingface dataset - dataset = PyDataset.load( + dataset = MsDataset.load( 'glue', subset_name='sst2', split='train', @@ -91,7 +91,7 @@ class SequenceClassificationTest(unittest.TestCase): def test_run_with_modelscope_dataset(self): text_classification = pipeline(task=Tasks.text_classification) # loaded from modelscope dataset - dataset = PyDataset.load( + dataset = MsDataset.load( 'squad', split='train', target='context', hub=Hubs.modelscope) result = text_classification(dataset) self.printDataset(result) From 9ff6b704b0cfe8fede2e524ab0bf788a3a8c7b38 Mon Sep 17 00:00:00 2001 From: "eniac.xcw" Date: Mon, 27 Jun 2022 11:57:22 +0800 Subject: [PATCH 2/7] [to #42322933]add multi-modal-feature 
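
[Review note on PATCH 1/7 above] The rename is mechanical: PyDataset becomes
MsDataset with an unchanged interface. A minimal sketch of the renamed entry
point, mirroring quick_start.md and the tests ('squad' is simply the dataset
the tests load):

    from modelscope.msdatasets import MsDataset

    # MsDataset.load() keeps the old PyDataset.load() call pattern and
    # returns an MsDataset backed by a Hugging Face Dataset.
    ms_ds_train = MsDataset.load('squad', split='train')
    print(ms_ds_train[0])
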
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add Chinese image-text feature model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9157786

* add multi-modal-feature

* fix issues found in code review
---
 modelscope/metainfo.py                        |   2 +
 modelscope/models/__init__.py                 |   2 +-
 .../{multi_model => multi_modal}/__init__.py  |   1 +
 .../models/multi_modal/clip/__init__.py       |   0
 .../models/multi_modal/clip/clip_bert.py      |  26 +++
 .../models/multi_modal/clip/clip_model.py     | 159 ++++++++++++++++++
 .../models/multi_modal/clip/clip_vit.py       | 121 ++++++++++++++
 .../image_captioning_model.py                 |   0
 modelscope/pipelines/builder.py               |   7 +-
 modelscope/pipelines/multi_modal/__init__.py  |   1 +
 .../multi_modal_embedding_pipeline.py         |  34 ++++
 modelscope/pipelines/outputs.py               |   7 +
 modelscope/utils/constant.py                  |   1 +
 tests/pipelines/test_multi_modal_embedding.py |  52 ++++++
 14 files changed, 410 insertions(+), 3 deletions(-)
 rename modelscope/models/{multi_model => multi_modal}/__init__.py (50%)
 create mode 100644 modelscope/models/multi_modal/clip/__init__.py
 create mode 100644 modelscope/models/multi_modal/clip/clip_bert.py
 create mode 100644 modelscope/models/multi_modal/clip/clip_model.py
 create mode 100644 modelscope/models/multi_modal/clip/clip_vit.py
 rename modelscope/models/{multi_model => multi_modal}/image_captioning_model.py (100%)
 create mode 100644 modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py
 create mode 100644 tests/pipelines/test_multi_modal_embedding.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index af39f3f4..af89cf33 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -24,6 +24,7 @@ class Models(object):
 
     # multi-modal models
     ofa = 'ofa'
+    clip = 'clip-multi-modal-embedding'
 
 
 class Pipelines(object):
@@ -55,6 +56,7 @@ class Pipelines(object):
 
     # multi-modal tasks
     image_caption = 'image-caption'
+    multi_modal_embedding = 'multi-modal-embedding'
 
 
 class Trainers(object):
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index f873dcca..06380035 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -4,5 +4,5 @@ from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
-from .multi_model import OfaForImageCaptioning
+from .multi_modal import OfaForImageCaptioning
 from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
diff --git a/modelscope/models/multi_model/__init__.py b/modelscope/models/multi_modal/__init__.py
similarity index 50%
rename from modelscope/models/multi_model/__init__.py
rename to modelscope/models/multi_modal/__init__.py
index 02e8d6ab..2e6cc3bf 100644
--- a/modelscope/models/multi_model/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -1 +1,2 @@
+from .clip.clip_model import CLIPForMultiModalEmbedding
 from .image_captioning_model import OfaForImageCaptioning
diff --git a/modelscope/models/multi_modal/clip/__init__.py b/modelscope/models/multi_modal/clip/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/multi_modal/clip/clip_bert.py b/modelscope/models/multi_modal/clip/clip_bert.py
new file mode 100644
index 00000000..50ddba99
--- /dev/null
+++ b/modelscope/models/multi_modal/clip/clip_bert.py
@@ -0,0 +1,26 @@
+import torch.nn as nn
+from transformers import BertConfig, BertForMaskedLM
+
+
+class TextTransformer(nn.Module):
+
+    def __init__(self, config_dict, feat_dim=768):
+        super(TextTransformer, self).__init__()
+        bert_config = BertConfig.from_dict(config_dict)
+        self.bert = BertForMaskedLM(bert_config).bert
+
+        self.projector = nn.Linear(
+            bert_config.hidden_size, feat_dim, bias=False)
+
+    def forward(self, input_ids, attention_mask):
+        trans_features = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask
+        }
+
+        output_states = self.bert(**trans_features, return_dict=False)
+        output_tokens = output_states[0]
+
+        cls_tokens = output_tokens[:, 0, :]
+
+        return self.projector(cls_tokens)
diff --git a/modelscope/models/multi_modal/clip/clip_model.py b/modelscope/models/multi_modal/clip/clip_model.py
new file mode 100644
index 00000000..4283886f
--- /dev/null
+++ b/modelscope/models/multi_modal/clip/clip_model.py
@@ -0,0 +1,159 @@
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from tokenizers import BertWordPieceTokenizer
+from torchvision.transforms import Compose, Normalize, Resize, ToTensor
+
+from modelscope.metainfo import Models
+from modelscope.models.base import Model
+from modelscope.models.builder import MODELS
+from modelscope.models.multi_modal.clip.clip_bert import TextTransformer
+from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+__all__ = ['CLIPForMultiModalEmbedding']
+
+
+class CLIPModel(nn.Module):
+
+    def __init__(self, model_dir):
+        super(CLIPModel, self).__init__()
+        # including vision config and text config
+        model_config = json.load(
+            open('{}/encoder_config.json'.format(model_dir)))
+
+        # vision encoder
+        vision_config = model_config['vision_config']
+        self.img_size = vision_config['input_resolution']
+        self.vision_encoder = VisionTransformer(
+            input_resolution=self.img_size,
+            patch_size=vision_config['patch_size'],
+            width=vision_config['width'],
+            layers=vision_config['layers'],
+            heads=vision_config['heads'],
+            output_dim=vision_config['feat_dim'])
+
+        # text encoder
+        text_config = model_config['text_config']
+        self.text_encoder = TextTransformer(
+            text_config['bert_config'], feat_dim=text_config['feat_dim'])
+
+    def forward(self, input_data, input_type):
+        if input_type == 'img':
+            img_embedding = self.vision_encoder(input_data)
+            img_embedding = F.normalize(img_embedding, p=2.0, dim=1)
+            return img_embedding
+        elif input_type == 'text':
+            text_ids_tensor, text_mask_tensor = input_data
+            text_embedding = self.text_encoder(text_ids_tensor,
+                                               text_mask_tensor)
+            text_embedding = F.normalize(text_embedding, p=2.0, dim=1)
+            return text_embedding
+        else:
+            raise ValueError('Unknown input type')
+
+
+@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
+class CLIPForMultiModalEmbedding(Model):
+
+    def __init__(self, model_dir, device_id=-1):
+        super().__init__(model_dir=model_dir, device_id=device_id)
+        self.clip_model = CLIPModel(model_dir=model_dir)
+        pretrained_params = torch.load(
+            '{}/pytorch_model.bin'.format(model_dir), 'cpu')
+        self.clip_model.load_state_dict(pretrained_params)
+        self.clip_model.eval()
+
+        self.device_id = device_id
+        if self.device_id >= 0:
+            self.clip_model.to('cuda:{}'.format(self.device_id))
+            logger.info('Use GPU: {}'.format(self.device_id))
+        else:
+            logger.info('Use CPU for inference')
+
+        # image preprocessor
+        norm_op = Normalize((0.48145466, 0.4578275, 0.40821073),
+                            (0.26862954, 0.26130258, 0.27577711))
+
self.img_preprocessor = Compose([ + Resize((self.clip_model.img_size, self.clip_model.img_size), + interpolation=Image.BICUBIC), + ToTensor(), norm_op + ]) + + # text tokenizer + vocab_path = '{}/vocab.txt'.format(model_dir) + self.text_tokenizer = BertWordPieceTokenizer( + vocab_path, lowercase=False) + self.text_tokenizer.enable_truncation(max_length=30) + + def tokenize_text(self, text_str): + tokens = self.text_tokenizer.encode(text_str) + max_tokens = 30 + text_ids_tensor = torch.zeros((1, max_tokens)).long() + text_mask_tensor = torch.zeros((1, max_tokens)) + + text_ids, text_mask = tokens.ids, tokens.attention_mask + text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids) + text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask) + + return text_ids_tensor, text_mask_tensor + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = {'img_embedding': None, 'text_embedding': None} + if 'img' in input and input['img'] is not None: + input_img = input['img'] + if isinstance(input_img, Image.Image): + img_tensor = self.img_preprocessor(input_img)[None, ...] + elif isinstance(input_img, np.ndarray): + if len(input_img.shape) == 2: + input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR) + input_img = input_img[:, :, ::-1] # in rgb order + input_img = Image.fromarray( + input_img.astype('uint8')).convert('RGB') + img_tensor = self.img_preprocessor(input_img)[None, ...] + else: + raise TypeError( + f'img should be either PIL.Image or np.array, but got {type(input_img)}' + ) + + if self.device_id >= 0: + img_tensor = img_tensor.to('cuda:{}'.format(self.device_id)) + + img_embedding = self.clip_model( + input_data=img_tensor, input_type='img') + output['img_embedding'] = img_embedding.data.cpu().numpy() + + if 'text' in input and input['text'] is not None: + text_str = input['text'] + if isinstance(text_str, str): + text_ids_tensor, text_mask_tensor = self.tokenize_text( + text_str) + else: + raise TypeError( + f'text should be str, but got {type(text_str)}') + + if self.device_id >= 0: + text_ids_tensor = text_ids_tensor.to('cuda:{}'.format( + self.device_id)) + text_mask_tensor = text_mask_tensor.to('cuda:{}'.format( + self.device_id)) + + text_embedding = self.clip_model( + input_data=(text_ids_tensor, text_mask_tensor), + input_type='text') + output['text_embedding'] = text_embedding.data.cpu().numpy() + + return output + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/models/multi_modal/clip/clip_vit.py b/modelscope/models/multi_modal/clip/clip_vit.py new file mode 100644 index 00000000..95bb1adc --- /dev/null +++ b/modelscope/models/multi_modal/clip/clip_vit.py @@ -0,0 +1,121 @@ +# Copyright 2021 The OpenAI CLIP Authors. All rights reserved. 
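
[Review note] clip_vit.py below tracks the reference OpenAI CLIP vision
tower: QuickGELU is the x * sigmoid(1.702 * x) approximation of GELU that
the released CLIP weights were trained with, and VisionTransformer pools
the class token (x[:, 0, :]) before the output projection. The sequence
length is (input_resolution // patch_size)**2 + 1; for example, assuming a
224-pixel input with patch size 14 (as the ViT-L/14 model name suggests;
the actual value comes from encoder_config.json), that is 16**2 + 1 = 257
tokens.
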
+ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + class_embeddings = self.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + x = torch.cat([class_embeddings, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x diff --git a/modelscope/models/multi_model/image_captioning_model.py b/modelscope/models/multi_modal/image_captioning_model.py similarity index 100% rename from modelscope/models/multi_model/image_captioning_model.py rename to modelscope/models/multi_modal/image_captioning_model.py diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index d3be06bc..41cd73da 
100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -21,8 +21,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.sentence_similarity: (Pipelines.sentence_similarity, 'damo/nlp_structbert_sentence-similarity_chinese-base'), - Tasks.image_matting: - (Pipelines.image_matting, 'damo/cv_unet_image-matting'), + Tasks.image_matting: (Pipelines.image_matting, + 'damo/cv_unet_image-matting'), Tasks.text_classification: (Pipelines.sentiment_analysis, 'damo/bert-base-sst2'), Tasks.text_generation: (Pipelines.text_generation, @@ -37,6 +37,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), Tasks.action_recognition: (Pipelines.action_recognition, 'damo/cv_TAdaConv_action-recognition'), + Tasks.multi_modal_embedding: + (Pipelines.multi_modal_embedding, + 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding') } diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index b7402b93..6c96d843 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -1 +1,2 @@ from .image_captioning_pipeline import ImageCaptionPipeline +from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py new file mode 100644 index 00000000..a21ecc79 --- /dev/null +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -0,0 +1,34 @@ +from typing import Any, Dict, Union + +from modelscope.metainfo import Pipelines +from modelscope.pipelines.base import Input +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from ..base import Model, Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding) +class MultiModalEmbeddingPipeline(Pipeline): + + def __init__(self, model: str, device_id: int = -1): + if isinstance(model, str): + pipe_model = Model.from_pretrained(model) + elif isinstance(model, Model): + pipe_model = model + else: + raise NotImplementedError('model must be a single str') + + super().__init__(model=pipe_model) + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return self.model(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 3b1c67de..52b7eeae 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -117,6 +117,13 @@ TASK_OUTPUTS = { # } Tasks.image_captioning: ['caption'], + # multi-modal embedding result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # "text_embedding": np.array with shape [1, D] + # } + Tasks.multi_modal_embedding: ['img_embedding', 'text_embedding'], + # visual grounding result for single sample # { # "boxes": [ diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index e824db9a..2045efb6 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -57,6 +57,7 @@ class Tasks(object): image_captioning = 'image-captioning' visual_grounding = 'visual-grounding' text_to_image_synthesis = 'text-to-image-synthesis' + multi_modal_embedding = 'multi-modal-embedding' 
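
[Review note on PATCH 2/7] Both encoders in clip_model.py L2-normalize their
outputs, so cross-modal similarity is a plain dot product of the returned
embeddings. A minimal sketch of how the pieces compose, mirroring the
dict-style inputs used by the test below; 'test.png' is a placeholder path:

    import numpy as np
    from PIL import Image

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # resolves to the default CLIP model registered in builder.py above
    clip = pipeline(Tasks.multi_modal_embedding)
    img_emb = clip({'img': Image.open('test.png')})['img_embedding']  # [1, D]
    txt_emb = clip({'text': '一张风景图'})['text_embedding']  # [1, D]
    # embeddings are unit-norm, so the dot product is cosine similarity
    score = float(np.dot(img_emb, txt_emb.T))
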
class InputFields(object): diff --git a/tests/pipelines/test_multi_modal_embedding.py b/tests/pipelines/test_multi_modal_embedding.py new file mode 100644 index 00000000..001bf951 --- /dev/null +++ b/tests/pipelines/test_multi_modal_embedding.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import unittest + +import numpy as np + +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class MultiModalEmbeddingTest(unittest.TestCase): + model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding' + test_text = {'text': '一张风景图'} + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run(self): + pipe_line_multi_modal_embedding = pipeline( + Tasks.multi_modal_embedding, model=self.model_id) + test_str_embedding = pipe_line_multi_modal_embedding( + self.test_text)['text_embedding'] + print(np.sum(np.abs(test_str_embedding))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + model = Model.from_pretrained(self.model_id) + pipe_line_multi_modal_embedding = pipeline( + task=Tasks.multi_modal_embedding, model=model) + test_str_embedding = pipe_line_multi_modal_embedding( + self.test_text)['text_embedding'] + print(np.sum(np.abs(test_str_embedding))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_model_name(self): + pipe_line_multi_modal_embedding = pipeline( + task=Tasks.multi_modal_embedding, model=self.model_id) + test_str_embedding = pipe_line_multi_modal_embedding( + self.test_text)['text_embedding'] + print(np.sum(np.abs(test_str_embedding))) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipe_line_multi_modal_embedding = pipeline( + task=Tasks.multi_modal_embedding) + test_str_embedding = pipe_line_multi_modal_embedding( + self.test_text)['text_embedding'] + print(np.sum(np.abs(test_str_embedding))) + + +if __name__ == '__main__': + unittest.main() From 5386748bc4d765e54357be50d0835aec3974bae8 Mon Sep 17 00:00:00 2001 From: "shichen.fsc" Date: Mon, 27 Jun 2022 11:59:44 +0800 Subject: [PATCH 3/7] [to #42322933] add some code check Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9122842 * [Add] add KWS code * [Update] check code linters and formatter * [Update] update kws code * Merge branch 'master' into dev/kws * [Fix] fix kws warning * [Add] add ROC for KWS * [Update] add some code check * feat: Fix confilct, auto commit by WebIDE * feat: Fix confilct, auto commit by WebIDE * Merge branch 'master' into dev/kws * [Update] refactor kws code * [Update] refactor kws code * [Update] refactor kws code, bug fix * [Update] refactor kws code, bug fix --- modelscope/metainfo.py | 3 + modelscope/models/__init__.py | 1 + modelscope/models/audio/kws/__init__.py | 1 + .../audio/kws/generic_key_word_spotting.py | 30 ++ modelscope/pipelines/audio/__init__.py | 1 + .../pipelines/audio/kws_kwsbp_pipeline.py | 449 ++++++++++++++++++ modelscope/preprocessors/__init__.py | 1 + modelscope/preprocessors/kws.py | 253 ++++++++++ modelscope/utils/constant.py | 1 + tests/pipelines/test_key_word_spotting.py | 334 +++++++++++++ 10 files changed, 1074 insertions(+) create mode 100644 modelscope/models/audio/kws/__init__.py create mode 100644 modelscope/models/audio/kws/generic_key_word_spotting.py 
create mode 100644 modelscope/pipelines/audio/kws_kwsbp_pipeline.py create mode 100644 modelscope/preprocessors/kws.py create mode 100644 tests/pipelines/test_key_word_spotting.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index af89cf33..680fe2e8 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -21,6 +21,7 @@ class Models(object): sambert_hifi_16k = 'sambert-hifi-16k' generic_tts_frontend = 'generic-tts-frontend' hifigan16k = 'hifigan16k' + kws_kwsbp = 'kws-kwsbp' # multi-modal models ofa = 'ofa' @@ -53,6 +54,7 @@ class Pipelines(object): # audio tasks sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + kws_kwsbp = 'kws-kwsbp' # multi-modal tasks image_caption = 'image-caption' @@ -94,6 +96,7 @@ class Preprocessors(object): # audio preprocessor linear_aec_fbank = 'linear-aec-fbank' text_to_tacotron_symbols = 'text-to-tacotron-symbols' + wav_to_lists = 'wav-to-lists' # multi-modal ofa_image_caption = 'ofa-image-caption' diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index 06380035..ebf81c32 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .audio.kws import GenericKeyWordSpotting from .audio.tts.am import SambertNetHifi16k from .audio.tts.vocoder import Hifigan16k from .base import Model diff --git a/modelscope/models/audio/kws/__init__.py b/modelscope/models/audio/kws/__init__.py new file mode 100644 index 00000000..d7e163a9 --- /dev/null +++ b/modelscope/models/audio/kws/__init__.py @@ -0,0 +1 @@ +from .generic_key_word_spotting import * # noqa F403 diff --git a/modelscope/models/audio/kws/generic_key_word_spotting.py b/modelscope/models/audio/kws/generic_key_word_spotting.py new file mode 100644 index 00000000..7a738d5b --- /dev/null +++ b/modelscope/models/audio/kws/generic_key_word_spotting.py @@ -0,0 +1,30 @@ +import os +from typing import Any, Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['GenericKeyWordSpotting'] + + +@MODELS.register_module(Tasks.key_word_spotting, module_name=Models.kws_kwsbp) +class GenericKeyWordSpotting(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the info of model. + + Args: + model_dir (str): the model path. 
+ """ + + self.model_cfg = { + 'model_workspace': model_dir, + 'config_path': os.path.join(model_dir, 'config.yaml') + } + + def forward(self) -> Dict[str, Any]: + """return the info of the model + """ + return self.model_cfg diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py index 20c7710a..87ccd49a 100644 --- a/modelscope/pipelines/audio/__init__.py +++ b/modelscope/pipelines/audio/__init__.py @@ -1,2 +1,3 @@ +from .kws_kwsbp_pipeline import * # noqa F403 from .linear_aec_pipeline import LinearAECPipeline from .text_to_speech_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py new file mode 100644 index 00000000..4a69976a --- /dev/null +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -0,0 +1,449 @@ +import io +import os +import shutil +import stat +import subprocess +from typing import Any, Dict, List + +from modelscope.metainfo import Pipelines +from modelscope.models import Model +from modelscope.pipelines.base import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import WavToLists +from modelscope.utils.constant import Tasks + +__all__ = ['KeyWordSpottingKwsbpPipeline'] + + +@PIPELINES.register_module( + Tasks.key_word_spotting, module_name=Pipelines.kws_kwsbp) +class KeyWordSpottingKwsbpPipeline(Pipeline): + """KWS Pipeline - key word spotting decoding + """ + + def __init__(self, + config_file: str = None, + model: Model = None, + preprocessor: WavToLists = None, + **kwargs): + """use `model` and `preprocessor` to create a kws pipeline for prediction + """ + + super().__init__( + config_file=config_file, + model=model, + preprocessor=preprocessor, + **kwargs) + assert model is not None, 'kws model should be provided' + assert preprocessor is not None, 'preprocessor is none' + + self._preprocessor = preprocessor + self._model = model + + def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]: + assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', + 'roc'], f'kws_type {kws_type} is invalid' + output = self._preprocessor.forward(self._model.forward(), kws_type, + wav_path) + output = self.forward(output) + rst = self.postprocess(output) + return rst + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Decoding + """ + + # will generate kws result into dump/dump.JOB.log + out = self._run_with_kwsbp(inputs) + + return out + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """process the kws results + """ + + pos_result_json = {} + neg_result_json = {} + + if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + self._parse_dump_log(pos_result_json, inputs['pos_dump_path']) + if inputs['kws_set'] in ['neg_testsets', 'roc']: + self._parse_dump_log(neg_result_json, inputs['neg_dump_path']) + """ + result_json format example: + { + "wav_count": 450, + "keywords": ["小云小云"], + "wav_time": 3560.999999, + "detected": [ + { + "xxx.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + { + "yyy.wav": { + "confidence": "0.990368", + "keyword": "小云小云" + } + }, + ...... + ], + "detected_count": 429, + "rejected_count": 21, + "rejected": [ + "yyy.wav", + "zzz.wav", + ...... 
+ ] + } + """ + + rst_dict = {'kws_set': inputs['kws_set']} + + # parsing the result of wav + if inputs['kws_set'] == 'wav': + rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ + 'pos_wav_count'] + rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) + if pos_result_json['detected_count'] == 1: + rst_dict['keywords'] = pos_result_json['keywords'] + rst_dict['detected'] = True + wav_file_name = os.path.basename(inputs['pos_wav_path']) + rst_dict['confidence'] = float(pos_result_json['detected'][0] + [wav_file_name]['confidence']) + else: + rst_dict['detected'] = False + + # parsing the result of pos_tests + elif inputs['kws_set'] == 'pos_testsets': + rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[ + 'pos_wav_count'] + rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6) + if pos_result_json.__contains__('keywords'): + rst_dict['keywords'] = pos_result_json['keywords'] + + rst_dict['recall'] = round( + pos_result_json['detected_count'] / rst_dict['wav_count'], 6) + + if pos_result_json.__contains__('detected_count'): + rst_dict['detected_count'] = pos_result_json['detected_count'] + if pos_result_json.__contains__('rejected_count'): + rst_dict['rejected_count'] = pos_result_json['rejected_count'] + if pos_result_json.__contains__('rejected'): + rst_dict['rejected'] = pos_result_json['rejected'] + + # parsing the result of neg_tests + elif inputs['kws_set'] == 'neg_testsets': + rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[ + 'neg_wav_count'] + rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6) + if neg_result_json.__contains__('keywords'): + rst_dict['keywords'] = neg_result_json['keywords'] + + rst_dict['fa_rate'] = 0.0 + rst_dict['fa_per_hour'] = 0.0 + + if neg_result_json.__contains__('detected_count'): + rst_dict['detected_count'] = neg_result_json['detected_count'] + rst_dict['fa_rate'] = round( + neg_result_json['detected_count'] / rst_dict['wav_count'], + 6) + if neg_result_json.__contains__('wav_time'): + rst_dict['fa_per_hour'] = round( + neg_result_json['detected_count'] + / float(neg_result_json['wav_time'] / 3600), 6) + + if neg_result_json.__contains__('rejected_count'): + rst_dict['rejected_count'] = neg_result_json['rejected_count'] + + if neg_result_json.__contains__('detected'): + rst_dict['detected'] = neg_result_json['detected'] + + # parsing the result of roc + elif inputs['kws_set'] == 'roc': + threshold_start = 0.000 + threshold_step = 0.001 + threshold_end = 1.000 + + pos_keywords_list = [] + neg_keywords_list = [] + if pos_result_json.__contains__('keywords'): + pos_keywords_list = pos_result_json['keywords'] + if neg_result_json.__contains__('keywords'): + neg_keywords_list = neg_result_json['keywords'] + + keywords_list = list(set(pos_keywords_list + neg_keywords_list)) + + pos_result_json['wav_count'] = inputs['pos_wav_count'] + neg_result_json['wav_count'] = inputs['neg_wav_count'] + + if len(keywords_list) > 0: + rst_dict['keywords'] = keywords_list + + for index in range(len(rst_dict['keywords'])): + cur_keyword = rst_dict['keywords'][index] + output_list = self._generate_roc_list( + start=threshold_start, + step=threshold_step, + end=threshold_end, + keyword=cur_keyword, + pos_inputs=pos_result_json, + neg_inputs=neg_result_json) + + rst_dict[cur_keyword] = output_list + + return rst_dict + + def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + if inputs['kws_set'] == 'roc': + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], 'keywords_roc.json') 
+ + if inputs['kws_set'] == 'wav': + dump_log_path: str = os.path.join(inputs['pos_dump_path'], + 'dump.log') + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + os.system(kws_cmd) + + if inputs['kws_set'] in ['pos_testsets', 'roc']: + data_dir: str = os.listdir(inputs['pos_data_path']) + wav_list = [] + for i in data_dir: + suffix = os.path.splitext(os.path.basename(i))[1] + if suffix == '.list': + wav_list.append(os.path.join(inputs['pos_data_path'], i)) + + j: int = 0 + process = [] + while j < inputs['pos_num_thread']: + wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str( + j) + '.list' + dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str( + j) + '.log' + + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + wav_list_path + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + p = subprocess.Popen(kws_cmd, shell=True) + process.append(p) + j += 1 + + k: int = 0 + while k < len(process): + process[k].wait() + k += 1 + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + data_dir: str = os.listdir(inputs['neg_data_path']) + wav_list = [] + for i in data_dir: + suffix = os.path.splitext(os.path.basename(i))[1] + if suffix == '.list': + wav_list.append(os.path.join(inputs['neg_data_path'], i)) + + j: int = 0 + process = [] + while j < inputs['neg_num_thread']: + wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str( + j) + '.list' + dump_log_path: str = inputs['neg_dump_path'] + '/dump.' 
+ str( + j) + '.log' + + kws_cmd: str = inputs['kws_tool_path'] + \ + ' --sys-dir=' + inputs['model_workspace'] + \ + ' --cfg-file=' + inputs['cfg_file_path'] + \ + ' --sample-rate=' + inputs['sample_rate'] + \ + ' --keyword-grammar=' + inputs['keyword_grammar_path'] + \ + ' --wave-scp=' + wav_list_path + \ + ' --num-thread=1 > ' + dump_log_path + ' 2>&1' + p = subprocess.Popen(kws_cmd, shell=True) + process.append(p) + j += 1 + + k: int = 0 + while k < len(process): + process[k].wait() + k += 1 + + return inputs + + def _parse_dump_log(self, result_json: Dict[str, Any], + dump_path: str) -> Dict[str, Any]: + dump_dir = os.listdir(dump_path) + for i in dump_dir: + basename = os.path.splitext(os.path.basename(i))[0] + # find dump.JOB.log + if 'dump' in basename: + with open( + os.path.join(dump_path, i), mode='r', + encoding='utf-8') as file: + while 1: + line = file.readline() + if not line: + break + else: + result_json = self._parse_result_log( + line, result_json) + + def _parse_result_log(self, line: str, + result_json: Dict[str, Any]) -> Dict[str, Any]: + # valid info + if '[rejected]' in line or '[detected]' in line: + detected_count = 0 + rejected_count = 0 + + if result_json.__contains__('detected_count'): + detected_count = result_json['detected_count'] + if result_json.__contains__('rejected_count'): + rejected_count = result_json['rejected_count'] + + if '[detected]' in line: + # [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav, + # kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00, + detected_count += 1 + content_list = line.split(', ') + file_name = os.path.basename(content_list[1].split(':')[1]) + keyword = content_list[2].split(':')[1] + confidence = content_list[3].split(':')[1] + + keywords_list = [] + if result_json.__contains__('keywords'): + keywords_list = result_json['keywords'] + + if keyword not in keywords_list: + keywords_list.append(keyword) + result_json['keywords'] = keywords_list + + keyword_item = {} + keyword_item['confidence'] = confidence + keyword_item['keyword'] = keyword + item = {} + item[file_name] = keyword_item + + detected_list = [] + if result_json.__contains__('detected'): + detected_list = result_json['detected'] + + detected_list.append(item) + result_json['detected'] = detected_list + + elif '[rejected]' in line: + # [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav + rejected_count += 1 + content_list = line.split(', ') + file_name = os.path.basename(content_list[1].split(':')[1]) + file_name = file_name.strip().replace('\n', + '').replace('\r', '') + + rejected_list = [] + if result_json.__contains__('rejected'): + rejected_list = result_json['rejected'] + + rejected_list.append(file_name) + result_json['rejected'] = rejected_list + + result_json['detected_count'] = detected_count + result_json['rejected_count'] = rejected_count + + elif 'total_proc_time=' in line and 'wav_time=' in line: + # eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799 + wav_total_time = 0 + content_list = line.split('), ') + if result_json.__contains__('wav_time'): + wav_total_time = result_json['wav_time'] + + wav_time_str = content_list[1].split('=')[1] + wav_time_str = wav_time_str.split('(')[0] + wav_time = float(wav_time_str) + wav_time = round(wav_time, 6) + + if isinstance(wav_time, float): + wav_total_time += wav_time + + result_json['wav_time'] = wav_total_time + + return result_json + + def _generate_roc_list(self, start: float, step: float, end: float, + keyword: str, pos_inputs: Dict[str, Any], + 
neg_inputs: Dict[str, Any]) -> Dict[str, Any]:
+        pos_wav_count = pos_inputs['wav_count']
+        neg_wav_time = neg_inputs['wav_time']
+        det_lists = pos_inputs['detected']
+        fa_lists = neg_inputs['detected']
+        threshold_cur = start
+        """
+        input det_lists dict
+        [
+            {
+                "xxx.wav": {
+                    "confidence": "0.990368",
+                    "keyword": "小云小云"
+                }
+            },
+            {
+                "yyy.wav": {
+                    "confidence": "0.990368",
+                    "keyword": "小云小云"
+                }
+            },
+        ]
+
+        output dict
+        [
+            {
+                "threshold": 0.000,
+                "recall": 0.999888,
+                "fa_per_hour": 1.999999
+            },
+            {
+                "threshold": 0.001,
+                "recall": 0.999888,
+                "fa_per_hour": 1.999999
+            },
+        ]
+        """
+
+        output = []
+        while threshold_cur <= end:
+            det_count = 0
+            fa_count = 0
+            for index in range(len(det_lists)):
+                det_item = det_lists[index]
+                det_wav_item = det_item.get(next(iter(det_item)))
+                if det_wav_item['keyword'] == keyword:
+                    confidence = float(det_wav_item['confidence'])
+                    if confidence >= threshold_cur:
+                        det_count += 1
+
+            for index in range(len(fa_lists)):
+                fa_item = fa_lists[index]
+                fa_wav_item = fa_item.get(next(iter(fa_item)))
+                if fa_wav_item['keyword'] == keyword:
+                    confidence = float(fa_wav_item['confidence'])
+                    if confidence >= threshold_cur:
+                        fa_count += 1
+
+            output_item = {
+                'threshold': round(threshold_cur, 3),
+                'recall': round(float(det_count / pos_wav_count), 6),
+                'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6)
+            }
+            output.append(output_item)
+
+            threshold_cur += step
+
+        return output
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 942d17c3..1bc06ce3 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -5,6 +5,7 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
+from .kws import WavToLists
 from .multi_modal import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
diff --git a/modelscope/preprocessors/kws.py b/modelscope/preprocessors/kws.py
new file mode 100644
index 00000000..d69e8283
--- /dev/null
+++ b/modelscope/preprocessors/kws.py
@@ -0,0 +1,253 @@
+import os
+import shutil
+import stat
+from pathlib import Path
+from typing import Any, Dict, List
+
+import yaml
+
+from modelscope.metainfo import Preprocessors
+from modelscope.models.base import Model
+from modelscope.utils.constant import Fields
+from .base import Preprocessor
+from .builder import PREPROCESSORS
+
+__all__ = ['WavToLists']
+
+
+@PREPROCESSORS.register_module(
+    Fields.audio, module_name=Preprocessors.wav_to_lists)
+class WavToLists(Preprocessor):
+    """Generate wave list files from wav inputs.
+
+    Args:
+        workspace (str): directory for temporary KWS intermediate files and results
+    """
+
+    def __init__(self, workspace: str = None):
+        # the workspace path; fall back to ./.tmp when none is given
+        if workspace is None or len(workspace) == 0:
+            self._workspace = os.path.join(os.getcwd(), '.tmp')
+        else:
+            self._workspace = workspace
+
+        if not os.path.exists(self._workspace):
+            os.mkdir(self._workspace)
+
+    def __call__(self,
+                 model: Model = None,
+                 kws_type: str = None,
+                 wav_path: List[str] = None) -> Dict[str, Any]:
+        """Call functions to load model and wav.
+ + Args: + model (Model): model should be provided + kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc + wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path + Returns: + Dict[str, Any]: the kws result + """ + + assert model is not None, 'preprocess kws model should be provided' + assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc' + ], f'preprocess kws_type {kws_type} is invalid' + assert wav_path[0] is not None or wav_path[ + 1] is not None, 'preprocess wav_path is invalid' + + self._model = model + out = self.forward(self._model.forward(), kws_type, wav_path) + return out + + def forward(self, model: Dict[str, Any], kws_type: str, + wav_path: List[str]) -> Dict[str, Any]: + assert len(kws_type) > 0, 'preprocess kws_type is empty' + assert len( + model['config_path']) > 0, 'preprocess model[config_path] is empty' + assert os.path.exists( + model['config_path']), 'model config.yaml is absent' + + inputs = model.copy() + + inputs['kws_set'] = kws_type + inputs['workspace'] = self._workspace + if wav_path[0] is not None: + inputs['pos_wav_path'] = wav_path[0] + if wav_path[1] is not None: + inputs['neg_wav_path'] = wav_path[1] + + out = self._read_config(inputs) + out = self._generate_wav_lists(out) + + return out + + def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """read and parse config.yaml to get all model files + """ + + assert os.path.exists( + inputs['config_path']), 'model config yaml file does not exist' + + config_file = open(inputs['config_path']) + root = yaml.full_load(config_file) + config_file.close() + + inputs['cfg_file'] = root['cfg_file'] + inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'], + root['cfg_file']) + inputs['keyword_grammar'] = root['keyword_grammar'] + inputs['keyword_grammar_path'] = os.path.join( + inputs['model_workspace'], root['keyword_grammar']) + inputs['sample_rate'] = str(root['sample_rate']) + inputs['kws_tool'] = root['kws_tool'] + + if os.path.exists( + os.path.join(inputs['workspace'], inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join(inputs['workspace'], + inputs['kws_tool']) + elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join('/usr/bin', + inputs['kws_tool']) + elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])): + inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool']) + + assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp' + os.chmod(inputs['kws_tool_path'], + stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH) + + self._config_checking(inputs) + return inputs + + def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """assemble wav lists + """ + + if inputs['kws_set'] == 'wav': + inputs['pos_num_thread'] = 1 + wave_scp_content: str = inputs['pos_wav_path'] + '\n' + + with open(os.path.join(inputs['pos_data_path'], 'wave.list'), + 'a') as f: + f.write(wave_scp_content) + + inputs['pos_wav_count'] = 1 + + if inputs['kws_set'] in ['pos_testsets', 'roc']: + # find all positive wave + wav_list = [] + wav_dir = inputs['pos_wav_path'] + wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['pos_wav_count'] = list_count + + if list_count <= 128: + inputs['pos_num_thread'] = list_count + j: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['pos_data_path'] + '/wave.' 
+ str( + j) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + + else: + inputs['pos_num_thread'] = 128 + j: int = 0 + k: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['pos_data_path'] + '/wave.' + str( + k) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + k += 1 + if k >= 128: + k = 0 + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + # find all negative wave + wav_list = [] + wav_dir = inputs['neg_wav_path'] + wav_list = self._recursion_dir_all_wave(wav_list, wav_dir) + + list_count: int = len(wav_list) + inputs['neg_wav_count'] = list_count + + if list_count <= 128: + inputs['neg_num_thread'] = list_count + j: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['neg_data_path'] + '/wave.' + str( + j) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + + else: + inputs['neg_num_thread'] = 128 + j: int = 0 + k: int = 0 + while j < list_count: + wave_scp_content: str = wav_list[j] + '\n' + wav_list_path = inputs['neg_data_path'] + '/wave.' + str( + k) + '.list' + with open(wav_list_path, 'a') as f: + f.write(wave_scp_content) + j += 1 + k += 1 + if k >= 128: + k = 0 + + return inputs + + def _recursion_dir_all_wave(self, wav_list, + dir_path: str) -> Dict[str, Any]: + dir_files = os.listdir(dir_path) + for file in dir_files: + file_path = os.path.join(dir_path, file) + if os.path.isfile(file_path): + if file_path.endswith('.wav') or file_path.endswith('.WAV'): + wav_list.append(file_path) + elif os.path.isdir(file_path): + self._recursion_dir_all_wave(wav_list, file_path) + + return wav_list + + def _config_checking(self, inputs: Dict[str, Any]): + + if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']: + inputs['pos_data_path'] = os.path.join(inputs['workspace'], + 'pos_data') + if not os.path.exists(inputs['pos_data_path']): + os.mkdir(inputs['pos_data_path']) + else: + shutil.rmtree(inputs['pos_data_path']) + os.mkdir(inputs['pos_data_path']) + + inputs['pos_dump_path'] = os.path.join(inputs['workspace'], + 'pos_dump') + if not os.path.exists(inputs['pos_dump_path']): + os.mkdir(inputs['pos_dump_path']) + else: + shutil.rmtree(inputs['pos_dump_path']) + os.mkdir(inputs['pos_dump_path']) + + if inputs['kws_set'] in ['neg_testsets', 'roc']: + inputs['neg_data_path'] = os.path.join(inputs['workspace'], + 'neg_data') + if not os.path.exists(inputs['neg_data_path']): + os.mkdir(inputs['neg_data_path']) + else: + shutil.rmtree(inputs['neg_data_path']) + os.mkdir(inputs['neg_data_path']) + + inputs['neg_dump_path'] = os.path.join(inputs['workspace'], + 'neg_dump') + if not os.path.exists(inputs['neg_dump_path']): + os.mkdir(inputs['neg_dump_path']) + else: + shutil.rmtree(inputs['neg_dump_path']) + os.mkdir(inputs['neg_dump_path']) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 2045efb6..f2215359 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -52,6 +52,7 @@ class Tasks(object): auto_speech_recognition = 'auto-speech-recognition' text_to_speech = 'text-to-speech' speech_signal_process = 'speech-signal-process' + key_word_spotting = 'key-word-spotting' # multi-modal tasks image_captioning = 'image-captioning' diff --git a/tests/pipelines/test_key_word_spotting.py b/tests/pipelines/test_key_word_spotting.py new file mode 100644 index 00000000..e82a4211 --- /dev/null +++ b/tests/pipelines/test_key_word_spotting.py @@ 
-0,0 +1,334 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tarfile +import unittest + +import requests + +from modelscope.metainfo import Pipelines, Preprocessors +from modelscope.models import Model +from modelscope.pipelines import pipeline +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.constant import Fields, InputFields, Tasks +from modelscope.utils.test_utils import test_level + +KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp' + +POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav' +POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE + +POS_TESTSETS_FILE = 'pos_testsets.tar.gz' +POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' + +NEG_TESTSETS_FILE = 'neg_testsets.tar.gz' +NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz' + + +def un_tar_gz(fname, dirs): + t = tarfile.open(fname) + t.extractall(path=dirs) + + +class KeyWordSpottingTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun' + self.workspace = os.path.join(os.getcwd(), '.tmp') + if not os.path.exists(self.workspace): + os.mkdir(self.workspace) + + def tearDown(self) -> None: + if os.path.exists(self.workspace): + shutil.rmtree(self.workspace) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_wav(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'wav' + + # downloading wav file + wav_file_path = os.path.join(self.workspace, POS_WAV_FILE) + if not os.path.exists(wav_file_path): + r = requests.get(POS_WAV_URL) + with open(wav_file_path, 'wb') as f: + f.write(r.content) + + # downloading kwsbp + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('detected')) + """ + kws result json format example: + { + 'wav_count': 1, + 'kws_set': 'wav', + 'wav_time': 9.132938, + 'keywords': ['小云小云'], + 'detected': True, + 'confidence': 0.990368 + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_wav keywords: ', kws_result['keywords']) + print('test_run_with_wav detected result: ', kws_result['detected']) + print('test_run_with_wav wave time(seconds): ', kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_pos_testsets(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'pos_testsets' + + # downloading pos_testsets file + testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(POS_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + 
testsets_dir_name = os.path.splitext( + os.path.basename(POS_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # wav_file_path = /.tmp_pos_testsets/pos_testsets/ + wav_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the pos_testsets file + if not os.path.exists(wav_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[wav_file_path, None]) + self.assertTrue(kws_result.__contains__('recall')) + """ + kws result json format example: + { + 'wav_count': 450, + 'kws_set': 'pos_testsets', + 'wav_time': 3013.759254, + 'keywords': ["小云小云"], + 'recall': 0.953333, + 'detected_count': 429, + 'rejected_count': 21, + 'rejected': [ + 'yyy.wav', + 'zzz.wav', + ...... + ] + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_pos_testsets keywords: ', + kws_result['keywords']) + print('test_run_with_pos_testsets recall: ', kws_result['recall']) + print('test_run_with_pos_testsets wave time(seconds): ', + kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_with_neg_testsets(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'neg_testsets' + + # downloading neg_testsets file + testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(NEG_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(NEG_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # wav_file_path = /.tmp_neg_testsets/neg_testsets/ + wav_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the neg_testsets file + if not os.path.exists(wav_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[None, wav_file_path]) + self.assertTrue(kws_result.__contains__('fa_rate')) + """ + kws result json format example: + { + 'wav_count': 751, + 
'kws_set': 'neg_testsets', + 'wav_time': 3572.180812, + 'keywords': ['小云小云'], + 'fa_rate': 0.001332, + 'fa_per_hour': 1.007788, + 'detected_count': 1, + 'rejected_count': 750, + 'detected': [ + { + '6.wav': { + 'confidence': '0.321170' + } + } + ] + } + """ + if kws_result.__contains__('keywords'): + print('test_run_with_neg_testsets keywords: ', + kws_result['keywords']) + print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate']) + print('test_run_with_neg_testsets fa per hour: ', + kws_result['fa_per_hour']) + print('test_run_with_neg_testsets wave time(seconds): ', + kws_result['wav_time']) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_roc(self): + # wav, neg_testsets, pos_testsets, roc + kws_set = 'roc' + + # downloading neg_testsets file + testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(NEG_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(NEG_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # neg_file_path = /.tmp_roc/neg_testsets/ + neg_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the neg_testsets file + if not os.path.exists(neg_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading pos_testsets file + testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE) + if not os.path.exists(testsets_file_path): + r = requests.get(POS_TESTSETS_URL) + with open(testsets_file_path, 'wb') as f: + f.write(r.content) + + testsets_dir_name = os.path.splitext( + os.path.basename(POS_TESTSETS_FILE))[0] + testsets_dir_name = os.path.splitext( + os.path.basename(testsets_dir_name))[0] + # pos_file_path = /.tmp_roc/pos_testsets/ + pos_file_path = os.path.join(self.workspace, testsets_dir_name) + + # untar the pos_testsets file + if not os.path.exists(pos_file_path): + un_tar_gz(testsets_file_path, self.workspace) + + # downloading kwsbp -- a kws batch processing tool + kwsbp_file_path = os.path.join(self.workspace, 'kwsbp') + if not os.path.exists(kwsbp_file_path): + r = requests.get(KWSBP_URL) + with open(kwsbp_file_path, 'wb') as f: + f.write(r.content) + + model = Model.from_pretrained(self.model_id) + self.assertTrue(model is not None) + + cfg_preprocessor = dict( + type=Preprocessors.wav_to_lists, workspace=self.workspace) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + kwsbp_16k_pipline = pipeline( + pipeline_name=Pipelines.kws_kwsbp, + model=model, + preprocessor=preprocessor) + self.assertTrue(kwsbp_16k_pipline is not None) + + kws_result = kwsbp_16k_pipline( + kws_type=kws_set, wav_path=[pos_file_path, neg_file_path]) + """ + kws result json format example: + { + 'kws_set': 'roc', + 'keywords': ['小云小云'], + '小云小云': [ + {'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788}, + {'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788}, + ...... 
+ {'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0} + ] + } + """ + if kws_result.__contains__('keywords'): + find_keyword = kws_result['keywords'][0] + print('test_run_with_roc keywords: ', find_keyword) + keyword_list = kws_result[find_keyword] + for item in iter(keyword_list): + threshold: float = item['threshold'] + recall: float = item['recall'] + fa_per_hour: float = item['fa_per_hour'] + print(' threshold:', threshold, ' recall:', recall, + ' fa_per_hour:', fa_per_hour) + + +if __name__ == '__main__': + unittest.main() From 1cc1b3c63711d8542f2735064c886010eeb78ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=BA=E4=B8=9E?= Date: Mon, 27 Jun 2022 14:28:05 +0800 Subject: [PATCH 4/7] revise based on comment except chinease comment --- .../nlp/bert_for_sequence_classification.py | 4 ++-- .../models/nlp/sbert_for_sentence_similarity.py | 4 ++-- .../nlp/sbert_for_sentiment_classification.py | 4 ++-- .../models/nlp/sbert_for_token_classification.py | 4 ++-- .../nlp/space/dialog_intent_prediction_model.py | 14 +++++++------- .../models/nlp/space/dialog_modeling_model.py | 14 +++++++------- modelscope/models/nlp/space/model/model_base.py | 15 +++++++++------ .../models/nlp/space/model/unified_transformer.py | 4 ++-- .../nlp/dialog_intent_prediction_pipeline.py | 4 ++-- .../pipelines/nlp/dialog_modeling_pipeline.py | 8 ++++---- .../dialog_intent_prediction_preprocessor.py | 4 ++-- .../space/dialog_modeling_preprocessor.py | 4 ++-- .../preprocessors/space/fields/gen_field.py | 1 + .../nlp/space}/__init__.py | 0 .../nlp/space/metrics/__init__.py | 0 .../nlp/space/metrics/metrics_tracker.py | 0 modelscope/trainers/nlp/space/trainer/__init__.py | 0 .../nlp/space/trainer/gen_trainer.py} | 0 .../nlp/space/trainer/intent_trainer.py} | 0 tests/pipelines/test_dialog_intent_prediction.py | 4 ++-- tests/pipelines/test_dialog_modeling.py | 4 ++-- 21 files changed, 48 insertions(+), 44 deletions(-) rename modelscope/{models/nlp/space/application => trainers/nlp/space}/__init__.py (100%) rename modelscope/{models => trainers}/nlp/space/metrics/__init__.py (100%) rename modelscope/{models => trainers}/nlp/space/metrics/metrics_tracker.py (100%) create mode 100644 modelscope/trainers/nlp/space/trainer/__init__.py rename modelscope/{models/nlp/space/application/gen_app.py => trainers/nlp/space/trainer/gen_trainer.py} (100%) rename modelscope/{models/nlp/space/application/intent_app.py => trainers/nlp/space/trainer/intent_trainer.py} (100%) diff --git a/modelscope/models/nlp/bert_for_sequence_classification.py b/modelscope/models/nlp/bert_for_sequence_classification.py index 7d85fa28..6eb27f03 100644 --- a/modelscope/models/nlp/bert_for_sequence_classification.py +++ b/modelscope/models/nlp/bert_for_sequence_classification.py @@ -4,8 +4,8 @@ from typing import Any, Dict import json import numpy as np -from modelscope.metainfo import Models -from modelscope.utils.constant import Tasks +from ...metainfo import Models +from ...utils.constant import Tasks from ..base import Model from ..builder import MODELS diff --git a/modelscope/models/nlp/sbert_for_sentence_similarity.py b/modelscope/models/nlp/sbert_for_sentence_similarity.py index 25c38a2e..5fa487e5 100644 --- a/modelscope/models/nlp/sbert_for_sentence_similarity.py +++ b/modelscope/models/nlp/sbert_for_sentence_similarity.py @@ -1,5 +1,5 @@ -from modelscope.metainfo import Models -from modelscope.utils.constant import Tasks +from ...metainfo import Models +from ...utils.constant import Tasks from ..builder import MODELS from 
.sbert_for_sequence_classification import \ SbertForSequenceClassificationBase diff --git a/modelscope/models/nlp/sbert_for_sentiment_classification.py b/modelscope/models/nlp/sbert_for_sentiment_classification.py index 72fb92f0..00ec8b73 100644 --- a/modelscope/models/nlp/sbert_for_sentiment_classification.py +++ b/modelscope/models/nlp/sbert_for_sentiment_classification.py @@ -1,5 +1,5 @@ -from modelscope.metainfo import Models -from modelscope.utils.constant import Tasks +from ...metainfo import Models +from ...utils.constant import Tasks from ..builder import MODELS from .sbert_for_sequence_classification import \ SbertForSequenceClassificationBase diff --git a/modelscope/models/nlp/sbert_for_token_classification.py b/modelscope/models/nlp/sbert_for_token_classification.py index fd175033..a23002ee 100644 --- a/modelscope/models/nlp/sbert_for_token_classification.py +++ b/modelscope/models/nlp/sbert_for_token_classification.py @@ -3,8 +3,8 @@ from typing import Any, Dict, Union import numpy as np import torch -from modelscope.metainfo import Models -from modelscope.utils.constant import Tasks +from ...metainfo import Models +from ...utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS diff --git a/modelscope/models/nlp/space/dialog_intent_prediction_model.py b/modelscope/models/nlp/space/dialog_intent_prediction_model.py index a6bd1d27..74e4e9e7 100644 --- a/modelscope/models/nlp/space/dialog_intent_prediction_model.py +++ b/modelscope/models/nlp/space/dialog_intent_prediction_model.py @@ -2,19 +2,19 @@ import os from typing import Any, Dict from ....preprocessors.space.fields.intent_field import IntentBPETextField +from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer from ....utils.config import Config -from ....utils.constant import Tasks +from ....utils.constant import ModelFile, Tasks from ...base import Model, Tensor from ...builder import MODELS -from .application.intent_app import IntentTrainer from .model.generator import Generator -from .model.model_base import ModelBase +from .model.model_base import SpaceModelBase -__all__ = ['DialogIntentModel'] +__all__ = ['SpaceForDialogIntentModel'] @MODELS.register_module(Tasks.dialog_intent_prediction, module_name=r'space') -class DialogIntentModel(Model): +class SpaceForDialogIntentModel(Model): def __init__(self, model_dir: str, *args, **kwargs): """initialize the test generation model from the `model_dir` path. 
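(Editor's aside: the `@MODELS.register_module(...)` decorator above is what lets pipelines build these renamed classes by task and module name instead of importing them directly. Below is a schematic sketch of such a decorator-based registry; the class, the task string, and the toy model are illustrative only, not modelscope's actual builder code.)

```python
# Schematic of a decorator-based registry, keyed by (task, module_name).
from typing import Callable, Dict, Tuple, Type


class Registry:

    def __init__(self) -> None:
        self._modules: Dict[Tuple[str, str], Type] = {}

    def register_module(self, task: str, module_name: str) -> Callable:
        def decorator(cls: Type) -> Type:
            # record the class under its (task, module_name) key
            self._modules[(task, module_name)] = cls
            return cls
        return decorator

    def build(self, task: str, module_name: str, **kwargs):
        # look the class up and instantiate it with the given kwargs
        return self._modules[(task, module_name)](**kwargs)


MODELS = Registry()


@MODELS.register_module('dialog-intent-prediction', module_name='space')
class ToySpaceModel:

    def __init__(self, model_dir: str = ''):
        self.model_dir = model_dir


model = MODELS.build('dialog-intent-prediction', 'space', model_dir='/tmp/m')
print(type(model).__name__)  # -> ToySpaceModel
```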
@@ -30,13 +30,13 @@ class DialogIntentModel(Model):
         self.config = kwargs.pop(
             'config',
             Config.from_file(
-                os.path.join(self.model_dir, 'configuration.json')))
+                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
         self.text_field = kwargs.pop(
             'text_field',
             IntentBPETextField(self.model_dir, config=self.config))
 
         self.generator = Generator.create(self.config, reader=self.text_field)
-        self.model = ModelBase.create(
+        self.model = SpaceModelBase.create(
             model_dir=model_dir,
             config=self.config,
             reader=self.text_field,
diff --git a/modelscope/models/nlp/space/dialog_modeling_model.py b/modelscope/models/nlp/space/dialog_modeling_model.py
index ad8212c0..e11ef9fd 100644
--- a/modelscope/models/nlp/space/dialog_modeling_model.py
+++ b/modelscope/models/nlp/space/dialog_modeling_model.py
@@ -2,19 +2,19 @@ import os
 from typing import Any, Dict, Optional
 
 from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField
+from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
 from ....utils.config import Config
-from ....utils.constant import Tasks
+from ....utils.constant import ModelFile, Tasks
 from ...base import Model, Tensor
 from ...builder import MODELS
-from .application.gen_app import MultiWOZTrainer
 from .model.generator import Generator
-from .model.model_base import ModelBase
+from .model.model_base import SpaceModelBase
 
-__all__ = ['DialogModelingModel']
+__all__ = ['SpaceForDialogModelingModel']
 
 
 @MODELS.register_module(Tasks.dialog_modeling, module_name=r'space')
-class DialogModelingModel(Model):
+class SpaceForDialogModelingModel(Model):
 
     def __init__(self, model_dir: str, *args, **kwargs):
         """initialize the text generation model from the `model_dir` path.
 
@@ -30,12 +30,12 @@ class DialogModelingModel(Model):
         self.config = kwargs.pop(
             'config',
             Config.from_file(
-                os.path.join(self.model_dir, 'configuration.json')))
+                os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
         self.text_field = kwargs.pop(
             'text_field',
             MultiWOZBPETextField(self.model_dir, config=self.config))
         self.generator = Generator.create(self.config, reader=self.text_field)
-        self.model = ModelBase.create(
+        self.model = SpaceModelBase.create(
             model_dir=model_dir,
             config=self.config,
             reader=self.text_field,
diff --git a/modelscope/models/nlp/space/model/model_base.py b/modelscope/models/nlp/space/model/model_base.py
index cdd355a5..42496e76 100644
--- a/modelscope/models/nlp/space/model/model_base.py
+++ b/modelscope/models/nlp/space/model/model_base.py
@@ -5,8 +5,10 @@ import os
 
 import torch.nn as nn
 
+from .....utils.constant import ModelFile
 
-class ModelBase(nn.Module):
+
+class SpaceModelBase(nn.Module):
     """
     Basic model wrapper for static graph and dygraph.
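     Subclasses register themselves by name via SpaceModelBase.register()
     and are instantiated through SpaceModelBase.create(), which looks the
     class up using config.Model.model.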
""" @@ -14,21 +16,22 @@ class ModelBase(nn.Module): @classmethod def register(cls, name): - ModelBase._registry[name] = cls + SpaceModelBase._registry[name] = cls return @staticmethod def by_name(name): - return ModelBase._registry[name] + return SpaceModelBase._registry[name] @staticmethod def create(model_dir, config, *args, **kwargs): - model_cls = ModelBase.by_name(config.Model.model) + model_cls = SpaceModelBase.by_name(config.Model.model) return model_cls(model_dir, config, *args, **kwargs) def __init__(self, model_dir, config): - super(ModelBase, self).__init__() - self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin') + super(SpaceModelBase, self).__init__() + self.init_checkpoint = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) self.abandon_label = config.Dataset.abandon_label self.use_gpu = config.use_gpu self.gpu = config.Trainer.gpu diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py index 2636553d..15a18056 100644 --- a/modelscope/models/nlp/space/model/unified_transformer.py +++ b/modelscope/models/nlp/space/model/unified_transformer.py @@ -9,10 +9,10 @@ import torch.nn.functional as F from ..modules.embedder import Embedder from ..modules.transformer_block import TransformerBlock -from .model_base import ModelBase +from .model_base import SpaceModelBase -class UnifiedTransformer(ModelBase): +class UnifiedTransformer(SpaceModelBase): """ Implement unified transformer. """ diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py index 3fd38641..4677b62e 100644 --- a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -1,7 +1,7 @@ from typing import Any, Dict from ...metainfo import Pipelines -from ...models.nlp import DialogIntentModel +from ...models.nlp import SpaceForDialogIntentModel from ...preprocessors import DialogIntentPredictionPreprocessor from ...utils.constant import Tasks from ..base import Pipeline @@ -15,7 +15,7 @@ __all__ = ['DialogIntentPredictionPipeline'] module_name=Pipelines.dialog_intent_prediction) class DialogIntentPredictionPipeline(Pipeline): - def __init__(self, model: DialogIntentModel, + def __init__(self, model: SpaceForDialogIntentModel, preprocessor: DialogIntentPredictionPreprocessor, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py index 778284de..29303d4b 100644 --- a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -1,9 +1,9 @@ from typing import Any, Dict, Optional -from modelscope.models.nlp import DialogModelingModel -from modelscope.preprocessors import DialogModelingPreprocessor -from modelscope.utils.constant import Tasks from ...metainfo import Pipelines +from ...models.nlp import SpaceForDialogModelingModel +from ...preprocessors import DialogModelingPreprocessor +from ...utils.constant import Tasks from ..base import Pipeline, Tensor from ..builder import PIPELINES @@ -14,7 +14,7 @@ __all__ = ['DialogModelingPipeline'] Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) class DialogModelingPipeline(Pipeline): - def __init__(self, model: DialogModelingModel, + def __init__(self, model: SpaceForDialogModelingModel, 
preprocessor: DialogModelingPreprocessor, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction diff --git a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py index 733abf24..4b46b044 100644 --- a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py @@ -4,7 +4,7 @@ import os from typing import Any, Dict from ...utils.config import Config -from ...utils.constant import Fields +from ...utils.constant import Fields, ModelFile from ...utils.type_assert import type_assert from ..base import Preprocessor from ..builder import PREPROCESSORS @@ -26,7 +26,7 @@ class DialogIntentPredictionPreprocessor(Preprocessor): self.model_dir: str = model_dir self.config = Config.from_file( - os.path.join(self.model_dir, 'configuration.json')) + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) self.text_field = IntentBPETextField( self.model_dir, config=self.config) diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py index b0758b40..d5e02c4a 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -4,7 +4,7 @@ import os from typing import Any, Dict from ...utils.config import Config -from ...utils.constant import Fields +from ...utils.constant import Fields, ModelFile from ...utils.type_assert import type_assert from ..base import Preprocessor from ..builder import PREPROCESSORS @@ -26,7 +26,7 @@ class DialogModelingPreprocessor(Preprocessor): self.model_dir: str = model_dir self.config = Config.from_file( - os.path.join(self.model_dir, 'configuration.json')) + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) self.text_field = MultiWOZBPETextField( self.model_dir, config=self.config) diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/space/fields/gen_field.py index 49a30e8f..410aea7e 100644 --- a/modelscope/preprocessors/space/fields/gen_field.py +++ b/modelscope/preprocessors/space/fields/gen_field.py @@ -8,6 +8,7 @@ from itertools import chain import numpy as np +from ....utils.constant import ModelFile from ....utils.nlp.space import ontology, utils from ....utils.nlp.space.db_ops import MultiWozDB from ....utils.nlp.space.utils import list2np diff --git a/modelscope/models/nlp/space/application/__init__.py b/modelscope/trainers/nlp/space/__init__.py similarity index 100% rename from modelscope/models/nlp/space/application/__init__.py rename to modelscope/trainers/nlp/space/__init__.py diff --git a/modelscope/models/nlp/space/metrics/__init__.py b/modelscope/trainers/nlp/space/metrics/__init__.py similarity index 100% rename from modelscope/models/nlp/space/metrics/__init__.py rename to modelscope/trainers/nlp/space/metrics/__init__.py diff --git a/modelscope/models/nlp/space/metrics/metrics_tracker.py b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py similarity index 100% rename from modelscope/models/nlp/space/metrics/metrics_tracker.py rename to modelscope/trainers/nlp/space/metrics/metrics_tracker.py diff --git a/modelscope/trainers/nlp/space/trainer/__init__.py b/modelscope/trainers/nlp/space/trainer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/space/application/gen_app.py 
b/modelscope/trainers/nlp/space/trainer/gen_trainer.py similarity index 100% rename from modelscope/models/nlp/space/application/gen_app.py rename to modelscope/trainers/nlp/space/trainer/gen_trainer.py diff --git a/modelscope/models/nlp/space/application/intent_app.py b/modelscope/trainers/nlp/space/trainer/intent_trainer.py similarity index 100% rename from modelscope/models/nlp/space/application/intent_app.py rename to modelscope/trainers/nlp/space/trainer/intent_trainer.py diff --git a/tests/pipelines/test_dialog_intent_prediction.py b/tests/pipelines/test_dialog_intent_prediction.py index 97cdbb3d..ae3a9bf1 100644 --- a/tests/pipelines/test_dialog_intent_prediction.py +++ b/tests/pipelines/test_dialog_intent_prediction.py @@ -3,7 +3,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import DialogIntentModel +from modelscope.models.nlp import SpaceForDialogIntentModel from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline from modelscope.preprocessors import DialogIntentPredictionPreprocessor from modelscope.utils.constant import Tasks @@ -20,7 +20,7 @@ class DialogIntentPredictionTest(unittest.TestCase): def test_run(self): cache_path = snapshot_download(self.model_id) preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) - model = DialogIntentModel( + model = SpaceForDialogIntentModel( model_dir=cache_path, text_field=preprocessor.text_field, config=preprocessor.config) diff --git a/tests/pipelines/test_dialog_modeling.py b/tests/pipelines/test_dialog_modeling.py index f606ba49..79644bc5 100644 --- a/tests/pipelines/test_dialog_modeling.py +++ b/tests/pipelines/test_dialog_modeling.py @@ -6,7 +6,7 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import DialogModelingModel +from modelscope.models.nlp import SpaceForDialogModelingModel from modelscope.pipelines import DialogModelingPipeline, pipeline from modelscope.preprocessors import DialogModelingPreprocessor from modelscope.utils.constant import Tasks @@ -97,7 +97,7 @@ class DialogModelingTest(unittest.TestCase): cache_path = snapshot_download(self.model_id) preprocessor = DialogModelingPreprocessor(model_dir=cache_path) - model = DialogModelingModel( + model = SpaceForDialogModelingModel( model_dir=cache_path, text_field=preprocessor.text_field, config=preprocessor.config) From cfeac7afd893d215225ad738aceec28148902015 Mon Sep 17 00:00:00 2001 From: Yingda Chen Date: Mon, 27 Jun 2022 15:07:46 +0800 Subject: [PATCH 5/7] [to #42322933] skip aec test --- tests/pipelines/test_speech_signal_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index 1b070fda..bc3a542e 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -34,7 +34,7 @@ class SpeechSignalProcessTest(unittest.TestCase): # A temporary hack to provide c++ lib. Download it first. 
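         # (download() fetches the given URL into the local file path; the
         # native AEC library must be present before the pipeline runs.)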
        download(AEC_LIB_URL, AEC_LIB_FILE)
 
-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
         download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
         download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)

From 5a2865c273820b6db8647e6c15ea770d10bab7e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=A8=80=E6=9E=AB?=
Date: Tue, 28 Jun 2022 11:18:40 +0800
Subject: [PATCH 6/7] change Chinese notes of space3.0 into English

---
 modelscope/models/nlp/space/model/generator.py    | 10 ++++------
 .../models/nlp/space/model/unified_transformer.py | 10 ++--------
 .../preprocessors/space/fields/dst_processors.py  |  1 -
 .../preprocessors/space/fields/gen_field.py       | 15 ++-------------
 .../preprocessors/space/fields/intent_field.py    | 10 ----------
 modelscope/preprocessors/space/tokenizer.py       |  4 ----
 .../trainers/nlp/space/metrics/metrics_tracker.py |  4 ++--
 .../trainers/nlp/space/trainer/gen_trainer.py     |  2 +-
 .../trainers/nlp/space/trainer/intent_trainer.py  | 15 +++++++--------
 modelscope/utils/nlp/space/db_ops.py              |  4 ++--
 10 files changed, 20 insertions(+), 55 deletions(-)

diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py
index bdf6b135..fab70fd6 100644
--- a/modelscope/models/nlp/space/model/generator.py
+++ b/modelscope/models/nlp/space/model/generator.py
@@ -183,7 +183,7 @@ class BeamSearch(Generator):
 
         scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
         scores_after_end[
-            self.pad_id] = 0  # 希望<eos>之后只生成<pad>,故使词表中<pad>的log(p)最高(0)
+            self.pad_id] = 0  # we want only <pad> after <eos>, so give <pad> the highest log(p) in the vocabulary (0)
         scores_after_end = torch.from_numpy(scores_after_end)
 
         if self.use_gpu:
@@ -245,10 +245,8 @@ class BeamSearch(Generator):
             scores = scores.reshape(batch_size, beam_size * self.vocab_size)
             topk_scores, topk_indices = torch.topk(scores, beam_size)
 
-            # topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape)
-            # 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商
+            # topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped)
             parent_idx = topk_indices.floor_divide(self.vocab_size)
-            # 对vocab_size取余
             preds = topk_indices % self.vocab_size
 
             # Gather state / sequence_scores
@@ -262,14 +260,14 @@ class BeamSearch(Generator):
             predictions = predictions.reshape(batch_size, beam_size, step)
             predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
 
-            # 希望生成的整个句子已完结,所以要求最后一个token为<eos>或者<pad>(跟在<eos>之后),否则惩罚
+            # The last token should be <eos> or <pad> (following <eos>); penalize otherwise
             pre_ids = predictions[:, :, -1]
             pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
                 (1 - torch.not_equal(pre_ids, self.pad_id).float())
             sequence_scores = sequence_scores * pre_eos_mask + (
                 1 - pre_eos_mask) * (-1e10)
 
-            # 先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴)
+            # first get ascending ordered indices, then sort "predictions" and "sequence_scores"
             indices = torch.argsort(sequence_scores, dim=1)
             indices = indices + pos_index
             indices = indices.reshape(-1)
diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py
index 15a18056..8060879d 100644
--- a/modelscope/models/nlp/space/model/unified_transformer.py
+++ b/modelscope/models/nlp/space/model/unified_transformer.py
@@ -122,11 +122,7 @@ class UnifiedTransformer(SpaceModelBase):
                      auto_regressive=False):
         """
         Create attention mask.
- 创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] - mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) - 注: - 1. 一个句子中的非词看整个句子,该句中只有词才被mask - 2. 一个句子中的词看整个句子,该句的所有词都应该被mask + from sequence to matrix:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] @param : input_mask @type : Variable(shape: [batch_size, max_seq_len]) @@ -142,13 +138,11 @@ class UnifiedTransformer(SpaceModelBase): mask = mask1 * mask2 if append_head: - # 拼接上句首位置([M]/z)的mask mask = torch.cat([mask[:, :1, :], mask], dim=1) mask = torch.cat([mask[:, :, :1], mask], dim=2) seq_len += 1 if auto_regressive: - # 将tgt端的 mask和自回归attention mask融合 seq_mask = self.sequence_mask[:seq_len, :seq_len] seq_mask = seq_mask.to(mask.device) mask = mask * seq_mask @@ -159,7 +153,7 @@ class UnifiedTransformer(SpaceModelBase): def _join_mask(self, mask1, mask2): """ Merge source attention mask and target attention mask. - 合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb + There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb) @param : mask1 : source attention mask @type : Variable(shape: [batch_size, max_src_len, max_src_len]) diff --git a/modelscope/preprocessors/space/fields/dst_processors.py b/modelscope/preprocessors/space/fields/dst_processors.py index c5c81f66..22e06eec 100644 --- a/modelscope/preprocessors/space/fields/dst_processors.py +++ b/modelscope/preprocessors/space/fields/dst_processors.py @@ -570,7 +570,6 @@ class multiwoz22Processor(DSTProcessor): def delex_utt(self, utt, values, unk_token='[UNK]'): utt_norm = self.tokenize(utt) for s, vals in values.items(): - # TODO vals可能不是数组形式,而是初始化的字符串"none" for v in vals: if v != 'none': v_norm = self.tokenize(v) diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/space/fields/gen_field.py index 410aea7e..fa037145 100644 --- a/modelscope/preprocessors/space/fields/gen_field.py +++ b/modelscope/preprocessors/space/fields/gen_field.py @@ -36,18 +36,10 @@ class BPETextField(object): @property def bot_id(self): - """ - 用于区分user和bot两个角色 - 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' - """ return 0 @property def user_id(self): - """ - 用于区分user和bot两个角色 - 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' - """ return 1 @property @@ -186,7 +178,7 @@ class BPETextField(object): ] src_role.append(list(chain(*role))[-self.max_len:]) - # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 + # src sequence and tgt sequence should be padded separately,to make sure the first word is aligned src_token = list2np(src_token, padding=self.pad_id) src_pos = list2np(src_pos, padding=self.pad_id) src_turn = list2np(src_turn, padding=self.pad_id) @@ -439,7 +431,7 @@ class MultiWOZBPETextField(BPETextField): # logging.info(log_str) # cfg.num_training_steps = num_training_steps * cfg.epoch_num self.set_stats[set_name][ - 'num_training_steps_per_epoch'] = num_training_steps # turn-level的steps + 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps self.set_stats[set_name]['num_turns'] = num_turns self.set_stats[set_name]['num_dials'] = num_dials @@ -548,9 +540,6 @@ class MultiWOZBPETextField(BPETextField): def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False): """ - URURU:这里的含义是指轮级别的训练(数据整理),区别于session级别的训练方式(convert_batch_session); - 但不同于eval时的含义,eval时二者都是逐轮依次生成的,那时URURU的含义请见相关的函数注释; - convert the current and the last turn concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t] firts turn: [U_t, B_t, A_t, R_t] diff --git 
a/modelscope/preprocessors/space/fields/intent_field.py b/modelscope/preprocessors/space/fields/intent_field.py index 35e1693c..15bd20b6 100644 --- a/modelscope/preprocessors/space/fields/intent_field.py +++ b/modelscope/preprocessors/space/fields/intent_field.py @@ -154,18 +154,10 @@ class BPETextField(object): @property def bot_id(self): - """ - 用于区分user和bot两个角色 - 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' - """ return 0 @property def user_id(self): - """ - 用于区分user和bot两个角色 - 1和0不是词表中的index,而是专门针对role的index,大小就为2,对应超参数'num_type_embeddings' - """ return 1 def add_sepcial_tokens(self): @@ -862,7 +854,6 @@ class BPETextField(object): ] src_role.append(list(chain(*role))[-self.max_len:]) - # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 src_token = list2np(src_token, padding=self.pad_id) src_pos = list2np(src_pos, padding=self.pad_id) src_turn = list2np(src_turn, padding=self.pad_id) @@ -1038,7 +1029,6 @@ class IntentBPETextField(BPETextField): ] * l for i, l in enumerate(utt_lens)] src_role.append(list(chain(*role))[-self.max_len:]) - # src端序列和tgt端序列需要分开pad,以保证解码时第一个词对齐 src_token = list2np(src_token, padding=self.pad_id) src_pos = list2np(src_pos, padding=self.pad_id) src_turn = list2np(src_turn, padding=self.pad_id) diff --git a/modelscope/preprocessors/space/tokenizer.py b/modelscope/preprocessors/space/tokenizer.py index 764552cd..87f7e8c3 100644 --- a/modelscope/preprocessors/space/tokenizer.py +++ b/modelscope/preprocessors/space/tokenizer.py @@ -56,10 +56,6 @@ class Tokenizer(object): self._tokenizer = BertTokenizer( vocab_path, never_split=self.special_tokens) for tok in self.special_tokens: - ''' - 需要先保证special_tokens在词表中,这里设置special_tokens的目的是为了这些词能够完整占位,不再切分为子词; - 若不在词表中,可以使用词表中的[unused]符号进行转换:spec_convert_dict; - ''' assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary" self.vocab_size = len(self._tokenizer.vocab) elif tokenizer_type == 'GPT2': diff --git a/modelscope/trainers/nlp/space/metrics/metrics_tracker.py b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py index c08eba68..865600d3 100644 --- a/modelscope/trainers/nlp/space/metrics/metrics_tracker.py +++ b/modelscope/trainers/nlp/space/metrics/metrics_tracker.py @@ -10,8 +10,8 @@ class MetricsTracker(object): """ Tracking metrics. 
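    Keeps the metrics of the latest batch (metrics_val) alongside a running
    average over the batches of an epoch (metrics_avg), updated with each
    batch's num_samples.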
""" def __init__(self): - self.metrics_val = defaultdict(float) # 记录最新一个batch返回的指标 - self.metrics_avg = defaultdict(float) # 维护一个epoch内已训练batches的平均指标 + self.metrics_val = defaultdict(float) # for one batch + self.metrics_avg = defaultdict(float) # avg batches self.num_samples = 0 def update(self, metrics, num_samples): diff --git a/modelscope/trainers/nlp/space/trainer/gen_trainer.py b/modelscope/trainers/nlp/space/trainer/gen_trainer.py index e09e2100..41e5f81e 100644 --- a/modelscope/trainers/nlp/space/trainer/gen_trainer.py +++ b/modelscope/trainers/nlp/space/trainer/gen_trainer.py @@ -563,7 +563,7 @@ class MultiWOZTrainer(Trainer): generated_bs = outputs[0].cpu().numpy().tolist() bspn_gen = self.decode_generated_bspn(generated_bs) # check DB result - if self.reader.use_true_db_pointer: # 控制当前轮的db是否为ground truth + if self.reader.use_true_db_pointer: # To control whether current db is ground truth db = turn['db'] else: db_result = self.reader.bspan_to_DBpointer( diff --git a/modelscope/trainers/nlp/space/trainer/intent_trainer.py b/modelscope/trainers/nlp/space/trainer/intent_trainer.py index 2c5081d7..f5ae6e31 100644 --- a/modelscope/trainers/nlp/space/trainer/intent_trainer.py +++ b/modelscope/trainers/nlp/space/trainer/intent_trainer.py @@ -314,18 +314,18 @@ class IntentTrainer(Trainer): self.can_norm = config.Trainer.can_norm def can_normalization(self, y_pred, y_true, ex_data_iter): - # 预测结果,计算修正前准确率 + # compute ACC acc_original = np.mean([y_pred.argmax(1) == y_true]) message = 'original acc: %s' % acc_original - # 评价每个预测结果的不确定性 + # compute uncertainty k = 3 y_pred_topk = np.sort(y_pred, axis=1)[:, -k:] y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True) y_pred_uncertainty =\ -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k) - # 选择阈值,划分高、低置信度两部分 + # choose threshold # print(np.sort(y_pred_uncertainty)[-100:].tolist()) threshold = 0.7 y_pred_confident = y_pred[y_pred_uncertainty < threshold] @@ -333,8 +333,7 @@ class IntentTrainer(Trainer): y_true_confident = y_true[y_pred_uncertainty < threshold] y_true_unconfident = y_true[y_pred_uncertainty >= threshold] - # 显示两部分各自的准确率 - # 一般而言,高置信度集准确率会远高于低置信度的 + # compute ACC again for high and low confidence sets acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \ if len(y_true_confident) else 0. 
        acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \
@@ -344,7 +343,7 @@ class IntentTrainer(Trainer):
         message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident),
                                                   acc_unconfident)
 
-        # 从训练集统计先验分布
+        # get prior distribution from training set
         prior = np.zeros(self.func_model.num_intent)
         for _, (batch, batch_size) in ex_data_iter:
             for intent_label in batch['intent_label']:
@@ -352,7 +351,7 @@
 
         prior /= prior.sum()
 
-        # 逐个修改低置信度样本,并重新评价准确率
+        # revise each sample from the low confidence set, and compute new ACC
         right, alpha, iters = 0, 1, 1
         for i, y in enumerate(y_pred_unconfident):
             Y = np.concatenate([y_pred_confident, y[None]], axis=0)
@@ -365,7 +364,7 @@
             if y.argmax() == y_true_unconfident[i]:
                 right += 1
 
-        # 输出修正后的准确率
+        # get final ACC
         acc_final = \
             (acc_confident * len(y_pred_confident) + right) / \
             len(y_pred)
diff --git a/modelscope/utils/nlp/space/db_ops.py b/modelscope/utils/nlp/space/db_ops.py
index 2168c079..10c3aab7 100644
--- a/modelscope/utils/nlp/space/db_ops.py
+++ b/modelscope/utils/nlp/space/db_ops.py
@@ -172,8 +172,8 @@ class MultiWozDB(object):
                     continue
                 if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
                         (domain == 'restaurant' and s in ['day', 'time']):
-                    # 因为这些inform slot属于book info,而数据库中没有这些slot;
-                    # 能否book是根据user goal中的信息判断,而非通过数据库查询;
+                    # These inform slots belong to "book info", which do not exist in the DB
+                    # whether a booking succeeds is decided by the user goal, not by a DB query
                     continue
 
                 skip_case = {
From 1e361a4ad19dd0a0626c6e3518a61ac34a671e57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=99=BA=E4=B8=9E?=
Date: Tue, 28 Jun 2022 11:36:19 +0800
Subject: [PATCH 7/7] translate chinese comment to english

---
 modelscope/models/nlp/space/model/generator.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py
index fab70fd6..08e1c765 100644
--- a/modelscope/models/nlp/space/model/generator.py
+++ b/modelscope/models/nlp/space/model/generator.py
@@ -183,7 +183,8 @@ class BeamSearch(Generator):
 
         scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
         scores_after_end[
-            self.pad_id] = 0  # we want only <pad> after <eos>, so give <pad> the highest log(p) in the vocabulary (0)
+            self.
+            pad_id] = 0  # we want only <pad> after <eos>, so give <pad> the highest log(p) in the vocabulary (0)
         scores_after_end = torch.from_numpy(scores_after_end)
 
         if self.use_gpu:
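(Editor's aside, for reference: the `scores_after_end` trick that the retranslated comment describes means that once a hypothesis has emitted `<eos>`, every later step should extend it only with `<pad>`, so all vocabulary entries are scored -1e10 except `pad_id`, which scores 0, the maximum log-probability. A minimal sketch with toy ids follows; the tensor shapes and the `finished` masking are illustrative, not the exact modelscope code.)

```python
import numpy as np
import torch

vocab_size, pad_id, eos_id = 5, 0, 1

# Score row applied to finished hypotheses: only <pad> stays viable.
scores_after_end = np.full(vocab_size, -1e10, dtype='float32')
scores_after_end[pad_id] = 0  # log p(<pad>) = 0, i.e. probability 1
scores_after_end = torch.from_numpy(scores_after_end)

# One hypothesis whose last token is <eos>: its next-step scores are
# overwritten so that only <pad> can follow.
step_scores = torch.log(torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2]]))
last_token = torch.tensor([eos_id])
finished = (last_token == eos_id) | (last_token == pad_id)
step_scores = torch.where(finished.unsqueeze(1), scores_after_end, step_scores)
print(step_scores.argmax(dim=1))  # tensor([0]) -> <pad>
```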