| @@ -10,6 +10,5 @@ from .multi_modal import OfaForImageCaptioning | |||||
| from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI, | ||||
| SbertForSentenceSimilarity, SbertForSentimentClassification, | SbertForSentenceSimilarity, SbertForSentimentClassification, | ||||
| SbertForTokenClassification, SpaceForDialogIntentModel, | SbertForTokenClassification, SpaceForDialogIntentModel, | ||||
| SpaceForDialogModelingModel, | |||||
| SpaceForDialogStateTrackingModel, StructBertForMaskedLM, | |||||
| VecoForMaskedLM) | |||||
| SpaceForDialogModelingModel, SpaceForDialogStateTracking, | |||||
| StructBertForMaskedLM, VecoForMaskedLM) | |||||
| @@ -6,11 +6,11 @@ from ....utils.nlp.space.utils_dst import batch_to_device | |||||
| from ...base import Model, Tensor | from ...base import Model, Tensor | ||||
| from ...builder import MODELS | from ...builder import MODELS | ||||
| __all__ = ['SpaceForDialogStateTrackingModel'] | |||||
| __all__ = ['SpaceForDialogStateTracking'] | |||||
| @MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | @MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space') | ||||
| class SpaceForDialogStateTrackingModel(Model): | |||||
| class SpaceForDialogStateTracking(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| """initialize the test generation model from the `model_dir` path. | """initialize the test generation model from the `model_dir` path. | ||||
| @@ -1,7 +1,7 @@ | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import SpaceForDialogStateTrackingModel | |||||
| from ...models import SpaceForDialogStateTracking | |||||
| from ...preprocessors import DialogStateTrackingPreprocessor | from ...preprocessors import DialogStateTrackingPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| @@ -14,7 +14,7 @@ __all__ = ['DialogStateTrackingPipeline'] | |||||
| Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking) | ||||
| class DialogStateTrackingPipeline(Pipeline): | class DialogStateTrackingPipeline(Pipeline): | ||||
| def __init__(self, model: SpaceForDialogStateTrackingModel, | |||||
| def __init__(self, model: SpaceForDialogStateTracking, | |||||
| preprocessor: DialogStateTrackingPreprocessor, **kwargs): | preprocessor: DialogStateTrackingPreprocessor, **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
| @@ -5,7 +5,7 @@ import tempfile | |||||
| import unittest | import unittest | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models import Model, SpaceForDialogStateTrackingModel | |||||
| from modelscope.models import Model, SpaceForDialogStateTracking | |||||
| from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | from modelscope.pipelines import DialogStateTrackingPipeline, pipeline | ||||
| from modelscope.preprocessors import DialogStateTrackingPreprocessor | from modelscope.preprocessors import DialogStateTrackingPreprocessor | ||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| @@ -81,7 +81,7 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
| cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking' | cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking' | ||||
| # cache_path = snapshot_download(self.model_id) | # cache_path = snapshot_download(self.model_id) | ||||
| model = SpaceForDialogStateTrackingModel(cache_path) | |||||
| model = SpaceForDialogStateTracking(cache_path) | |||||
| preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) | ||||
| pipelines = [ | pipelines = [ | ||||
| DialogStateTrackingPipeline( | DialogStateTrackingPipeline( | ||||
| @@ -94,20 +94,19 @@ class DialogStateTrackingTest(unittest.TestCase): | |||||
| pipelines_len = len(pipelines) | pipelines_len = len(pipelines) | ||||
| import json | import json | ||||
| for _test_case in self.test_case: | |||||
| history_states = [{}] | |||||
| utter = {} | |||||
| for step, item in enumerate(_test_case): | |||||
| utter.update(item) | |||||
| result = pipelines[step % pipelines_len]({ | |||||
| 'utter': | |||||
| utter, | |||||
| 'history_states': | |||||
| history_states | |||||
| }) | |||||
| print(json.dumps(result)) | |||||
| history_states = [{}] | |||||
| utter = {} | |||||
| for step, item in enumerate(self.test_case): | |||||
| utter.update(item) | |||||
| result = pipelines[step % pipelines_len]({ | |||||
| 'utter': | |||||
| utter, | |||||
| 'history_states': | |||||
| history_states | |||||
| }) | |||||
| print(json.dumps(result)) | |||||
| history_states.extend([result['dialog_states'], {}]) | |||||
| history_states.extend([result['dialog_states'], {}]) | |||||
| @unittest.skip('test with snapshot_download') | @unittest.skip('test with snapshot_download') | ||||
| def test_run_with_model_from_modelhub(self): | def test_run_with_model_from_modelhub(self): | ||||