@@ -58,8 +58,8 @@ Note that you can run the scripts based on the dataset mentioned in the original paper
 ```txt
 numpy
-sacrebleu==1.2.10
-sacremoses==0.0.19
+sacrebleu==1.4.14
+sacremoses==0.0.35
 subword_nmt==0.3.7
 ```
@@ -77,15 +77,15 @@ After dataset preparation, you can start training and evaluation as follows:
 ```bash
 # run training example
 cd ./scripts
-sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
+sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET
 
 # run distributed training example
 cd ./scripts
-sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
+sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET
 
 # run evaluation example
 cd ./scripts
-sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
+sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
 VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
 ```
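For orientation, a concrete single-device run might look like the sketch below; the MindRecord path is a placeholder borrowed from the example paths echoed by the scripts further down.

```bash
cd ./scripts
# single-device training on a prepared MindRecord file (illustrative path)
sh run_standalone_train_ascend.sh /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord
```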
@@ -187,7 +187,7 @@ For more configuration details, please refer to `config/config.py`.
 ## Training Process
 
-For a pre-trained model, configure the following options in the `scripts/run_standalone_train_ascend.json` file:
+For a pre-trained model, configure the following options in the `config/config.json` file:
 - Select an optimizer ('momentum', 'adam', or 'lamb').
 - Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to set where the model file is saved.
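As an illustration only, the two options above might look as follows inside `config/config.json`; the exact grouping and surrounding keys of the real file may differ, so treat this as a sketch rather than the file's actual layout.

```json
{
  "optimizer": "adam",
  "checkpoint_path": {
    "ckpt_prefix": "gnmt",
    "ckpt_path": "./checkpoints"
  }
}
```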
@@ -198,17 +198,17 @@ Start task training on a single device and run the shell script `scripts/run_standalone_train_ascend.sh`
 ```bash
 cd ./scripts
-sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
+sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET
 ```
-In this script, the `DATASET_SCHEMA_TRAIN` and `PRE_TRAIN_DATASET` are the dataset schema and dataset address.
+In this script, `PRE_TRAIN_DATASET` is the path of the pre-training dataset.
 Run `scripts/run_distributed_train_ascend.sh` for distributed training of the GNMTv2 model.
 For task training on multiple devices, run the following command in bash from `scripts/`:
 ```bash
 cd ./scripts
-sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
+sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET
 ```
 Note: `RANK_TABLE_ADDR` is the path of the hccl rank table JSON file used for distributed training.
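For reference, a rank table for multiple devices usually follows the hccl v1.0 layout sketched below; every ID and IP address here is a placeholder, and the exact fields depend on the CANN/MindSpore version in use.

```json
{
  "version": "1.0",
  "server_count": "1",
  "server_list": [
    {
      "server_id": "10.0.0.1",
      "device": [
        {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
        {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"}
      ],
      "host_nic_ip": "reserve"
    }
  ],
  "status": "completed"
}
```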
@@ -224,11 +224,11 @@ Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the output
 ```bash
 cd ./scripts
-sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
+sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
 VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
 ```
-The `DATASET_SCHEMA_TEST` and the `TEST_DATASET` are the schema and address of inference dataset respectively, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process.
+The `TEST_DATASET` is the address of the inference dataset, and `EXISTED_CKPT_PATH` is the path of the model file generated during the training process.
 The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the BPE code address, and `TEST_TARGET` is the path of the reference answers.
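Because `TEST_TARGET` holds the reference translations, the BLEU score of the generated output can also be reproduced by hand with the pinned `sacrebleu` package; both file paths below are placeholders.

```bash
# score detokenized predictions against the reference file (illustrative paths)
sacrebleu /home/workspace/wmt16_de_en/newstest2014.de < predictions.detok.txt
```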
 # [Model Description](#contents)
@@ -3,7 +3,6 @@
   "random_seed": 50,
   "epochs": 6,
   "batch_size": 128,
-  "dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
   "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord",
   "fine_tune_dataset": null,
   "valid_dataset": null,
@@ -67,7 +67,6 @@ class GNMTConfig:
         random_seed (int): Random seed, it can be changed.
         epochs (int): Epoch number.
         batch_size (int): Batch size of input dataset.
-        dataset_schema (str): Path of dataset schema file.
         pre_train_dataset (str): Path of pre-training dataset file or folder.
         fine_tune_dataset (str): Path of fine-tune dataset file or folder.
         test_dataset (str): Path of test dataset file or folder.
@@ -126,7 +125,6 @@ class GNMTConfig:
     def __init__(self,
                  random_seed=50,
                  epochs=6, batch_size=128,
-                 dataset_schema: str = None,
                  pre_train_dataset: str = None,
                  fine_tune_dataset: str = None,
                  test_dataset: str = None,
@@ -157,7 +155,6 @@ class GNMTConfig:
         self.save_graphs = save_graphs
         self.random_seed = random_seed
-        self.dataset_schema = dataset_schema
         self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
         self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
         self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
@@ -3,7 +3,6 @@
   "random_seed": 50,
   "epochs": 6,
   "batch_size": 128,
-  "dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json",
   "pre_train_dataset": null,
   "fine_tune_dataset": null,
   "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.mindrecord",
@@ -27,8 +27,6 @@ from src.dataset.tokenizer import Tokenizer
 parser = argparse.ArgumentParser(description='gnmt')
 parser.add_argument("--config", type=str, required=True,
                     help="model config json file path.")
-parser.add_argument("--dataset_schema_test", type=str, required=True,
-                    help="dataset schema for evaluation.")
 parser.add_argument("--test_dataset", type=str, required=True,
                     help="test dataset address.")
 parser.add_argument("--existed_ckpt", type=str, required=True,
@@ -63,7 +61,6 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     _check_args(args.config)
     _config = get_config(args.config)
-    _config.dataset_schema = args.dataset_schema_test
     _config.test_dataset = args.test_dataset
     _config.existed_ckpt = args.existed_ckpt
     result = infer(_config)
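With the schema flag gone, a direct call to `eval.py` (normally wrapped by `scripts/run_standalone_eval_ascend.sh`) reduces to the arguments below; the paths are placeholders and the trailing flags mirror what the shell wrapper forwards.

```bash
python eval.py \
  --config=./config/config_test.json \
  --test_dataset=/home/workspace/dataset_menu/newstest2014.en.mindrecord \
  --existed_ckpt=/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
  --vocab=/home/workspace/wmt16_de_en/vocab.bpe.32000 \
  --bpe_codes=/home/workspace/wmt16_de_en/bpe.32000 \
  --test_tgt=/home/workspace/wmt16_de_en/newstest2014.de
```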
@@ -0,0 +1,102 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""export checkpoint file into air models"""
+import argparse
+
+import numpy as np
+
+from mindspore import Tensor, context, Parameter
+from mindspore.common import dtype as mstype
+from mindspore.train.serialization import export
+
+from config import GNMTConfig
+from src.gnmt_model.gnmt import GNMT
+from src.gnmt_model.gnmt_for_infer import GNMTInferCell
+from src.utils import zero_weight
+from src.utils.load_weights import load_infer_weights
+
+parser = argparse.ArgumentParser(description="gnmt_v2 export")
+parser.add_argument("--file_name", type=str, default="gnmt_v2", help="output file name.")
+parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
+parser.add_argument('--infer_config', type=str, required=True, help='gnmt_v2 config file')
+parser.add_argument("--existed_ckpt", type=str, required=True, help="existed checkpoint address.")
+parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
+parser.add_argument("--bpe_codes", type=str, required=True, help="bpe codes to use.")
+args = parser.parse_args()
+
+context.set_context(
+    mode=context.GRAPH_MODE,
+    save_graphs=False,
+    device_target="Ascend",
+    reserve_class_name_in_scope=False)
+
+
+def get_config(config_file):
+    tfm_config = GNMTConfig.from_json_file(config_file)
+    tfm_config.compute_type = mstype.float16
+    tfm_config.dtype = mstype.float32
+    return tfm_config
+
+
+if __name__ == '__main__':
+    config = get_config(args.infer_config)
+    config.existed_ckpt = args.existed_ckpt
+    vocab = args.vocab_file
+    bpe_codes = args.bpe_codes
+
+    tfm_model = GNMT(config=config,
+                     is_training=False,
+                     use_one_hot_embeddings=False)
+
+    params = tfm_model.trainable_params()
+    weights = load_infer_weights(config)
+
+    for param in params:
+        value = param.data
+        name = param.name
+        if name not in weights:
+            raise ValueError(f"{name} is not found in weights.")
+        with open("weight_after_deal.txt", "a+") as f:
+            weights_name = name
+            f.write(weights_name)
+            f.write("\n")
+            if isinstance(value, Tensor):
+                print(name, value.asnumpy().shape)
+            if weights_name in weights:
+                assert weights_name in weights
+                if isinstance(weights[weights_name], Parameter):
+                    if param.data.dtype == "Float32":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
+                    elif param.data.dtype == "Float16":
+                        param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
+                elif isinstance(weights[weights_name], Tensor):
+                    param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
+                elif isinstance(weights[weights_name], np.ndarray):
+                    param.set_data(Tensor(weights[weights_name], config.dtype))
+                else:
+                    param.set_data(weights[weights_name])
+            else:
+                print("weight not found in checkpoint: " + weights_name)
+                param.set_data(zero_weight(value.asnumpy().shape))
+        f.close()
+    print(" | Load weights successfully.")
+    tfm_infer = GNMTInferCell(tfm_model)
+    tfm_infer.set_train(False)
+
+    source_ids = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
+    source_mask = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
+    export(tfm_infer, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)
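Assuming the new file lands at the repository root as `export.py`, a typical export run would be the following; all paths are placeholders.

```bash
python export.py \
  --infer_config=./config/config_test.json \
  --existed_ckpt=/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
  --vocab_file=/home/workspace/wmt16_de_en/vocab.bpe.32000 \
  --bpe_codes=/home/workspace/wmt16_de_en/bpe.32000 \
  --file_name=gnmt_v2 \
  --file_format=MINDIR
```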
@@ -1,6 +1,4 @@
-nltk
-jieba
 numpy
 subword-nmt==0.3.7
-sacrebleu==1.2.10
-sacremoses==0.0.19
+sacrebleu==1.4.14
+sacremoses==0.0.35
@@ -16,18 +16,16 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET"
+echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET"
 echo "for example:"
 echo "sh run_distributed_train_ascend.sh \
 /home/workspace/rank_table_8p.json \
-/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
 /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
 echo "It is better to use absolute path."
 echo "=============================================================================================================="
 
 RANK_TABLE_ADDR=$1
-DATASET_SCHEMA_TRAIN=$2
-PRE_TRAIN_DATASET=$3
+PRE_TRAIN_DATASET=$2
 
 current_exec_path=$(pwd)
 echo ${current_exec_path}
@@ -49,10 +47,9 @@ do
     cp -r ../../config .
     export RANK_ID=$i
    export DEVICE_ID=$i
-    python ../../train.py \
-        --config=${current_exec_path}/device${i}/config/config.json \
-        --dataset_schema_train=$DATASET_SCHEMA_TRAIN \
-        --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 &
+    python ../../train.py \
+        --config=${current_exec_path}/device${i}/config/config.json \
+        --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 &
     cd ${current_exec_path} || exit
 done
 cd ${current_exec_path} || exit
@@ -16,11 +16,10 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
+echo "sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
 VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
 echo "for example:"
 echo "sh run_standalone_eval_ascend.sh \
-/home/workspace/dataset_menu/newstest2014.en.json \
 /home/workspace/dataset_menu/newstest2014.en.mindrecord \
 /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
 /home/workspace/wmt16_de_en/vocab.bpe.32000 \
@@ -29,19 +28,16 @@ echo "sh run_standalone_eval_ascend.sh \
 echo "It is better to use absolute path."
 echo "=============================================================================================================="
 
-DATASET_SCHEMA_TEST=$1
-TEST_DATASET=$2
-EXISTED_CKPT_PATH=$3
-VOCAB_ADDR=$4
-BPE_CODE_ADDR=$5
-TEST_TARGET=$6
+TEST_DATASET=$1
+EXISTED_CKPT_PATH=$2
+VOCAB_ADDR=$3
+BPE_CODE_ADDR=$4
+TEST_TARGET=$5
 
 current_exec_path=$(pwd)
 echo ${current_exec_path}
 
-export DEVICE_NUM=1
-export RANK_ID=0
-export RANK_SIZE=1
 export GLOG_v=2
 
 if [ -d "eval" ];
@@ -57,7 +53,6 @@ echo "start for evaluation"
 env > env.log
 python eval.py \
   --config=${current_exec_path}/eval/config/config_test.json \
-  --dataset_schema_test=$DATASET_SCHEMA_TEST \
   --test_dataset=$TEST_DATASET \
   --existed_ckpt=$EXISTED_CKPT_PATH \
   --vocab=$VOCAB_ADDR \
@@ -16,21 +16,17 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET"
+echo "sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET"
 echo "for example:"
 echo "sh run_standalone_train_ascend.sh \
-/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
 /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
 echo "It is better to use absolute path."
 echo "=============================================================================================================="
 
-DATASET_SCHEMA_TRAIN=$1
-PRE_TRAIN_DATASET=$2
+PRE_TRAIN_DATASET=$1
 
-export DEVICE_NUM=1
-export RANK_ID=0
-export RANK_SIZE=1
 export GLOG_v=2
 
 current_exec_path=$(pwd)
 echo ${current_exec_path}
 
 if [ -d "train" ];
@@ -45,7 +41,6 @@ cd ./train || exit
 echo "start for training"
 env > env.log
 python train.py \
-    --config=${current_exec_path}/train/config/config.json \
-    --dataset_schema_train=$DATASET_SCHEMA_TRAIN \
-    --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 &
+    --config=${current_exec_path}/train/config/config.json \
+    --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 &
 cd ..
@@ -13,13 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """Dataset loader to feed into model."""
-import os
 
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.c_transforms as deC
 
 
-def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
+def _load_dataset(input_files, batch_size, sink_mode=False,
                   rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,
                   is_translate=False):
     """
@@ -27,7 +26,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
     Args:
         input_files (list): Data files.
-        schema_file (str): Schema file path.
         batch_size (int): Batch size.
         sink_mode (bool): Whether enable sink mode.
         rank_size (int): Rank size.
@@ -42,12 +40,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
     if not input_files:
         raise FileNotFoundError("Require at least one dataset.")
 
-    if not (schema_file and
-            os.path.exists(schema_file)
-            and os.path.isfile(schema_file)
-            and os.path.basename(schema_file).endswith(".json")):
-        raise FileNotFoundError("`dataset_schema` must be a existed json file.")
-
     if not isinstance(sink_mode, bool):
         raise ValueError("`sink` must be type of bool.")
@@ -116,14 +108,13 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
     return ds
 
 
-def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool,
+def load_dataset(data_files: list, batch_size: int, sink_mode: bool,
                  rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
     """
     Load dataset.
 
     Args:
         data_files (list): Data files.
-        schema (str): Schema file path.
         batch_size (int): Batch size.
         sink_mode (bool): Whether enable sink mode.
         rank_size (int): Rank size.
@@ -133,5 +124,5 @@ def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool
     Returns:
         Dataset, dataset instance.
     """
-    return _load_dataset(data_files, schema, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
+    return _load_dataset(data_files, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
                          drop_remainder=drop_remainder, is_translate=is_translate)
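After this change a caller only needs the MindRecord files themselves; a minimal usage sketch follows, where the import path and the file location are assumptions.

```python
# Minimal sketch of the updated loader API; import path and MindRecord path are assumed.
from src.dataset import load_dataset

train_ds = load_dataset(
    data_files=["/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"],
    batch_size=128,
    sink_mode=True)
print("batches per epoch:", train_ds.get_dataset_size())
```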
@@ -38,7 +38,7 @@ class BahdanauAttention(nn.Cell):
         initializer_range: range for uniform initializer parameters.
 
     Returns:
-        Tensor, shape (N, T, D).
+        Tensor, shape (t_q_length, N, D).
     """
 
     def __init__(self,
@@ -93,108 +93,107 @@
         Construct attention block.
 
         Args:
-            query (Tensor): Shape (t_q, N, D).
-            keys (Tensor): Shape (t_k, N, D).
-            attention_mask: Shape(N, t_k).
+            query (Tensor): Shape (t_q_length, N, D).
+            keys (Tensor): Shape (t_k_length, N, D).
+            attention_mask: Shape(N, t_k_length).
 
         Returns:
-            Tensor, shape (N, t_q, D).
+            Tensor, shape (t_q_length, N, D).
         """
-        # (t_k, N, D) -> (N, t_k, D).
+        # (t_k_length, N, D) -> (N, t_k_length, D).
         keys = self.transpose(keys, self.transpose_orders)
-        # (t_q, N, D) -> (N, t_q, D).
-        query = self.transpose(query, self.transpose_orders)
+        # (t_q_length, N, D) -> (N, t_q_length, D).
+        query_trans = self.transpose(query, self.transpose_orders)
 
-        query_shape = self.shape_op(query)
-        b = query_shape[0]
-        t_q = query_shape[1]
-        t_k = self.shape_op(keys)[1]
+        query_shape = self.shape_op(query_trans)
+        batch_size = query_shape[0]
+        t_q_length = query_shape[1]
+        t_k_length = self.shape_op(keys)[1]
 
-        # (N, t_q, D)
-        query = self.reshape(query, (b * t_q, self.query_size))
+        # (N, t_q_length, D)
+        query_trans = self.reshape(query_trans, (batch_size * t_q_length, self.query_size))
         if self.is_training:
-            query = self.cast(query, mstype.float16)
-        processed_query = self.linear_q(query)
+            query_trans = self.cast(query_trans, mstype.float16)
+        processed_query = self.linear_q(query_trans)
         if self.is_trining:
             processed_query = self.cast(processed_query, mstype.float32)
-        processed_query = self.reshape(processed_query, (b, t_q, self.num_units))
-        # (N, t_k, D)
-        keys = self.reshape(keys, (b * t_k, self.key_size))
+        processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units))
+        # (N, t_k_length, D)
+        keys = self.reshape(keys, (batch_size * t_k_length, self.key_size))
         if self.is_training:
             keys = self.cast(keys, mstype.float16)
         processed_key = self.linear_k(keys)
         if self.is_trining:
             processed_key = self.cast(processed_key, mstype.float32)
-        processed_key = self.reshape(processed_key, (b, t_k, self.num_units))
+        processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units))
 
-        # scores: (N , T_q, T_k)
-        scores = self.calc_score(processed_query, processed_key)
-        # attention_mask: (N, T_k)
+        # scores: (N, t_q_length, t_k_length)
+        scores = self.obtain_score(processed_query, processed_key)
+        # attention_mask: (N, t_k_length)
         mask = attention_mask
-        # [N, 1]
         if mask is not None:
             mask = 1.0 - mask
-            mask = self.tile(self.expand(mask, 1), (1, t_q, 1))
+            mask = self.tile(self.expand(mask, 1), (1, t_q_length, 1))
             scores += mask * (-INF)
-        # [b, t_q, t_k]
-        scores_normalized = self.softmax(scores)
+        # [batch_size, t_q_length, t_k_length]
+        scores_softmax = self.softmax(scores)
 
-        keys = self.reshape(keys, (b, t_k, self.key_size))
+        keys = self.reshape(keys, (batch_size, t_k_length, self.key_size))
         if self.is_training:
             keys = self.cast(keys, mstype.float16)
-            scores_normalized_fp16 = self.cast(scores_normalized, mstype.float16)
+            scores_softmax_fp16 = self.cast(scores_softmax, mstype.float16)
         else:
-            scores_normalized_fp16 = scores_normalized
+            scores_softmax_fp16 = scores_softmax
-        # (b, t_q, n)
-        context_attention = self.batchMatmul(scores_normalized_fp16, keys)
-        # [t_q,b,D]
+        # (b, t_q_length, D)
+        context_attention = self.batchMatmul(scores_softmax_fp16, keys)
+        # [t_q_length, b, D]
         context_attention = self.transpose(context_attention, self.transpose_orders)
         if self.is_training:
             context_attention = self.cast(context_attention, mstype.float32)
-        return context_attention, scores_normalized
+        return context_attention, scores_softmax
-    def calc_score(self, att_query, att_keys):
+    def obtain_score(self, attention_q, attention_k):
         """
         Calculate Bahdanau score
 
         Args:
-            att_query: (N, T_q, D).
-            att_keys: (N, T_k, D).
+            attention_q: (batch_size, t_q_length, D).
+            attention_k: (batch_size, t_k_length, D).
 
         returns:
-            scores: (N, T_q, T_k).
+            scores: (batch_size, t_q_length, t_k_length).
         """
-        b, t_k, n = self.shape_op(att_keys)
-        t_q = self.shape_op(att_query)[1]
-        # (b, t_q, t_k, n)
-        att_query = self.tile(self.expand(att_query, 2), (1, 1, t_k, 1))
-        att_keys = self.tile(self.expand(att_keys, 1), (1, t_q, 1, 1))
-        # (b, t_q, t_k, n)
-        sum_qk = att_query + att_keys
+        batch_size, t_k_length, D = self.shape_op(attention_k)
+        t_q_length = self.shape_op(attention_q)[1]
+        # (batch_size, t_q_length, t_k_length, n)
+        attention_q = self.tile(self.expand(attention_q, 2), (1, 1, t_k_length, 1))
+        attention_k = self.tile(self.expand(attention_k, 1), (1, t_q_length, 1, 1))
+        # (batch_size, t_q_length, t_k_length, n)
+        sum_qk_add = attention_q + attention_k
 
         if self.normalize:
-            # (b, t_q, t_k, n)
-            sum_qk = sum_qk + self.normalize_bias
-            linear_att = self.linear_att / self.norm(self.linear_att)
-            linear_att = self.cast(linear_att, mstype.float32)
-            linear_att = self.mul(linear_att, self.normalize_scalar)
+            # (batch_size, t_q_length, t_k_length, n)
+            sum_qk_add = sum_qk_add + self.normalize_bias
+            linear_att_norm = self.linear_att / self.norm(self.linear_att)
+            linear_att_norm = self.cast(linear_att_norm, mstype.float32)
+            linear_att_norm = self.mul(linear_att_norm, self.normalize_scalar)
         else:
-            linear_att = self.linear_att
+            linear_att_norm = self.linear_att
 
-        linear_att = self.expand(linear_att, -1)
-        sum_qk = self.reshape(sum_qk, (-1, n))
+        linear_att_norm = self.expand(linear_att_norm, -1)
+        sum_qk_add = self.reshape(sum_qk_add, (-1, D))
-        tanh_sum_qk = self.tanh(sum_qk)
+        tanh_sum_qk = self.tanh(sum_qk_add)
         if self.is_training:
-            linear_att = self.cast(linear_att, mstype.float16)
+            linear_att_norm = self.cast(linear_att_norm, mstype.float16)
             tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16)
-        out = self.matmul(tanh_sum_qk, linear_att)
+        scores_out = self.matmul(tanh_sum_qk, linear_att_norm)
 
-        # (b, t_q, t_k)
-        out = self.reshape(out, (b, t_q, t_k))
+        # (N, t_q_length, t_k_length)
+        scores_out = self.reshape(scores_out, (batch_size, t_q_length, t_k_length))
         if self.is_training:
-            out = self.cast(out, mstype.float32)
-        return out
+            scores_out = self.cast(scores_out, mstype.float32)
+        return scores_out
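For reference, `obtain_score` computes the standard additive (Bahdanau) score. Writing the projected query as \(W_q q_t\) and the projected key as \(W_k k_s\), the unnormalized branch evaluates

$$\mathrm{score}(q_t, k_s) = v^{\top}\tanh\left(W_q q_t + W_k k_s\right),$$

and the `normalize` branch the weight-normalized variant

$$\mathrm{score}(q_t, k_s) = g\,\frac{v^{\top}}{\lVert v \rVert}\tanh\left(W_q q_t + W_k k_s + b\right),$$

where \(v\) is `linear_att`, \(g\) is `normalize_scalar`, and \(b\) is `normalize_bias`.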
@@ -214,9 +214,8 @@ class BeamSearchDecoder(nn.Cell):
         self.concat = P.Concat(axis=-1)
         self.gather_nd = P.GatherNd()
 
-        self.start = Tensor(0, dtype=mstype.int32)
         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
-        self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), mstype.int32)
+        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
 
         init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
         self.init_scores = Tensor(init_scores, mstype.float32)
@@ -260,7 +259,7 @@ class BeamSearchDecoder(nn.Cell):
         self.sub = P.Sub()
 
     def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                 state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
+                 state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None,
                  state_finished=None):
         """
         Beam search one_step output.
@@ -270,7 +269,7 @@
             enc_states (Tensor): with shape (batch_size * beam_width, T, D).
             enc_attention_mask (Tensor): with shape (batch_size * beam_width, T).
             state_log_probs (Tensor): with shape (batch_size, beam_width).
-            state_seq (Tensor): with shape (batch_size, beam_width, max_decoder_length).
+            state_seq (Tensor): with shape (batch_size, beam_width, m).
             state_length (Tensor): with shape (batch_size, beam_width).
             idx (Tensor): with shape ().
             decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).
@@ -360,10 +359,7 @@
                                                            self.hidden_size))
 
         # update state_seq
-        state_seq_new = self.cast(seq, mstype.float32)
-        word_indices_fp32 = self.cast(word_indices, mstype.float32)
-        state_seq_new[:, :, idx] = word_indices_fp32
-        state_seq = self.cast(state_seq_new, mstype.int32)
+        state_seq = self.concat((seq, self.expand(word_indices, -1)))
 
         cur_input_ids = self.reshape(word_indices, (-1, 1))
         state_log_probs = topk_scores
@@ -392,15 +388,11 @@
         decoder_hidden_state = self.decoder_hidden_state
         accu_attn_scores = self.accu_attn_scores
 
-        idx = self.start + 1
-        ends = self.start + self.max_decode_length + 1
-        while idx < ends:
+        for _ in range(self.max_decode_length + 1):
             cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
                 state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                                               state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
+                                               state_seq, state_length, decoder_hidden_state, accu_attn_scores,
                                                state_finished)
-            idx = idx + 1
 
         # add length penalty scores
         penalty_len = self.length_penalty(state_length)
         # return penalty_len
@@ -416,6 +408,6 @@
         gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
         # sort sequence and attention scores
         predicted_ids = self.gather_nd(state_seq, gather_indices)
-        predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]
+        predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
 
         return predicted_ids
@@ -187,7 +187,6 @@ def infer(config):
         list, result with
     """
     eval_dataset = load_dataset(data_files=config.test_dataset,
-                                schema=config.dataset_schema,
                                 batch_size=config.batch_size,
                                 sink_mode=config.dataset_sink_mode,
                                 drop_remainder=False,
@@ -40,7 +40,6 @@ from src.utils.optimizer import Adam
 parser = argparse.ArgumentParser(description='GNMT train entry point.')
 parser.add_argument("--config", type=str, required=True, help="model config json file path.")
-parser.add_argument("--dataset_schema_train", type=str, required=True, help="dataset schema for train.")
 parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
 
 device_id = os.getenv('DEVICE_ID', None)
@@ -273,21 +272,20 @@ def train_parallel(config: GNMTConfig):
     pre_train_dataset = load_dataset(
         data_files=config.pre_train_dataset,
-        schema=config.dataset_schema,
         batch_size=config.batch_size,
         sink_mode=config.dataset_sink_mode,
         rank_size=MultiAscend.get_group_size(),
         rank_id=MultiAscend.get_rank()
     ) if config.pre_train_dataset else None
     fine_tune_dataset = load_dataset(
-        data_files=config.fine_tune_dataset, schema=config.dataset_schema,
+        data_files=config.fine_tune_dataset,
         batch_size=config.batch_size,
         sink_mode=config.dataset_sink_mode,
         rank_size=MultiAscend.get_group_size(),
         rank_id=MultiAscend.get_rank()
     ) if config.fine_tune_dataset else None
     test_dataset = load_dataset(
-        data_files=config.test_dataset, schema=config.dataset_schema,
+        data_files=config.test_dataset,
         batch_size=config.batch_size,
         sink_mode=config.dataset_sink_mode,
         rank_size=MultiAscend.get_group_size(),
@@ -310,15 +308,12 @@ def train_single(config: GNMTConfig):
     print(" | Starting training on single device.")
     pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
-                                     schema=config.dataset_schema,
                                      batch_size=config.batch_size,
                                      sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
     fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
-                                     schema=config.dataset_schema,
                                      batch_size=config.batch_size,
                                      sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
     test_dataset = load_dataset(data_files=config.test_dataset,
-                                schema=config.dataset_schema,
                                 batch_size=config.batch_size,
                                 sink_mode=config.dataset_sink_mode) if config.test_dataset else None
@@ -341,7 +336,6 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     _check_args(args.config)
     _config = get_config(args.config)
-    _config.dataset_schema = args.dataset_schema_train
     _config.pre_train_dataset = args.pre_train_dataset
     set_seed(_config.random_seed)
     if _rank_size is not None and int(_rank_size) > 1:
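With the schema argument removed, launching training directly (outside the shell wrappers) only needs the config file and the pre-training dataset; the paths below are placeholders.

```bash
python train.py \
  --config=./config/config.json \
  --pre_train_dataset=/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord
```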
@@ -31,15 +31,11 @@ parser = argparse.ArgumentParser(description='GNMT train and eval.')
 # train
 parser.add_argument("--config_train", type=str, required=True,
                     help="model config json file path.")
-parser.add_argument("--dataset_schema_train", type=str, required=True,
-                    help="dataset schema for train.")
 parser.add_argument("--pre_train_dataset", type=str, required=True,
                     help="pre-train dataset address.")
 # eval
 parser.add_argument("--config_test", type=str, required=True,
                     help="model config json file path.")
-parser.add_argument("--dataset_schema_test", type=str, required=True,
-                    help="dataset schema for evaluation.")
 parser.add_argument("--test_dataset", type=str, required=True,
                     help="test dataset address.")
 parser.add_argument("--existed_ckpt", type=str, required=True,
@@ -77,7 +73,6 @@ if __name__ == '__main__':
     # train
     _check_args(args.config_train)
     _config_train = get_config(args.config_train)
-    _config_train.dataset_schema = args.dataset_schema_train
     _config_train.pre_train_dataset = args.pre_train_dataset
     set_seed(_config_train.random_seed)
     assert _rank_size is not None and int(_rank_size) > 1
@@ -86,7 +81,6 @@ if __name__ == '__main__':
     # eval
     _check_args(args.config_test)
     _config_test = get_config(args.config_test)
-    _config_test.dataset_schema = args.dataset_schema_test
     _config_test.test_dataset = args.test_dataset
     _config_test.existed_ckpt = args.existed_ckpt
     result = infer(_config_test)
@@ -16,19 +16,15 @@
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "sh run_distributed_train_ascend.sh \
-GNMT_ADDR RANK_TABLE_ADDR \
-DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET \
-DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
+echo "sh test_gnmt_v2.sh \
+GNMT_ADDR RANK_TABLE_ADDR PRE_TRAIN_DATASET TEST_DATASET EXISTED_CKPT_PATH \
 VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
 echo "for example:"
-echo "sh run_distributed_train_ascend.sh \
+echo "sh test_gnmt_v2.sh \
 /home/workspace/gnmt_v2 \
 /home/workspace/rank_table_8p.json \
-/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
-/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001 \
-/home/workspace/dataset_menu/newstest2014.en.json \
-/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \
+/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord \
+/home/workspace/dataset_menu/newstest2014.en.mindrecord \
 /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
 /home/workspace/wmt16_de_en/vocab.bpe.32000 \
 /home/workspace/wmt16_de_en/bpe.32000 \
@@ -39,15 +35,13 @@ echo "=============================================================================================================="
 GNMT_ADDR=$1
 RANK_TABLE_ADDR=$2
 # train dataset addr
-DATASET_SCHEMA_TRAIN=$3
-PRE_TRAIN_DATASET=$4
+PRE_TRAIN_DATASET=$3
 # eval dataset addr
-DATASET_SCHEMA_TEST=$5
-TEST_DATASET=$6
-EXISTED_CKPT_PATH=$7
-VOCAB_ADDR=$8
-BPE_CODE_ADDR=$9
-TEST_TARGET=${10}
+TEST_DATASET=$4
+EXISTED_CKPT_PATH=$5
+VOCAB_ADDR=$6
+BPE_CODE_ADDR=$7
+TEST_TARGET=$8
 
 current_exec_path=$(pwd)
 echo ${current_exec_path}
@@ -70,17 +64,15 @@ do
     cp -r ${GNMT_ADDR}/config .
     export RANK_ID=$i
     export DEVICE_ID=$i
-    python test_gnmt_v2.py \
-        --config_train=${GNMT_ADDR}/config/config.json \
-        --dataset_schema_train=$DATASET_SCHEMA_TRAIN \
-        --pre_train_dataset=$PRE_TRAIN_DATASET \
-        --config_test=${GNMT_ADDR}/config/config_test.json \
-        --dataset_schema_test=$DATASET_SCHEMA_TEST \
-        --test_dataset=$TEST_DATASET \
-        --existed_ckpt=$EXISTED_CKPT_PATH \
-        --vocab=$VOCAB_ADDR \
-        --bpe_codes=$BPE_CODE_ADDR \
-        --test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 &
+    python test_gnmt_v2.py \
+        --config_train=${GNMT_ADDR}/config/config.json \
+        --pre_train_dataset=$PRE_TRAIN_DATASET \
+        --config_test=${GNMT_ADDR}/config/config_test.json \
+        --test_dataset=$TEST_DATASET \
+        --existed_ckpt=$EXISTED_CKPT_PATH \
+        --vocab=$VOCAB_ADDR \
+        --bpe_codes=$BPE_CODE_ADDR \
+        --test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 &
     cd ${current_exec_path} || exit
 done
 cd ${current_exec_path} || exit