From: @Somnus2020
Reviewed-by: @c_34, @oacjiewen
Signed-off-by: @c_34
pull/15515/MERGE
@@ -65,25 +65,59 @@ After installing MindSpore via the official website, you can start training and
sh scripts/run_eval.sh ckpt_path
```

If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can start training and evaluation as follows:

```python
# run distributed training on ModelArts example
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the yaml file.
#          Set the other parameters you need in the yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add the other parameters on the website UI interface.
# (2) Set the code directory to "/path/textcnn" on the website UI interface.
# (3) Set the startup file to "train.py" on the website UI interface.
# (4) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
# (5) Create your job.

# run evaluation on ModelArts example
# (1) Copy or upload your trained model to an S3 bucket.
# (2) Perform a or b.
#       a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
#       b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Set the code directory to "/path/textcnn" on the website UI interface.
# (4) Set the startup file to "eval.py" on the website UI interface.
# (5) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
# (6) Create your job.
```
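For reference, the same flow can be reproduced locally with the new per-dataset yaml configs through the wrapper scripts. A minimal sketch, assuming the MR dataset sits at the data_path configured in mr_config.yaml and that training leaves train_textcnn-4_149.ckpt under ./output/checkpoint/ (both paths are assumptions, not guaranteed names):

```bash
# minimal local sketch (dataset location and checkpoint name are assumptions)
bash scripts/run_train.sh MR
bash scripts/run_eval.sh ./output/checkpoint/train_textcnn-4_149.ckpt MR
```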
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```bash
├── model_zoo
    ├── README.md                      // descriptions about all the models
    ├── textcnn
        ├── README.md                  // descriptions about textcnn
        ├── scripts
        │   ├── run_train.sh           // shell script for distributed training on Ascend
        │   ├── run_eval.sh            // shell script for evaluation on Ascend
        ├── src
        │   ├── dataset.py             // dataset processing
        │   ├── textcnn.py             // textcnn architecture
        ├── utils
        │   ├── device_adapter.py      // device adapter
        │   ├── local_adapter.py       // local adapter
        │   ├── moxing_adapter.py      // moxing adapter
        │   ├── config.py              // parameter parsing
        ├── mr_config.yaml             // parameter configuration for MR
        ├── sst2_config.yaml           // parameter configuration for SST2
        ├── subj_config.yaml           // parameter configuration for SUBJ
        ├── train.py                   // training script
        ├── eval.py                    // evaluation script
        ├── export.py                  // export checkpoint to other format file
```

## [Script Parameters](#contents)
@@ -16,50 +16,40 @@
##############test textcnn example on movie review#################
python eval.py
"""
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_device_id
from utils.config import config
from src.textcnn import TextCNN
from src.dataset import MovieReview, SST2, Subjectivity


@moxing_wrapper()
def eval_net():
    '''eval net'''
    if config.dataset == 'MR':
        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SUBJ':
        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SST2':
        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    device_target = config.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    if device_target == "Ascend":
        context.set_context(device_id=get_device_id())
    dataset = instance.create_test_dataset(batch_size=config.batch_size)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
                  num_classes=config.num_classes, vec_length=config.vec_length)
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.001,
                  weight_decay=float(config.weight_decay))
    param_dict = load_checkpoint(config.checkpoint_file_path)
    print("load checkpoint from [{}].".format(config.checkpoint_file_path))
    load_param_into_net(net, param_dict)
    net.set_train(False)
@@ -67,3 +57,6 @@ if __name__ == '__main__':
    acc = model.eval(dataset)
    print("accuracy: ", acc)


if __name__ == '__main__':
    eval_net()
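With this change eval.py no longer defines its own argparse flags; all settings come from the yaml file selected by --config_path, and any key can still be overridden on the command line. A hedged direct invocation (the checkpoint path is illustrative):

```bash
# direct local evaluation; the checkpoint path is an illustrative assumption
python eval.py --config_path=mr_config.yaml --dataset=MR \
    --checkpoint_file_path=./checkpoint/train_textcnn-4_149.ckpt
```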
@@ -16,50 +16,34 @@
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import numpy as np

from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context

from utils.config import config
from src.textcnn import TextCNN
from src.dataset import MovieReview, SST2, Subjectivity

context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
    context.set_context(device_id=config.device_id)

if __name__ == '__main__':
    if config.dataset == 'MR':
        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SUBJ':
        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SST2':
        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    else:
        raise ValueError("dataset is not supported.")

    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
                  num_classes=config.num_classes, vec_length=config.vec_length)

    param_dict = load_checkpoint(config.ckpt_file)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.ones([config.batch_size, config.word_len], np.int32))
    export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
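The export options (ckpt_file, file_name, file_format, device_id) default to empty strings in the yaml files, so they are expected on the command line; a hedged sketch (the checkpoint name is illustrative):

```bash
# export a trained checkpoint to MINDIR; the ckpt path is an illustrative assumption
python export.py --config_path=mr_config.yaml --dataset=MR \
    --ckpt_file=./checkpoint/train_textcnn-4_149.ckpt \
    --file_name=textcnn --file_format=MINDIR
```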
@@ -0,0 +1,57 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False

# ==============================================================================
# Training options
dataset: 'MR'
pre_trained: False
num_classes: 2
batch_size: 64
epoch_size: 4
weight_decay: 3e-5
keep_checkpoint_max: 1
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51
vec_length: 40
base_lr: 1e-3

# Export options
device_id: 0
ckpt_file: ""
file_name: ""
file_format: ""

---
# Help description for each configuration
enable_modelarts: 'Whether training on ModelArts, default: False'
data_url: 'Dataset url for OBS'
train_url: 'Training output url for OBS'
checkpoint_url: 'The location of the checkpoint on OBS'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The local cache path for the checkpoint downloaded from OBS'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
train_epochs: "The number of epochs used to train."
pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
num_classes: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
weight_decay: "Weight decay."
keep_checkpoint_max: "Keep only the last keep_checkpoint_max checkpoints"
checkpoint_path: "The location of the checkpoint file."
eval_file_name: "Eval output file."
checkpoint_file_path: "The location of the checkpoint file."
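Every scalar key in the first yaml document is also exposed as a command-line flag by utils/config.py, so single values can be overridden without editing the file; a hedged example (the override values are illustrative):

```bash
# override yaml defaults from the command line (values are illustrative)
python train.py --config_path=mr_config.yaml --batch_size=32 --epoch_size=8 --data_path=./data/
```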
@@ -13,14 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
BASE_PATH=$(cd ./ "`dirname $0`" || exit; pwd)
dataset_type='MR'
CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
if [ $# == 2 ]
then
    if [ $2 == "MR" ]; then
        CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
    elif [ $2 == "SUBJ" ]; then
        CONFIG_FILE="${BASE_PATH}/subj_config.yaml"
    elif [ $2 == "SST2" ]; then
        CONFIG_FILE="${BASE_PATH}/sst2_config.yaml"
    else
        echo "error: the selected dataset is not in the supported set {MR, SUBJ, SST2}"
        exit 1
    fi
    dataset_type=$2
fi

python eval.py --checkpoint_file_path="$1" --dataset=$dataset_type --config_path=$CONFIG_FILE > eval.log 2>&1 &
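Updated usage of the wrapper: the first argument is the checkpoint file, and the optional second argument selects the dataset and therefore the yaml config; a hedged example (the checkpoint name is illustrative):

```bash
# evaluate a SUBJ checkpoint; subj_config.yaml is picked automatically
bash scripts/run_eval.sh ./checkpoint/train_textcnn-4_149.ckpt SUBJ
```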
@@ -13,15 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
BASE_PATH=$(cd ./ "`dirname $0`" || exit; pwd)
dataset_type='MR'
CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
if [ $# == 1 ]
then
    if [ $1 == "MR" ]; then
        CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
    elif [ $1 == "SUBJ" ]; then
        CONFIG_FILE="${BASE_PATH}/subj_config.yaml"
    elif [ $1 == "SST2" ]; then
        CONFIG_FILE="${BASE_PATH}/sst2_config.yaml"
    else
        echo "error: the selected dataset is not in the supported set {MR, SUBJ, SST2}"
        exit 1
    fi
    dataset_type=$1
fi

rm -rf ./ckpt_0
python train.py --dataset=$dataset_type --config_path=$CONFIG_FILE --output_path './output' > train.log 2>&1 &
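Updated usage of the training wrapper: the single optional argument selects the dataset and its yaml config, and because of the --output_path override the checkpoints should end up under ./output/ (an assumption about the resulting layout):

```bash
# train on SST2 with sst2_config.yaml; progress is written to train.log
bash scripts/run_train.sh SST2
```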
@@ -1,69 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict

cfg_mr = edict({
    'name': 'movie review',
    'pre_trained': False,
    'num_classes': 2,
    'batch_size': 64,
    'epoch_size': 4,
    'weight_decay': 3e-5,
    'data_path': './data/',
    'device_target': 'Ascend',
    'device_id': 7,
    'keep_checkpoint_max': 1,
    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
    'word_len': 51,
    'vec_length': 40,
    'base_lr': 1e-3
})

cfg_subj = edict({
    'name': 'subjectivity',
    'pre_trained': False,
    'num_classes': 2,
    'batch_size': 64,
    'epoch_size': 5,
    'weight_decay': 3e-5,
    'data_path': './Subj/',
    'device_target': 'Ascend',
    'device_id': 7,
    'keep_checkpoint_max': 1,
    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
    'word_len': 51,
    'vec_length': 40,
    'base_lr': 8e-4
})

cfg_sst2 = edict({
    'name': 'SST2',
    'pre_trained': False,
    'num_classes': 2,
    'batch_size': 64,
    'epoch_size': 4,
    'weight_decay': 3e-5,
    'data_path': './SST-2/',
    'device_target': 'Ascend',
    'device_id': 7,
    'keep_checkpoint_max': 1,
    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
    'word_len': 51,
    'vec_length': 40,
    'base_lr': 5e-3
})
@@ -0,0 +1,57 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False

# ==============================================================================
# Training options
dataset: 'SST2'
pre_trained: False
num_classes: 2
batch_size: 64
epoch_size: 4
weight_decay: 3e-5
keep_checkpoint_max: 1
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51
vec_length: 40
base_lr: 1e-3

# Export options
device_id: 0
ckpt_file: ""
file_name: ""
file_format: ""

---
# Help description for each configuration
enable_modelarts: 'Whether training on ModelArts, default: False'
data_url: 'Dataset url for OBS'
train_url: 'Training output url for OBS'
checkpoint_url: 'The location of the checkpoint on OBS'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The local cache path for the checkpoint downloaded from OBS'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
train_epochs: "The number of epochs used to train."
pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
num_classes: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
weight_decay: "Weight decay."
keep_checkpoint_max: "Keep only the last keep_checkpoint_max checkpoints"
checkpoint_path: "The location of the checkpoint file."
eval_file_name: "Eval output file."
checkpoint_file_path: "The location of the checkpoint file."
@@ -0,0 +1,58 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path/"
device_target: 'Ascend'
enable_profiling: False

# ==============================================================================
# Training options
dataset: 'SUBJ'
pre_trained: False
num_classes: 2
batch_size: 64
epoch_size: 4
weight_decay: 3e-5
keep_checkpoint_max: 1
checkpoint_path: './checkpoint/'
checkpoint_file_path: 'train_textcnn-4_149.ckpt'
word_len: 51
vec_length: 40
base_lr: 1e-3

# Export options
device_id: 0
ckpt_file: ""
file_name: ""
file_format: ""

---
# Help description for each configuration
enable_modelarts: 'Whether training on ModelArts, default: False'
data_url: 'Dataset url for OBS'
train_url: 'Training output url for OBS'
checkpoint_url: 'The location of the checkpoint on OBS'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
load_path: 'The local cache path for the checkpoint downloaded from OBS'
device_target: 'Target device type, available: [Ascend, GPU, CPU]'
enable_profiling: 'Whether to enable profiling while training, default: False'
dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
train_epochs: "The number of epochs used to train."
pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
num_classes: 'Number of classes in the dataset'
batch_size: "Batch size for training and evaluation"
epoch_size: "Total training epochs."
weight_decay: "Weight decay."
keep_checkpoint_max: "Keep only the last keep_checkpoint_max checkpoints"
checkpoint_path: "The location of the checkpoint file."
eval_file_name: "Eval output file."
checkpoint_file_path: "The location of the checkpoint file."
@@ -16,7 +16,7 @@
#################train textcnn example on movie review########################
python train.py
"""
import os
import math
import mindspore.nn as nn
@@ -26,62 +26,62 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_device_id, get_rank_id
from utils.config import config
from src.textcnn import TextCNN
from src.textcnn import SoftmaxCrossEntropyExpand
from src.dataset import MovieReview, SST2, Subjectivity


def modelarts_pre_process():
    config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)


@moxing_wrapper(pre_process=modelarts_pre_process)
def train_net():
    '''train net'''
    # set context
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    context.set_context(device_id=get_device_id())
    if config.dataset == 'MR':
        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SUBJ':
        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
    elif config.dataset == 'SST2':
        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)

    dataset = instance.create_train_dataset(batch_size=config.batch_size, epoch_size=config.epoch_size)
    batch_num = dataset.get_dataset_size()

    base_lr = float(config.base_lr)
    learning_rate = []
    warm_up = [base_lr / math.floor(config.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
               range(math.floor(config.epoch_size / 5))]
    shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(config.epoch_size * 3 / 5))]
    normal_run = [base_lr for _ in range(batch_num) for i in
                  range(config.epoch_size - math.floor(config.epoch_size / 5) - math.floor(config.epoch_size * 2 / 5))]
    learning_rate = learning_rate + warm_up + normal_run + shrink

    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
                  num_classes=config.num_classes, vec_length=config.vec_length)
    # Continue training if set pre_trained to be True
    if config.pre_trained:
        param_dict = load_checkpoint(config.checkpoint_path)
        load_param_into_net(net, param_dict)

    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()),
                  learning_rate=learning_rate, weight_decay=float(config.weight_decay))
    loss = SoftmaxCrossEntropyExpand(sparse=True)

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

    config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size * batch_num / 2),
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    time_cb = TimeMonitor(data_size=batch_num)
    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
    ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck)
    loss_cb = LossMonitor()
    model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
    print("train success")


if __name__ == '__main__':
    train_net()
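train.py can also be launched directly, with the dataset location and output directory given as flags; a hedged sketch (the ./SST-2/ path is illustrative):

```bash
# direct training run; the data_path is an illustrative assumption
python train.py --config_path=sst2_config.yaml --dataset=SST2 \
    --data_path=./SST-2/ --output_path=./output
```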
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

_config_path = "./mr_config.yaml"


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments into the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Choices constraint for each item.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
            else:
                raise ValueError("At most 2 docs (config and help description) are supported in the config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper


def merge(args, cfg):
    """
    Merge the base config from the yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../mr_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser, default, helper, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)


config = get_config()
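Since parse_cli_to_yaml turns every scalar yaml key into an argparse flag, the generated interface and the merged configuration can be inspected from the shell; a hedged sketch, assuming it is run from the model directory so the relative yaml and utils paths resolve:

```bash
# list the flags generated from the selected yaml file
python train.py --config_path=mr_config.yaml --help
# print the merged (yaml defaults + CLI overrides) configuration object
python -c "from utils.config import config; print(config)"
```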
@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""

from utils.config import config

if config.enable_modelarts:
    from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""

import os


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""

import os
import functools

from mindspore import context

from utils.config import config

_global_sync_count = 0


def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id else "default"
    return job_id


def sync_data(from_path, to_path):
    """
    Download data from a remote OBS url to a local directory when the first path is remote,
    or upload a local directory to remote OBS when the direction is reversed.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains at most 8 devices; only the first device on each server copies the data.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download the dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
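When enable_modelarts is set, the wrapper expects the OBS locations in data_url / train_url (and optionally checkpoint_url), which ModelArts normally supplies as job arguments; a hedged illustration of the effective startup arguments (the bucket paths are hypothetical):

```bash
# effective arguments of a ModelArts run (bucket paths are hypothetical)
python train.py --config_path=mr_config.yaml --enable_modelarts=True \
    --data_url=obs://your-bucket/textcnn/data \
    --train_url=obs://your-bucket/textcnn/output
```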