Browse Source

!15829 add modelzoo network transformer testcase

From: @anzhengqi
Reviewed-by: @jonyguo,@liucunwei
Signed-off-by: @jonyguo,@liucunwei
pull/15829/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
7dd9819e9b
1 changed files with 26 additions and 0 deletions
  1. +26
    -0
      tests/st/model_zoo_tests/transformer/test_transformer.py

+ 26
- 0
tests/st/model_zoo_tests/transformer/test_transformer.py View File

@@ -33,6 +33,8 @@ from model_zoo.official.nlp.transformer.src.transformer_for_train import Transfo
from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr

from tests.st.model_zoo_tests import utils

DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]


@@ -201,5 +203,29 @@ def test_transformer():
assert per_step_mseconds <= expect_per_step_mseconds + 2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_transformer_export_mindir():
cur_path = os.path.dirname(os.path.abspath(__file__))
model_path = "{}/../../../../model_zoo/official/nlp".format(cur_path)
model_name = "transformer"
utils.copy_files(model_path, cur_path, model_name)
cur_model_path = os.path.join(cur_path, model_name)
export_file = "transformer80_bs_0"
ckpt_path = os.path.join(utils.ckpt_root, "transformer/transformer_trained.ckpt")
print("ckpt_path:", ckpt_path)
old_list = ["'model_file': '/your/path/checkpoint_file'"]
new_list = ["'model_file': '{}'".format(ckpt_path)]
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/eval_config.py"))
old_list = ["context.set_context(device_id=args.device_id)"]
new_list = ["context.set_context()"]
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "export.py"))
exec_export_shell = "cd transformer; python -u export.py --file_name={}" \
" --file_format=MINDIR".format(export_file)
os.system(exec_export_shell)
assert os.path.exists(os.path.join(cur_model_path, "{}.mindir".format(export_file)))

if __name__ == '__main__':
test_transformer()

Loading…
Cancel
Save