From f7c9c1018ab8ced1a2f5559cc1b49866915ce6c4 Mon Sep 17 00:00:00 2001 From: gaojing Date: Sun, 27 Dec 2020 02:55:16 -0500 Subject: [PATCH] modified gnmt --- model_zoo/official/nlp/gnmt_v2/README.md | 32 ++--- .../official/nlp/gnmt_v2/config/config.json | 1 - .../official/nlp/gnmt_v2/config/config.py | 3 - .../nlp/gnmt_v2/config/config_test.json | 1 - model_zoo/official/nlp/gnmt_v2/eval.py | 3 - model_zoo/official/nlp/gnmt_v2/export.py | 97 ++++++++++++++ .../nlp/gnmt_v2/mindspore_hub_conf.py | 40 ++++++ .../official/nlp/gnmt_v2/requirements.txt | 6 +- .../scripts/run_distributed_train_ascend.sh | 13 +- .../scripts/run_standalone_eval_ascend.sh | 19 ++- .../scripts/run_standalone_train_ascend.sh | 15 +-- .../nlp/gnmt_v2/src/dataset/load_dataset.py | 15 +-- .../nlp/gnmt_v2/src/gnmt_model/attention.py | 119 +++++++++--------- .../nlp/gnmt_v2/src/gnmt_model/beam_search.py | 51 +++++--- .../gnmt_v2/src/gnmt_model/gnmt_for_infer.py | 63 ++++------ .../nlp/gnmt_v2/src/utils/load_weights.py | 46 +++---- model_zoo/official/nlp/gnmt_v2/train.py | 10 +- .../model_zoo_tests/gnmt_v2/test_gnmt_v2.py | 6 - .../model_zoo_tests/gnmt_v2/test_gnmt_v2.sh | 48 +++---- 19 files changed, 331 insertions(+), 257 deletions(-) create mode 100644 model_zoo/official/nlp/gnmt_v2/export.py create mode 100644 model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py diff --git a/model_zoo/official/nlp/gnmt_v2/README.md b/model_zoo/official/nlp/gnmt_v2/README.md index 25e2a229b0..64d2745073 100644 --- a/model_zoo/official/nlp/gnmt_v2/README.md +++ b/model_zoo/official/nlp/gnmt_v2/README.md @@ -58,8 +58,8 @@ Note that you can run the scripts based on the dataset mentioned in original pap ```txt numpy -sacrebleu==1.2.10 -sacremoses==0.0.19 +sacrebleu==1.4.14 +sacremoses==0.0.35 subword_nmt==0.3.7 ``` @@ -77,15 +77,15 @@ After dataset preparation, you can start training and evaluation as follows: ```bash # run training example cd ./scripts -sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN 
PRE_TRAIN_DATASET +sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET # run distributed training example cd ./scripts -sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET +sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET # run evaluation example cd ./scripts -sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ +sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \ VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET ``` @@ -135,11 +135,13 @@ The GNMT network script and code result are as follows: │ ├──lr_scheduler.py // Learning rate scheduler. │ ├──optimizer.py // Optimizer. ├── scripts - │ ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend. - │ ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend. - │ ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend. - ├── create_dataset.py // dataset preparation. + │ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend. + │ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend. + │ ├──run_standalone_train_ascend.sh // Shell script for standalone train on ascend. + ├── create_dataset.py // Dataset preparation. ├── eval.py // Infer API entry. + ├── export.py // Export checkpoint file into air models. + ├── mindspore_hub_conf.py // Hub config. ├── requirements.txt // Requirements of third party package. ├── train.py // Train API entry. ``` @@ -187,7 +189,7 @@ For more configuration details, please refer the script `config/config.py` file. ## Training Process -For a pre-trained model, configure the following options in the `scripts/run_standalone_train_ascend.json` file: +For a pre-trained model, configure the following options in the `config/config.json` file: - Select an optimizer ('momentum/adam/lamb' is available). - Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file. 
@@ -198,17 +200,17 @@ Start task training on a single device and run the shell script `scripts/run_sta ```bash cd ./scripts -sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET +sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET ``` -In this script, the `DATASET_SCHEMA_TRAIN` and `PRE_TRAIN_DATASET` are the dataset schema and dataset address. +In this script, the `PRE_TRAIN_DATASET` is the dataset address. Run `scripts/run_distributed_train_ascend.sh` for distributed training of GNMTv2 model. Task training on multiple devices and run the following command in bash to be executed in `scripts/`.: ```bash cd ./scripts -sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET +sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET ``` Note: the `RANK_TABLE_ADDR` is the hccl_json file assigned when distributed training is running. @@ -224,11 +226,11 @@ Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the outp ```bash cd ./scripts sh run_standalone_eval_ascend.sh -sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ +sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \ VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET ``` -The `DATASET_SCHEMA_TEST` and the `TEST_DATASET` are the schema and address of inference dataset respectively, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process. +The `TEST_DATASET` is the address of inference dataset, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process. The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code address and the `TEST_TARGET` are the path of answers. 
# [Model Description](#contents) diff --git a/model_zoo/official/nlp/gnmt_v2/config/config.json b/model_zoo/official/nlp/gnmt_v2/config/config.json index e1c6c3fab2..65911bf583 100644 --- a/model_zoo/official/nlp/gnmt_v2/config/config.json +++ b/model_zoo/official/nlp/gnmt_v2/config/config.json @@ -3,7 +3,6 @@ "random_seed": 50, "epochs": 6, "batch_size": 128, - "dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json", "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord", "fine_tune_dataset": null, "valid_dataset": null, diff --git a/model_zoo/official/nlp/gnmt_v2/config/config.py b/model_zoo/official/nlp/gnmt_v2/config/config.py index 21f0b5f6fb..d87fad2ee4 100644 --- a/model_zoo/official/nlp/gnmt_v2/config/config.py +++ b/model_zoo/official/nlp/gnmt_v2/config/config.py @@ -67,7 +67,6 @@ class GNMTConfig: random_seed (int): Random seed, it can be changed. epochs (int): Epoch number. batch_size (int): Batch size of input dataset. - dataset_schema (str): Path of dataset schema file. pre_train_dataset (str): Path of pre-training dataset file or folder. fine_tune_dataset (str): Path of fine-tune dataset file or folder. test_dataset (str): Path of test dataset file or folder. 
@@ -126,7 +125,6 @@ class GNMTConfig: def __init__(self, random_seed=50, epochs=6, batch_size=128, - dataset_schema: str = None, pre_train_dataset: str = None, fine_tune_dataset: str = None, test_dataset: str = None, @@ -157,7 +155,6 @@ class GNMTConfig: self.save_graphs = save_graphs self.random_seed = random_seed - self.dataset_schema = dataset_schema self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str] self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str] self.valid_dataset = get_source_list(valid_dataset) # type: List[str] diff --git a/model_zoo/official/nlp/gnmt_v2/config/config_test.json b/model_zoo/official/nlp/gnmt_v2/config/config_test.json index bfc44438c0..0b039b3aec 100644 --- a/model_zoo/official/nlp/gnmt_v2/config/config_test.json +++ b/model_zoo/official/nlp/gnmt_v2/config/config_test.json @@ -3,7 +3,6 @@ "random_seed": 50, "epochs": 6, "batch_size": 128, - "dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json", "pre_train_dataset": null, "fine_tune_dataset": null, "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.mindrecord", diff --git a/model_zoo/official/nlp/gnmt_v2/eval.py b/model_zoo/official/nlp/gnmt_v2/eval.py index 34814402cf..849116798c 100644 --- a/model_zoo/official/nlp/gnmt_v2/eval.py +++ b/model_zoo/official/nlp/gnmt_v2/eval.py @@ -27,8 +27,6 @@ from src.dataset.tokenizer import Tokenizer parser = argparse.ArgumentParser(description='gnmt') parser.add_argument("--config", type=str, required=True, help="model config json file path.") -parser.add_argument("--dataset_schema_test", type=str, required=True, - help="dataset schema for evaluation.") parser.add_argument("--test_dataset", type=str, required=True, help="test dataset address.") parser.add_argument("--existed_ckpt", type=str, required=True, @@ -63,7 +61,6 @@ if __name__ == '__main__': args, _ = parser.parse_known_args() _check_args(args.config) _config = get_config(args.config) - _config.dataset_schema = 
args.dataset_schema_test _config.test_dataset = args.test_dataset _config.existed_ckpt = args.existed_ckpt result = infer(_config) diff --git a/model_zoo/official/nlp/gnmt_v2/export.py b/model_zoo/official/nlp/gnmt_v2/export.py new file mode 100644 index 0000000000..822ebee1c5 --- /dev/null +++ b/model_zoo/official/nlp/gnmt_v2/export.py @@ -0,0 +1,97 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""export checkpoint file into air models""" + +import argparse +import numpy as np + +from mindspore import Tensor, context, Parameter +from mindspore.common import dtype as mstype +from mindspore.train.serialization import export + +from config import GNMTConfig +from src.gnmt_model.gnmt import GNMT +from src.gnmt_model.gnmt_for_infer import GNMTInferCell +from src.utils import zero_weight +from src.utils.load_weights import load_infer_weights + +parser = argparse.ArgumentParser(description="gnmt_v2 export") +parser.add_argument("--file_name", type=str, default="gnmt_v2", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") +parser.add_argument('--infer_config', type=str, required=True, help='gnmt_v2 config file') +parser.add_argument("--existed_ckpt", type=str, required=True, help="existed checkpoint address.") +parser.add_argument('--vocab_file', 
type=str, required=True, help='vocabulary file') +parser.add_argument("--bpe_codes", type=str, required=True, help="bpe codes to use.") +args = parser.parse_args() + +context.set_context( + mode=context.GRAPH_MODE, + save_graphs=False, + device_target="Ascend", + reserve_class_name_in_scope=False) + + +def get_config(config_file): + tfm_config = GNMTConfig.from_json_file(config_file) + tfm_config.compute_type = mstype.float16 + tfm_config.dtype = mstype.float32 + return tfm_config + + +if __name__ == '__main__': + config = get_config(args.infer_config) + config.existed_ckpt = args.existed_ckpt + vocab = args.vocab_file + bpe_codes = args.bpe_codes + + tfm_model = GNMT(config=config, + is_training=False, + use_one_hot_embeddings=False) + + params = tfm_model.trainable_params() + weights = load_infer_weights(config) + + for param in params: + value = param.data + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) + else: + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + + print(" | Load weights successfully.") + tfm_infer = GNMTInferCell(tfm_model) + tfm_infer.set_train(False) + + source_ids = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32)) + source_mask = 
Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32)) + + export(tfm_infer, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py b/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py new file mode 100644 index 0000000000..c85ba28580 --- /dev/null +++ b/model_zoo/official/nlp/gnmt_v2/mindspore_hub_conf.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""hub config.""" +import mindspore.common.dtype as mstype + +from config import GNMTConfig +from src.gnmt_model import GNMTNetworkWithLoss, GNMT + + +def get_config(config): + config = GNMTConfig.from_json_file(config) + config.compute_type = mstype.float16 + config.dtype = mstype.float32 + return config + + +def create_network(name, *args, **kwargs): + """Create the GNMT network for MindSpore Hub; the `config` keyword (path of the JSON config file) is required, and a truthy `is_training` keyword selects GNMTNetworkWithLoss instead of GNMT.""" + if name == "gnmt": + if "config" in kwargs: + config = get_config(kwargs["config"]) + else: + raise NotImplementedError("Please pass the configuration file path through the `config` keyword argument.") + is_training = kwargs.get("is_training", False) + if is_training: + return GNMTNetworkWithLoss(config, *args, is_training=is_training) + return GNMT(config, *args) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/nlp/gnmt_v2/requirements.txt b/model_zoo/official/nlp/gnmt_v2/requirements.txt index 8c457a4f72..ed71712362 100644 --- a/model_zoo/official/nlp/gnmt_v2/requirements.txt +++ b/model_zoo/official/nlp/gnmt_v2/requirements.txt @@ -1,6 +1,4 @@ -nltk -jieba numpy subword-nmt==0.3.7 -sacrebleu==1.2.10 -sacremoses==0.0.19 +sacrebleu==1.4.14 +sacremoses==0.0.35 diff --git a/model_zoo/official/nlp/gnmt_v2/scripts/run_distributed_train_ascend.sh b/model_zoo/official/nlp/gnmt_v2/scripts/run_distributed_train_ascend.sh index 34028a19bb..7af838480b 100644 --- a/model_zoo/official/nlp/gnmt_v2/scripts/run_distributed_train_ascend.sh +++ b/model_zoo/official/nlp/gnmt_v2/scripts/run_distributed_train_ascend.sh @@ -16,18 +16,16 @@ echo "==============================================================================================================" echo "Please run the script as: " -echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET" +echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET" echo "for example:" echo "sh 
run_distributed_train_ascend.sh \ /home/workspace/rank_table_8p.json \ - /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \ /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord" echo "It is better to use absolute path." echo "==============================================================================================================" RANK_TABLE_ADDR=$1 -DATASET_SCHEMA_TRAIN=$2 -PRE_TRAIN_DATASET=$3 +PRE_TRAIN_DATASET=$2 current_exec_path=$(pwd) echo ${current_exec_path} @@ -49,10 +47,9 @@ do cp -r ../../config . export RANK_ID=$i export DEVICE_ID=$i - python ../../train.py \ - --config=${current_exec_path}/device${i}/config/config.json \ - --dataset_schema_train=$DATASET_SCHEMA_TRAIN \ - --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 & + python ../../train.py \ + --config=${current_exec_path}/device${i}/config/config.json \ + --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 & cd ${current_exec_path} || exit done cd ${current_exec_path} || exit diff --git a/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_eval_ascend.sh b/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_eval_ascend.sh index 5b7f56e5bc..5929112808 100644 --- a/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_eval_ascend.sh +++ b/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_eval_ascend.sh @@ -16,11 +16,10 @@ echo "==============================================================================================================" echo "Please run the script as: " -echo "sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ +echo "sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \ VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET" echo "for example:" echo "sh run_standalone_eval_ascend.sh \ - /home/workspace/dataset_menu/newstest2014.en.json \ /home/workspace/dataset_menu/newstest2014.en.mindrecord \ /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \ 
/home/workspace/wmt16_de_en/vocab.bpe.32000 \ @@ -29,19 +28,16 @@ echo "sh run_standalone_eval_ascend.sh \ echo "It is better to use absolute path." echo "==============================================================================================================" -DATASET_SCHEMA_TEST=$1 -TEST_DATASET=$2 -EXISTED_CKPT_PATH=$3 -VOCAB_ADDR=$4 -BPE_CODE_ADDR=$5 -TEST_TARGET=$6 +TEST_DATASET=$1 +EXISTED_CKPT_PATH=$2 +VOCAB_ADDR=$3 +BPE_CODE_ADDR=$4 +TEST_TARGET=$5 current_exec_path=$(pwd) echo ${current_exec_path} -export DEVICE_NUM=1 -export RANK_ID=0 -export RANK_SIZE=1 + export GLOG_v=2 if [ -d "eval" ]; @@ -57,7 +53,6 @@ echo "start for evaluation" env > env.log python eval.py \ --config=${current_exec_path}/eval/config/config_test.json \ - --dataset_schema_test=$DATASET_SCHEMA_TEST \ --test_dataset=$TEST_DATASET \ --existed_ckpt=$EXISTED_CKPT_PATH \ --vocab=$VOCAB_ADDR \ diff --git a/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_train_ascend.sh b/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_train_ascend.sh index 6ed8a6dac8..7aad9beea9 100644 --- a/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/official/nlp/gnmt_v2/scripts/run_standalone_train_ascend.sh @@ -16,21 +16,17 @@ echo "==============================================================================================================" echo "Please run the script as: " -echo "sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET" +echo "sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET" echo "for example:" echo "sh run_standalone_train_ascend.sh \ - /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \ /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord" echo "It is better to use absolute path." 
echo "==============================================================================================================" -DATASET_SCHEMA_TRAIN=$1 -PRE_TRAIN_DATASET=$2 +PRE_TRAIN_DATASET=$1 -export DEVICE_NUM=1 -export RANK_ID=0 -export RANK_SIZE=1 export GLOG_v=2 + current_exec_path=$(pwd) echo ${current_exec_path} if [ -d "train" ]; @@ -45,7 +41,6 @@ cd ./train || exit echo "start for training" env > env.log python train.py \ - --config=${current_exec_path}/train/config/config.json \ - --dataset_schema_train=$DATASET_SCHEMA_TRAIN \ - --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 & + --config=${current_exec_path}/train/config/config.json \ + --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 & cd .. diff --git a/model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py b/model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py index 681f9040e5..8af9fe84cc 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py +++ b/model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py @@ -13,13 +13,12 @@ # limitations under the License. # ============================================================================ """Dataset loader to feed into model.""" -import os import mindspore.common.dtype as mstype import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as deC -def _load_dataset(input_files, schema_file, batch_size, sink_mode=False, +def _load_dataset(input_files, batch_size, sink_mode=False, rank_size=1, rank_id=0, shuffle=True, drop_remainder=True, is_translate=False): """ @@ -27,7 +26,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False, Args: input_files (list): Data files. - schema_file (str): Schema file path. batch_size (int): Batch size. sink_mode (bool): Whether enable sink mode. rank_size (int): Rank size. 
@@ -42,12 +40,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False, if not input_files: raise FileNotFoundError("Require at least one dataset.") - if not (schema_file and - os.path.exists(schema_file) - and os.path.isfile(schema_file) - and os.path.basename(schema_file).endswith(".json")): - raise FileNotFoundError("`dataset_schema` must be a existed json file.") - if not isinstance(sink_mode, bool): raise ValueError("`sink` must be type of bool.") @@ -116,14 +108,13 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False, return data_set -def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool, +def load_dataset(data_files: list, batch_size: int, sink_mode: bool, rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False): """ Load dataset. Args: data_files (list): Data files. - schema (str): Schema file path. batch_size (int): Batch size. sink_mode (bool): Whether enable sink mode. rank_size (int): Rank size. @@ -133,5 +124,5 @@ def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool Returns: Dataset, dataset instance. """ - return _load_dataset(data_files, schema, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle, + return _load_dataset(data_files, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle, drop_remainder=drop_remainder, is_translate=is_translate) diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py index cdc32943ed..68fa456dab 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/attention.py @@ -38,7 +38,7 @@ class BahdanauAttention(nn.Cell): initializer_range: range for uniform initializer parameters. Returns: - Tensor, shape (N, T, D). + Tensor, shape (t_q_length, N, D). 
""" def __init__(self, @@ -93,108 +93,107 @@ class BahdanauAttention(nn.Cell): Construct attention block. Args: - query (Tensor): Shape (t_q, N, D). - keys (Tensor): Shape (t_k, N, D). - attention_mask: Shape(N, t_k). + query (Tensor): Shape (t_q_length, N, D). + keys (Tensor): Shape (t_k_length, N, D). + attention_mask: Shape(N, t_k_length). Returns: - Tensor, shape (N, t_q, D). + Tensor, shape (t_q_length, N, D). """ - # (t_k, N, D) -> (N, t_k, D). + # (t_k_length, N, D) -> (N, t_k_length, D). keys = self.transpose(keys, self.transpose_orders) - # (t_q, N, D) -> (N, t_q, D). - query = self.transpose(query, self.transpose_orders) + # (t_q_length, N, D) -> (N, t_q_length, D). + query_trans = self.transpose(query, self.transpose_orders) - query_shape = self.shape_op(query) - b = query_shape[0] - t_q = query_shape[1] - t_k = self.shape_op(keys)[1] + query_shape = self.shape_op(query_trans) + batch_size = query_shape[0] + t_q_length = query_shape[1] + t_k_length = self.shape_op(keys)[1] - # (N, t_q, D) - query = self.reshape(query, (b * t_q, self.query_size)) + # (N, t_q_length, D) + query_trans = self.reshape(query_trans, (batch_size * t_q_length, self.query_size)) if self.is_training: - query = self.cast(query, mstype.float16) - processed_query = self.linear_q(query) + query_trans = self.cast(query_trans, mstype.float16) + processed_query = self.linear_q(query_trans) if self.is_trining: processed_query = self.cast(processed_query, mstype.float32) - processed_query = self.reshape(processed_query, (b, t_q, self.num_units)) - # (N, t_k, D) - keys = self.reshape(keys, (b * t_k, self.key_size)) + processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units)) + # (N, t_k_length, D) + keys = self.reshape(keys, (batch_size * t_k_length, self.key_size)) if self.is_training: keys = self.cast(keys, mstype.float16) processed_key = self.linear_k(keys) if self.is_trining: processed_key = self.cast(processed_key, mstype.float32) - processed_key = 
self.reshape(processed_key, (b, t_k, self.num_units)) + processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units)) - # scores: (N , T_q, T_k) - scores = self.calc_score(processed_query, processed_key) - # attention_mask: (N, T_k) + # scores: (N, t_q_length, t_k_length) + scores = self.obtain_score(processed_query, processed_key) + # attention_mask: (N, t_k_length) mask = attention_mask - # [N, 1] if mask is not None: mask = 1.0 - mask - mask = self.tile(self.expand(mask, 1), (1, t_q, 1)) + mask = self.tile(self.expand(mask, 1), (1, t_q_length, 1)) scores += mask * (-INF) - # [b, t_q, t_k] - scores_normalized = self.softmax(scores) + # [batch_size, t_q_length, t_k_length] + scores_softmax = self.softmax(scores) - keys = self.reshape(keys, (b, t_k, self.key_size)) + keys = self.reshape(keys, (batch_size, t_k_length, self.key_size)) if self.is_training: keys = self.cast(keys, mstype.float16) - scores_normalized_fp16 = self.cast(scores_normalized, mstype.float16) + scores_softmax_fp16 = self.cast(scores_softmax, mstype.float16) else: - scores_normalized_fp16 = scores_normalized + scores_softmax_fp16 = scores_softmax - # (b, t_q, n) - context_attention = self.batchMatmul(scores_normalized_fp16, keys) - # [t_q,b,D] + # (b, t_q_length, D) + context_attention = self.batchMatmul(scores_softmax_fp16, keys) + # [t_q_length, b, D] context_attention = self.transpose(context_attention, self.transpose_orders) if self.is_training: context_attention = self.cast(context_attention, mstype.float32) - return context_attention, scores_normalized + return context_attention, scores_softmax - def calc_score(self, att_query, att_keys): + def obtain_score(self, attention_q, attention_k): """ Calculate Bahdanau score Args: - att_query: (N, T_q, D). - att_keys: (N, T_k, D). + attention_q: (batch_size, t_q_length, D). + attention_k: (batch_size, t_k_length, D). returns: - scores: (N, T_q, T_k). + scores: (batch_size, t_q_length, t_k_length). 
""" - b, t_k, n = self.shape_op(att_keys) - t_q = self.shape_op(att_query)[1] - # (b, t_q, t_k, n) - att_query = self.tile(self.expand(att_query, 2), (1, 1, t_k, 1)) - att_keys = self.tile(self.expand(att_keys, 1), (1, t_q, 1, 1)) - # (b, t_q, t_k, n) - sum_qk = att_query + att_keys + batch_size, t_k_length, D = self.shape_op(attention_k) + t_q_length = self.shape_op(attention_q)[1] + # (batch_size, t_q_length, t_k_length, n) + attention_q = self.tile(self.expand(attention_q, 2), (1, 1, t_k_length, 1)) + attention_k = self.tile(self.expand(attention_k, 1), (1, t_q_length, 1, 1)) + # (batch_size, t_q_length, t_k_length, n) + sum_qk_add = attention_q + attention_k if self.normalize: - # (b, t_q, t_k, n) - sum_qk = sum_qk + self.normalize_bias - linear_att = self.linear_att / self.norm(self.linear_att) - linear_att = self.cast(linear_att, mstype.float32) - linear_att = self.mul(linear_att, self.normalize_scalar) + # (batch_size, t_q_length, t_k_length, n) + sum_qk_add = sum_qk_add + self.normalize_bias + linear_att_norm = self.linear_att / self.norm(self.linear_att) + linear_att_norm = self.cast(linear_att_norm, mstype.float32) + linear_att_norm = self.mul(linear_att_norm, self.normalize_scalar) else: - linear_att = self.linear_att + linear_att_norm = self.linear_att - linear_att = self.expand(linear_att, -1) - sum_qk = self.reshape(sum_qk, (-1, n)) + linear_att_norm = self.expand(linear_att_norm, -1) + sum_qk_add = self.reshape(sum_qk_add, (-1, D)) - tanh_sum_qk = self.tanh(sum_qk) + tanh_sum_qk = self.tanh(sum_qk_add) if self.is_training: - linear_att = self.cast(linear_att, mstype.float16) + linear_att_norm = self.cast(linear_att_norm, mstype.float16) tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16) - out = self.matmul(tanh_sum_qk, linear_att) + scores_out = self.matmul(tanh_sum_qk, linear_att_norm) - # (b, t_q, t_k) - out = self.reshape(out, (b, t_q, t_k)) + # (N, t_q_length, t_k_length) + scores_out = self.reshape(scores_out, (batch_size, t_q_length, 
t_k_length)) if self.is_training: - out = self.cast(out, mstype.float32) - return out + scores_out = self.cast(scores_out, mstype.float32) + return scores_out diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py index 8314538d9f..fceb987099 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/beam_search.py @@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell): max_decode_length=64, sos_id=2, eos_id=3, + is_using_while=False, compute_type=mstype.float32): super(BeamSearchDecoder, self).__init__() @@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell): self.cov_penalty_factor = cov_penalty_factor self.max_decode_length = max_decode_length self.decoder = decoder + self.is_using_while = is_using_while self.add = P.TensorAdd() self.expand = P.ExpandDims() @@ -214,9 +216,13 @@ class BeamSearchDecoder(nn.Cell): self.concat = P.Concat(axis=-1) self.gather_nd = P.GatherNd() - self.start = Tensor(0, dtype=mstype.int32) self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) - self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), mstype.int32) + if self.is_using_while: + self.start = Tensor(0, dtype=mstype.int32) + self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), + mstype.int32) + else: + self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1]) self.init_scores = Tensor(init_scores, mstype.float32) @@ -270,7 +276,7 @@ class BeamSearchDecoder(nn.Cell): enc_states (Tensor): with shape (batch_size * beam_width, T, D). enc_attention_mask (Tensor): with shape (batch_size * beam_width, T). state_log_probs (Tensor): with shape (batch_size, beam_width). 
- state_seq (Tensor): with shape (batch_size, beam_width, max_decoder_length). + state_seq (Tensor): with shape (batch_size, beam_width, m). state_length (Tensor): with shape (batch_size, beam_width). idx (Tensor): with shape (). decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D). @@ -360,10 +366,13 @@ class BeamSearchDecoder(nn.Cell): self.hidden_size)) # update state_seq - state_seq_new = self.cast(seq, mstype.float32) - word_indices_fp32 = self.cast(word_indices, mstype.float32) - state_seq_new[:, :, idx] = word_indices_fp32 - state_seq = self.cast(state_seq_new, mstype.int32) + if self.is_using_while: + state_seq_new = self.cast(seq, mstype.float32) + word_indices_fp32 = self.cast(word_indices, mstype.float32) + state_seq_new[:, :, idx] = word_indices_fp32 + state_seq = self.cast(state_seq_new, mstype.int32) + else: + state_seq = self.concat((seq, self.expand(word_indices, -1))) cur_input_ids = self.reshape(word_indices, (-1, 1)) state_log_probs = topk_scores @@ -392,14 +401,21 @@ class BeamSearchDecoder(nn.Cell): decoder_hidden_state = self.decoder_hidden_state accu_attn_scores = self.accu_attn_scores - idx = self.start + 1 - ends = self.start + self.max_decode_length + 1 - while idx < ends: - cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ - state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, - state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores, - state_finished) - idx = idx + 1 + if not self.is_using_while: + for _ in range(self.max_decode_length + 1): + cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ + state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, + state_seq, state_length, None, decoder_hidden_state, accu_attn_scores, + state_finished) + else: + idx = self.start + 1 + ends = self.start + 
self.max_decode_length + 1 + while idx < ends: + cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ + state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, + state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores, + state_finished) + idx = idx + 1 # add length penalty scores penalty_len = self.length_penalty(state_length) @@ -416,6 +432,9 @@ class BeamSearchDecoder(nn.Cell): gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1))) # sort sequence and attention scores predicted_ids = self.gather_nd(state_seq, gather_indices) - predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length] + if not self.is_using_while: + predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)] + else: + predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length] return predicted_ids diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py index 0c157045d3..2c62caf517 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore import context, Parameter from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint from src.dataset import load_dataset from .gnmt import GNMT @@ -37,18 +36,6 @@ context.set_context( reserve_class_name_in_scope=False) -def get_weight_and_variable(model_path, params): - print("model path is {}".format(model_path)) - ms_ckpt = load_checkpoint(model_path) - with open("variable.txt", "w") as f: - for msname in ms_ckpt: - f.write(msname + "\n") - with open("weights.txt", "w") as f: - for param in params: - name = param.name - f.write(name + "\n") - - class GNMTInferCell(nn.Cell): 
""" Encapsulation class of GNMT network infer. @@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): use_one_hot_embeddings=False) params = tfm_model.trainable_params() - get_weight_and_variable(config.existed_ckpt, params) weights = load_infer_weights(config) - for param in params: value = param.data - name = param.name - if name not in weights: - raise ValueError(f"{name} is not found in weights.") - with open("weight_after_deal.txt", "a+") as f: - weights_name = name - f.write(weights_name) - f.write("\n") - if isinstance(value, Tensor): - print(name, value.asnumpy().shape) - if weights_name in weights: - assert weights_name in weights - if isinstance(weights[weights_name], Parameter): - if param.data.dtype == "Float32": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) - elif param.data.dtype == "Float16": - param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) - - elif isinstance(weights[weights_name], Tensor): - param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) - elif isinstance(weights[weights_name], np.ndarray): - param.set_data(Tensor(weights[weights_name], config.dtype)) - else: - param.set_data(weights[weights_name]) + weights_name = param.name + if weights_name not in weights: + raise ValueError(f"{weights_name} is not found in weights.") + if isinstance(value, Tensor): + if weights_name in weights: + assert weights_name in weights + if isinstance(weights[weights_name], Parameter): + if param.data.dtype == "Float32": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) + elif param.data.dtype == "Float16": + param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) + + elif isinstance(weights[weights_name], Tensor): + param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) + elif isinstance(weights[weights_name], np.ndarray): + param.set_data(Tensor(weights[weights_name], config.dtype)) else: - print("weight not found in 
checkpoint: " + weights_name) - param.set_data(zero_weight(value.asnumpy().shape)) - f.close() + param.set_data(weights[weights_name]) + else: + print("weight not found in checkpoint: " + weights_name) + param.set_data(zero_weight(value.asnumpy().shape)) + print(" | Load weights successfully.") tfm_infer = GNMTInferCell(tfm_model) model = Model(tfm_infer) @@ -187,7 +167,6 @@ def infer(config): list, result with """ eval_dataset = load_dataset(data_files=config.test_dataset, - schema=config.dataset_schema, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode, drop_remainder=False, diff --git a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py index a29c9b4d12..9add9b5def 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py +++ b/model_zoo/official/nlp/gnmt_v2/src/utils/load_weights.py @@ -37,36 +37,26 @@ def load_infer_weights(config): ms_ckpt = load_checkpoint(model_path) is_npz = False weights = {} - with open("variable_after_deal.txt", "w") as f: - for param_name in ms_ckpt: - infer_name = param_name.replace("gnmt.gnmt.", "") - if infer_name.startswith("embedding_lookup."): - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - infer_name = "beam_decoder.decoder." + infer_name - if is_npz: - weights[infer_name] = ms_ckpt[param_name] - else: - weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - continue - - elif not infer_name.startswith("gnmt_encoder"): - if infer_name.startswith("gnmt_decoder."): - infer_name = infer_name.replace("gnmt_decoder.", "decoder.") - infer_name = "beam_decoder.decoder." 
+ infer_name - + for param_name in ms_ckpt: + infer_name = param_name.replace("gnmt.gnmt.", "") + if infer_name.startswith("embedding_lookup."): if is_npz: weights[infer_name] = ms_ckpt[param_name] else: weights[infer_name] = ms_ckpt[param_name].data.asnumpy() - f.write(infer_name) - f.write("\n") - - f.close() + infer_name = "beam_decoder.decoder." + infer_name + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() + continue + elif not infer_name.startswith("gnmt_encoder"): + if infer_name.startswith("gnmt_decoder."): + infer_name = infer_name.replace("gnmt_decoder.", "decoder.") + infer_name = "beam_decoder.decoder." + infer_name + + if is_npz: + weights[infer_name] = ms_ckpt[param_name] + else: + weights[infer_name] = ms_ckpt[param_name].data.asnumpy() return weights diff --git a/model_zoo/official/nlp/gnmt_v2/train.py b/model_zoo/official/nlp/gnmt_v2/train.py index 2af1f64ba9..17fde1d3b7 100644 --- a/model_zoo/official/nlp/gnmt_v2/train.py +++ b/model_zoo/official/nlp/gnmt_v2/train.py @@ -40,7 +40,6 @@ from src.utils.optimizer import Adam parser = argparse.ArgumentParser(description='GNMT train entry point.') parser.add_argument("--config", type=str, required=True, help="model config json file path.") -parser.add_argument("--dataset_schema_train", type=str, required=True, help="dataset schema for train.") parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.") device_id = os.getenv('DEVICE_ID', None) @@ -273,21 +272,20 @@ def train_parallel(config: GNMTConfig): pre_train_dataset = load_dataset( data_files=config.pre_train_dataset, - schema=config.dataset_schema, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode, rank_size=MultiAscend.get_group_size(), rank_id=MultiAscend.get_rank() ) if config.pre_train_dataset else None fine_tune_dataset = load_dataset( - data_files=config.fine_tune_dataset, schema=config.dataset_schema, + 
data_files=config.fine_tune_dataset, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode, rank_size=MultiAscend.get_group_size(), rank_id=MultiAscend.get_rank() ) if config.fine_tune_dataset else None test_dataset = load_dataset( - data_files=config.test_dataset, schema=config.dataset_schema, + data_files=config.test_dataset, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode, rank_size=MultiAscend.get_group_size(), @@ -310,15 +308,12 @@ def train_single(config: GNMTConfig): print(" | Starting training on single device.") pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, - schema=config.dataset_schema, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, - schema=config.dataset_schema, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None test_dataset = load_dataset(data_files=config.test_dataset, - schema=config.dataset_schema, batch_size=config.batch_size, sink_mode=config.dataset_sink_mode) if config.test_dataset else None @@ -341,7 +336,6 @@ if __name__ == '__main__': args, _ = parser.parse_known_args() _check_args(args.config) _config = get_config(args.config) - _config.dataset_schema = args.dataset_schema_train _config.pre_train_dataset = args.pre_train_dataset set_seed(_config.random_seed) if _rank_size is not None and int(_rank_size) > 1: diff --git a/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.py b/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.py index 742c5b90f3..86e47f9a47 100644 --- a/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.py +++ b/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.py @@ -31,15 +31,11 @@ parser = argparse.ArgumentParser(description='GNMT train and eval.') # train parser.add_argument("--config_train", type=str, required=True, help="model config json file path.") -parser.add_argument("--dataset_schema_train", type=str, 
required=True, - help="dataset schema for train.") parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.") # eval parser.add_argument("--config_test", type=str, required=True, help="model config json file path.") -parser.add_argument("--dataset_schema_test", type=str, required=True, - help="dataset schema for evaluation.") parser.add_argument("--test_dataset", type=str, required=True, help="test dataset address.") parser.add_argument("--existed_ckpt", type=str, required=True, @@ -77,7 +73,6 @@ if __name__ == '__main__': # train _check_args(args.config_train) _config_train = get_config(args.config_train) - _config_train.dataset_schema = args.dataset_schema_train _config_train.pre_train_dataset = args.pre_train_dataset set_seed(_config_train.random_seed) assert _rank_size is not None and int(_rank_size) > 1 @@ -86,7 +81,6 @@ if __name__ == '__main__': # eval _check_args(args.config_test) _config_test = get_config(args.config_test) - _config_test.dataset_schema = args.dataset_schema_test _config_test.test_dataset = args.test_dataset _config_test.existed_ckpt = args.existed_ckpt result = infer(_config_test) diff --git a/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.sh b/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.sh index 8ca8685f0c..a793be8985 100644 --- a/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.sh +++ b/tests/st/model_zoo_tests/gnmt_v2/test_gnmt_v2.sh @@ -16,19 +16,15 @@ echo "==============================================================================================================" echo "Please run the scipt as: " -echo "sh run_distributed_train_ascend.sh \ - GNMT_ADDR RANK_TABLE_ADDR \ - DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET \ - DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ +echo "sh test_gnmt_v2.sh \ + GNMT_ADDR RANK_TABLE_ADDR PRE_TRAIN_DATASET TEST_DATASET EXISTED_CKPT_PATH \ VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET" echo "for example:" -echo "sh run_distributed_train_ascend.sh \ +echo "sh 
test_gnmt_v2.sh \ /home/workspace/gnmt_v2 \ /home/workspace/rank_table_8p.json \ - /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \ - /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001 \ - /home/workspace/dataset_menu/newstest2014.en.json \ - /home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \ + /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord \ + /home/workspace/dataset_menu/newstest2014.en.mindrecord \ /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \ /home/workspace/wmt16_de_en/vocab.bpe.32000 \ /home/workspace/wmt16_de_en/bpe.32000 \ @@ -39,15 +35,13 @@ echo "========================================================================== GNMT_ADDR=$1 RANK_TABLE_ADDR=$2 # train dataset addr -DATASET_SCHEMA_TRAIN=$3 -PRE_TRAIN_DATASET=$4 +PRE_TRAIN_DATASET=$3 # eval dataset addr -DATASET_SCHEMA_TEST=$5 -TEST_DATASET=$6 -EXISTED_CKPT_PATH=$7 -VOCAB_ADDR=$8 -BPE_CODE_ADDR=$9 -TEST_TARGET=${10} +TEST_DATASET=$4 +EXISTED_CKPT_PATH=$5 +VOCAB_ADDR=$6 +BPE_CODE_ADDR=$7 +TEST_TARGET=$8 current_exec_path=$(pwd) echo ${current_exec_path} @@ -70,17 +64,15 @@ do cp -r ${GNMT_ADDR}/config . 
export RANK_ID=$i export DEVICE_ID=$i - python test_gnmt_v2.py \ - --config_train=${GNMT_ADDR}/config/config.json \ - --dataset_schema_train=$DATASET_SCHEMA_TRAIN \ - --pre_train_dataset=$PRE_TRAIN_DATASET \ - --config_test=${GNMT_ADDR}/config/config_test.json \ - --dataset_schema_test=$DATASET_SCHEMA_TEST \ - --test_dataset=$TEST_DATASET \ - --existed_ckpt=$EXISTED_CKPT_PATH \ - --vocab=$VOCAB_ADDR \ - --bpe_codes=$BPE_CODE_ADDR \ - --test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 & + python test_gnmt_v2.py \ + --config_train=${GNMT_ADDR}/config/config.json \ + --pre_train_dataset=$PRE_TRAIN_DATASET \ + --config_test=${GNMT_ADDR}/config/config_test.json \ + --test_dataset=$TEST_DATASET \ + --existed_ckpt=$EXISTED_CKPT_PATH \ + --vocab=$VOCAB_ADDR \ + --bpe_codes=$BPE_CODE_ADDR \ + --test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 & cd ${current_exec_path} || exit done cd ${current_exec_path} || exit