diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index 08f7c31d..0ee451c2 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -1,5 +1,5 @@ MODELSCOPE_URL_SCHEME = 'http://' -DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330' +DEFAULT_MODELSCOPE_DOMAIN = '47.94.223.21:31090' DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102' DEFAULT_MODELSCOPE_GROUP = 'damo' diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index e1be09ed..eb556d86 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -16,6 +16,7 @@ class Models(object): palm = 'palm-v2' structbert = 'structbert' veco = 'veco' + space = 'space' # audio models sambert_hifi_16k = 'sambert-hifi-16k' @@ -67,7 +68,7 @@ class Pipelines(object): kws_kwsbp = 'kws-kwsbp' # multi-modal tasks - image_caption = 'image-caption' + image_caption = 'image-captioning' multi_modal_embedding = 'multi-modal-embedding' visual_question_answering = 'visual-question-answering' @@ -105,6 +106,9 @@ class Preprocessors(object): token_cls_tokenizer = 'token-cls-tokenizer' nli_tokenizer = 'nli-tokenizer' sen_cls_tokenizer = 'sen-cls-tokenizer' + dialog_intent_preprocessor = 'dialog-intent-preprocessor' + dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' + dialog_state_tracking_preprocessor = 'dialog_state_tracking_preprocessor' sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index 678d0b38..d7a4bfed 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -9,6 +9,7 @@ from .builder import MODELS, build_model from .multi_modal import OfaForImageCaptioning from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, SbertForSentenceSimilarity, SbertForSentimentClassification, - SbertForTokenClassification, SpaceForDialogIntentModel, - SpaceForDialogModelingModel, SpaceForDialogStateTracking, - StructBertForMaskedLM, VecoForMaskedLM) + SbertForTokenClassification, SbertForZeroShotClassification, + SpaceForDialogIntent, SpaceForDialogModeling, + SpaceForDialogStateTracking, StructBertForMaskedLM, + VecoForMaskedLM) diff --git a/modelscope/models/nlp/space/dialog_intent_prediction_model.py b/modelscope/models/nlp/space/dialog_intent_prediction_model.py index 74e4e9e7..644af4c7 100644 --- a/modelscope/models/nlp/space/dialog_intent_prediction_model.py +++ b/modelscope/models/nlp/space/dialog_intent_prediction_model.py @@ -1,6 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict +from ....metainfo import Models from ....preprocessors.space.fields.intent_field import IntentBPETextField from ....trainers.nlp.space.trainer.intent_trainer import IntentTrainer from ....utils.config import Config @@ -10,19 +13,18 @@ from ...builder import MODELS from .model.generator import Generator from .model.model_base import SpaceModelBase -__all__ = ['SpaceForDialogIntentModel'] +__all__ = ['SpaceForDialogIntent'] -@MODELS.register_module(Tasks.dialog_intent_prediction, module_name=r'space') -class SpaceForDialogIntentModel(Model): +@MODELS.register_module( + Tasks.dialog_intent_prediction, module_name=Models.space) +class SpaceForDialogIntent(Model): def __init__(self, model_dir: str, *args, **kwargs): """initialize the test generation model from the `model_dir` path. Args: model_dir (str): the model path. - model_cls (Optional[Any], optional): model loader, if None, use the - default loader to load model weights, by default None. """ super().__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/nlp/space/dialog_modeling_model.py b/modelscope/models/nlp/space/dialog_modeling_model.py index e11ef9fd..872155e2 100644 --- a/modelscope/models/nlp/space/dialog_modeling_model.py +++ b/modelscope/models/nlp/space/dialog_modeling_model.py @@ -1,6 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + import os from typing import Any, Dict, Optional +from ....metainfo import Models from ....preprocessors.space.fields.gen_field import MultiWOZBPETextField from ....trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer from ....utils.config import Config @@ -10,19 +13,17 @@ from ...builder import MODELS from .model.generator import Generator from .model.model_base import SpaceModelBase -__all__ = ['SpaceForDialogModelingModel'] +__all__ = ['SpaceForDialogModeling'] -@MODELS.register_module(Tasks.dialog_modeling, module_name=r'space') -class SpaceForDialogModelingModel(Model): +@MODELS.register_module(Tasks.dialog_modeling, module_name=Models.space) +class SpaceForDialogModeling(Model): def __init__(self, model_dir: str, *args, **kwargs): """initialize the test generation model from the `model_dir` path. Args: model_dir (str): the model path. - model_cls (Optional[Any], optional): model loader, if None, use the - default loader to load model weights, by default None. """ super().__init__(model_dir, *args, **kwargs) diff --git a/modelscope/models/nlp/space/model/gen_unified_transformer.py b/modelscope/models/nlp/space/model/gen_unified_transformer.py index 0f1b1a83..c5d50cd9 100644 --- a/modelscope/models/nlp/space/model/gen_unified_transformer.py +++ b/modelscope/models/nlp/space/model/gen_unified_transformer.py @@ -1,6 +1,5 @@ -""" -IntentUnifiedTransformer -""" +# Copyright (c) Alibaba, Inc. and its affiliates. + import torch from .unified_transformer import UnifiedTransformer diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py index 08e1c765..c1521e3d 100644 --- a/modelscope/models/nlp/space/model/generator.py +++ b/modelscope/models/nlp/space/model/generator.py @@ -1,6 +1,4 @@ -""" -Generator class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import math diff --git a/modelscope/models/nlp/space/model/intent_unified_transformer.py b/modelscope/models/nlp/space/model/intent_unified_transformer.py index b9c699d7..cae96479 100644 --- a/modelscope/models/nlp/space/model/intent_unified_transformer.py +++ b/modelscope/models/nlp/space/model/intent_unified_transformer.py @@ -1,6 +1,5 @@ -""" -IntentUnifiedTransformer -""" +# Copyright (c) Alibaba, Inc. and its affiliates. + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/modelscope/models/nlp/space/model/model_base.py b/modelscope/models/nlp/space/model/model_base.py index 42496e76..7e0a6b0b 100644 --- a/modelscope/models/nlp/space/model/model_base.py +++ b/modelscope/models/nlp/space/model/model_base.py @@ -1,6 +1,5 @@ -""" -Model base -""" +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import torch.nn as nn diff --git a/modelscope/models/nlp/space/model/unified_transformer.py b/modelscope/models/nlp/space/model/unified_transformer.py index 8060879d..17f9fde3 100644 --- a/modelscope/models/nlp/space/model/unified_transformer.py +++ b/modelscope/models/nlp/space/model/unified_transformer.py @@ -1,6 +1,4 @@ -""" -UnifiedTransformer -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np import torch diff --git a/modelscope/models/nlp/space/modules/embedder.py b/modelscope/models/nlp/space/modules/embedder.py index 4fb592ef..e68ac7d3 100644 --- a/modelscope/models/nlp/space/modules/embedder.py +++ b/modelscope/models/nlp/space/modules/embedder.py @@ -1,6 +1,4 @@ -""" -Embedder class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn diff --git a/modelscope/models/nlp/space/modules/feedforward.py b/modelscope/models/nlp/space/modules/feedforward.py index e9a5f4c7..43318eb6 100644 --- a/modelscope/models/nlp/space/modules/feedforward.py +++ b/modelscope/models/nlp/space/modules/feedforward.py @@ -1,6 +1,4 @@ -""" -FeedForward class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn diff --git a/modelscope/models/nlp/space/modules/functions.py b/modelscope/models/nlp/space/modules/functions.py index 45c02e21..daa62bb4 100644 --- a/modelscope/models/nlp/space/modules/functions.py +++ b/modelscope/models/nlp/space/modules/functions.py @@ -1,6 +1,4 @@ -""" -Helpful functions. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np import torch diff --git a/modelscope/models/nlp/space/modules/multihead_attention.py b/modelscope/models/nlp/space/modules/multihead_attention.py index 209eab5e..d075e9c5 100644 --- a/modelscope/models/nlp/space/modules/multihead_attention.py +++ b/modelscope/models/nlp/space/modules/multihead_attention.py @@ -1,6 +1,4 @@ -""" -MultiheadAttention class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn @@ -53,8 +51,6 @@ class MultiheadAttention(nn.Module): if mask is not None: ''' mask: [batch size, num_heads, seq_len, seq_len] - mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中位看的行 - 导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 >>> F.softmax([-1e10, -100, -100]) >>> [0.00, 0.50, 0.50] diff --git a/modelscope/models/nlp/space/modules/transformer_block.py b/modelscope/models/nlp/space/modules/transformer_block.py index 5b6c79a5..37f968d9 100644 --- a/modelscope/models/nlp/space/modules/transformer_block.py +++ b/modelscope/models/nlp/space/modules/transformer_block.py @@ -1,6 +1,4 @@ -""" -TransformerBlock class. -""" +# Copyright (c) Alibaba, Inc. and its affiliates. import torch import torch.nn as nn diff --git a/modelscope/msdatasets/config.py b/modelscope/msdatasets/config.py index 00c24c3a..22390ed7 100644 --- a/modelscope/msdatasets/config.py +++ b/modelscope/msdatasets/config.py @@ -19,4 +19,4 @@ DOWNLOADED_DATASETS_PATH = Path( os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH)) MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT', - 'http://123.57.189.90:31752') + 'http://47.94.223.21:31752') diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 5f564d0b..36f87269 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -36,6 +36,11 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.zero_shot_classification: (Pipelines.zero_shot_classification, 'damo/nlp_structbert_zero-shot-classification_chinese-base'), + Tasks.dialog_intent_prediction: + (Pipelines.dialog_intent_prediction, + 'damo/nlp_space_dialog-intent-prediction'), + Tasks.dialog_modeling: (Pipelines.dialog_modeling, + 'damo/nlp_space_dialog-modeling'), Tasks.image_captioning: (Pipelines.image_caption, 'damo/ofa_image-caption_coco_large_en'), Tasks.image_generation: diff --git a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py index 4677b62e..4b2e29dd 100644 --- a/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_intent_prediction_pipeline.py @@ -1,7 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Any, Dict from ...metainfo import Pipelines -from ...models.nlp import SpaceForDialogIntentModel +from ...models.nlp import SpaceForDialogIntent from ...preprocessors import DialogIntentPredictionPreprocessor from ...utils.constant import Tasks from ..base import Pipeline @@ -15,7 +17,7 @@ __all__ = ['DialogIntentPredictionPipeline'] module_name=Pipelines.dialog_intent_prediction) class DialogIntentPredictionPipeline(Pipeline): - def __init__(self, model: SpaceForDialogIntentModel, + def __init__(self, model: SpaceForDialogIntent, preprocessor: DialogIntentPredictionPreprocessor, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction @@ -26,7 +28,7 @@ class DialogIntentPredictionPipeline(Pipeline): super().__init__(model=model, preprocessor=preprocessor, **kwargs) self.model = model - # self.tokenizer = preprocessor.tokenizer + self.categories = preprocessor.categories def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: """process the prediction results @@ -41,6 +43,10 @@ class DialogIntentPredictionPipeline(Pipeline): pred = inputs['pred'] pos = np.where(pred == np.max(pred)) - result = {'pred': pred, 'label': pos[0]} + result = { + 'pred': pred, + 'label_pos': pos[0], + 'label': self.categories[pos[0][0]] + } return result diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py index 29303d4b..76b00511 100644 --- a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -1,7 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + from typing import Any, Dict, Optional from ...metainfo import Pipelines -from ...models.nlp import SpaceForDialogModelingModel +from ...models.nlp import SpaceForDialogModeling from ...preprocessors import DialogModelingPreprocessor from ...utils.constant import Tasks from ..base import Pipeline, Tensor @@ -14,7 +16,7 @@ __all__ = ['DialogModelingPipeline'] Tasks.dialog_modeling, module_name=Pipelines.dialog_modeling) class DialogModelingPipeline(Pipeline): - def __init__(self, model: SpaceForDialogModelingModel, + def __init__(self, model: SpaceForDialogModeling, preprocessor: DialogModelingPreprocessor, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction @@ -40,7 +42,6 @@ class DialogModelingPipeline(Pipeline): inputs['resp']) assert len(sys_rsp) > 2 sys_rsp = sys_rsp[1:len(sys_rsp) - 1] - # sys_rsp = self.preprocessor.text_field.tokenizer. inputs['sys'] = sys_rsp diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py index 5b9e36b7..f7592410 100644 --- a/modelscope/pipelines/outputs.py +++ b/modelscope/pipelines/outputs.py @@ -108,14 +108,19 @@ TASK_OUTPUTS = { # } Tasks.sentiment_classification: ['scores', 'labels'], + # zero-shot classification result for single sample + # { + # "labels": ["happy", "sad", "calm", "angry"], + # "scores": [0.9, 0.1, 0.05, 0.05] + # } + Tasks.zero_shot_classification: ['scores', 'labels'], + # nli result for single sample # { # "labels": ["happy", "sad", "calm", "angry"], # "scores": [0.9, 0.1, 0.05, 0.05] # } Tasks.nli: ['scores', 'labels'], - Tasks.dialog_modeling: [], - Tasks.dialog_intent_prediction: [], # { # "dialog_states": { @@ -153,6 +158,31 @@ TASK_OUTPUTS = { # } Tasks.dialog_state_tracking: ['dialog_states'], + # {'pred': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05, + # 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04, + # 6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01, + # 2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05, + # 2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05, + # 7.34537680e-05, 6.61411541e-05, 3.62534920e-05, 8.58885178e-05, + # 8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05, + # 5.97518992e-05, 3.92273068e-05, 3.44069012e-05, 9.92335918e-05, + # 9.25978165e-05, 6.26462061e-05, 3.32317031e-05, 1.32061413e-03, + # 2.01607945e-05, 3.36636294e-05, 3.99156743e-05, 5.84108493e-05, + # 2.53432900e-05, 4.95731190e-04, 2.64443643e-05, 4.46992999e-05, + # 2.42672231e-05, 4.75615161e-05, 2.66230145e-05, 4.00083954e-05, + # 2.90536875e-04, 4.23891543e-05, 8.63691166e-05, 4.98188965e-05, + # 3.47019341e-05, 4.52718523e-05, 4.20905781e-05, 5.50173208e-05, + # 4.92360487e-05, 3.56021264e-05, 2.13957210e-05, 6.17428886e-05, + # 1.43893281e-04, 7.32152112e-05, 2.91354867e-04, 2.46623786e-05, + # 3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05, + # 4.31488479e-05, 4.94503947e-05, 4.30105974e-05, 1.00963116e-04, + # 2.82062047e-05, 1.15582036e-04, 4.48261271e-05, 3.99339879e-05, + # 7.27692823e-05], dtype=float32), 'label_pos': array([11]), 'label': 'lost_or_stolen_card'} + Tasks.dialog_intent_prediction: ['pred', 'label_pos', 'label'], + + # sys : ['you', 'are', 'welcome', '.', 'have', 'a', 'great', 'day', '!'] + Tasks.dialog_modeling: ['sys'], + # ============ audio tasks =================== # audio processed for single file in PCM format diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 3bd1f110..007a3ac1 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -15,7 +15,7 @@ __all__ = [ 'Tokenize', 'SequenceClassificationPreprocessor', 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', 'NLIPreprocessor', 'SentimentClassificationPreprocessor', - 'FillMaskPreprocessor' + 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' ] @@ -421,3 +421,47 @@ class TokenClassifcationPreprocessor(Preprocessor): 'attention_mask': attention_mask, 'token_type_ids': token_type_ids } + + +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) +class ZeroShotClassificationPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + + from sofa import SbertTokenizer + self.model_dir: str = model_dir + self.sequence_length = kwargs.pop('sequence_length', 512) + self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, str) + def __call__(self, data: str, hypothesis_template: str, + candidate_labels: list) -> Dict[str, Any]: + """process the raw input data + + Args: + data (str): a sentence + Example: + 'you are so handsome.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + pairs = [[data, hypothesis_template.format(label)] + for label in candidate_labels] + + features = self.tokenizer( + pairs, + padding=True, + truncation=True, + max_length=self.sequence_length, + return_tensors='pt', + truncation_strategy='only_first') + return features diff --git a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py index 4b46b044..2ceede02 100644 --- a/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_intent_prediction_preprocessor.py @@ -3,6 +3,9 @@ import os from typing import Any, Dict +import json + +from ...metainfo import Preprocessors from ...utils.config import Config from ...utils.constant import Fields, ModelFile from ...utils.type_assert import type_assert @@ -13,7 +16,8 @@ from .fields.intent_field import IntentBPETextField __all__ = ['DialogIntentPredictionPreprocessor'] -@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-intent') +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_intent_preprocessor) class DialogIntentPredictionPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): @@ -30,6 +34,11 @@ class DialogIntentPredictionPreprocessor(Preprocessor): self.text_field = IntentBPETextField( self.model_dir, config=self.config) + self.categories = None + with open(os.path.join(self.model_dir, 'categories.json'), 'r') as f: + self.categories = json.load(f) + assert len(self.categories) == 77 + @type_assert(object, str) def __call__(self, data: str) -> Dict[str, Any]: """process the raw input data diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py index d5e02c4a..db83d906 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -3,6 +3,7 @@ import os from typing import Any, Dict +from ...metainfo import Preprocessors from ...utils.config import Config from ...utils.constant import Fields, ModelFile from ...utils.type_assert import type_assert @@ -13,7 +14,8 @@ from .fields.gen_field import MultiWOZBPETextField __all__ = ['DialogModelingPreprocessor'] -@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-modeling') +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.dialog_modeling_preprocessor) class DialogModelingPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/space/fields/gen_field.py index fa037145..28928029 100644 --- a/modelscope/preprocessors/space/fields/gen_field.py +++ b/modelscope/preprocessors/space/fields/gen_field.py @@ -1,6 +1,5 @@ -""" -Field class -""" +# Copyright (c) Alibaba, Inc. and its affiliates. + import os import random from collections import OrderedDict @@ -8,7 +7,6 @@ from itertools import chain import numpy as np -from ....utils.constant import ModelFile from ....utils.nlp.space import ontology, utils from ....utils.nlp.space.db_ops import MultiWozDB from ....utils.nlp.space.utils import list2np diff --git a/modelscope/preprocessors/space/fields/intent_field.py b/modelscope/preprocessors/space/fields/intent_field.py index 15bd20b6..d7f69eec 100644 --- a/modelscope/preprocessors/space/fields/intent_field.py +++ b/modelscope/preprocessors/space/fields/intent_field.py @@ -1,6 +1,5 @@ -""" -Intent Field class -""" +# Copyright (c) Alibaba, Inc. and its affiliates. + import glob import multiprocessing import os diff --git a/modelscope/utils/nlp/space/db_ops.py b/modelscope/utils/nlp/space/db_ops.py index 10c3aab7..880b018b 100644 --- a/modelscope/utils/nlp/space/db_ops.py +++ b/modelscope/utils/nlp/space/db_ops.py @@ -308,14 +308,6 @@ if __name__ == '__main__': 'attraction': 5, 'train': 1, } - # for ent in res: - # if reidx.get(domain): - # report.append(ent[reidx[domain]]) - # for ent in res: - # if 'name' in ent: - # report.append(ent['name']) - # if 'trainid' in ent: - # report.append(ent['trainid']) print(constraints) print(res) print('count:', len(res), '\nnames:', report) diff --git a/modelscope/utils/nlp/space/ontology.py b/modelscope/utils/nlp/space/ontology.py index 4f27168a..99b084bb 100644 --- a/modelscope/utils/nlp/space/ontology.py +++ b/modelscope/utils/nlp/space/ontology.py @@ -123,19 +123,6 @@ dialog_act_all_slots = all_slots + ['choice', 'open'] # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange] slot_name_to_slot_token = {} -# special slot tokens in responses -# not use at the momoent -slot_name_to_value_token = { - # 'entrance fee': '[value_price]', - # 'pricerange': '[value_price]', - # 'arriveby': '[value_time]', - # 'leaveat': '[value_time]', - # 'departure': '[value_place]', - # 'destination': '[value_place]', - # 'stay': 'count', - # 'people': 'count' -} - # eos tokens definition eos_tokens = { 'user': '', diff --git a/modelscope/utils/nlp/space/utils.py b/modelscope/utils/nlp/space/utils.py index ba956b7d..ef38684a 100644 --- a/modelscope/utils/nlp/space/utils.py +++ b/modelscope/utils/nlp/space/utils.py @@ -53,16 +53,9 @@ def clean_replace(s, r, t, forward=True, backward=False): return s, -1 return s[:idx] + t + s[idx_r:], idx_r - # source, replace, target = s, r, t - # count = 0 sidx = 0 while sidx != -1: s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) - # count += 1 - # print(s, sidx) - # if count == 20: - # print(source, '\n', replace, '\n', target) - # quit() return s @@ -193,14 +186,3 @@ class MultiWOZVocab(object): return self._idx2word[idx] else: return self._idx2word[idx] + '(o)' - - # def sentence_decode(self, index_list, eos=None, indicate_oov=False): - # l = [self.decode(_, indicate_oov) for _ in index_list] - # if not eos or eos not in l: - # return ' '.join(l) - # else: - # idx = l.index(eos) - # return ' '.join(l[:idx]) - # - # def nl_decode(self, l, eos=None): - # return [self.sentence_decode(_, eos) + '\n' for _ in l] diff --git a/tests/pipelines/test_dialog_intent_prediction.py b/tests/pipelines/test_dialog_intent_prediction.py index ae3a9bf1..051f979b 100644 --- a/tests/pipelines/test_dialog_intent_prediction.py +++ b/tests/pipelines/test_dialog_intent_prediction.py @@ -3,10 +3,11 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import SpaceForDialogIntentModel +from modelscope.models.nlp import SpaceForDialogIntent from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline from modelscope.preprocessors import DialogIntentPredictionPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level class DialogIntentPredictionTest(unittest.TestCase): @@ -16,11 +17,11 @@ class DialogIntentPredictionTest(unittest.TestCase): 'I still have not received my new card, I ordered over a week ago.' ] - @unittest.skip('test with snapshot_download') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run(self): cache_path = snapshot_download(self.model_id) preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path) - model = SpaceForDialogIntentModel( + model = SpaceForDialogIntent( model_dir=cache_path, text_field=preprocessor.text_field, config=preprocessor.config) @@ -37,6 +38,7 @@ class DialogIntentPredictionTest(unittest.TestCase): for my_pipeline, item in list(zip(pipelines, self.test_case)): print(my_pipeline(item)) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) preprocessor = DialogIntentPredictionPreprocessor( diff --git a/tests/pipelines/test_dialog_modeling.py b/tests/pipelines/test_dialog_modeling.py index 79644bc5..cd17502e 100644 --- a/tests/pipelines/test_dialog_modeling.py +++ b/tests/pipelines/test_dialog_modeling.py @@ -1,15 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os -import os.path as osp -import tempfile import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model -from modelscope.models.nlp import SpaceForDialogModelingModel +from modelscope.models.nlp import SpaceForDialogModeling from modelscope.pipelines import DialogModelingPipeline, pipeline from modelscope.preprocessors import DialogModelingPreprocessor from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level class DialogModelingTest(unittest.TestCase): @@ -91,13 +89,13 @@ class DialogModelingTest(unittest.TestCase): } } - @unittest.skip('test with snapshot_download') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run(self): cache_path = snapshot_download(self.model_id) preprocessor = DialogModelingPreprocessor(model_dir=cache_path) - model = SpaceForDialogModelingModel( + model = SpaceForDialogModeling( model_dir=cache_path, text_field=preprocessor.text_field, config=preprocessor.config) @@ -120,6 +118,7 @@ class DialogModelingTest(unittest.TestCase): }) print('sys : {}'.format(result['sys'])) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) diff --git a/tests/pipelines/test_nli.py b/tests/pipelines/test_nli.py index 0c8da8b4..ef824aa9 100644 --- a/tests/pipelines/test_nli.py +++ b/tests/pipelines/test_nli.py @@ -29,7 +29,7 @@ class NLITest(unittest.TestCase): f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' f'pipeline1: {pipeline2(input=(self.sentence1, self.sentence2))}') - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) tokenizer = NLIPreprocessor(model.model_dir) diff --git a/tests/pipelines/test_sentiment_classification.py b/tests/pipelines/test_sentiment_classification.py index 0ba22d5c..829c0f7d 100644 --- a/tests/pipelines/test_sentiment_classification.py +++ b/tests/pipelines/test_sentiment_classification.py @@ -42,7 +42,7 @@ class SentimentClassificationTest(unittest.TestCase): preprocessor=tokenizer) print(pipeline_ins(input=self.sentence1)) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_model_name(self): pipeline_ins = pipeline( task=Tasks.sentiment_classification, model=self.model_id)