Browse Source

add modelzoo network transformer testcase

pull/15829/head
anzhengqi 5 years ago
parent
commit
fa3354c19b
2 changed files with 27 additions and 1 deletions
  1. +26
    -0
      tests/st/model_zoo_tests/transformer/test_transformer.py
  2. +1
    -1
      tests/st/model_zoo_tests/utils.py

+ 26
- 0
tests/st/model_zoo_tests/transformer/test_transformer.py View File

@@ -33,6 +33,8 @@ from model_zoo.official.nlp.transformer.src.transformer_for_train import Transfo
from model_zoo.official.nlp.transformer.src.config import cfg, transformer_net_cfg
from model_zoo.official.nlp.transformer.src.lr_schedule import create_dynamic_lr

from tests.st.model_zoo_tests import utils

DATA_DIR = ["/home/workspace/mindspore_dataset/transformer/test-mindrecord"]


@@ -201,5 +203,29 @@ def test_transformer():
assert per_step_mseconds <= expect_per_step_mseconds + 2


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_transformer_export_mindir():
cur_path = os.path.dirname(os.path.abspath(__file__))
model_path = "{}/../../../../model_zoo/official/nlp".format(cur_path)
model_name = "transformer"
utils.copy_files(model_path, cur_path, model_name)
cur_model_path = os.path.join(cur_path, model_name)
export_file = "transformer80_bs_0"
ckpt_path = os.path.join(utils.ckpt_root, "transformer/transformer_trained.ckpt")
print("ckpt_path:", ckpt_path)
old_list = ["'model_file': '/your/path/checkpoint_file'"]
new_list = ["'model_file': '{}'".format(ckpt_path)]
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "src/eval_config.py"))
old_list = ["context.set_context(device_id=args.device_id)"]
new_list = ["context.set_context()"]
utils.exec_sed_command(old_list, new_list, os.path.join(cur_model_path, "export.py"))
exec_export_shell = "cd transformer; python -u export.py --file_name={}" \
" --file_format=MINDIR".format(export_file)
os.system(exec_export_shell)
assert os.path.exists(os.path.join(cur_model_path, "{}.mindir".format(export_file)))

if __name__ == '__main__':
test_transformer()

+ 1
- 1
tests/st/model_zoo_tests/utils.py View File

@@ -14,7 +14,7 @@ from mindspore import log as logger

rank_table_path = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
data_root = "/home/workspace/mindspore_dataset/"
ckpt_root = "/home/workspace/mindspore_ckpt/"
ckpt_root = "/home/workspace/mindspore_dataset/checkpoint"
cur_path = os.path.split(os.path.realpath(__file__))[0]
geir_root = os.path.join(cur_path, "mindspore_geir")
arm_main_path = os.path.join(cur_path, "mindir_310infer_exe")


Loading…
Cancel
Save