|
|
@@ -1,7 +1,4 @@ |
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
import os |
|
|
|
|
|
import os.path as osp |
|
|
|
|
|
import tempfile |
|
|
|
|
|
import unittest |
|
|
import unittest |
|
|
|
|
|
|
|
|
from modelscope.hub.snapshot_download import snapshot_download |
|
|
from modelscope.hub.snapshot_download import snapshot_download |
|
|
@@ -9,6 +6,7 @@ from modelscope.models import Model, SpaceForDialogStateTracking |
|
|
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline |
|
|
from modelscope.pipelines import DialogStateTrackingPipeline, pipeline |
|
|
from modelscope.preprocessors import DialogStateTrackingPreprocessor |
|
|
from modelscope.preprocessors import DialogStateTrackingPreprocessor |
|
|
from modelscope.utils.constant import Tasks |
|
|
from modelscope.utils.constant import Tasks |
|
|
|
|
|
from modelscope.utils.test_utils import test_level |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DialogStateTrackingTest(unittest.TestCase): |
|
|
class DialogStateTrackingTest(unittest.TestCase): |
|
|
@@ -77,9 +75,9 @@ class DialogStateTrackingTest(unittest.TestCase): |
|
|
'User-8': 'Thank you, goodbye', |
|
|
'User-8': 'Thank you, goodbye', |
|
|
}] |
|
|
}] |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_run(self): |
|
|
def test_run(self): |
|
|
cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking' |
|
|
|
|
|
# cache_path = snapshot_download(self.model_id) |
|
|
|
|
|
|
|
|
cache_path = snapshot_download(self.model_id) |
|
|
|
|
|
|
|
|
model = SpaceForDialogStateTracking(cache_path) |
|
|
model = SpaceForDialogStateTracking(cache_path) |
|
|
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) |
|
|
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) |
|
|
@@ -108,9 +106,35 @@ class DialogStateTrackingTest(unittest.TestCase): |
|
|
|
|
|
|
|
|
history_states.extend([result['dialog_states'], {}]) |
|
|
history_states.extend([result['dialog_states'], {}]) |
|
|
|
|
|
|
|
|
@unittest.skip('test with snapshot_download') |
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
def test_run_with_model_from_modelhub(self): |
|
|
def test_run_with_model_from_modelhub(self): |
|
|
pass |
|
|
|
|
|
|
|
|
model = Model.from_pretrained(self.model_id) |
|
|
|
|
|
preprocessor = DialogStateTrackingPreprocessor( |
|
|
|
|
|
model_dir=model.model_dir) |
|
|
|
|
|
pipelines = [ |
|
|
|
|
|
DialogStateTrackingPipeline( |
|
|
|
|
|
model=model, preprocessor=preprocessor), |
|
|
|
|
|
pipeline( |
|
|
|
|
|
task=Tasks.dialog_state_tracking, |
|
|
|
|
|
model=model, |
|
|
|
|
|
preprocessor=preprocessor) |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
pipelines_len = len(pipelines) |
|
|
|
|
|
import json |
|
|
|
|
|
history_states = [{}] |
|
|
|
|
|
utter = {} |
|
|
|
|
|
for step, item in enumerate(self.test_case): |
|
|
|
|
|
utter.update(item) |
|
|
|
|
|
result = pipelines[step % pipelines_len]({ |
|
|
|
|
|
'utter': |
|
|
|
|
|
utter, |
|
|
|
|
|
'history_states': |
|
|
|
|
|
history_states |
|
|
|
|
|
}) |
|
|
|
|
|
print(json.dumps(result)) |
|
|
|
|
|
|
|
|
|
|
|
history_states.extend([result['dialog_states'], {}]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |