Browse Source

dst test ready

master
ly119399 3 years ago
parent
commit
82f8d2aefd
4 changed files with 37 additions and 12 deletions
  1. +1
    -1
      modelscope/metainfo.py
  2. +2
    -3
      modelscope/pipelines/__init__.py
  3. +3
    -1
      modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py
  4. +31
    -7
      tests/pipelines/test_dialog_state_tracking.py

+ 1
- 1
modelscope/metainfo.py View File

@@ -108,7 +108,7 @@ class Preprocessors(object):
sen_cls_tokenizer = 'sen-cls-tokenizer'
dialog_intent_preprocessor = 'dialog-intent-preprocessor'
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
dialog_state_tracking_preprocessor = 'dialog_state_tracking_preprocessor'
dialog_state_tracking_preprocessor = 'dialog-state-tracking-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'



+ 2
- 3
modelscope/pipelines/__init__.py View File

@@ -1,7 +1,6 @@
from .audio import LinearAECPipeline
from .audio.ans_pipeline import ANSPipeline
# from .audio import LinearAECPipeline
# from .audio.ans_pipeline import ANSPipeline
from .base import Pipeline
from .builder import pipeline
from .cv import * # noqa F403
from .multi_modal import * # noqa F403
from .nlp import * # noqa F403

+ 3
- 1
modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py View File

@@ -5,6 +5,7 @@ from typing import Any, Dict

from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from ...metainfo import Preprocessors
from ..base import Preprocessor
from ..builder import PREPROCESSORS
from .dst_processors import convert_examples_to_features, multiwoz22Processor
@@ -12,7 +13,8 @@ from .dst_processors import convert_examples_to_features, multiwoz22Processor
__all__ = ['DialogStateTrackingPreprocessor']


@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space-dst')
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.dialog_state_tracking_preprocessor)
class DialogStateTrackingPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):


tests/pipelines/nlp/test_dialog_state_tracking.py → tests/pipelines/test_dialog_state_tracking.py View File

@@ -1,7 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
@@ -9,6 +6,7 @@ from modelscope.models import Model, SpaceForDialogStateTracking
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
from modelscope.preprocessors import DialogStateTrackingPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class DialogStateTrackingTest(unittest.TestCase):
@@ -77,9 +75,9 @@ class DialogStateTrackingTest(unittest.TestCase):
'User-8': 'Thank you, goodbye',
}]

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
# cache_path = snapshot_download(self.model_id)
cache_path = snapshot_download(self.model_id)

model = SpaceForDialogStateTracking(cache_path)
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
@@ -108,9 +106,35 @@ class DialogStateTrackingTest(unittest.TestCase):

history_states.extend([result['dialog_states'], {}])

@unittest.skip('test with snapshot_download')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
pass
model = Model.from_pretrained(self.model_id)
preprocessor = DialogStateTrackingPreprocessor(
model_dir=model.model_dir)
pipelines = [
DialogStateTrackingPipeline(
model=model, preprocessor=preprocessor),
pipeline(
task=Tasks.dialog_state_tracking,
model=model,
preprocessor=preprocessor)
]

pipelines_len = len(pipelines)
import json
history_states = [{}]
utter = {}
for step, item in enumerate(self.test_case):
utter.update(item)
result = pipelines[step % pipelines_len]({
'utter':
utter,
'history_states':
history_states
})
print(json.dumps(result))

history_states.extend([result['dialog_states'], {}])


if __name__ == '__main__':

Loading…
Cancel
Save