Merge pull request !18159 from zhanghuiyao/gru_clouldtags/v1.3.0
| @@ -1,53 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting | |||
| """ | |||
| import argparse | |||
| def parser_args(): | |||
| """Config for BGCF""" | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("-d", "--dataset", type=str, default="Beauty", help="choose which dataset") | |||
| parser.add_argument("-dpath", "--datapath", type=str, default="./scripts/data_mr", help="minddata path") | |||
| parser.add_argument("-de", "--device", type=str, default='0', help="device id") | |||
| parser.add_argument('--Ks', type=list, default=[5, 10, 20, 100], help="top K") | |||
| parser.add_argument('-w', '--workers', type=int, default=8, help="number of process to generate data") | |||
| parser.add_argument("-ckpt", "--ckptpath", type=str, default="./ckpts", help="checkpoint path") | |||
| parser.add_argument("-eps", "--epsilon", type=float, default=1e-8, help="optimizer parameter") | |||
| parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate") | |||
| parser.add_argument("-l2", "--l2", type=float, default=0.03, help="l2 coefficient") | |||
| parser.add_argument("-act", "--activation", type=str, default='tanh', choices=['relu', 'tanh'], | |||
| help="activation function") | |||
| parser.add_argument("-ndrop", "--neighbor_dropout", type=list, default=[0.0, 0.2, 0.3], | |||
| help="dropout ratio for different aggregation layer") | |||
| parser.add_argument("-log", "--log_name", type=str, default='test', help="log name") | |||
| parser.add_argument("-e", "--num_epoch", type=int, default=600, help="epoch sizes for training") | |||
| parser.add_argument('-input', '--input_dim', type=int, default=64, choices=[64, 128], | |||
| help="user and item embedding dimension") | |||
| parser.add_argument("-b", "--batch_pairs", type=int, default=5000, help="batch size") | |||
| parser.add_argument('--eval_interval', type=int, default=20, help="evaluation interval") | |||
| parser.add_argument("-neg", "--num_neg", type=int, default=10, help="negative sampling rate ") | |||
| parser.add_argument("-g1", "--raw_neighs", type=int, default=40, help="num of sampling neighbors in raw graph") | |||
| parser.add_argument("-g2", "--gnew_neighs", type=int, default=20, help="num of sampling neighbors in sample graph") | |||
| parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim") | |||
| parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient") | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target') | |||
| return parser.parse_args() | |||
| @@ -76,19 +76,73 @@ nltk.download() | |||
| # [Quick Start](#content) | |||
| After dataset preparation, you can start training and evaluation as follows: | |||
| - Running on local with Ascend | |||
| ```bash | |||
| # run training example | |||
| cd ./scripts | |||
| sh run_standalone_train.sh [TRAIN_DATASET_PATH] | |||
| After dataset preparation, you can start training and evaluation as follows: | |||
| # run distributed training example | |||
| sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH] | |||
| ```bash | |||
| # run training example | |||
| cd ./scripts | |||
| sh run_standalone_train.sh [TRAIN_DATASET_PATH] | |||
| # run evaluation example | |||
| sh run_eval.sh [CKPT_FILE] [DATASET_PATH] | |||
| ``` | |||
| # run distributed training example | |||
| sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATASET_PATH] | |||
| # run evaluation example | |||
| sh run_eval.sh [CKPT_FILE] [DATASET_PATH] | |||
| ``` | |||
| - Running on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows) | |||
| ```python | |||
| # Train 8p on ModelArts | |||
| # (1) Perform a or b. | |||
| # a. Set "enable_modelarts=True" on default_config.yaml file. | |||
| # Set "run_distribute=True" on default_config.yaml file. | |||
| # Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file. | |||
| # Set other parameters on default_config.yaml file you need. | |||
| # b. Add "enable_modelarts=True" on the website UI interface. | |||
| # Add "run_distribute=True" on the website UI interface. | |||
| # Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32_0" on the website UI interface. | |||
| # Add other parameters on the website UI interface. | |||
| # (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.) | |||
| # (3) Set the code directory to "/path/gru" on the website UI interface. | |||
| # (4) Set the startup file to "train.py" on the website UI interface. | |||
| # (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (6) Create your job. | |||
| # | |||
| # Train 1p on ModelArts | |||
| # (1) Perform a or b. | |||
| # a. Set "enable_modelarts=True" on default_config.yaml file. | |||
| # Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file. | |||
| # Set other parameters on default_config.yaml file you need. | |||
| # b. Add "enable_modelarts=True" on the website UI interface. | |||
| # Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32_0" on the website UI interface. | |||
| # Add other parameters on the website UI interface. | |||
| # (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.) | |||
| # (3) Set the code directory to "/path/gru" on the website UI interface. | |||
| # (4) Set the startup file to "train.py" on the website UI interface. | |||
| # (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (6) Create your job. | |||
| # | |||
| # Eval 1p on ModelArts | |||
| # (1) Perform a or b. | |||
| # a. Set "enable_modelarts=True" on default_config.yaml file. | |||
| # Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file. | |||
| # Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file. | |||
| # Set "dataset_path='/cache/data/mindrecord/multi30k_train_mindrecord_32_0'" on default_config.yaml file. | |||
| # Set other parameters on default_config.yaml file you need. | |||
| # b. Add "enable_modelarts=True" on the website UI interface. | |||
| # Add "ckpt_file=/cache/checkpoint_path/model.ckpt" on the website UI interface. | |||
| # Add "checkpoint_url=s3://dir_to_trained_ckpt/" on the website UI interface. | |||
| # Add "dataset_path=/cache/data/mindrecord/multi30k_train_mindrecord_32" on the website UI interface. | |||
| # Add other parameters on the website UI interface. | |||
| # (2) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.) | |||
| # (3) Set the code directory to "/path/gru" on the website UI interface. | |||
| # (4) Set the startup file to "eval.py" on the website UI interface. | |||
| # (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (6) Create your job. | |||
| ``` | |||
| # [Script Description](#content) | |||
| @@ -97,9 +151,14 @@ The GRU network script and code result are as follows: | |||
| ```text | |||
| ├── gru | |||
| ├── README.md // Introduction of GRU model. | |||
| ├── model_utils | |||
| │ ├──__init__.py // module init file | |||
| │ ├──config.py // Parse arguments | |||
| │ ├──device_adapter.py // Device adapter for ModelArts | |||
| │ ├──local_adapter.py // Local adapter | |||
| │ ├──moxing_adapter.py // Moxing adapter for ModelArts | |||
| ├── src | |||
| | ├──gru.py // gru cell architecture. | |||
| │ ├──config.py // Configuration instance definition. | |||
| │ ├──create_data.py // Dataset preparation. | |||
| │ ├──dataset.py // Dataset loader to feed into model. | |||
| │ ├──gru_for_infer.py // GRU eval model architecture. | |||
| @@ -118,6 +177,10 @@ The GRU network script and code result are as follows: | |||
| │ ├──run_distributed_train.sh // shell script for distributed train on ascend. | |||
| │ ├──run_eval.sh // shell script for standalone eval on ascend. | |||
| │ ├──run_standalone_train.sh // shell script for standalone eval on ascend. | |||
| ├── default_config.yaml // Configurations | |||
| ├── postprocess.py // GRU postprocess script. | |||
| ├── preprocess.py // GRU preprocess script. | |||
| ├── export.py // Export API entry. | |||
| ├── eval.py // Infer API entry. | |||
| ├── requirements.txt // Requirements of third party package. | |||
| ├── train.py // Train API entry. | |||
| @@ -213,22 +276,49 @@ Parameters for both training and evaluation can be set in config.py. All the dat | |||
| sh parse_output.sh target.txt output.txt /path/vocab.en | |||
| ``` | |||
| Extra: We recommend doing this locally, but you can also do it on modelarts by running a python script with the following command "os.system("sh parse_output.sh target.txt output.txt /path/vocab.en")". | |||
| - After parse output, we will get target.txt.forbleu and output.txt.forbleu.To calculate BLEU score, you may use this [perl script](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl) and run following command to get the BLEU score. | |||
| ```bash | |||
| perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu | |||
| ``` | |||
| Extra: We recommend doing this locally, but you can also do it on modelarts by running a python script with the following command "os.system("perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu")". | |||
| Note: The `DATASET_PATH` is path to mindrecord. eg. train: /dataset_path/multi30k_train_mindrecord_0 eval: /dataset_path/multi30k_test_mindrecord | |||
| ## [Export MindIR](#contents) | |||
| ```shell | |||
| python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] | |||
| ``` | |||
| - Export on local | |||
| ```python | |||
| # The ckpt_file parameter is required, `EXPORT_FORMAT` should be in ["AIR", "MINDIR"] | |||
| python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] | |||
| ``` | |||
| The ckpt_file parameter is required, | |||
| `EXPORT_FORMAT` should be in ["AIR", "MINDIR"] | |||
| - Export on ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start as follows) | |||
| ```python | |||
| # Eval 1p on ModelArts | |||
| # (1) Perform a or b. | |||
| # a. Set "enable_modelarts=True" on default_config.yaml file. | |||
| # Set "ckpt_file='/cache/checkpoint_path/model.ckpt'" on default_config.yaml file. | |||
| # Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on default_config.yaml file. | |||
| # Set "file_name='./gru'" on default_config.yaml file. | |||
| # Set "file_format='MINDIR'" on default_config.yaml file. | |||
| # Set other parameters on default_config.yaml file you need. | |||
| # b. Add "enable_modelarts=True" on the website UI interface. | |||
| # Add "ckpt_file='/cache/checkpoint_path/model.ckpt'" on the website UI interface. | |||
| # Add "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface. | |||
| # Add "file_name='./gru'" on the website UI interface. | |||
| # Add "file_format='MINDIR'" on the website UI interface. | |||
| # Add other parameters on the website UI interface. | |||
| # (2) Set the code directory to "/path/gru" on the website UI interface. | |||
| # (3) Set the startup file to "export.py" on the website UI interface. | |||
| # (4) Set the "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (5) Create your job. | |||
| ``` | |||
| ## [Inference Process](#contents) | |||
| @@ -0,0 +1,83 @@ | |||
| # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) | |||
| enable_modelarts: False | |||
| # Url for modelarts | |||
| data_url: "" | |||
| train_url: "" | |||
| checkpoint_url: "" | |||
| # Path for local | |||
| data_path: "/cache/data" | |||
| output_path: "/cache/train" | |||
| load_path: "/cache/checkpoint_path" | |||
| device_target: "Ascend" | |||
| need_modelarts_dataset_unzip: False | |||
| modelarts_dataset_unzip_name: "" | |||
| # ============================================================================== | |||
| # options | |||
| batch_size: 16 | |||
| eval_batch_size: 1 | |||
| src_vocab_size: 8154 | |||
| trg_vocab_size: 6113 | |||
| encoder_embedding_size: 256 | |||
| decoder_embedding_size: 256 | |||
| hidden_size: 512 | |||
| max_length: 32 | |||
| num_epochs: 30 | |||
| save_checkpoint: True | |||
| ckpt_epoch: 10 | |||
| target_file: "target.txt" | |||
| output_file: "output.txt" | |||
| keep_checkpoint_max: 30 | |||
| base_lr: 0.001 | |||
| warmup_step: 300 | |||
| momentum: 0.9 | |||
| init_loss_scale_value: 1024 | |||
| scale_factor: 2 | |||
| scale_window: 2000 | |||
| warmup_ratio: 0.333333 | |||
| teacher_force_ratio: 0.5 | |||
| run_distribute: False | |||
| dataset_path: "" | |||
| pre_trained: "" | |||
| ckpt_path: "outputs/" | |||
| outputs_dir: "./" | |||
| ckpt_file: "" | |||
| # export option | |||
| file_name: "gru" | |||
| file_format: "MINDIR" | |||
| # postprocess option | |||
| label_dir: "" | |||
| result_dir: "./result_Files" | |||
| # preprocess option | |||
| device_num: 1 | |||
| result_path: "./preprocess_Result/" | |||
| --- | |||
| # Help description for each configuration | |||
| enable_modelarts: "Whether training on modelarts, default: False" | |||
| data_url: "Url for modelarts" | |||
| train_url: "Url for modelarts" | |||
| data_path: "The location of the input data." | |||
| output_path: "The location of the output file." | |||
| device_target: 'Target device type' | |||
| run_distribute: "Run distribute, default: false." | |||
| dataset_path: "Dataset path" | |||
| pre_trained: "Pretrained file path." | |||
| ckpt_path: "Checkpoint save location. Default: outputs/" | |||
| outputs_dir: "Checkpoint save location. Default: outputs/" | |||
| ckpt_file: "ckpt file path" | |||
| # export option | |||
| file_name: "output file name." | |||
| file_format: "file format. choices in ['AIR', 'MINDIR']" | |||
| # postprocess option | |||
| label_dir: "label data dir" | |||
| result_dir: "infer result Files" | |||
| # preprocess option | |||
| device_num: "Use device nums, default is 1" | |||
| result_path: "result path" | |||
| @@ -14,46 +14,97 @@ | |||
| # ============================================================================ | |||
| """Transformer evaluation script.""" | |||
| import os | |||
| import argparse | |||
| import time | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore import context | |||
| from src.dataset import create_gru_dataset | |||
| from src.seq2seq import Seq2Seq | |||
| from src.gru_for_infer import GRUInferCell | |||
| from src.config import config | |||
| from model_utils.config import config | |||
| from model_utils.moxing_adapter import moxing_wrapper | |||
| from model_utils.device_adapter import get_device_id, get_device_num | |||
| def modelarts_pre_process(): | |||
| '''modelarts pre process function.''' | |||
| def unzip(zip_file, save_dir): | |||
| import zipfile | |||
| s_time = time.time() | |||
| if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)): | |||
| zip_isexist = zipfile.is_zipfile(zip_file) | |||
| if zip_isexist: | |||
| fz = zipfile.ZipFile(zip_file, 'r') | |||
| data_num = len(fz.namelist()) | |||
| print("Extract Start...") | |||
| print("unzip file num: {}".format(data_num)) | |||
| data_print = int(data_num / 100) if data_num > 100 else 1 | |||
| i = 0 | |||
| for file in fz.namelist(): | |||
| if i % data_print == 0: | |||
| print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) | |||
| i += 1 | |||
| fz.extract(file, save_dir) | |||
| print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), | |||
| int(int(time.time() - s_time) % 60))) | |||
| print("Extract Done.") | |||
| else: | |||
| print("This is not zip.") | |||
| else: | |||
| print("Zip has been extracted.") | |||
| if config.need_modelarts_dataset_unzip: | |||
| zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip") | |||
| save_dir_1 = os.path.join(config.data_path) | |||
| sync_lock = "/tmp/unzip_sync.lock" | |||
| # Each server contains 8 devices as most. | |||
| if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||
| print("Zip file path: ", zip_file_1) | |||
| print("Unzip file save dir: ", save_dir_1) | |||
| unzip(zip_file_1, save_dir_1) | |||
| print("===Finish extract data synchronization===") | |||
| try: | |||
| os.mknod(sync_lock) | |||
| except IOError: | |||
| pass | |||
| while True: | |||
| if os.path.exists(sync_lock): | |||
| break | |||
| time.sleep(1) | |||
| print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) | |||
| config.output_file = os.path.join(config.output_path, config.output_file) | |||
| config.target_file = os.path.join(config.output_path, config.target_file) | |||
| @moxing_wrapper(pre_process=modelarts_pre_process) | |||
| def run_gru_eval(): | |||
| """ | |||
| Transformer evaluation. | |||
| """ | |||
| parser = argparse.ArgumentParser(description='GRU eval') | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="device where the code will be implemented, default is Ascend") | |||
| parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0') | |||
| parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1') | |||
| parser.add_argument('--ckpt_file', type=str, default="", help='ckpt file path') | |||
| parser.add_argument("--dataset_path", type=str, default="", | |||
| help="Dataset path, default: f`sns.") | |||
| args = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ | |||
| device_id=args.device_id, save_graphs=False) | |||
| mindrecord_file = args.dataset_path | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, reserve_class_name_in_scope=False, | |||
| device_id=get_device_id(), save_graphs=False) | |||
| mindrecord_file = config.dataset_path | |||
| if not os.path.exists(mindrecord_file): | |||
| print("dataset file {} not exists, please check!".format(mindrecord_file)) | |||
| raise ValueError(mindrecord_file) | |||
| dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ | |||
| dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) | |||
| dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, | |||
| dataset_path=mindrecord_file, rank_size=get_device_num(), rank_id=0, | |||
| do_shuffle=False, is_training=False) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print("dataset size is {}".format(dataset_size)) | |||
| network = Seq2Seq(config, is_training=False) | |||
| network = GRUInferCell(network) | |||
| network.set_train(False) | |||
| if args.ckpt_file != "": | |||
| parameter_dict = load_checkpoint(args.ckpt_file) | |||
| if config.ckpt_file != "": | |||
| parameter_dict = load_checkpoint(config.ckpt_file) | |||
| load_param_into_net(network, parameter_dict) | |||
| model = Model(network) | |||
| @@ -13,35 +13,39 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """export script.""" | |||
| import argparse | |||
| import os | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||
| from src.seq2seq import Seq2Seq | |||
| from src.gru_for_infer import GRUInferCell | |||
| from src.config import config | |||
| parser = argparse.ArgumentParser(description='export') | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="device where the code will be implemented, default is Ascend") | |||
| parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0') | |||
| parser.add_argument('--file_name', type=str, default="gru", help='output file name.') | |||
| parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format.") | |||
| parser.add_argument('--ckpt_file', type=str, required=True, help='ckpt file path') | |||
| args = parser.parse_args() | |||
| from model_utils.config import config | |||
| from model_utils.moxing_adapter import moxing_wrapper | |||
| from model_utils.device_adapter import get_device_id | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ | |||
| device_id=args.device_id, save_graphs=False) | |||
| if __name__ == "__main__": | |||
| def modelarts_pre_process(): | |||
| '''modelarts pre process function.''' | |||
| config.file_name = os.path.join(config.output_path, config.file_name) | |||
| @moxing_wrapper(pre_process=modelarts_pre_process) | |||
| def run_export(): | |||
| """run export.""" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, reserve_class_name_in_scope=False, | |||
| device_id=get_device_id(), save_graphs=False) | |||
| network = Seq2Seq(config, is_training=False) | |||
| network = GRUInferCell(network) | |||
| network.set_train(False) | |||
| if args.ckpt_file != "": | |||
| parameter_dict = load_checkpoint(args.ckpt_file) | |||
| if config.ckpt_file != "": | |||
| parameter_dict = load_checkpoint(config.ckpt_file) | |||
| load_param_into_net(network, parameter_dict) | |||
| source_ids = Tensor(np.random.uniform(0.0, 1e5, size=[config.eval_batch_size, config.max_length]).astype(np.int32)) | |||
| target_ids = Tensor(np.random.uniform(0.0, 1e5, size=[config.eval_batch_size, config.max_length]).astype(np.int32)) | |||
| export(network, source_ids, target_ids, file_name=args.file_name, file_format=args.file_format) | |||
| export(network, source_ids, target_ids, file_name=config.file_name, file_format=config.file_format) | |||
| if __name__ == "__main__": | |||
| run_export() | |||
| @@ -0,0 +1,126 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Parse arguments""" | |||
| import os | |||
| import ast | |||
| import argparse | |||
| from pprint import pformat | |||
| import yaml | |||
| class Config: | |||
| """ | |||
| Configuration namespace. Convert dictionary to members. | |||
| """ | |||
| def __init__(self, cfg_dict): | |||
| for k, v in cfg_dict.items(): | |||
| if isinstance(v, (list, tuple)): | |||
| setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) | |||
| else: | |||
| setattr(self, k, Config(v) if isinstance(v, dict) else v) | |||
| def __str__(self): | |||
| return pformat(self.__dict__) | |||
| def __repr__(self): | |||
| return self.__str__() | |||
| def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): | |||
| """ | |||
| Parse command line arguments to the configuration according to the default yaml. | |||
| Args: | |||
| parser: Parent parser. | |||
| cfg: Base configuration. | |||
| helper: Helper description. | |||
| cfg_path: Path to the default yaml config. | |||
| """ | |||
| parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", | |||
| parents=[parser]) | |||
| helper = {} if helper is None else helper | |||
| choices = {} if choices is None else choices | |||
| for item in cfg: | |||
| if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): | |||
| help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) | |||
| choice = choices[item] if item in choices else None | |||
| if isinstance(cfg[item], bool): | |||
| parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, | |||
| help=help_description) | |||
| else: | |||
| parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, | |||
| help=help_description) | |||
| args = parser.parse_args() | |||
| return args | |||
| def parse_yaml(yaml_path): | |||
| """ | |||
| Parse the yaml config file. | |||
| Args: | |||
| yaml_path: Path to the yaml config. | |||
| """ | |||
| with open(yaml_path, 'r') as fin: | |||
| try: | |||
| cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) | |||
| cfgs = [x for x in cfgs] | |||
| if len(cfgs) == 1: | |||
| cfg_helper = {} | |||
| cfg = cfgs[0] | |||
| cfg_choices = {} | |||
| elif len(cfgs) == 2: | |||
| cfg, cfg_helper = cfgs | |||
| cfg_choices = {} | |||
| elif len(cfgs) == 3: | |||
| cfg, cfg_helper, cfg_choices = cfgs | |||
| else: | |||
| raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml") | |||
| print(cfg_helper) | |||
| except: | |||
| raise ValueError("Failed to parse yaml") | |||
| return cfg, cfg_helper, cfg_choices | |||
| def merge(args, cfg): | |||
| """ | |||
| Merge the base config from yaml file and command line arguments. | |||
| Args: | |||
| args: Command line arguments. | |||
| cfg: Base configuration. | |||
| """ | |||
| args_var = vars(args) | |||
| for item in args_var: | |||
| cfg[item] = args_var[item] | |||
| return cfg | |||
| def get_config(): | |||
| """ | |||
| Get Config according to the yaml file and cli arguments. | |||
| """ | |||
| parser = argparse.ArgumentParser(description="default name", add_help=False) | |||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
| parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"), | |||
| help="Config file path") | |||
| path_args, _ = parser.parse_known_args() | |||
| default, helper, choices = parse_yaml(path_args.config_path) | |||
| args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) | |||
| final_config = merge(args, default) | |||
| return Config(final_config) | |||
| config = get_config() | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Device adapter for ModelArts""" | |||
| from .config import config | |||
| if config.enable_modelarts: | |||
| from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||
| else: | |||
| from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||
| __all__ = [ | |||
| "get_device_id", "get_device_num", "get_rank_id", "get_job_id" | |||
| ] | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Local adapter""" | |||
| import os | |||
| def get_device_id(): | |||
| device_id = os.getenv('DEVICE_ID', '0') | |||
| return int(device_id) | |||
| def get_device_num(): | |||
| device_num = os.getenv('RANK_SIZE', '1') | |||
| return int(device_num) | |||
| def get_rank_id(): | |||
| global_rank_id = os.getenv('RANK_ID', '0') | |||
| return int(global_rank_id) | |||
| def get_job_id(): | |||
| return "Local Job" | |||
| @@ -0,0 +1,116 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Moxing adapter for ModelArts""" | |||
| import os | |||
| import functools | |||
| from mindspore import context | |||
| from .config import config | |||
| _global_sync_count = 0 | |||
| def get_device_id(): | |||
| device_id = os.getenv('DEVICE_ID', '0') | |||
| return int(device_id) | |||
| def get_device_num(): | |||
| device_num = os.getenv('RANK_SIZE', '1') | |||
| return int(device_num) | |||
| def get_rank_id(): | |||
| global_rank_id = os.getenv('RANK_ID', '0') | |||
| return int(global_rank_id) | |||
| def get_job_id(): | |||
| job_id = os.getenv('JOB_ID') | |||
| job_id = job_id if job_id != "" else "default" | |||
| return job_id | |||
| def sync_data(from_path, to_path): | |||
| """ | |||
| Download data from remote obs to local directory if the first url is remote url and the second one is local path | |||
| Upload data from local directory to remote obs in contrast. | |||
| """ | |||
| import moxing as mox | |||
| import time | |||
| global _global_sync_count | |||
| sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) | |||
| _global_sync_count += 1 | |||
| # Each server contains 8 devices as most. | |||
| if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||
| print("from path: ", from_path) | |||
| print("to path: ", to_path) | |||
| mox.file.copy_parallel(from_path, to_path) | |||
| print("===finish data synchronization===") | |||
| try: | |||
| os.mknod(sync_lock) | |||
| except IOError: | |||
| pass | |||
| print("===save flag===") | |||
| while True: | |||
| if os.path.exists(sync_lock): | |||
| break | |||
| time.sleep(1) | |||
| print("Finish sync data from {} to {}.".format(from_path, to_path)) | |||
| def moxing_wrapper(pre_process=None, post_process=None): | |||
| """ | |||
| Moxing wrapper to download dataset and upload outputs. | |||
| """ | |||
| def wrapper(run_func): | |||
| @functools.wraps(run_func) | |||
| def wrapped_func(*args, **kwargs): | |||
| # Download data from data_url | |||
| if config.enable_modelarts: | |||
| if config.data_url: | |||
| sync_data(config.data_url, config.data_path) | |||
| print("Dataset downloaded: ", os.listdir(config.data_path)) | |||
| if config.checkpoint_url: | |||
| sync_data(config.checkpoint_url, config.load_path) | |||
| print("Preload downloaded: ", os.listdir(config.load_path)) | |||
| if config.train_url: | |||
| sync_data(config.train_url, config.output_path) | |||
| print("Workspace downloaded: ", os.listdir(config.output_path)) | |||
| context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) | |||
| config.device_num = get_device_num() | |||
| config.device_id = get_device_id() | |||
| if not os.path.exists(config.output_path): | |||
| os.makedirs(config.output_path) | |||
| if pre_process: | |||
| pre_process() | |||
| # Run the main function | |||
| run_func(*args, **kwargs) | |||
| # Upload data to train_url | |||
| if config.enable_modelarts: | |||
| if post_process: | |||
| post_process() | |||
| if config.train_url: | |||
| print("Start to copy output directory") | |||
| sync_data(config.output_path, config.train_url) | |||
| return wrapped_func | |||
| return wrapper | |||
| @@ -18,24 +18,17 @@ postprocess script. | |||
| ''' | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| from src.config import config | |||
| parser = argparse.ArgumentParser(description="postprocess") | |||
| parser.add_argument("--label_dir", type=str, default="", help="label data dir") | |||
| parser.add_argument("--result_dir", type=str, default="./result_Files", help="infer result Files") | |||
| args, _ = parser.parse_known_args() | |||
| from model_utils.config import config | |||
| if __name__ == "__main__": | |||
| file_name = os.listdir(args.label_dir) | |||
| file_name = os.listdir(config.label_dir) | |||
| predictions = [] | |||
| target_sents = [] | |||
| for f in file_name: | |||
| target_ids = np.fromfile(os.path.join(args.label_dir, f), np.int32) | |||
| target_ids = np.fromfile(os.path.join(config.label_dir, f), np.int32) | |||
| target_sents.append(target_ids.reshape(config.eval_batch_size, config.max_length)) | |||
| predicted_ids = np.fromfile(os.path.join(args.result_dir, f.split('.')[0] + '_0.bin'), np.int32) | |||
| predicted_ids = np.fromfile(os.path.join(config.result_dir, f.split('.')[0] + '_0.bin'), np.int32) | |||
| predictions.append(predicted_ids.reshape(config.eval_batch_size, config.max_length - 1)) | |||
| f_output = open(config.output_file, 'w') | |||
| @@ -14,27 +14,19 @@ | |||
| # ============================================================================ | |||
| """GRU preprocess script.""" | |||
| import os | |||
| import argparse | |||
| from src.dataset import create_gru_dataset | |||
| from src.config import config | |||
| parser = argparse.ArgumentParser(description='GRU preprocess') | |||
| parser.add_argument("--dataset_path", type=str, default="", | |||
| help="Dataset path, default: f`sns.") | |||
| parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1') | |||
| parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path') | |||
| args = parser.parse_args() | |||
| from model_utils.config import config | |||
| if __name__ == "__main__": | |||
| mindrecord_file = args.dataset_path | |||
| mindrecord_file = config.dataset_path | |||
| if not os.path.exists(mindrecord_file): | |||
| print("dataset file {} not exists, please check!".format(mindrecord_file)) | |||
| raise ValueError(mindrecord_file) | |||
| dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ | |||
| dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) | |||
| dataset_path=mindrecord_file, rank_size=config.device_num, rank_id=0, do_shuffle=False, is_training=False) | |||
| source_ids_path = os.path.join(args.result_path, "00_data") | |||
| target_ids_path = os.path.join(args.result_path, "01_data") | |||
| source_ids_path = os.path.join(config.result_path, "00_data") | |||
| target_ids_path = os.path.join(config.result_path, "01_data") | |||
| os.makedirs(source_ids_path) | |||
| os.makedirs(target_ids_path) | |||
| @@ -58,11 +58,13 @@ do | |||
| rm -rf ./train_parallel$i | |||
| mkdir ./train_parallel$i | |||
| cp ../*.py ./train_parallel$i | |||
| cp ../*.yaml ./train_parallel$i | |||
| cp *.sh ./train_parallel$i | |||
| cp -r ../src ./train_parallel$i | |||
| cp -r ../model_utils ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$DATASET_PATH &> log & | |||
| python train.py --run_distribute=True --dataset_path=$DATASET_PATH &> log & | |||
| cd .. | |||
| done | |||
| @@ -49,8 +49,10 @@ fi | |||
| rm -rf ./eval | |||
| mkdir ./eval | |||
| cp ../*.py ./eval | |||
| cp ../*.yaml ./eval | |||
| cp *.sh ./eval | |||
| cp -r ../src ./eval | |||
| cp -r ../model_utils ./eval | |||
| cd ./eval || exit | |||
| echo "start eval for device $DEVICE_ID" | |||
| env > env.log | |||
| @@ -42,10 +42,12 @@ fi | |||
| rm -rf ./train | |||
| mkdir ./train | |||
| cp ../*.py ./train | |||
| cp ../*.yaml ./train | |||
| cp *.sh ./train | |||
| cp -r ../src ./train | |||
| cp -r ../model_utils ./train | |||
| cd ./train || exit | |||
| echo "start training for device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$DEVICE_ID --dataset_path=$DATASET_PATH &> log & | |||
| python train.py --dataset_path=$DATASET_PATH &> log & | |||
| cd .. | |||
| @@ -1,41 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """GRU config""" | |||
| from easydict import EasyDict | |||
| config = EasyDict({ | |||
| "batch_size": 16, | |||
| "eval_batch_size": 1, | |||
| "src_vocab_size": 8154, | |||
| "trg_vocab_size": 6113, | |||
| "encoder_embedding_size": 256, | |||
| "decoder_embedding_size": 256, | |||
| "hidden_size": 512, | |||
| "max_length": 32, | |||
| "num_epochs": 30, | |||
| "save_checkpoint": True, | |||
| "ckpt_epoch": 10, | |||
| "target_file": "target.txt", | |||
| "output_file": "output.txt", | |||
| "keep_checkpoint_max": 30, | |||
| "base_lr": 0.001, | |||
| "warmup_step": 300, | |||
| "momentum": 0.9, | |||
| "init_loss_scale_value": 1024, | |||
| 'scale_factor': 2, | |||
| 'scale_window': 2000, | |||
| "warmup_ratio": 1/3.0, | |||
| "teacher_force_ratio": 0.5 | |||
| }) | |||
| @@ -18,7 +18,7 @@ import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.transforms.c_transforms as deC | |||
| from src.config import config | |||
| from model_utils.config import config | |||
| de.config.set_seed(1) | |||
| @@ -18,7 +18,7 @@ from mindspore import Tensor | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as P | |||
| import mindspore.common.dtype as mstype | |||
| from src.config import config | |||
| from model_utils.config import config | |||
| class GRUInferCell(nn.Cell): | |||
| ''' | |||
| @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| from mindspore.communication.management import get_group_size | |||
| from src.config import config | |||
| from model_utils.config import config | |||
| from src.loss import NLLLoss | |||
| class GRUWithLossCell(nn.Cell): | |||
| @@ -15,8 +15,6 @@ | |||
| """train script""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import ast | |||
| from mindspore.context import ParallelMode | |||
| from mindspore import context | |||
| from mindspore.communication.management import init | |||
| @@ -25,31 +23,25 @@ from mindspore.train import Model | |||
| from mindspore.common import set_seed | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| from mindspore.nn.optim import Adam | |||
| from src.config import config | |||
| from src.seq2seq import Seq2Seq | |||
| from src.gru_for_train import GRUWithLossCell, GRUTrainOneStepWithLossScaleCell | |||
| from src.dataset import create_gru_dataset | |||
| from src.lr_schedule import dynamic_lr | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description="GRU training") | |||
| parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") | |||
| parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path") | |||
| parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained file path.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | |||
| parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") | |||
| parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") | |||
| parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/') | |||
| parser.add_argument('--outputs_dir', type=str, default='./', help='Checkpoint save location. Default: outputs/') | |||
| args = parser.parse_args() | |||
| from model_utils.config import config | |||
| from model_utils.moxing_adapter import moxing_wrapper | |||
| from model_utils.device_adapter import get_rank_id, get_device_id, get_device_num | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id, save_graphs=False) | |||
| set_seed(1) | |||
| def get_ms_timestamp(): | |||
| t = time.time() | |||
| return int(round(t * 1000)) | |||
| time_stamp_init = False | |||
| time_stamp_first = 0 | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| @@ -89,17 +81,72 @@ class LossCallBack(Callback): | |||
| str(cb_params.net_outputs[2].asnumpy()))) | |||
| f.write('\n') | |||
| if __name__ == '__main__': | |||
| if args.run_distribute: | |||
| rank = args.rank_id | |||
| device_num = args.device_num | |||
| def modelarts_pre_process(): | |||
| '''modelarts pre process function.''' | |||
| def unzip(zip_file, save_dir): | |||
| import zipfile | |||
| s_time = time.time() | |||
| if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)): | |||
| zip_isexist = zipfile.is_zipfile(zip_file) | |||
| if zip_isexist: | |||
| fz = zipfile.ZipFile(zip_file, 'r') | |||
| data_num = len(fz.namelist()) | |||
| print("Extract Start...") | |||
| print("unzip file num: {}".format(data_num)) | |||
| data_print = int(data_num / 100) if data_num > 100 else 1 | |||
| i = 0 | |||
| for file in fz.namelist(): | |||
| if i % data_print == 0: | |||
| print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) | |||
| i += 1 | |||
| fz.extract(file, save_dir) | |||
| print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), | |||
| int(int(time.time() - s_time) % 60))) | |||
| print("Extract Done.") | |||
| else: | |||
| print("This is not zip.") | |||
| else: | |||
| print("Zip has been extracted.") | |||
| if config.need_modelarts_dataset_unzip: | |||
| zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip") | |||
| save_dir_1 = os.path.join(config.data_path) | |||
| sync_lock = "/tmp/unzip_sync.lock" | |||
| # Each server contains 8 devices as most. | |||
| if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||
| print("Zip file path: ", zip_file_1) | |||
| print("Unzip file save dir: ", save_dir_1) | |||
| unzip(zip_file_1, save_dir_1) | |||
| print("===Finish extract data synchronization===") | |||
| try: | |||
| os.mknod(sync_lock) | |||
| except IOError: | |||
| pass | |||
| while True: | |||
| if os.path.exists(sync_lock): | |||
| break | |||
| time.sleep(1) | |||
| print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) | |||
| config.outputs_dir = os.path.join(config.output_path, config.outputs_dir) | |||
| @moxing_wrapper(pre_process=modelarts_pre_process) | |||
| def run_train(): | |||
| """run train.""" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id(), save_graphs=False) | |||
| rank = get_rank_id() | |||
| device_num = get_device_num() | |||
| if config.run_distribute: | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| init() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| mindrecord_file = args.dataset_path | |||
| mindrecord_file = config.dataset_path | |||
| if not os.path.exists(mindrecord_file): | |||
| print("dataset file {} not exists, please check!".format(mindrecord_file)) | |||
| raise ValueError(mindrecord_file) | |||
| @@ -120,15 +167,18 @@ if __name__ == '__main__': | |||
| time_cb = TimeMonitor(data_size=dataset_size) | |||
| loss_cb = LossCallBack(rank_id=rank) | |||
| cb = [time_cb, loss_cb] | |||
| #Save Checkpoint | |||
| # Save Checkpoint | |||
| if config.save_checkpoint: | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch*dataset_size, | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=config.ckpt_epoch * dataset_size, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_'+str(args.rank_id)+'/') | |||
| save_ckpt_path = os.path.join(config.outputs_dir, 'ckpt_' + str(get_rank_id()) + '/') | |||
| ckpt_cb = ModelCheckpoint(config=ckpt_config, | |||
| directory=save_ckpt_path, | |||
| prefix='{}'.format(args.rank_id)) | |||
| prefix='{}'.format(get_rank_id())) | |||
| cb += [ckpt_cb] | |||
| netwithgrads.set_train(True) | |||
| model = Model(netwithgrads) | |||
| model.train(config.num_epochs, dataset, callbacks=cb, dataset_sink_mode=True) | |||
| if __name__ == '__main__': | |||
| run_train() | |||