Browse Source

[to #43259593] task_oriented_conversation model hub use master branch

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9652513
master
ly119399 yingda.chen 3 years ago
parent
commit
4af5ae52ce
1 changed files with 6 additions and 16 deletions
  1. +6
    -16
      tests/pipelines/test_task_oriented_conversation.py

+ 6
- 16
tests/pipelines/test_task_oriented_conversation.py View File

@@ -108,8 +108,7 @@ class TaskOrientedConversationTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self): def test_run_by_direct_model_download(self):


cache_path = snapshot_download(
self.model_id, revision='task_oriented_conversation')
cache_path = snapshot_download(self.model_id)


preprocessor = DialogModelingPreprocessor(model_dir=cache_path) preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
model = SpaceForDialogModeling( model = SpaceForDialogModeling(
@@ -128,8 +127,7 @@ class TaskOrientedConversationTest(unittest.TestCase):


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self): def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(
self.model_id, revision='task_oriented_conversation')
model = Model.from_pretrained(self.model_id)
preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir) preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)


pipelines = [ pipelines = [
@@ -147,25 +145,17 @@ class TaskOrientedConversationTest(unittest.TestCase):
def test_run_with_model_name(self): def test_run_with_model_name(self):
pipelines = [ pipelines = [
pipeline( pipeline(
task=Tasks.task_oriented_conversation,
model=self.model_id,
model_revision='task_oriented_conversation'),
task=Tasks.task_oriented_conversation, model=self.model_id),
pipeline( pipeline(
task=Tasks.task_oriented_conversation,
model=self.model_id,
model_revision='task_oriented_conversation')
task=Tasks.task_oriented_conversation, model=self.model_id)
] ]
self.generate_and_print_dialog_response(pipelines) self.generate_and_print_dialog_response(pipelines)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self): def test_run_with_default_model(self):
pipelines = [ pipelines = [
pipeline(
task=Tasks.task_oriented_conversation,
model_revision='task_oriented_conversation'),
pipeline(
task=Tasks.task_oriented_conversation,
model_revision='task_oriented_conversation')
pipeline(task=Tasks.task_oriented_conversation),
pipeline(task=Tasks.task_oriented_conversation)
] ]
self.generate_and_print_dialog_response(pipelines) self.generate_and_print_dialog_response(pipelines)




Loading…
Cancel
Save