From fa3354c19b269330c35e924f5552c5bab601ee68 Mon Sep 17 00:00:00 2001 From: anzhengqi Date: Tue, 27 Apr 2021 16:07:06 +0800 Subject: [PATCH] add modelzoo network transformer testcase --- .../transformer/test_transformer.py | 26 +++++++++++++++++++ tests/st/model_zoo_tests/utils.py | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/st/model_zoo_tests/transformer/test_transformer.py b/tests/st/model_zoo_tests/transformer/test_transformer.py index fe7766aad7..39732e057a 100644 --- a/tests/st/model_zoo_tests/transformer/test_transformer.py +++ b/tests/st/model_zoo_tests/transformer/test_transformer.py @@ -33,6 +33,8 @@ from model_zoo.official.nlp.transformer.src.transformer_for_train import Transfo from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr +from tests.st.model_zoo_tests import utils + DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"] @@ -201,5 +203,29 @@ def test_transformer(): assert per_step_mseconds <= expect_per_step_mseconds + 2 +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_transformer_export_mindir(): + cur_path = os.path.dirname(os.path.abspath(__file__)) + model_path = "{}/../../../../model_zoo/official/nlp".format(cur_path) + model_name = "transformer" + utils.copy_files(model_path, cur_path, model_name) + cur_model_path = os.path.join(cur_path, model_name) + export_file = "transformer80_bs_0" + ckpt_path = os.path.join(utils.ckpt_root, "transformer/transformer_trained.ckpt") + print("ckpt_path:", ckpt_path) + old_list = ["'model_file': '/your/path/checkpoint_file'"] + new_list = ["'model_file': '{}'".format(ckpt_path)] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/eval_config.py")) + old_list = ["context.set_context(device_id=args.device_id)"] + new_list = ["context.set_context()"] + utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "export.py")) + exec_export_shell = "cd transformer; python -u export.py --file_name={}" \ + " --file_format=MINDIR".format(export_file) + os.system(exec_export_shell) + assert os.path.exists(os.path.join(cur_model_path, "{}.mindir".format(export_file))) + if __name__ == '__main__': test_transformer() diff --git a/tests/st/model_zoo_tests/utils.py b/tests/st/model_zoo_tests/utils.py index a5b3152031..fe773449d9 100644 --- a/tests/st/model_zoo_tests/utils.py +++ b/tests/st/model_zoo_tests/utils.py @@ -14,7 +14,7 @@ from mindspore import log as logger rank_table_path = "/home/workspace/mindspore_config/hccl/rank_table_8p.json" data_root = "/home/workspace/mindspore_dataset/" -ckpt_root = "/home/workspace/mindspore_ckpt/" +ckpt_root = "/home/workspace/mindspore_dataset/checkpoint" cur_path = os.path.split(os.path.realpath(__file__))[0] geir_root = os.path.join(cur_path, "mindspore_geir") arm_main_path = os.path.join(cur_path, "mindir_310infer_exe")