Browse Source

dst test ready

master
ly119399 3 years ago
parent
commit
82f8d2aefd
4 changed files with 37 additions and 12 deletions
  1. +1
    -1
      modelscope/metainfo.py
  2. +2
    -3
      modelscope/pipelines/__init__.py
  3. +3
    -1
      modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py
  4. +31
    -7
      tests/pipelines/test_dialog_state_tracking.py

+ 1
- 1
modelscope/metainfo.py View File

@@ -108,7 +108,7 @@ class Preprocessors(object):
sen_cls_tokenizer = 'sen-cls-tokenizer' sen_cls_tokenizer = 'sen-cls-tokenizer'
dialog_intent_preprocessor = 'dialog-intent-preprocessor' dialog_intent_preprocessor = 'dialog-intent-preprocessor'
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor' dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
dialog_state_tracking_preprocessor = 'dialog_state_tracking_preprocessor'
dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'




+ 2
- 3
modelscope/pipelines/__init__.py View File

@@ -1,7 +1,6 @@
from .audio import LinearAECPipeline
from .audio.ans_pipeline import ANSPipeline
# from .audio import LinearAECPipeline
# from .audio.ans_pipeline import ANSPipeline
from .base import Pipeline from .base import Pipeline
from .builder import pipeline from .builder import pipeline
from .cv import * # noqa F403
from .multi_modal import * # noqa F403 from .multi_modal import * # noqa F403
from .nlp import * # noqa F403 from .nlp import * # noqa F403

+ 3
- 1
modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py View File

@@ -5,6 +5,7 @@ from typing import Any, Dict


from modelscope.utils.constant import Fields from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert from modelscope.utils.type_assert import type_assert
from ...metainfo import Preprocessors
from ..base import Preprocessor from ..base import Preprocessor
from ..builder import PREPROCESSORS from ..builder import PREPROCESSORS
from .dst_processors import convert_examples_to_features, multiwoz22Processor from .dst_processors import convert_examples_to_features, multiwoz22Processor
@@ -12,7 +13,8 @@ from .dst_processors import convert_examples_to_features, multiwoz22Processor
__all__ = ['DialogStateTrackingPreprocessor'] __all__ = ['DialogStateTrackingPreprocessor']




@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-dst')
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor)
class DialogStateTrackingPreprocessor(Preprocessor): class DialogStateTrackingPreprocessor(Preprocessor):


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):


tests/pipelines/nlp/test_dialog_state_tracking.py → tests/pipelines/test_dialog_state_tracking.py View File

@@ -1,7 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import tempfile
import unittest import unittest


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
@@ -9,6 +6,7 @@ from modelscope.models import Model, SpaceForDialogStateTracking
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level




class DialogStateTrackingTest(unittest.TestCase): class DialogStateTrackingTest(unittest.TestCase):
@@ -77,9 +75,9 @@ class DialogStateTrackingTest(unittest.TestCase):
'User-8': 'Thank you, goodbye', 'User-8': 'Thank you, goodbye',
}] }]


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self): def test_run(self):
cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
# cache_path = snapshot_download(self.model_id)
cache_path = snapshot_download(self.model_id)


model = SpaceForDialogStateTracking(cache_path) model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
@@ -108,9 +106,35 @@ class DialogStateTrackingTest(unittest.TestCase):


history_states.extend([result['dialog_states'], {}]) history_states.extend([result['dialog_states'], {}])


@unittest.skip('test with snapshot_download')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self): def test_run_with_model_from_modelhub(self):
pass
model = Model.from_pretrained(self.model_id)
preprocessor = DialogStateTrackingPreprocessor(
model_dir=model.model_dir)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]

pipelines_len = len(pipelines)
import json
history_states = [{}]
utter = {}
for step, item in enumerate(self.test_case):
utter.update(item)
result = pipelines[step % pipelines_len]({
'utter':
utter,
'history_states':
history_states
})
print(json.dumps(result))

history_states.extend([result['dialog_states'], {}])




if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save