From: @qujianwei Reviewed-by: @linqingke, @liangchenghui Signed-off-by: @liangchenghui tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||||
|  | |||||
|  | |||||
| <!-- TOC --> | <!-- TOC --> | ||||
| @@ -52,6 +52,26 @@ In this model, we use the Multi30K dataset as our train and test dataset.As trai | |||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | ||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | ||||
| ## Requirements | |||||
| ```txt | |||||
| nltk | |||||
| numpy | |||||
| ``` | |||||
| To install nltk, run the following command: | |||||
| ```bash | |||||
| pip install nltk | |||||
| ``` | |||||
| Then download the extra nltk packages as follows: | |||||
| ```python | |||||
| import nltk | |||||
| nltk.download() | |||||
| ``` | |||||
| # [Quick Start](#content) | # [Quick Start](#content) | ||||
| After dataset preparation, you can start training and evaluation as follows: | After dataset preparation, you can start training and evaluation as follows: | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Transformer evaluation script.""" | """Transformer evaluation script.""" | ||||
| import os | |||||
| import argparse | import argparse | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| @@ -41,8 +41,13 @@ def run_gru_eval(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ | ||||
| device_id=args.device_id, save_graphs=False) | device_id=args.device_id, save_graphs=False) | ||||
| prefix = "multi30k_test_mindrecord_32" | |||||
| mindrecord_file = os.path.join(args.dataset_path, prefix) | |||||
| if not os.path.exists(mindrecord_file): | |||||
| print("dataset file {} not exists, please check!".format(mindrecord_file)) | |||||
| raise ValueError(mindrecord_file) | |||||
| dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ | dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ | ||||
| dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) | |||||
| dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) | |||||
| dataset_size = dataset.get_dataset_size() | dataset_size = dataset.get_dataset_size() | ||||
| print("dataset size is {}".format(dataset_size)) | print("dataset size is {}".format(dataset_size)) | ||||
| network = Seq2Seq(config, is_training=False) | network = Seq2Seq(config, is_training=False) | ||||
| @@ -40,9 +40,9 @@ fi | |||||
| DATASET_PATH=$(get_real_path $2) | DATASET_PATH=$(get_real_path $2) | ||||
| echo $DATASET_PATH | echo $DATASET_PATH | ||||
| if [ ! -f $DATASET_PATH ] | |||||
| if [ ! -d $DATASET_PATH ] | |||||
| then | then | ||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a file" | |||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a directory" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -41,9 +41,9 @@ fi | |||||
| DATASET_PATH=$(get_real_path $2) | DATASET_PATH=$(get_real_path $2) | ||||
| echo $DATASET_PATH | echo $DATASET_PATH | ||||
| if [ ! -f $DATASET_PATH ] | |||||
| if [ ! -d $DATASET_PATH ] | |||||
| then | then | ||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a file" | |||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a directory" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| rm -rf ./eval | rm -rf ./eval | ||||
| @@ -33,9 +33,9 @@ get_real_path(){ | |||||
| DATASET_PATH=$(get_real_path $1) | DATASET_PATH=$(get_real_path $1) | ||||
| echo $DATASET_PATH | echo $DATASET_PATH | ||||
| if [ ! -f $DATASET_PATH ] | |||||
| if [ ! -d $DATASET_PATH ] | |||||
| then | then | ||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a file" | |||||
| echo "error: DATASET_PATH=$DATASET_PATH is not a directory" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -99,8 +99,13 @@ if __name__ == '__main__': | |||||
| else: | else: | ||||
| rank = 0 | rank = 0 | ||||
| device_num = 1 | device_num = 1 | ||||
| prefix = "multi30k_train_mindrecord_32_" | |||||
| mindrecord_file = os.path.join(args.dataset_path, prefix+"0") | |||||
| if not os.path.exists(mindrecord_file): | |||||
| print("dataset file {} not exists, please check!".format(mindrecord_file)) | |||||
| raise ValueError(mindrecord_file) | |||||
| dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size, | dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size, | ||||
| dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank) | |||||
| dataset_path=mindrecord_file, rank_size=device_num, rank_id=rank) | |||||
| dataset_size = dataset.get_dataset_size() | dataset_size = dataset.get_dataset_size() | ||||
| print("dataset size is {}".format(dataset_size)) | print("dataset size is {}".format(dataset_size)) | ||||
| network = Seq2Seq(config) | network = Seq2Seq(config) | ||||