
merge feat/nlp

Branch: master · ly119399 · 3 years ago
Commit: 3bfb5948e2
64 changed files with 1602 additions and 152 deletions
  1. +4 -4    docs/source/api/modelscope.pydatasets.rst
  2. +1 -1    docs/source/api/modelscope.rst
  3. +5 -5    docs/source/quick_start.md
  4. +1 -1    modelscope/hub/file_download.py
  5. +5 -0    modelscope/metainfo.py
  6. +4 -3    modelscope/models/__init__.py
  7. +1 -0    modelscope/models/audio/kws/__init__.py
  8. +30 -0   modelscope/models/audio/kws/generic_key_word_spotting.py
  9. +1 -0    modelscope/models/multi_modal/__init__.py
  10. +0 -0   modelscope/models/multi_modal/clip/__init__.py
  11. +26 -0  modelscope/models/multi_modal/clip/clip_bert.py
  12. +158 -0 modelscope/models/multi_modal/clip/clip_model.py
  13. +121 -0 modelscope/models/multi_modal/clip/clip_vit.py
  14. +0 -0   modelscope/models/multi_modal/image_captioning_model.py
  15. +2 -2   modelscope/models/nlp/bert_for_sequence_classification.py
  16. +2 -2   modelscope/models/nlp/sbert_for_sentence_similarity.py
  17. +2 -2   modelscope/models/nlp/sbert_for_sentiment_classification.py
  18. +2 -2   modelscope/models/nlp/sbert_for_token_classification.py
  19. +7 -7   modelscope/models/nlp/space/dialog_intent_prediction_model.py
  20. +7 -7   modelscope/models/nlp/space/dialog_modeling_model.py
  21. +5 -6   modelscope/models/nlp/space/model/generator.py
  22. +9 -6   modelscope/models/nlp/space/model/model_base.py
  23. +4 -10  modelscope/models/nlp/space/model/unified_transformer.py
  24. +1 -0   modelscope/msdatasets/__init__.py
  25. +0 -0   modelscope/msdatasets/config.py
  26. +12 -12 modelscope/msdatasets/ms_dataset.py
  27. +0 -0   modelscope/msdatasets/utils/__init__.py
  28. +1 -1   modelscope/msdatasets/utils/ms_api.py
  29. +1 -0   modelscope/pipelines/audio/__init__.py
  30. +449 -0 modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  31. +3 -3   modelscope/pipelines/base.py
  32. +3 -1   modelscope/pipelines/builder.py
  33. +1 -0   modelscope/pipelines/multi_modal/__init__.py
  34. +34 -0  modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py
  35. +2 -2   modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py
  36. +4 -4   modelscope/pipelines/nlp/dialog_modeling_pipeline.py
  37. +7 -0   modelscope/pipelines/outputs.py
  38. +3 -0   modelscope/preprocessors/__init__.py
  39. +253 -0 modelscope/preprocessors/kws.py
  40. +2 -2   modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py
  41. +2 -2   modelscope/preprocessors/space/dialog_modeling_preprocessor.py
  42. +0 -1   modelscope/preprocessors/space/dst_processors.py
  43. +3 -13  modelscope/preprocessors/space/fields/gen_field.py
  44. +0 -10  modelscope/preprocessors/space/fields/intent_field.py
  45. +0 -4   modelscope/preprocessors/space/tokenizer.py
  46. +0 -1   modelscope/pydatasets/__init__.py
  47. +0 -0   modelscope/trainers/nlp/space/__init__.py
  48. +0 -0   modelscope/trainers/nlp/space/metrics/__init__.py
  49. +2 -2   modelscope/trainers/nlp/space/metrics/metrics_tracker.py
  50. +0 -0   modelscope/trainers/nlp/space/trainer/__init__.py
  51. +1 -1   modelscope/trainers/nlp/space/trainer/gen_trainer.py
  52. +7 -8   modelscope/trainers/nlp/space/trainer/intent_trainer.py
  53. +2 -0   modelscope/utils/constant.py
  54. +2 -2   modelscope/utils/nlp/space/db_ops.py
  55. +0 -0   tests/msdatasets/__init__.py
  56. +9 -10  tests/msdatasets/test_ms_dataset.py
  57. +1 -1   tests/pipelines/test_action_recognition.py
  58. +2 -2   tests/pipelines/test_dialog_intent_prediction.py
  59. +2 -2   tests/pipelines/test_dialog_modeling.py
  60. +3 -3   tests/pipelines/test_image_matting.py
  61. +334 -0 tests/pipelines/test_key_word_spotting.py
  62. +52 -0  tests/pipelines/test_multi_modal_embedding.py
  63. +1 -1   tests/pipelines/test_speech_signal_process.py
  64. +6 -6   tests/pipelines/test_text_classification.py

+4 -4  docs/source/api/modelscope.pydatasets.rst

@@ -1,7 +1,7 @@
-modelscope.pydatasets package
+modelscope.msdatasets package
=============================

-.. automodule:: modelscope.pydatasets
+.. automodule:: modelscope.msdatasets
:members:
:undoc-members:
:show-inheritance:
@@ -9,10 +9,10 @@ modelscope.pydatasets package
Submodules
----------

-modelscope.pydatasets.py\_dataset module
+modelscope.msdatasets.ms\_dataset module
----------------------------------------

-.. automodule:: modelscope.pydatasets.py_dataset
+.. automodule:: modelscope.msdatasets.ms_dataset
:members:
:undoc-members:
:show-inheritance:

+1 -1  docs/source/api/modelscope.rst

@@ -16,7 +16,7 @@ Subpackages
modelscope.models
modelscope.pipelines
modelscope.preprocessors
-modelscope.pydatasets
+modelscope.msdatasets
modelscope.trainers
modelscope.utils



+5 -5  docs/source/quick_start.md

@@ -3,7 +3,7 @@
## Python environment setup
First, install and configure an Anaconda environment following the [documentation](https://docs.anaconda.com/anaconda/install/).

-After installation, run the following command to create a python environment for the maas library.
+After installation, run the following command to create a python environment for the modelscope library.
```shell
conda create -n modelscope python=3.6
conda activate modelscope
@@ -105,15 +105,15 @@ import cv2
import os.path as osp
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.pydatasets import PyDataset
from modelscope.msdatasets import MsDataset

-# Build a PyDataset from image urls; a local folder can also be used via input_location = '/dir/to/images'
+# Build a MsDataset from image urls; a local folder can also be used via input_location = '/dir/to/images'
input_location = [
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
]
-dataset = PyDataset.load(input_location, target='image')
+dataset = MsDataset.load(input_location, target='image')
img_matting = pipeline(Tasks.image_matting, model='damo/image-matting-person')
-# When the input is a PyDataset, the output is an iterator
+# When the input is a MsDataset, the output is an iterator
result = img_matting(dataset)
cv2.imwrite('result.png', next(result)['output_png'])
print(f'Output written to {osp.abspath("result.png")}')


+1 -1  modelscope/hub/file_download.py

@@ -187,7 +187,7 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
"""
Format file download url according to `model_id`, `revision` and `file_path`.
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
-the resulted download url is: https://maas.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
+the resulted download url is: https://modelscope.co/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
"""
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
return download_url_template.format(


+5 -0  modelscope/metainfo.py

@@ -21,9 +21,11 @@ class Models(object):
sambert_hifi_16k = 'sambert-hifi-16k'
generic_tts_frontend = 'generic-tts-frontend'
hifigan16k = 'hifigan16k'
kws_kwsbp = 'kws-kwsbp'

# multi-modal models
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'


class Pipelines(object):
@@ -57,9 +59,11 @@ class Pipelines(object):
# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
kws_kwsbp = 'kws-kwsbp'

# multi-modal tasks
image_caption = 'image-caption'
multi_modal_embedding = 'multi-modal-embedding'


class Trainers(object):
@@ -99,6 +103,7 @@ class Preprocessors(object):
# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
wav_to_lists = 'wav-to-lists'

# multi-modal
ofa_image_caption = 'ofa-image-caption'

+4 -3  modelscope/models/__init__.py

@@ -1,10 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

-# from .audio.tts.am import SambertNetHifi16k
-# from .audio.tts.vocoder import Hifigan16k
+from .audio.kws import GenericKeyWordSpotting
+from .audio.tts.am import SambertNetHifi16k
+from .audio.tts.vocoder import Hifigan16k
from .base import Model
from .builder import MODELS, build_model
-# from .multi_model import OfaForImageCaptioning
+from .multi_modal import OfaForImageCaptioning
from .nlp import (BertForSequenceClassification, SbertForNLI,
SbertForSentenceSimilarity, SbertForSentimentClassification,
SbertForTokenClassification, StructBertForMaskedLM,


+1 -0  modelscope/models/audio/kws/__init__.py

@@ -0,0 +1 @@
from .generic_key_word_spotting import * # noqa F403

+30 -0  modelscope/models/audio/kws/generic_key_word_spotting.py

@@ -0,0 +1,30 @@
import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['GenericKeyWordSpotting']


@MODELS.register_module(Tasks.key_word_spotting, module_name=Models.kws_kwsbp)
class GenericKeyWordSpotting(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the info of model.

Args:
model_dir (str): the model path.
"""

self.model_cfg = {
'model_workspace': model_dir,
'config_path': os.path.join(model_dir, 'config.yaml')
}

def forward(self) -> Dict[str, Any]:
"""return the info of the model
"""
return self.model_cfg
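The contract here is unusual: `forward` simply hands back the model's config paths, which the `WavToLists` preprocessor added later in this commit consumes. A minimal sketch, with a placeholder model directory:

```python
from modelscope.models.audio.kws import GenericKeyWordSpotting

# '/path/to/kws_model' is a placeholder; the directory must contain config.yaml
model = GenericKeyWordSpotting('/path/to/kws_model')
cfg = model.forward()
print(cfg['model_workspace'], cfg['config_path'])
```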

modelscope/models/multi_model/__init__.py → modelscope/models/multi_modal/__init__.py

@@ -1 +1,2 @@
+from .clip.clip_model import CLIPForMultiModalEmbedding
from .image_captioning_model import OfaForImageCaptioning

modelscope/models/nlp/space/application/__init__.py → modelscope/models/multi_modal/clip/__init__.py


+26 -0  modelscope/models/multi_modal/clip/clip_bert.py

@@ -0,0 +1,26 @@
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM


class TextTransformer(nn.Module):

def __init__(self, config_dict, feat_dim=768):
super(TextTransformer, self).__init__()
bert_config = BertConfig.from_dict(config_dict)
self.bert = BertForMaskedLM(bert_config).bert

self.projector = nn.Linear(
bert_config.hidden_size, feat_dim, bias=False)

def forward(self, input_ids, attention_mask):
trans_features = {
'input_ids': input_ids,
'attention_mask': attention_mask
}

output_states = self.bert(**trans_features, return_dict=False)
output_tokens = output_states[0]

cls_tokens = output_tokens[:, 0, :]

return self.projector(cls_tokens)

+158 -0  modelscope/models/multi_modal/clip/clip_model.py

@@ -0,0 +1,158 @@
import os.path as osp
from typing import Any, Dict

import cv2  # used below to convert grayscale ndarray inputs
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tokenizers import BertWordPieceTokenizer
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.models.multi_modal.clip.clip_bert import TextTransformer
from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['CLIPForMultiModalEmbedding']


class CLIPModel(nn.Module):

def __init__(self, model_dir):
super(CLIPModel, self).__init__()
# including vision config and text config
model_config = json.load(
open('{}/encoder_config.json'.format(model_dir)))

# vision encoder
vision_config = model_config['vision_config']
self.img_size = vision_config['input_resolution']
self.vision_encoder = VisionTransformer(
input_resolution=self.img_size,
patch_size=vision_config['patch_size'],
width=vision_config['width'],
layers=vision_config['layers'],
heads=vision_config['heads'],
output_dim=vision_config['feat_dim'])

# text encoder
text_config = model_config['text_config']
self.text_encoder = TextTransformer(
text_config['bert_config'], feat_dim=text_config['feat_dim'])

def forward(self, input_data, input_type):
if input_type == 'img':
img_embedding = self.vision_encoder(input_data)
img_embedding = F.normalize(img_embedding, p=2.0, dim=1)
return img_embedding
elif input_type == 'text':
text_ids_tensor, text_mask_tensor = input_data
text_embedding = self.text_encoder(text_ids_tensor,
text_mask_tensor)
text_embedding = F.normalize(text_embedding, p=2.0, dim=1)
return text_embedding
else:
raise ValueError('Unknown input type')


@MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
class CLIPForMultiModalEmbedding(Model):

def __init__(self, model_dir, device_id=-1):
super().__init__(model_dir=model_dir, device_id=device_id)
self.clip_model = CLIPModel(model_dir=model_dir)
pretrained_params = torch.load(
'{}/pytorch_model.bin'.format(model_dir), 'cpu')
self.clip_model.load_state_dict(pretrained_params)
self.clip_model.eval()

self.device_id = device_id
if self.device_id >= 0:
self.clip_model.to('cuda:{}'.format(self.device_id))
logger.info('Use GPU: {}'.format(self.device_id))
else:
logger.info('Use CPU for inference')

# image preprocessor
norm_op = Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))
self.img_preprocessor = Compose([
Resize((self.clip_model.img_size, self.clip_model.img_size),
interpolation=Image.BICUBIC),
ToTensor(), norm_op
])

# text tokenizer
vocab_path = '{}/vocab.txt'.format(model_dir)
self.text_tokenizer = BertWordPieceTokenizer(
vocab_path, lowercase=False)
self.text_tokenizer.enable_truncation(max_length=30)

def tokenize_text(self, text_str):
tokens = self.text_tokenizer.encode(text_str)
max_tokens = 30
text_ids_tensor = torch.zeros((1, max_tokens)).long()
text_mask_tensor = torch.zeros((1, max_tokens))

text_ids, text_mask = tokens.ids, tokens.attention_mask
text_ids_tensor[0, 0:len(text_ids)] = torch.tensor(text_ids)
text_mask_tensor[0, 0:len(text_mask)] = torch.tensor(text_mask)

return text_ids_tensor, text_mask_tensor

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
output = {'img_embedding': None, 'text_embedding': None}
if 'img' in input and input['img'] is not None:
input_img = input['img']
if isinstance(input_img, Image.Image):
img_tensor = self.img_preprocessor(input_img)[None, ...]
elif isinstance(input_img, np.ndarray):
if len(input_img.shape) == 2:
input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
input_img = input_img[:, :, ::-1] # in rgb order
input_img = Image.fromarray(
input_img.astype('uint8')).convert('RGB')
img_tensor = self.img_preprocessor(input_img)[None, ...]
else:
raise TypeError(
f'img should be either PIL.Image or np.array, but got {type(input_img)}'
)

if self.device_id >= 0:
img_tensor = img_tensor.to('cuda:{}'.format(self.device_id))

img_embedding = self.clip_model(
input_data=img_tensor, input_type='img')
output['img_embedding'] = img_embedding.data.cpu().numpy()

if 'text' in input and input['text'] is not None:
text_str = input['text']
if isinstance(text_str, str):
text_ids_tensor, text_mask_tensor = self.tokenize_text(
text_str)
else:
raise TypeError(
f'text should be str, but got {type(text_str)}')

if self.device_id >= 0:
text_ids_tensor = text_ids_tensor.to('cuda:{}'.format(
self.device_id))
text_mask_tensor = text_mask_tensor.to('cuda:{}'.format(
self.device_id))

text_embedding = self.clip_model(
input_data=(text_ids_tensor, text_mask_tensor),
input_type='text')
output['text_embedding'] = text_embedding.data.cpu().numpy()

return output

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
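For reference, a minimal sketch of driving `CLIPForMultiModalEmbedding` directly, assuming a local directory laid out as the constructor above expects (`encoder_config.json`, `pytorch_model.bin`, `vocab.txt`); the directory path and image file are placeholders:

```python
from PIL import Image

from modelscope.models.multi_modal import CLIPForMultiModalEmbedding

# placeholder paths; device_id=-1 selects CPU, as in the constructor above
model = CLIPForMultiModalEmbedding('/path/to/clip_model_dir', device_id=-1)
output = model.forward({'img': Image.open('test.jpg'), 'text': '一张猫的照片'})
# each embedding is L2-normalized with shape [1, feat_dim]
print(output['img_embedding'].shape, output['text_embedding'].shape)
```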

+121 -0  modelscope/models/multi_modal/clip/clip_vit.py

@@ -0,0 +1,121 @@
# Copyright 2021 The OpenAI CLIP Authors. All rights reserved.

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)


class QuickGELU(nn.Module):

def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask

def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x


class Transformer(nn.Module):

def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class VisionTransformer(nn.Module):

def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)

scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
class_embeddings = self.class_embedding.to(x.dtype) + \
torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
x = torch.cat([class_embeddings, x], dim=1)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD

x = self.ln_post(x[:, 0, :])

if self.proj is not None:
x = x @ self.proj

return x
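As a sanity check on the ViT geometry: a 224-pixel input with patch size 14 yields (224/14)² = 256 patch tokens plus one class token, and `forward` returns one projected class embedding per image. The sizes below are illustrative, not the shipped checkpoint's config:

```python
import torch

from modelscope.models.multi_modal.clip.clip_vit import VisionTransformer

# illustrative hyperparameters; the real values come from encoder_config.json
vit = VisionTransformer(input_resolution=224, patch_size=14, width=1024,
                        layers=24, heads=16, output_dim=768)
with torch.no_grad():
    feats = vit(torch.randn(2, 3, 224, 224))  # [batch, channels, H, W]
print(feats.shape)  # torch.Size([2, 768])
```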

modelscope/models/multi_model/image_captioning_model.py → modelscope/models/multi_modal/image_captioning_model.py


+2 -2  modelscope/models/nlp/bert_for_sequence_classification.py

@@ -4,8 +4,8 @@ from typing import Any, Dict
import json
import numpy as np

-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
from ..base import Model
from ..builder import MODELS



+2 -2  modelscope/models/nlp/sbert_for_sentence_similarity.py

@@ -1,5 +1,5 @@
-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
from ..builder import MODELS
from .sbert_for_sequence_classification import \
SbertForSequenceClassificationBase


+2 -2  modelscope/models/nlp/sbert_for_sentiment_classification.py

@@ -1,5 +1,5 @@
-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
from ..builder import MODELS
from .sbert_for_sequence_classification import \
SbertForSequenceClassificationBase


+2 -2  modelscope/models/nlp/sbert_for_token_classification.py

@@ -3,8 +3,8 @@ from typing import Any, Dict, Union
import numpy as np
import torch

-from modelscope.metainfo import Models
-from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS



+7 -7  modelscope/models/nlp/space/dialog_intent_prediction_model.py

@@ -2,19 +2,19 @@ import os
from typing import Any, Dict

from ....preprocessors.space.fields.intent_field import IntentBPETextField
+from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer
from ....utils.config import Config
-from ....utils.constant import Tasks
+from ....utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from ...builder import MODELS
-from .application.intent_app import IntentTrainer
from .model.generator import Generator
-from .model.model_base import ModelBase
+from .model.model_base import SpaceModelBase

-__all__ = ['DialogIntentModel']
+__all__ = ['SpaceForDialogIntentModel']


@MODELS.register_module(Tasks.dialog_intent_prediction, module_name=r'space')
-class DialogIntentModel(Model):
+class SpaceForDialogIntentModel(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the test generation model from the `model_dir` path.
@@ -30,13 +30,13 @@ class DialogIntentModel(Model):
self.config = kwargs.pop(
'config',
Config.from_file(
-os.path.join(self.model_dir, 'configuration.json')))
+os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
self.text_field = kwargs.pop(
'text_field',
IntentBPETextField(self.model_dir, config=self.config))

self.generator = Generator.create(self.config, reader=self.text_field)
-self.model = ModelBase.create(
+self.model = SpaceModelBase.create(
model_dir=model_dir,
config=self.config,
reader=self.text_field,


+7 -7  modelscope/models/nlp/space/dialog_modeling_model.py

@@ -2,19 +2,19 @@ import os
from typing import Any, Dict, Optional

from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField
+from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
from ....utils.config import Config
-from ....utils.constant import Tasks
+from ....utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from ...builder import MODELS
-from .application.gen_app import MultiWOZTrainer
from .model.generator import Generator
-from .model.model_base import ModelBase
+from .model.model_base import SpaceModelBase

-__all__ = ['DialogModelingModel']
+__all__ = ['SpaceForDialogModelingModel']


@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space')
-class DialogModelingModel(Model):
+class SpaceForDialogModelingModel(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the test generation model from the `model_dir` path.
@@ -30,12 +30,12 @@ class DialogModelingModel(Model):
self.config = kwargs.pop(
'config',
Config.from_file(
-os.path.join(self.model_dir, 'configuration.json')))
+os.path.join(self.model_dir, ModelFile.CONFIGURATION)))
self.text_field = kwargs.pop(
'text_field',
MultiWOZBPETextField(self.model_dir, config=self.config))
self.generator = Generator.create(self.config, reader=self.text_field)
-self.model = ModelBase.create(
+self.model = SpaceModelBase.create(
model_dir=model_dir,
config=self.config,
reader=self.text_field,


+5 -6  modelscope/models/nlp/space/model/generator.py

@@ -183,7 +183,8 @@ class BeamSearch(Generator):

scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32')
scores_after_end[
-self.pad_id] = 0  # 希望<eos>之后只生成<pad>,故使词表中log(p(<pad>))最高(0)
+self.
+pad_id] = 0  # we want <pad> is generated after <eos>,so maximum log(p(<pad>)) is (0)
scores_after_end = torch.from_numpy(scores_after_end)

if self.use_gpu:
@@ -245,10 +246,8 @@ class BeamSearch(Generator):
scores = scores.reshape(batch_size, beam_size * self.vocab_size)

topk_scores, topk_indices = torch.topk(scores, beam_size)
-# topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape)
-# 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商
+# topk_indices: [batch_size, beam_size * self.vocab_size] (already reshaped)
parent_idx = topk_indices.floor_divide(self.vocab_size)
-# 对vocab_size取余
preds = topk_indices % self.vocab_size

# Gather state / sequence_scores
@@ -262,14 +261,14 @@ class BeamSearch(Generator):
predictions = predictions.reshape(batch_size, beam_size, step)
predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)

-# 希望生成的整个句子已完结,所以要求最后一个token为<eos>或者<pad>(跟在<eos>之后),否则惩罚
+# The last token should be <eos> or <pad>
pre_ids = predictions[:, :, -1]
pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
(1 - torch.not_equal(pre_ids, self.pad_id).float())
sequence_scores = sequence_scores * pre_eos_mask + (
1 - pre_eos_mask) * (-1e10)

-# 先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴)
+# first get ascending ordered index,then sort "predictions" and "sequence_scores"
indices = torch.argsort(sequence_scores, dim=1)
indices = indices + pos_index
indices = indices.reshape(-1)


+9 -6  modelscope/models/nlp/space/model/model_base.py

@@ -5,8 +5,10 @@ import os

import torch.nn as nn

+from .....utils.constant import ModelFile

-class ModelBase(nn.Module):

+class SpaceModelBase(nn.Module):
"""
Basic model wrapper for static graph and dygraph.
"""
@@ -14,21 +16,22 @@ class ModelBase(nn.Module):

@classmethod
def register(cls, name):
-ModelBase._registry[name] = cls
+SpaceModelBase._registry[name] = cls
return

@staticmethod
def by_name(name):
-return ModelBase._registry[name]
+return SpaceModelBase._registry[name]

@staticmethod
def create(model_dir, config, *args, **kwargs):
-model_cls = ModelBase.by_name(config.Model.model)
+model_cls = SpaceModelBase.by_name(config.Model.model)
return model_cls(model_dir, config, *args, **kwargs)

def __init__(self, model_dir, config):
-super(ModelBase, self).__init__()
-self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin')
+super(SpaceModelBase, self).__init__()
+self.init_checkpoint = os.path.join(model_dir,
+                                    ModelFile.TORCH_MODEL_BIN_FILE)
self.abandon_label = config.Dataset.abandon_label
self.use_gpu = config.use_gpu
self.gpu = config.Trainer.gpu
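The rename keeps the existing name-based registry intact; a minimal sketch of the pattern with a hypothetical subclass (assuming `_registry` is the class-level dict the methods above reference):

```python
from modelscope.models.nlp.space.model.model_base import SpaceModelBase

# hypothetical subclass; real ones such as UnifiedTransformer register the same way
class ToySpaceModel(SpaceModelBase):
    pass

ToySpaceModel.register('toy')
assert SpaceModelBase.by_name('toy') is ToySpaceModel
# SpaceModelBase.create(model_dir, config) then instantiates by config.Model.model
```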


+4 -10  modelscope/models/nlp/space/model/unified_transformer.py

@@ -9,10 +9,10 @@ import torch.nn.functional as F

from ..modules.embedder import Embedder
from ..modules.transformer_block import TransformerBlock
-from .model_base import ModelBase
+from .model_base import SpaceModelBase


-class UnifiedTransformer(ModelBase):
+class UnifiedTransformer(SpaceModelBase):
"""
Implement unified transformer.
"""
@@ -122,11 +122,7 @@ class UnifiedTransformer(ModelBase):
auto_regressive=False):
"""
Create attention mask.
-创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len]
-mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向)
-注:
-1. 一个句子中的非<pad>词看整个句子,该句中只有<pad>词才被mask
-2. 一个句子中的<pad>词看整个句子,该句的所有词都应该被mask
+from sequence to matrix:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len]

@param : input_mask
@type : Variable(shape: [batch_size, max_seq_len])
@@ -142,13 +138,11 @@ class UnifiedTransformer(ModelBase):
mask = mask1 * mask2

if append_head:
-# 拼接上句首位置([M]/z)的mask
mask = torch.cat([mask[:, :1, :], mask], dim=1)
mask = torch.cat([mask[:, :, :1], mask], dim=2)
seq_len += 1

if auto_regressive:
-# 将tgt端的<pad> mask和自回归attention mask融合
seq_mask = self.sequence_mask[:seq_len, :seq_len]
seq_mask = seq_mask.to(mask.device)
mask = mask * seq_mask
@@ -159,7 +153,7 @@ class UnifiedTransformer(ModelBase):
def _join_mask(self, mask1, mask2):
"""
Merge source attention mask and target attention mask.
-合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb
+There are four parts:left upper (lu) / right upper (ru) / left below (lb) / right below (rb)

@param : mask1 : source attention mask
@type : Variable(shape: [batch_size, max_src_len, max_src_len])


+1 -0  modelscope/msdatasets/__init__.py

@@ -0,0 +1 @@
from .ms_dataset import MsDataset

modelscope/pydatasets/config.py → modelscope/msdatasets/config.py


modelscope/pydatasets/py_dataset.py → modelscope/msdatasets/ms_dataset.py

@@ -10,8 +10,8 @@ from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import (is_relative_path,
relative_to_absolute_path)

-from modelscope.pydatasets.config import MS_DATASETS_CACHE
-from modelscope.pydatasets.utils.ms_api import MsApi
+from modelscope.msdatasets.config import MS_DATASETS_CACHE
+from modelscope.msdatasets.utils.ms_api import MsApi
from modelscope.utils.constant import Hubs
from modelscope.utils.logger import get_logger

@@ -28,9 +28,9 @@ def format_list(para) -> List:
return para


-class PyDataset:
+class MsDataset:
_hf_ds = None # holds the underlying HuggingFace Dataset
"""A PyDataset backed by hugging face Dataset."""
"""A MsDataset backed by hugging face Dataset."""

def __init__(self, hf_ds: Dataset, target: Optional[str] = None):
self._hf_ds = hf_ds
@@ -49,7 +49,7 @@ class PyDataset:
@classmethod
def from_hf_dataset(cls,
hf_ds: Dataset,
-target: str = None) -> Union[dict, 'PyDataset']:
+target: str = None) -> Union[dict, 'MsDataset']:
if isinstance(hf_ds, Dataset):
return cls(hf_ds, target)
if len(hf_ds.keys()) == 1:
@@ -68,8 +68,8 @@ class PyDataset:
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
-) -> Union[dict, 'PyDataset']:
-"""Load a PyDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
+) -> Union[dict, 'MsDataset']:
+"""Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
Args:

dataset_name (str): Path or name of the dataset.
@@ -82,7 +82,7 @@ class PyDataset:
hub (Hubs, optional): When loading from a remote hub, where it is from

Returns:
-PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset.
+MsDataset (obj:`MsDataset`): MsDataset object for a certain dataset.
"""
if hub == Hubs.huggingface:
dataset = hf_load_dataset(
@@ -92,9 +92,9 @@ class PyDataset:
split=split,
data_dir=data_dir,
data_files=data_files)
-return PyDataset.from_hf_dataset(dataset, target=target)
+return MsDataset.from_hf_dataset(dataset, target=target)
else:
-return PyDataset._load_ms_dataset(
+return MsDataset._load_ms_dataset(
dataset_name,
target=target,
subset_name=subset_name,
@@ -114,7 +114,7 @@ class PyDataset:
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
-) -> Union[dict, 'PyDataset']:
+) -> Union[dict, 'MsDataset']:
if isinstance(dataset_name, str):
use_hf = False
if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
@@ -153,7 +153,7 @@ class PyDataset:
else:
raise TypeError('path must be a str or a list, but got'
f' {type(dataset_name)}')
-return PyDataset.from_hf_dataset(dataset, target=target)
+return MsDataset.from_hf_dataset(dataset, target=target)

def to_torch_dataset_with_processors(
self,

modelscope/models/nlp/space/metrics/__init__.py → modelscope/msdatasets/utils/__init__.py


modelscope/pydatasets/utils/ms_api.py → modelscope/msdatasets/utils/ms_api.py

@@ -4,7 +4,7 @@ from typing import Optional

import requests

-from modelscope.pydatasets.config import (DOWNLOADED_DATASETS_PATH,
+from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
MS_HUB_ENDPOINT)
from modelscope.utils.logger import get_logger


+1 -0  modelscope/pipelines/audio/__init__.py

@@ -1,2 +1,3 @@
+from .kws_kwsbp_pipeline import *  # noqa F403
from .linear_aec_pipeline import LinearAECPipeline
from .text_to_speech_pipeline import * # noqa F403

+449 -0  modelscope/pipelines/audio/kws_kwsbp_pipeline.py

@@ -0,0 +1,449 @@
import io
import os
import shutil
import stat
import subprocess
from typing import Any, Dict, List

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToLists
from modelscope.utils.constant import Tasks

__all__ = ['KeyWordSpottingKwsbpPipeline']


@PIPELINES.register_module(
Tasks.key_word_spotting, module_name=Pipelines.kws_kwsbp)
class KeyWordSpottingKwsbpPipeline(Pipeline):
"""KWS Pipeline - key word spotting decoding
"""

def __init__(self,
config_file: str = None,
model: Model = None,
preprocessor: WavToLists = None,
**kwargs):
"""use `model` and `preprocessor` to create a kws pipeline for prediction
"""

super().__init__(
config_file=config_file,
model=model,
preprocessor=preprocessor,
**kwargs)
assert model is not None, 'kws model should be provided'
assert preprocessor is not None, 'preprocessor is none'

self._preprocessor = preprocessor
self._model = model

def __call__(self, kws_type: str, wav_path: List[str]) -> Dict[str, Any]:
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets',
'roc'], f'kws_type {kws_type} is invalid'
output = self._preprocessor.forward(self._model.forward(), kws_type,
wav_path)
output = self.forward(output)
rst = self.postprocess(output)
return rst

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Decoding
"""

# will generate kws result into dump/dump.JOB.log
out = self._run_with_kwsbp(inputs)

return out

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the kws results
"""

pos_result_json = {}
neg_result_json = {}

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
self._parse_dump_log(pos_result_json, inputs['pos_dump_path'])
if inputs['kws_set'] in ['neg_testsets', 'roc']:
self._parse_dump_log(neg_result_json, inputs['neg_dump_path'])
"""
result_json format example:
{
"wav_count": 450,
"keywords": ["小云小云"],
"wav_time": 3560.999999,
"detected": [
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
......
],
"detected_count": 429,
"rejected_count": 21,
"rejected": [
"yyy.wav",
"zzz.wav",
......
]
}
"""

rst_dict = {'kws_set': inputs['kws_set']}

# parsing the result of wav
if inputs['kws_set'] == 'wav':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json['detected_count'] == 1:
rst_dict['keywords'] = pos_result_json['keywords']
rst_dict['detected'] = True
wav_file_name = os.path.basename(inputs['pos_wav_path'])
rst_dict['confidence'] = float(pos_result_json['detected'][0]
[wav_file_name]['confidence'])
else:
rst_dict['detected'] = False

# parsing the result of pos_tests
elif inputs['kws_set'] == 'pos_testsets':
rst_dict['wav_count'] = pos_result_json['wav_count'] = inputs[
'pos_wav_count']
rst_dict['wav_time'] = round(pos_result_json['wav_time'], 6)
if pos_result_json.__contains__('keywords'):
rst_dict['keywords'] = pos_result_json['keywords']

rst_dict['recall'] = round(
pos_result_json['detected_count'] / rst_dict['wav_count'], 6)

if pos_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = pos_result_json['detected_count']
if pos_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = pos_result_json['rejected_count']
if pos_result_json.__contains__('rejected'):
rst_dict['rejected'] = pos_result_json['rejected']

# parsing the result of neg_tests
elif inputs['kws_set'] == 'neg_testsets':
rst_dict['wav_count'] = neg_result_json['wav_count'] = inputs[
'neg_wav_count']
rst_dict['wav_time'] = round(neg_result_json['wav_time'], 6)
if neg_result_json.__contains__('keywords'):
rst_dict['keywords'] = neg_result_json['keywords']

rst_dict['fa_rate'] = 0.0
rst_dict['fa_per_hour'] = 0.0

if neg_result_json.__contains__('detected_count'):
rst_dict['detected_count'] = neg_result_json['detected_count']
rst_dict['fa_rate'] = round(
neg_result_json['detected_count'] / rst_dict['wav_count'],
6)
if neg_result_json.__contains__('wav_time'):
rst_dict['fa_per_hour'] = round(
neg_result_json['detected_count']
/ float(neg_result_json['wav_time'] / 3600), 6)

if neg_result_json.__contains__('rejected_count'):
rst_dict['rejected_count'] = neg_result_json['rejected_count']

if neg_result_json.__contains__('detected'):
rst_dict['detected'] = neg_result_json['detected']

# parsing the result of roc
elif inputs['kws_set'] == 'roc':
threshold_start = 0.000
threshold_step = 0.001
threshold_end = 1.000

pos_keywords_list = []
neg_keywords_list = []
if pos_result_json.__contains__('keywords'):
pos_keywords_list = pos_result_json['keywords']
if neg_result_json.__contains__('keywords'):
neg_keywords_list = neg_result_json['keywords']

keywords_list = list(set(pos_keywords_list + neg_keywords_list))

pos_result_json['wav_count'] = inputs['pos_wav_count']
neg_result_json['wav_count'] = inputs['neg_wav_count']

if len(keywords_list) > 0:
rst_dict['keywords'] = keywords_list

for index in range(len(rst_dict['keywords'])):
cur_keyword = rst_dict['keywords'][index]
output_list = self._generate_roc_list(
start=threshold_start,
step=threshold_step,
end=threshold_end,
keyword=cur_keyword,
pos_inputs=pos_result_json,
neg_inputs=neg_result_json)

rst_dict[cur_keyword] = output_list

return rst_dict

def _run_with_kwsbp(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

if inputs['kws_set'] == 'roc':
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], 'keywords_roc.json')

if inputs['kws_set'] == 'wav':
dump_log_path: str = os.path.join(inputs['pos_dump_path'],
'dump.log')
kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + os.path.join(inputs['pos_data_path'], 'wave.list') + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
os.system(kws_cmd)

if inputs['kws_set'] in ['pos_testsets', 'roc']:
data_dir: str = os.listdir(inputs['pos_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['pos_data_path'], i))

j: int = 0
process = []
while j < inputs['pos_num_thread']:
wav_list_path: str = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['pos_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1

if inputs['kws_set'] in ['neg_testsets', 'roc']:
data_dir: str = os.listdir(inputs['neg_data_path'])
wav_list = []
for i in data_dir:
suffix = os.path.splitext(os.path.basename(i))[1]
if suffix == '.list':
wav_list.append(os.path.join(inputs['neg_data_path'], i))

j: int = 0
process = []
while j < inputs['neg_num_thread']:
wav_list_path: str = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
dump_log_path: str = inputs['neg_dump_path'] + '/dump.' + str(
j) + '.log'

kws_cmd: str = inputs['kws_tool_path'] + \
' --sys-dir=' + inputs['model_workspace'] + \
' --cfg-file=' + inputs['cfg_file_path'] + \
' --sample-rate=' + inputs['sample_rate'] + \
' --keyword-grammar=' + inputs['keyword_grammar_path'] + \
' --wave-scp=' + wav_list_path + \
' --num-thread=1 > ' + dump_log_path + ' 2>&1'
p = subprocess.Popen(kws_cmd, shell=True)
process.append(p)
j += 1

k: int = 0
while k < len(process):
process[k].wait()
k += 1

return inputs

def _parse_dump_log(self, result_json: Dict[str, Any],
dump_path: str) -> Dict[str, Any]:
dump_dir = os.listdir(dump_path)
for i in dump_dir:
basename = os.path.splitext(os.path.basename(i))[0]
# find dump.JOB.log
if 'dump' in basename:
with open(
os.path.join(dump_path, i), mode='r',
encoding='utf-8') as file:
while 1:
line = file.readline()
if not line:
break
else:
result_json = self._parse_result_log(
line, result_json)

def _parse_result_log(self, line: str,
result_json: Dict[str, Any]) -> Dict[str, Any]:
# valid info
if '[rejected]' in line or '[detected]' in line:
detected_count = 0
rejected_count = 0

if result_json.__contains__('detected_count'):
detected_count = result_json['detected_count']
if result_json.__contains__('rejected_count'):
rejected_count = result_json['rejected_count']

if '[detected]' in line:
# [detected], fname:/xxx/.tmp_pos_testsets/pos_testsets/33.wav,
# kw:小云小云, confidence:0.965155, time:[4.62-5.10], threshold:0.00,
detected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
keyword = content_list[2].split(':')[1]
confidence = content_list[3].split(':')[1]

keywords_list = []
if result_json.__contains__('keywords'):
keywords_list = result_json['keywords']

if keyword not in keywords_list:
keywords_list.append(keyword)
result_json['keywords'] = keywords_list

keyword_item = {}
keyword_item['confidence'] = confidence
keyword_item['keyword'] = keyword
item = {}
item[file_name] = keyword_item

detected_list = []
if result_json.__contains__('detected'):
detected_list = result_json['detected']

detected_list.append(item)
result_json['detected'] = detected_list

elif '[rejected]' in line:
# [rejected], fname:/xxx/.tmp_pos_testsets/pos_testsets/28.wav
rejected_count += 1
content_list = line.split(', ')
file_name = os.path.basename(content_list[1].split(':')[1])
file_name = file_name.strip().replace('\n',
'').replace('\r', '')

rejected_list = []
if result_json.__contains__('rejected'):
rejected_list = result_json['rejected']

rejected_list.append(file_name)
result_json['rejected'] = rejected_list

result_json['detected_count'] = detected_count
result_json['rejected_count'] = rejected_count

elif 'total_proc_time=' in line and 'wav_time=' in line:
# eg: total_proc_time=0.289000(s), wav_time=20.944125(s), kwsbp_rtf=0.013799
wav_total_time = 0
content_list = line.split('), ')
if result_json.__contains__('wav_time'):
wav_total_time = result_json['wav_time']

wav_time_str = content_list[1].split('=')[1]
wav_time_str = wav_time_str.split('(')[0]
wav_time = float(wav_time_str)
wav_time = round(wav_time, 6)

if isinstance(wav_time, float):
wav_total_time += wav_time

result_json['wav_time'] = wav_total_time

return result_json

def _generate_roc_list(self, start: float, step: float, end: float,
keyword: str, pos_inputs: Dict[str, Any],
neg_inputs: Dict[str, Any]) -> Dict[str, Any]:
pos_wav_count = pos_inputs['wav_count']
neg_wav_time = neg_inputs['wav_time']
det_lists = pos_inputs['detected']
fa_lists = neg_inputs['detected']
threshold_cur = start
"""
input det_lists dict
[
{
"xxx.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
{
"yyy.wav": {
"confidence": "0.990368",
"keyword": "小云小云"
}
},
]

output dict
[
{
"threshold": 0.000,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
{
"threshold": 0.001,
"recall": 0.999888,
"fa_per_hour": 1.999999
},
]
"""

output = []
while threshold_cur <= end:
det_count = 0
fa_count = 0
for index in range(len(det_lists)):
det_item = det_lists[index]
det_wav_item = det_item.get(next(iter(det_item)))
if det_wav_item['keyword'] == keyword:
confidence = float(det_wav_item['confidence'])
if confidence >= threshold_cur:
det_count += 1

for index in range(len(fa_lists)):
fa_item = fa_lists[index]
fa_wav_item = fa_item.get(next(iter(fa_item)))
if fa_wav_item['keyword'] == keyword:
confidence = float(fa_wav_item['confidence'])
if confidence >= threshold_cur:
fa_count += 1

output_item = {
'threshold': round(threshold_cur, 3),
'recall': round(float(det_count / pos_wav_count), 6),
'fa_per_hour': round(fa_count / float(neg_wav_time / 3600), 6)
}
output.append(output_item)

threshold_cur += step

return output
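A sketch of how this pipeline might be invoked end to end in single-wav mode; the model id and wav path are placeholders, not names confirmed by this commit:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# placeholder model id; the factory wires up the model and WavToLists preprocessor
kws = pipeline(Tasks.key_word_spotting, model='damo/speech_kws_placeholder')

# wav_path[0] is the positive wav, wav_path[1] the negative one (unused for 'wav')
result = kws(kws_type='wav', wav_path=['/path/to/test.wav', None])
if result['detected']:
    print(result['keywords'], result['confidence'])
```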

+3 -3  modelscope/pipelines/base.py

@@ -6,15 +6,15 @@ from typing import Any, Dict, Generator, List, Union

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.base import Model
+from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import Preprocessor
-from modelscope.pydatasets import PyDataset
from modelscope.utils.config import Config
from modelscope.utils.logger import get_logger
from .outputs import TASK_OUTPUTS
from .util import is_model, is_official_hub_path

Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray']
InputModel = Union[str, Model]

output_keys = [
@@ -85,7 +85,7 @@ class Pipeline(ABC):
for ele in input:
output.append(self._process_single(ele, *args, **post_kwargs))

-elif isinstance(input, PyDataset):
+elif isinstance(input, MsDataset):
return self._process_iterator(input, *args, **post_kwargs)

else:


+3 -1  modelscope/pipelines/builder.py

@@ -21,7 +21,6 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.sentence_similarity:
(Pipelines.sentence_similarity,
'damo/nlp_structbert_sentence-similarity_chinese-base'),
-Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
Tasks.sentiment_classification:
(Pipelines.sentiment_classification,
@@ -44,6 +43,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.action_recognition: (Pipelines.action_recognition,
'damo/cv_TAdaConv_action-recognition'),
+Tasks.multi_modal_embedding:
+(Pipelines.multi_modal_embedding,
+ 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
}




+1 -0  modelscope/pipelines/multi_modal/__init__.py

@@ -1 +1,2 @@
from .image_captioning_pipeline import ImageCaptionPipeline
+from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline

+34 -0  modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py

@@ -0,0 +1,34 @@
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Model, Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.multi_modal_embedding, module_name=Pipelines.multi_modal_embedding)
class MultiModalEmbeddingPipeline(Pipeline):

def __init__(self, model: str, device_id: int = -1):
if isinstance(model, str):
pipe_model = Model.from_pretrained(model)
elif isinstance(model, Model):
pipe_model = model
else:
raise NotImplementedError('model must be a single str')

super().__init__(model=pipe_model)

def preprocess(self, input: Input) -> Dict[str, Any]:
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return self.model(input)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
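A sketch of the intended usage, with the default model id that the builder.py change above registers; the image path is a placeholder:

```python
from PIL import Image

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# default model id taken from builder.py; test.jpg is a placeholder
pipe = pipeline(
    Tasks.multi_modal_embedding,
    model='damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
output = pipe({'img': Image.open('test.jpg'), 'text': '一张风景照片'})
print(output['img_embedding'].shape, output['text_embedding'].shape)
```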

+2 -2  modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py

@@ -1,7 +1,7 @@
from typing import Any, Dict

from ...metainfo import Pipelines
-from ...models.nlp import DialogIntentModel
+from ...models.nlp import SpaceForDialogIntentModel
from ...preprocessors import DialogIntentPredictionPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
@@ -15,7 +15,7 @@ __all__ = ['DialogIntentPredictionPipeline']
module_name=Pipelines.dialog_intent_prediction)
class DialogIntentPredictionPipeline(Pipeline):

-def __init__(self, model: DialogIntentModel,
+def __init__(self, model: SpaceForDialogIntentModel,
preprocessor: DialogIntentPredictionPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction



+4 -4  modelscope/pipelines/nlp/dialog_modeling_pipeline.py

@@ -1,9 +1,9 @@
from typing import Any, Dict, Optional

-from modelscope.models.nlp import DialogModelingModel
-from modelscope.preprocessors import DialogModelingPreprocessor
-from modelscope.utils.constant import Tasks
+from ...metainfo import Pipelines
+from ...models.nlp import SpaceForDialogModelingModel
+from ...preprocessors import DialogModelingPreprocessor
+from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

@@ -14,7 +14,7 @@ __all__ = ['DialogModelingPipeline']
Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling)
class DialogModelingPipeline(Pipeline):

-def __init__(self, model: DialogModelingModel,
+def __init__(self, model: SpaceForDialogModelingModel,
preprocessor: DialogModelingPreprocessor, **kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction



+7 -0  modelscope/pipelines/outputs.py

@@ -131,6 +131,13 @@ TASK_OUTPUTS = {
# }
Tasks.image_captioning: ['caption'],

# multi-modal embedding result for single sample
# {
# "img_embedding": np.array with shape [1, D],
# "text_embedding": np.array with shape [1, D]
# }
Tasks.multi_modal_embedding: ['img_embedding', 'text_embedding'],

# visual grounding result for single sample
# {
# "boxes": [


+3 -0  modelscope/preprocessors/__init__.py

@@ -4,6 +4,9 @@
from .base import Preprocessor
# from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .kws import WavToLists
from .multi_modal import OfaImageCaptionPreprocessor
from .nlp import * # noqa F403
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
from .space.dialog_modeling_preprocessor import * # noqa F403


+253 -0  modelscope/preprocessors/kws.py

@@ -0,0 +1,253 @@
import os
import shutil
import stat
from pathlib import Path
from typing import Any, Dict, List

import yaml

from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['WavToLists']


@PREPROCESSORS.register_module(
Fields.audio, module_name=Preprocessors.wav_to_lists)
class WavToLists(Preprocessor):
"""generate audio lists file from wav

Args:
workspace (str): directory for temporary kws intermediate files and results
"""

def __init__(self, workspace: str = None):
# the workspace path
if workspace is None or len(workspace) == 0:
self._workspace = os.path.join(os.getcwd(), '.tmp')
else:
self._workspace = workspace

if not os.path.exists(self._workspace):
os.mkdir(self._workspace)

def __call__(self,
model: Model = None,
kws_type: str = None,
wav_path: List[str] = None) -> Dict[str, Any]:
"""Call functions to load model and wav.

Args:
model (Model): model should be provided
kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc
wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
Returns:
Dict[str, Any]: the kws result
"""

assert model is not None, 'preprocess kws model should be provided'
assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
], f'preprocess kws_type {kws_type} is invalid'
assert wav_path[0] is not None or wav_path[
1] is not None, 'preprocess wav_path is invalid'

self._model = model
out = self.forward(self._model.forward(), kws_type, wav_path)
return out

def forward(self, model: Dict[str, Any], kws_type: str,
wav_path: List[str]) -> Dict[str, Any]:
assert len(kws_type) > 0, 'preprocess kws_type is empty'
assert len(
model['config_path']) > 0, 'preprocess model[config_path] is empty'
assert os.path.exists(
model['config_path']), 'model config.yaml is absent'

inputs = model.copy()

inputs['kws_set'] = kws_type
inputs['workspace'] = self._workspace
if wav_path[0] is not None:
inputs['pos_wav_path'] = wav_path[0]
if wav_path[1] is not None:
inputs['neg_wav_path'] = wav_path[1]

out = self._read_config(inputs)
out = self._generate_wav_lists(out)

return out

def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""read and parse config.yaml to get all model files
"""

assert os.path.exists(
inputs['config_path']), 'model config yaml file does not exist'

config_file = open(inputs['config_path'])
root = yaml.full_load(config_file)
config_file.close()

inputs['cfg_file'] = root['cfg_file']
inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'],
root['cfg_file'])
inputs['keyword_grammar'] = root['keyword_grammar']
inputs['keyword_grammar_path'] = os.path.join(
inputs['model_workspace'], root['keyword_grammar'])
inputs['sample_rate'] = str(root['sample_rate'])
inputs['kws_tool'] = root['kws_tool']

if os.path.exists(
os.path.join(inputs['workspace'], inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join(inputs['workspace'],
inputs['kws_tool'])
elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/usr/bin',
inputs['kws_tool'])
elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])):
inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool'])

assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp'
os.chmod(inputs['kws_tool_path'],
stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH)

self._config_checking(inputs)
return inputs

def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""assemble wav lists
"""

if inputs['kws_set'] == 'wav':
inputs['pos_num_thread'] = 1
wave_scp_content: str = inputs['pos_wav_path'] + '\n'

with open(os.path.join(inputs['pos_data_path'], 'wave.list'),
'a') as f:
f.write(wave_scp_content)

inputs['pos_wav_count'] = 1

if inputs['kws_set'] in ['pos_testsets', 'roc']:
# find all positive wave
wav_list = []
wav_dir = inputs['pos_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)

list_count: int = len(wav_list)
inputs['pos_wav_count'] = list_count

if list_count <= 128:
inputs['pos_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['pos_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

if inputs['kws_set'] in ['neg_testsets', 'roc']:
# find all negative wave
wav_list = []
wav_dir = inputs['neg_wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)

list_count: int = len(wav_list)
inputs['neg_wav_count'] = list_count

if list_count <= 128:
inputs['neg_num_thread'] = list_count
j: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
j) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1

else:
inputs['neg_num_thread'] = 128
j: int = 0
k: int = 0
while j < list_count:
wave_scp_content: str = wav_list[j] + '\n'
wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
k) + '.list'
with open(wav_list_path, 'a') as f:
f.write(wave_scp_content)
j += 1
k += 1
if k >= 128:
k = 0

return inputs

def _recursion_dir_all_wave(self, wav_list,
dir_path: str) -> Dict[str, Any]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
self._recursion_dir_all_wave(wav_list, file_path)

return wav_list

def _config_checking(self, inputs: Dict[str, Any]):

if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
inputs['pos_data_path'] = os.path.join(inputs['workspace'],
'pos_data')
if not os.path.exists(inputs['pos_data_path']):
os.mkdir(inputs['pos_data_path'])
else:
shutil.rmtree(inputs['pos_data_path'])
os.mkdir(inputs['pos_data_path'])

inputs['pos_dump_path'] = os.path.join(inputs['workspace'],
'pos_dump')
if not os.path.exists(inputs['pos_dump_path']):
os.mkdir(inputs['pos_dump_path'])
else:
shutil.rmtree(inputs['pos_dump_path'])
os.mkdir(inputs['pos_dump_path'])

if inputs['kws_set'] in ['neg_testsets', 'roc']:
inputs['neg_data_path'] = os.path.join(inputs['workspace'],
'neg_data')
if not os.path.exists(inputs['neg_data_path']):
os.mkdir(inputs['neg_data_path'])
else:
shutil.rmtree(inputs['neg_data_path'])
os.mkdir(inputs['neg_data_path'])

inputs['neg_dump_path'] = os.path.join(inputs['workspace'],
'neg_dump')
if not os.path.exists(inputs['neg_dump_path']):
os.mkdir(inputs['neg_dump_path'])
else:
shutil.rmtree(inputs['neg_dump_path'])
os.mkdir(inputs['neg_dump_path'])

+2 -2  modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py

@@ -4,7 +4,7 @@ import os
from typing import Any, Dict

from ...utils.config import Config
from ...utils.constant import Fields
from ...utils.constant import Fields, ModelFile
from ...utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS
@@ -26,7 +26,7 @@ class DialogIntentPredictionPreprocessor(Preprocessor):

self.model_dir: str = model_dir
self.config = Config.from_file(
os.path.join(self.model_dir, 'configuration.json'))
os.path.join(self.model_dir, ModelFile.CONFIGURATION))
self.text_field = IntentBPETextField(
self.model_dir, config=self.config)



+ 2
- 2
modelscope/preprocessors/space/dialog_modeling_preprocessor.py View File

@@ -4,7 +4,7 @@ import os
from typing import Any, Dict

from ...utils.config import Config
from ...utils.constant import Fields
from ...utils.constant import Fields, ModelFile
from ...utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS
@@ -26,7 +26,7 @@ class DialogModelingPreprocessor(Preprocessor):

self.model_dir: str = model_dir
self.config = Config.from_file(
os.path.join(self.model_dir, 'configuration.json'))
os.path.join(self.model_dir, ModelFile.CONFIGURATION))
self.text_field = MultiWOZBPETextField(
self.model_dir, config=self.config)



+ 0
- 1
modelscope/preprocessors/space/dst_processors.py View File

@@ -570,7 +570,6 @@ class multiwoz22Processor(DSTProcessor):
def delex_utt(self, utt, values, unk_token='[UNK]'):
utt_norm = self.tokenize(utt)
for s, vals in values.items():
# TODO vals may not be a list here but the initialized string "none"
for v in vals:
if v != 'none':
v_norm = self.tokenize(v)
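
The hunk shows only the head of the loop, so for context: delexicalization replaces slot-value spans in the utterance with a placeholder token. A hedged, standalone sketch of that idea (operating on pre-tokenized lists; the real multiwoz22Processor uses its own tokenizer and matching rules):

def delex(tokens, value_tokens, unk_token='[UNK]'):
    """Replace every occurrence of value_tokens inside tokens with unk_token."""
    out, i, n = [], 0, len(value_tokens)
    while i < len(tokens):
        if n and tokens[i:i + n] == value_tokens:
            out.extend([unk_token] * n)  # length-preserving, keeps alignment
            i += n
        else:
            out.append(tokens[i])
            i += 1
    return out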


+ 3
- 13
modelscope/preprocessors/space/fields/gen_field.py View File

@@ -8,6 +8,7 @@ from itertools import chain

import numpy as np

from ....utils.constant import ModelFile
from ....utils.nlp.space import ontology, utils
from ....utils.nlp.space.db_ops import MultiWozDB
from ....utils.nlp.space.utils import list2np
@@ -35,18 +36,10 @@ class BPETextField(object):

@property
def bot_id(self):
"""
Distinguishes the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices; there are only two, matching the hyperparameter 'num_type_embeddings'.
"""
return 0

@property
def user_id(self):
"""
Distinguishes the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices; there are only two, matching the hyperparameter 'num_type_embeddings'.
"""
return 1

@property
@@ -185,7 +178,7 @@ class BPETextField(object):
]
src_role.append(list(chain(*role))[-self.max_len:])

# src-side and tgt-side sequences need separate padding, so the first decoded word stays aligned
# src sequence and tgt sequence should be padded separately, to make sure the first word is aligned
src_token = list2np(src_token, padding=self.pad_id)
src_pos = list2np(src_pos, padding=self.pad_id)
src_turn = list2np(src_turn, padding=self.pad_id)
@@ -438,7 +431,7 @@ class MultiWOZBPETextField(BPETextField):
# logging.info(log_str)
# cfg.num_training_steps = num_training_steps * cfg.epoch_num
self.set_stats[set_name][
'num_training_steps_per_epoch'] = num_training_steps # turn-levelsteps
'num_training_steps_per_epoch'] = num_training_steps # turn-level steps
self.set_stats[set_name]['num_turns'] = num_turns
self.set_stats[set_name]['num_dials'] = num_dials

@@ -547,9 +540,6 @@ class MultiWOZBPETextField(BPETextField):

def convert_batch_turn(self, turn_batch, pv_batch, first_turn=False):
"""
URURU: here this means turn-level training (data assembly), as opposed to session-level training (convert_batch_session);
it differs from the meaning at eval time, where both modes generate turn by turn; see the comments of the related functions for the eval-time meaning of URURU;

convert the current and the last turn
concat [U_0,R_0,...,U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t]
first turn: [U_t, B_t, A_t, R_t]
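
The concatenation pattern in that docstring is easy to misread, so here is a hedged sketch (plain Python; a turns list of dicts holding 'user'/'bspn'/'aspn'/'resp' token-id lists is an assumed shape, not the actual MultiWOZBPETextField API):

def build_turn_level_input(turns):
    """Concatenate [U_0, R_0, ..., U_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t]."""
    context = []
    for turn in turns[:-1]:
        # history turns contribute only the user utterance and the response
        context += turn['user'] + turn['resp']
    last = turns[-1]
    # the current turn adds belief state (B), act (A) and response (R)
    return context + last['user'] + last['bspn'] + last['aspn'] + last['resp']

With a single turn the history is empty and the input degenerates to [U_t, B_t, A_t, R_t], exactly the first-turn case the docstring describes.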


+ 0
- 10
modelscope/preprocessors/space/fields/intent_field.py View File

@@ -154,18 +154,10 @@ class BPETextField(object):

@property
def bot_id(self):
"""
Distinguishes the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices; there are only two, matching the hyperparameter 'num_type_embeddings'.
"""
return 0

@property
def user_id(self):
"""
Distinguishes the two roles, user and bot.
1 and 0 are not vocabulary indices but dedicated role indices; there are only two, matching the hyperparameter 'num_type_embeddings'.
"""
return 1

def add_sepcial_tokens(self):
@@ -862,7 +854,6 @@ class BPETextField(object):
]
src_role.append(list(chain(*role))[-self.max_len:])

# src-side and tgt-side sequences need separate padding, so the first decoded word stays aligned
src_token = list2np(src_token, padding=self.pad_id)
src_pos = list2np(src_pos, padding=self.pad_id)
src_turn = list2np(src_turn, padding=self.pad_id)
@@ -1038,7 +1029,6 @@ class IntentBPETextField(BPETextField):
] * l for i, l in enumerate(utt_lens)]
src_role.append(list(chain(*role))[-self.max_len:])

# src-side and tgt-side sequences need separate padding, so the first decoded word stays aligned
src_token = list2np(src_token, padding=self.pad_id)
src_pos = list2np(src_pos, padding=self.pad_id)
src_turn = list2np(src_turn, padding=self.pad_id)


+ 0
- 4
modelscope/preprocessors/space/tokenizer.py View File

@@ -56,10 +56,6 @@ class Tokenizer(object):
self._tokenizer = BertTokenizer(
vocab_path, never_split=self.special_tokens)
for tok in self.special_tokens:
'''
special_tokens must already be in the vocabulary; they are registered here so that each occupies a single slot and is never split into subwords;
tokens missing from the vocabulary can be converted via the vocabulary's [unused] symbols: spec_convert_dict;
'''
assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary"
self.vocab_size = len(self._tokenizer.vocab)
elif tokenizer_type == 'GPT2':
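
The policy described in the removed docstring still matters when extending the vocabulary. A hedged sketch of the check (assuming a HuggingFace-style BertTokenizer exposing a vocab dict; a spec_convert_dict mapping new tokens to reserved [unusedN] entries is inferred from the comment, not shown in this hunk):

def check_special_tokens(tokenizer, special_tokens, spec_convert_dict=None):
    """Ensure each special token maps to exactly one never-split vocab slot."""
    for tok in special_tokens:
        if tok in tokenizer.vocab:
            continue  # already a single vocabulary entry
        # otherwise it must be remapped to a reserved [unusedN] symbol
        mapped = (spec_convert_dict or {}).get(tok)
        assert mapped is not None and mapped in tokenizer.vocab, \
            f"special token '{tok}' has no vocab slot or [unused] mapping"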


+ 0
- 1
modelscope/pydatasets/__init__.py View File

@@ -1 +0,0 @@
from .py_dataset import PyDataset

modelscope/pydatasets/utils/__init__.py → modelscope/trainers/nlp/space/__init__.py View File


tests/pydatasets/__init__.py → modelscope/trainers/nlp/space/metrics/__init__.py View File


modelscope/models/nlp/space/metrics/metrics_tracker.py → modelscope/trainers/nlp/space/metrics/metrics_tracker.py View File

@@ -10,8 +10,8 @@ class MetricsTracker(object):
""" Tracking metrics. """

def __init__(self):
self.metrics_val = defaultdict(float) # metrics returned by the latest batch
self.metrics_avg = defaultdict(float) # running average over the batches trained so far in this epoch
self.metrics_val = defaultdict(float) # for one batch
self.metrics_avg = defaultdict(float) # avg batches
self.num_samples = 0

def update(self, metrics, num_samples):
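
The body of update() is not part of this hunk; here is a hedged sketch of the sample-weighted running average the two fields imply (the actual implementation may differ):

def update(self, metrics, num_samples):
    for name, val in metrics.items():
        self.metrics_val[name] = val  # latest batch value
        # fold the new batch into the epoch average, weighted by sample count
        total = self.metrics_avg[name] * self.num_samples + val * num_samples
        self.metrics_avg[name] = total / (self.num_samples + num_samples)
    self.num_samples += num_samples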

+ 0
- 0
modelscope/trainers/nlp/space/trainer/__init__.py View File


modelscope/models/nlp/space/application/gen_app.py → modelscope/trainers/nlp/space/trainer/gen_trainer.py View File

@@ -563,7 +563,7 @@ class MultiWOZTrainer(Trainer):
generated_bs = outputs[0].cpu().numpy().tolist()
bspn_gen = self.decode_generated_bspn(generated_bs)
# check DB result
if self.reader.use_true_db_pointer: # whether the current turn's db is ground truth
if self.reader.use_true_db_pointer: # To control whether current db is ground truth
db = turn['db']
else:
db_result = self.reader.bspan_to_DBpointer(

modelscope/models/nlp/space/application/intent_app.py → modelscope/trainers/nlp/space/trainer/intent_trainer.py View File

@@ -314,18 +314,18 @@ class IntentTrainer(Trainer):
self.can_norm = config.Trainer.can_norm

def can_normalization(self, y_pred, y_true, ex_data_iter):
# from the predictions, compute the accuracy before revision
# compute ACC
acc_original = np.mean([y_pred.argmax(1) == y_true])
message = 'original acc: %s' % acc_original

# evaluate the uncertainty of each prediction
# compute uncertainty
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty =\
-(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)

# choose a threshold to split predictions into high- and low-confidence parts
# choose threshold
# print(np.sort(y_pred_uncertainty)[-100:].tolist())
threshold = 0.7
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
@@ -333,8 +333,7 @@ class IntentTrainer(Trainer):
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]

# report the accuracy of each part
# in general, the high-confidence set is far more accurate than the low-confidence one
# compute ACC again for high and low confidence sets
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean() \
if len(y_true_confident) else 0.
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean() \
@@ -344,7 +343,7 @@ class IntentTrainer(Trainer):
message += ' (%s) unconfident acc: %s' % (len(y_true_unconfident),
acc_unconfident)

# estimate the prior distribution from the training set
# get prior distribution from training set
prior = np.zeros(self.func_model.num_intent)
for _, (batch, batch_size) in ex_data_iter:
for intent_label in batch['intent_label']:
@@ -352,7 +351,7 @@ class IntentTrainer(Trainer):

prior /= prior.sum()

# revise the low-confidence samples one by one, then re-evaluate accuracy
# revise each sample from the low confidence set, and compute new ACC
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):
Y = np.concatenate([y_pred_confident, y[None]], axis=0)
@@ -365,7 +364,7 @@ class IntentTrainer(Trainer):
if y.argmax() == y_true_unconfident[i]:
right += 1

# report the revised accuracy
# get final ACC
acc_final = \
(acc_confident * len(y_pred_confident) + right) / \
len(y_pred)
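
The uncertainty score in this hunk is the entropy of the renormalized top-k class probabilities, scaled into [0, 1] by dividing by log(k). Restated as a standalone helper (NumPy only; k=3 and the 0.7 threshold are the values visible above):

import numpy as np

def topk_uncertainty(y_pred, k=3):
    """Entropy of the renormalized top-k probabilities per row, in [0, 1]."""
    topk = np.sort(y_pred, axis=1)[:, -k:]          # k largest probs per sample
    topk = topk / topk.sum(axis=1, keepdims=True)   # renormalize to sum to 1
    return -(topk * np.log(topk)).sum(axis=1) / np.log(k)

Samples scoring below the threshold keep their argmax; the rest are revised against the training-set prior in the alpha/iters loop before accuracy is recomputed.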

+ 2
- 0
modelscope/utils/constant.py View File

@@ -57,11 +57,13 @@ class Tasks(object):
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'
key_word_spotting = 'key-word-spotting'

# multi-modal tasks
image_captioning = 'image-captioning'
visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis'
multi_modal_embedding = 'multi-modal-embedding'


class InputFields(object):


+ 2
- 2
modelscope/utils/nlp/space/db_ops.py View File

@@ -172,8 +172,8 @@ class MultiWozDB(object):
continue
if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
(domain == 'restaurant' and s in ['day', 'time']):
# these inform slots belong to the booking info, and the database has no such slots;
# whether booking succeeds is judged from the user goal, not from a database query;
# These inform slots belong to "book info", which do not exist in DB
# "book" is according to the user goal, not DB
continue

skip_case = {


+ 0
- 0
tests/msdatasets/__init__.py View File


tests/pydatasets/test_py_dataset.py → tests/msdatasets/test_ms_dataset.py View File

@@ -3,10 +3,9 @@ import unittest
import datasets as hfdata

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.preprocessors.base import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs
from modelscope.utils.test_utils import require_tf, require_torch, test_level


@@ -31,15 +30,15 @@ class ImgPreprocessor(Preprocessor):
}


class PyDatasetTest(unittest.TestCase):
class MsDatasetTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_ds_basic(self):
ms_ds_full = PyDataset.load('squad')
ms_ds_full = MsDataset.load('squad')
ms_ds_full_hf = hfdata.load_dataset('squad')
ms_ds_train = PyDataset.load('squad', split='train')
ms_ds_train = MsDataset.load('squad', split='train')
ms_ds_train_hf = hfdata.load_dataset('squad', split='train')
ms_image_train = PyDataset.from_hf_dataset(
ms_image_train = MsDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0])
self.assertEqual(ms_ds_full['validation'][0],
@@ -58,7 +57,7 @@ class PyDatasetTest(unittest.TestCase):
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = PyDataset.load('squad', split='train')
ms_ds_train = MsDataset.load('squad', split='train')
pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
@@ -75,7 +74,7 @@ class PyDatasetTest(unittest.TestCase):
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = PyDataset.load('squad', split='train')
ms_ds_train = MsDataset.load('squad', split='train')
tf_dataset = ms_ds_train.to_tf_dataset(
batch_size=5,
shuffle=True,
@@ -86,7 +85,7 @@ class PyDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_img(self):
ms_image_train = PyDataset.from_hf_dataset(
ms_image_train = MsDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
pt_dataset = ms_image_train.to_torch_dataset(
preprocessors=ImgPreprocessor(
@@ -100,7 +99,7 @@ class PyDatasetTest(unittest.TestCase):
def test_to_tf_dataset_img(self):
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
ms_image_train = PyDataset.load('beans', split='train')
ms_image_train = MsDataset.load('beans', split='train')
tf_dataset = ms_image_train.to_tf_dataset(
batch_size=5,
shuffle=True,

+ 1
- 1
tests/pipelines/test_action_recognition.py View File

@@ -8,8 +8,8 @@ import unittest
import cv2

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level



+ 2
- 2
tests/pipelines/test_dialog_intent_prediction.py View File

@@ -3,7 +3,7 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import DialogIntentModel
from modelscope.models.nlp import SpaceForDialogIntentModel
from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
from modelscope.preprocessors import DialogIntentPredictionPreprocessor
from modelscope.utils.constant import Tasks
@@ -20,7 +20,7 @@ class DialogIntentPredictionTest(unittest.TestCase):
def test_run(self):
cache_path = snapshot_download(self.model_id)
preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
model = DialogIntentModel(
model = SpaceForDialogIntentModel(
model_dir=cache_path,
text_field=preprocessor.text_field,
config=preprocessor.config)


+ 2
- 2
tests/pipelines/test_dialog_modeling.py View File

@@ -6,7 +6,7 @@ import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import DialogModelingModel
from modelscope.models.nlp import SpaceForDialogModelingModel
from modelscope.pipelines import DialogModelingPipeline, pipeline
from modelscope.preprocessors import DialogModelingPreprocessor
from modelscope.utils.constant import Tasks
@@ -97,7 +97,7 @@ class DialogModelingTest(unittest.TestCase):
cache_path = snapshot_download(self.model_id)

preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
model = DialogModelingModel(
model = SpaceForDialogModelingModel(
model_dir=cache_path,
text_field=preprocessor.text_field,
config=preprocessor.config)


+ 3
- 3
tests/pipelines/test_image_matting.py View File

@@ -7,8 +7,8 @@ import unittest
import cv2

from modelscope.fileio import File
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level

@@ -37,7 +37,7 @@ class ImageMattingTest(unittest.TestCase):
# alternatively:
# input_location = '/dir/to/images'

dataset = PyDataset.load(input_location, target='image')
dataset = MsDataset.load(input_location, target='image')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
# note that for a dataset input, the inference output is a Generator that can be iterated.
result = img_matting(dataset)
@@ -62,7 +62,7 @@ class ImageMattingTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
dataset = PyDataset.load('beans', split='train', target='image')
dataset = MsDataset.load('beans', split='train', target='image')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
result = img_matting(dataset)
for i in range(10):


+ 334
- 0
tests/pipelines/test_key_word_spotting.py View File

@@ -0,0 +1,334 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tarfile
import unittest

import requests

from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields, InputFields, Tasks
from modelscope.utils.test_utils import test_level

KWSBP_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/tools/kwsbp'

POS_WAV_FILE = '20200707_spk57db_storenoise52db_40cm_xiaoyun_sox_6.wav'
POS_WAV_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/' + POS_WAV_FILE

POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'

NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/neg_testsets.tar.gz'


def un_tar_gz(fname, dirs):
    with tarfile.open(fname) as t:
        t.extractall(path=dirs)


class KeyWordSpottingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun'
self.workspace = os.path.join(os.getcwd(), '.tmp')
if not os.path.exists(self.workspace):
os.mkdir(self.workspace)

def tearDown(self) -> None:
if os.path.exists(self.workspace):
shutil.rmtree(self.workspace)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'wav'

# downloading wav file
wav_file_path = os.path.join(self.workspace, POS_WAV_FILE)
if not os.path.exists(wav_file_path):
r = requests.get(POS_WAV_URL)
with open(wav_file_path, 'wb') as f:
f.write(r.content)

# downloading kwsbp
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipeline = pipeline(
    pipeline_name=Pipelines.kws_kwsbp,
    model=model,
    preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipeline is not None)

kws_result = kwsbp_16k_pipeline(
    kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertIn('detected', kws_result)
"""
kws result json format example:
{
'wav_count': 1,
'kws_set': 'wav',
'wav_time': 9.132938,
'keywords': ['小云小云'],
'detected': True,
'confidence': 0.990368
}
"""
if 'keywords' in kws_result:
print('test_run_with_wav keywords: ', kws_result['keywords'])
print('test_run_with_wav detected result: ', kws_result['detected'])
print('test_run_with_wav wave time(seconds): ', kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'pos_testsets'

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp/pos_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipeline = pipeline(
    pipeline_name=Pipelines.kws_kwsbp,
    model=model,
    preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipeline is not None)

kws_result = kwsbp_16k_pipeline(
    kws_type=kws_set, wav_path=[wav_file_path, None])
self.assertIn('recall', kws_result)
"""
kws result json format example:
{
'wav_count': 450,
'kws_set': 'pos_testsets',
'wav_time': 3013.759254,
'keywords': ["小云小云"],
'recall': 0.953333,
'detected_count': 429,
'rejected_count': 21,
'rejected': [
'yyy.wav',
'zzz.wav',
......
]
}
"""
if 'keywords' in kws_result:
print('test_run_with_pos_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_pos_testsets recall: ', kws_result['recall'])
print('test_run_with_pos_testsets wave time(seconds): ',
kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_neg_testsets(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'neg_testsets'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# wav_file_path = <cwd>/.tmp/neg_testsets/
wav_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(wav_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipeline = pipeline(
    pipeline_name=Pipelines.kws_kwsbp,
    model=model,
    preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipeline is not None)

kws_result = kwsbp_16k_pipeline(
    kws_type=kws_set, wav_path=[None, wav_file_path])
self.assertIn('fa_rate', kws_result)
"""
kws result json format example:
{
'wav_count': 751,
'kws_set': 'neg_testsets',
'wav_time': 3572.180812,
'keywords': ['小云小云'],
'fa_rate': 0.001332,
'fa_per_hour': 1.007788,
'detected_count': 1,
'rejected_count': 750,
'detected': [
{
'6.wav': {
'confidence': '0.321170'
}
}
]
}
"""
if 'keywords' in kws_result:
print('test_run_with_neg_testsets keywords: ',
kws_result['keywords'])
print('test_run_with_neg_testsets fa rate: ', kws_result['fa_rate'])
print('test_run_with_neg_testsets fa per hour: ',
kws_result['fa_per_hour'])
print('test_run_with_neg_testsets wave time(seconds): ',
kws_result['wav_time'])

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_roc(self):
# wav, neg_testsets, pos_testsets, roc
kws_set = 'roc'

# downloading neg_testsets file
testsets_file_path = os.path.join(self.workspace, NEG_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(NEG_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(NEG_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# neg_file_path = <workspace>/neg_testsets/
neg_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the neg_testsets file
if not os.path.exists(neg_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading pos_testsets file
testsets_file_path = os.path.join(self.workspace, POS_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(POS_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)

testsets_dir_name = os.path.splitext(
os.path.basename(POS_TESTSETS_FILE))[0]
testsets_dir_name = os.path.splitext(
os.path.basename(testsets_dir_name))[0]
# pos_file_path = <workspace>/pos_testsets/
pos_file_path = os.path.join(self.workspace, testsets_dir_name)

# untar the pos_testsets file
if not os.path.exists(pos_file_path):
un_tar_gz(testsets_file_path, self.workspace)

# downloading kwsbp -- a kws batch processing tool
kwsbp_file_path = os.path.join(self.workspace, 'kwsbp')
if not os.path.exists(kwsbp_file_path):
r = requests.get(KWSBP_URL)
with open(kwsbp_file_path, 'wb') as f:
f.write(r.content)

model = Model.from_pretrained(self.model_id)
self.assertTrue(model is not None)

cfg_preprocessor = dict(
type=Preprocessors.wav_to_lists, workspace=self.workspace)
preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio)
self.assertTrue(preprocessor is not None)

kwsbp_16k_pipeline = pipeline(
    pipeline_name=Pipelines.kws_kwsbp,
    model=model,
    preprocessor=preprocessor)
self.assertTrue(kwsbp_16k_pipeline is not None)

kws_result = kwsbp_16k_pipeline(
    kws_type=kws_set, wav_path=[pos_file_path, neg_file_path])
"""
kws result json format example:
{
'kws_set': 'roc',
'keywords': ['小云小云'],
'小云小云': [
{'threshold': 0.0, 'recall': 0.953333, 'fa_per_hour': 1.007788},
{'threshold': 0.001, 'recall': 0.953333, 'fa_per_hour': 1.007788},
......
{'threshold': 0.999, 'recall': 0.004444, 'fa_per_hour': 0.0}
]
}
"""
if 'keywords' in kws_result:
find_keyword = kws_result['keywords'][0]
print('test_run_with_roc keywords: ', find_keyword)
keyword_list = kws_result[find_keyword]
for item in keyword_list:
threshold: float = item['threshold']
recall: float = item['recall']
fa_per_hour: float = item['fa_per_hour']
print(' threshold:', threshold, ' recall:', recall,
' fa_per_hour:', fa_per_hour)


if __name__ == '__main__':
unittest.main()
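
Given the roc output format shown above, an operating point can be read off the threshold sweep. A hedged helper (illustrative only, not part of the test suite):

def pick_operating_point(roc_items, max_fa_per_hour=1.0):
    """Return the sweep entry with the best recall within a false-alarm budget."""
    feasible = [it for it in roc_items if it['fa_per_hour'] <= max_fa_per_hour]
    return max(feasible, key=lambda it: it['recall']) if feasible else None

Applied to kws_result['小云小云'] from the example, this keeps the point that concedes the least recall for an acceptable false-alarm rate.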

+ 52
- 0
tests/pipelines/test_multi_modal_embedding.py View File

@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

import numpy as np

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class MultiModalEmbeddingTest(unittest.TestCase):
model_id = 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'
test_text = {'text': '一张风景图'}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
pipe_line_multi_modal_embedding = pipeline(
Tasks.multi_modal_embedding, model=self.model_id)
test_str_embedding = pipe_line_multi_modal_embedding(
self.test_text)['text_embedding']
print(np.sum(np.abs(test_str_embedding)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
pipe_line_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding, model=model)
test_str_embedding = pipe_line_multi_modal_embedding(
self.test_text)['text_embedding']
print(np.sum(np.abs(test_str_embedding)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_name(self):
pipe_line_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding, model=self.model_id)
test_str_embedding = pipe_line_multi_modal_embedding(
self.test_text)['text_embedding']
print(np.sum(np.abs(test_str_embedding)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipe_line_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding)
test_str_embedding = pipe_line_multi_modal_embedding(
self.test_text)['text_embedding']
print(np.sum(np.abs(test_str_embedding)))


if __name__ == '__main__':
unittest.main()

+ 1
- 1
tests/pipelines/test_speech_signal_process.py View File

@@ -34,7 +34,7 @@ class SpeechSignalProcessTest(unittest.TestCase):
# A temporary hack to provide c++ lib. Download it first.
download(AEC_LIB_URL, AEC_LIB_FILE)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)


+ 6
- 6
tests/pipelines/test_text_classification.py View File

@@ -3,9 +3,9 @@ import shutil
import unittest

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import SequenceClassificationPipeline, pipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs, Tasks
from modelscope.utils.test_utils import test_level

@@ -28,7 +28,7 @@ class SequenceClassificationTest(unittest.TestCase):

print(data)

def printDataset(self, dataset: PyDataset):
def printDataset(self, dataset: MsDataset):
for i, r in enumerate(dataset):
if i > 10:
break
@@ -50,7 +50,7 @@ class SequenceClassificationTest(unittest.TestCase):
text_classification = pipeline(
task=Tasks.text_classification, model=self.model_id)
result = text_classification(
PyDataset.load(
MsDataset.load(
'glue',
subset_name='sst2',
split='train',
@@ -62,7 +62,7 @@ class SequenceClassificationTest(unittest.TestCase):
def test_run_with_default_model(self):
text_classification = pipeline(task=Tasks.text_classification)
result = text_classification(
PyDataset.load(
MsDataset.load(
'glue',
subset_name='sst2',
split='train',
@@ -78,7 +78,7 @@ class SequenceClassificationTest(unittest.TestCase):
text_classification = pipeline(
Tasks.text_classification, model=model, preprocessor=preprocessor)
# loaded from huggingface dataset
dataset = PyDataset.load(
dataset = MsDataset.load(
'glue',
subset_name='sst2',
split='train',
@@ -91,7 +91,7 @@ class SequenceClassificationTest(unittest.TestCase):
def test_run_with_modelscope_dataset(self):
text_classification = pipeline(task=Tasks.text_classification)
# loaded from modelscope dataset
dataset = PyDataset.load(
dataset = MsDataset.load(
'squad', split='train', target='context', hub=Hubs.modelscope)
result = text_classification(dataset)
self.printDataset(result)

