From: @Somnus2020
Reviewed-by: @c_34, @oacjiewen
Signed-off-by: @c_34
pull/15515/MERGE
@@ -65,25 +65,59 @@ After installing MindSpore via the official website, you can start training and
 sh scripts/run_eval.sh ckpt_path
 ```
+
+If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), then you can start training and evaluation as follows:
+
+```python
+# Example: run distributed training on ModelArts
+# (1) Perform a or b.
+#       a. Set "enable_modelarts=True" in the yaml file.
+#          Set the other parameters you need in the yaml file.
+#       b. Add "enable_modelarts=True" on the website UI interface.
+#          Add the other parameters on the website UI interface.
+# (2) Set the code directory to "/path/textcnn" on the website UI interface.
+# (3) Set the startup file to "train.py" on the website UI interface.
+# (4) Set the "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
+# (5) Create your job.
+
+# Example: run evaluation on ModelArts
+# (1) Copy or upload your trained model to an S3 bucket.
+# (2) Perform a or b.
+#       a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
+#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
+#       b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
+#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
+# (3) Set the code directory to "/path/textcnn" on the website UI interface.
+# (4) Set the startup file to "eval.py" on the website UI interface.
+# (5) Set the "Dataset path", "Output file path" and "Job log path" to your paths on the website UI interface.
+# (6) Create your job.
+```
 # [Script Description](#contents)

 ## [Script and Sample Code](#contents)

 ```bash
 ├── model_zoo
     ├── README.md                   // descriptions about all the models
     ├── textcnn
         ├── README.md               // descriptions about textcnn
         ├── scripts
         │   ├── run_train.sh        // shell script for distributed training on Ascend
         │   ├── run_eval.sh         // shell script for evaluation on Ascend
         ├── src
         │   ├── dataset.py          // dataset processing
         │   ├── textcnn.py          // textcnn architecture
-        │   ├── config.py           // parameter configuration
+        ├── utils
+        │   ├── device_adapter.py   // device adapter for ModelArts or local runs
+        │   ├── local_adapter.py    // local adapter
+        │   ├── moxing_adapter.py   // moxing adapter for ModelArts
+        │   ├── config.py           // parameter parsing
+        ├── mr_config.yaml          // parameter configuration for the MR dataset
+        ├── sst2_config.yaml        // parameter configuration for the SST2 dataset
+        ├── subj_config.yaml        // parameter configuration for the SUBJ dataset
         ├── train.py                // training script
         ├── eval.py                 // evaluation script
         ├── export.py               // export checkpoint to other format files
 ```

 ## [Script Parameters](#contents)
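With this change all tunable parameters live in the per-dataset yaml files and are parsed by `utils/config.py`, which also exposes every top-level yaml key as a command-line flag of the same name. A minimal sketch of overriding values from the shell (the flag values here are illustrative, not recommended settings):

```bash
# pick a config file, then override individual yaml keys on the command line
python train.py --config_path=./sst2_config.yaml --batch_size=32 --epoch_size=4
```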
@@ -16,50 +16,40 @@
 ##############test textcnn example on movie review#################
 python eval.py
 """
-import argparse
 import mindspore.nn as nn
 from mindspore.nn.metrics import Accuracy
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from src.config import cfg_mr, cfg_subj, cfg_sst2
+from utils.moxing_adapter import moxing_wrapper
+from utils.device_adapter import get_device_id
+from utils.config import config
 from src.textcnn import TextCNN
 from src.dataset import MovieReview, SST2, Subjectivity

-parser = argparse.ArgumentParser(description='TextCNN')
-parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
-parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
-args_opt = parser.parse_args()
-
-if __name__ == '__main__':
-    if args_opt.dataset == 'MR':
-        cfg = cfg_mr
-        instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args_opt.dataset == 'SUBJ':
-        cfg = cfg_subj
-        instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args_opt.dataset == 'SST2':
-        cfg = cfg_sst2
-        instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    device_target = cfg.device_target
-    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+
+@moxing_wrapper()
+def eval_net():
+    """Evaluate the network."""
+    if config.dataset == 'MR':
+        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SUBJ':
+        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SST2':
+        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    device_target = config.device_target
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
     if device_target == "Ascend":
-        context.set_context(device_id=cfg.device_id)
-    dataset = instance.create_test_dataset(batch_size=cfg.batch_size)
+        context.set_context(device_id=get_device_id())
+    dataset = instance.create_test_dataset(batch_size=config.batch_size)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
-    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
-                  num_classes=cfg.num_classes, vec_length=cfg.vec_length)
+    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
+                  num_classes=config.num_classes, vec_length=config.vec_length)
     opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.001,
-                  weight_decay=cfg.weight_decay)
-    if args_opt.checkpoint_path is not None:
-        param_dict = load_checkpoint(args_opt.checkpoint_path)
-        print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
-    else:
-        param_dict = load_checkpoint(cfg.checkpoint_path)
-        print("load checkpoint from [{}].".format(cfg.checkpoint_path))
+                  weight_decay=float(config.weight_decay))  # yaml may load values like 3e-5 as strings
+    param_dict = load_checkpoint(config.checkpoint_file_path)
+    print("load checkpoint from [{}].".format(config.checkpoint_file_path))
     load_param_into_net(net, param_dict)
     net.set_train(False)
@@ -67,3 +57,6 @@ if __name__ == '__main__':
     acc = model.eval(dataset)
     print("accuracy: ", acc)
+
+
+if __name__ == '__main__':
+    eval_net()
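For a quick local check, eval.py can also be invoked directly instead of through run_eval.sh; the flags mirror the yaml keys, and the paths below are placeholders:

```bash
python eval.py --config_path=./mr_config.yaml --dataset=MR --checkpoint_file_path=./ckpt/train_textcnn-4_149.ckpt
```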
@@ -16,50 +16,34 @@
 ##############export checkpoint file into air, onnx, mindir models#################
 python export.py
 """
-import argparse
 import numpy as np
 from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
-from src.config import cfg_mr, cfg_subj, cfg_sst2
+from utils.config import config
 from src.textcnn import TextCNN
 from src.dataset import MovieReview, SST2, Subjectivity

-parser = argparse.ArgumentParser(description='TextCNN export')
-parser.add_argument("--device_id", type=int, default=0, help="device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="checkpoint file path.")
-parser.add_argument("--file_name", type=str, default="textcnn", help="output file name.")
-parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
-parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
-                    help="device target")
-parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ', 'SST2'],
-                    help='dataset name.')
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
-if args.device_target == "Ascend":
-    context.set_context(device_id=args.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+if config.device_target == "Ascend":
+    context.set_context(device_id=config.device_id)

 if __name__ == '__main__':
-    if args.dataset == 'MR':
-        cfg = cfg_mr
-        instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args.dataset == 'SUBJ':
-        cfg = cfg_subj
-        instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args.dataset == 'SST2':
-        cfg = cfg_sst2
-        instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
+    if config.dataset == 'MR':
+        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SUBJ':
+        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SST2':
+        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
     else:
-        raise ValueError("dataset is not support.")
+        raise ValueError("dataset is not supported.")

-    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
-                  num_classes=cfg.num_classes, vec_length=cfg.vec_length)
+    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
+                  num_classes=config.num_classes, vec_length=config.vec_length)

-    param_dict = load_checkpoint(args.ckpt_file)
+    param_dict = load_checkpoint(config.ckpt_file)
     load_param_into_net(net, param_dict)

-    input_arr = Tensor(np.ones([cfg.batch_size, cfg.word_len], np.int32))
-    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
+    input_arr = Tensor(np.ones([config.batch_size, config.word_len], np.int32))
+    export(net, input_arr, file_name=config.file_name, file_format=config.file_format)
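export.py likewise takes everything from the merged config; a hedged example invocation (checkpoint and output names are illustrative):

```bash
python export.py --config_path=./mr_config.yaml --ckpt_file=./ckpt/train_textcnn-4_149.ckpt --file_name=textcnn --file_format=MINDIR
```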
@@ -0,0 +1,57 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path/"
+device_target: 'Ascend'
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+dataset: 'MR'
+pre_trained: False
+num_classes: 2
+batch_size: 64
+epoch_size: 4
+weight_decay: 3e-5
+keep_checkpoint_max: 1
+checkpoint_path: './checkpoint/'
+checkpoint_file_path: 'train_textcnn-4_149.ckpt'
+word_len: 51
+vec_length: 40
+base_lr: 1e-3
+
+# Export options
+device_id: 0
+ckpt_file: ""
+file_name: ""
+file_format: ""
+
+---
+# Help description for each configuration
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+checkpoint_url: 'The location of checkpoint for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+load_path: 'Local path the checkpoint is downloaded to'
+device_target: 'Target device type, available: [Ascend, GPU, CPU]'
+enable_profiling: 'Whether enable profiling while training, default: False'
+dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
+pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
+num_classes: 'Number of classes in the dataset'
+batch_size: "Batch size for training and evaluation"
+epoch_size: "Total training epochs."
+weight_decay: "Weight decay."
+keep_checkpoint_max: "Keep at most the last keep_checkpoint_max checkpoints"
+checkpoint_path: "Directory where checkpoints are saved."
+checkpoint_file_path: "The checkpoint file to load."
@@ -13,14 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+BASE_PATH=$(cd ./"$(dirname "$0")" || exit; pwd)
 dataset_type='MR'
+CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
 if [ $# == 2 ]
 then
-    if [ $2 != "MR" ] && [ $2 != "SUBJ" ] && [ $2 != "SST2" ]
-    then
-        echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
+    if [ $2 == "MR" ]; then
+        CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
+    elif [ $2 == "SUBJ" ]; then
+        CONFIG_FILE="${BASE_PATH}/subj_config.yaml"
+    elif [ $2 == "SST2" ]; then
+        CONFIG_FILE="${BASE_PATH}/sst2_config.yaml"
+    else
+        echo "error: the selected dataset is not in the supported set {MR, SUBJ, SST2}"
         exit 1
     fi
     dataset_type=$2
 fi
-python eval.py --checkpoint_path="$1" --dataset=$dataset_type > eval.log 2>&1 &
+python eval.py --checkpoint_file_path="$1" --dataset=$dataset_type --config_path=$CONFIG_FILE > eval.log 2>&1 &
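A typical call of this wrapper script, where the first argument is the checkpoint to evaluate and the optional second argument selects the dataset (and thus the yaml config); the checkpoint path is a placeholder:

```bash
bash scripts/run_eval.sh ./ckpt/train_textcnn-4_149.ckpt SST2
```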
@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+BASE_PATH=$(cd ./"$(dirname "$0")" || exit; pwd)
 dataset_type='MR'
+CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
 if [ $# == 1 ]
 then
-    if [ $1 != "MR" ] && [ $1 != "SUBJ" ] && [ $1 != "SST2" ]
-    then
-        echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
+    if [ $1 == "MR" ]; then
+        CONFIG_FILE="${BASE_PATH}/mr_config.yaml"
+    elif [ $1 == "SUBJ" ]; then
+        CONFIG_FILE="${BASE_PATH}/subj_config.yaml"
+    elif [ $1 == "SST2" ]; then
+        CONFIG_FILE="${BASE_PATH}/sst2_config.yaml"
+    else
+        echo "error: the selected dataset is not in the supported set {MR, SUBJ, SST2}"
         exit 1
     fi
     dataset_type=$1
 fi
 rm ./ckpt_0 -rf
-python train.py --dataset=$dataset_type > train.log 2>&1 &
+python train.py --dataset=$dataset_type --config_path=$CONFIG_FILE --output_path='./output' > train.log 2>&1 &
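And the matching training call; with no argument the script falls back to MR and mr_config.yaml:

```bash
bash scripts/run_train.sh SUBJ
```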
@@ -1,69 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-network config setting, will be used in main.py
-"""
-from easydict import EasyDict as edict
-
-cfg_mr = edict({
-    'name': 'movie review',
-    'pre_trained': False,
-    'num_classes': 2,
-    'batch_size': 64,
-    'epoch_size': 4,
-    'weight_decay': 3e-5,
-    'data_path': './data/',
-    'device_target': 'Ascend',
-    'device_id': 7,
-    'keep_checkpoint_max': 1,
-    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
-    'word_len': 51,
-    'vec_length': 40,
-    'base_lr': 1e-3
-})
-
-cfg_subj = edict({
-    'name': 'subjectivity',
-    'pre_trained': False,
-    'num_classes': 2,
-    'batch_size': 64,
-    'epoch_size': 5,
-    'weight_decay': 3e-5,
-    'data_path': './Subj/',
-    'device_target': 'Ascend',
-    'device_id': 7,
-    'keep_checkpoint_max': 1,
-    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
-    'word_len': 51,
-    'vec_length': 40,
-    'base_lr': 8e-4
-})
-
-cfg_sst2 = edict({
-    'name': 'SST2',
-    'pre_trained': False,
-    'num_classes': 2,
-    'batch_size': 64,
-    'epoch_size': 4,
-    'weight_decay': 3e-5,
-    'data_path': './SST-2/',
-    'device_target': 'Ascend',
-    'device_id': 7,
-    'keep_checkpoint_max': 1,
-    'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt',
-    'word_len': 51,
-    'vec_length': 40,
-    'base_lr': 5e-3
-})
@@ -0,0 +1,57 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path/"
+device_target: 'Ascend'
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+dataset: 'SST2'
+pre_trained: False
+num_classes: 2
+batch_size: 64
+epoch_size: 4
+weight_decay: 3e-5
+keep_checkpoint_max: 1
+checkpoint_path: './checkpoint/'
+checkpoint_file_path: 'train_textcnn-4_149.ckpt'
+word_len: 51
+vec_length: 40
+base_lr: 1e-3
+
+# Export options
+device_id: 0
+ckpt_file: ""
+file_name: ""
+file_format: ""
+
+---
+# Help description for each configuration
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+checkpoint_url: 'The location of checkpoint for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+load_path: 'Local path the checkpoint is downloaded to'
+device_target: 'Target device type, available: [Ascend, GPU, CPU]'
+enable_profiling: 'Whether enable profiling while training, default: False'
+dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
+pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
+num_classes: 'Number of classes in the dataset'
+batch_size: "Batch size for training and evaluation"
+epoch_size: "Total training epochs."
+weight_decay: "Weight decay."
+keep_checkpoint_max: "Keep at most the last keep_checkpoint_max checkpoints"
+checkpoint_path: "Directory where checkpoints are saved."
+checkpoint_file_path: "The checkpoint file to load."
@@ -0,0 +1,58 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path/"
+device_target: 'Ascend'
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+dataset: 'SUBJ'
+pre_trained: False
+num_classes: 2
+batch_size: 64
+epoch_size: 4
+weight_decay: 3e-5
+keep_checkpoint_max: 1
+checkpoint_path: './checkpoint/'
+checkpoint_file_path: 'train_textcnn-4_149.ckpt'
+word_len: 51
+vec_length: 40
+base_lr: 1e-3
+
+# Export options
+device_id: 0
+ckpt_file: ""
+file_name: ""
+file_format: ""
+
+---
+# Help description for each configuration
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+checkpoint_url: 'The location of checkpoint for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+load_path: 'Local path the checkpoint is downloaded to'
+device_target: 'Target device type, available: [Ascend, GPU, CPU]'
+enable_profiling: 'Whether enable profiling while training, default: False'
+dataset: "Dataset to be trained and evaluated, choices: ['MR', 'SUBJ', 'SST2']"
+pre_trained: 'Whether to load a pre-trained checkpoint, default: False'
+num_classes: 'Number of classes in the dataset'
+batch_size: "Batch size for training and evaluation"
+epoch_size: "Total training epochs."
+weight_decay: "Weight decay."
+keep_checkpoint_max: "Keep at most the last keep_checkpoint_max checkpoints"
+checkpoint_path: "Directory where checkpoints are saved."
+checkpoint_file_path: "The checkpoint file to load."
@@ -16,7 +16,7 @@
 #################train textcnn example on movie review########################
 python train.py
 """
-import argparse
+import os
 import math

 import mindspore.nn as nn
@@ -26,62 +26,62 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from src.config import cfg_mr, cfg_subj, cfg_sst2
+from utils.moxing_adapter import moxing_wrapper
+from utils.device_adapter import get_device_id, get_rank_id
+from utils.config import config
 from src.textcnn import TextCNN
 from src.textcnn import SoftmaxCrossEntropyExpand
 from src.dataset import MovieReview, SST2, Subjectivity

-parser = argparse.ArgumentParser(description='TextCNN')
-parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
-                    help='device where the code will be implemented (default: Ascend)')
-parser.add_argument('--device_id', type=int, default=5, help='device id of GPU or Ascend.')
-parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2'])
-args_opt = parser.parse_args()
-
-if __name__ == '__main__':
-    rank = 0
+
+def modelarts_pre_process():
+    # On ModelArts, save checkpoints under the per-rank output directory.
+    config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def train_net():
+    """Train the network."""
     # set context
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(device_id=args_opt.device_id)
-    if args_opt.dataset == 'MR':
-        cfg = cfg_mr
-        instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args_opt.dataset == 'SUBJ':
-        cfg = cfg_subj
-        instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-    elif args_opt.dataset == 'SST2':
-        cfg = cfg_sst2
-        instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
-
-    dataset = instance.create_train_dataset(batch_size=cfg.batch_size, epoch_size=cfg.epoch_size)
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+    context.set_context(device_id=get_device_id())
+    if config.dataset == 'MR':
+        instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SUBJ':
+        instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+    elif config.dataset == 'SST2':
+        instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9)
+
+    dataset = instance.create_train_dataset(batch_size=config.batch_size, epoch_size=config.epoch_size)
     batch_num = dataset.get_dataset_size()
-    base_lr = cfg.base_lr
+    base_lr = float(config.base_lr)  # yaml may load values like 1e-3 as strings
     learning_rate = []
-    warm_up = [base_lr / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
-               range(math.floor(cfg.epoch_size / 5))]
-    shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))]
+    # piecewise learning-rate schedule: linear warm-up, a constant stretch, then decay
+    warm_up = [base_lr / math.floor(config.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in
+               range(math.floor(config.epoch_size / 5))]
+    shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num)
+              for i in range(math.floor(config.epoch_size * 3 / 5))]
     normal_run = [base_lr for _ in range(batch_num) for i in
-                  range(cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5))]
+                  range(config.epoch_size - math.floor(config.epoch_size / 5) - math.floor(config.epoch_size * 2 / 5))]
     learning_rate = learning_rate + warm_up + normal_run + shrink

-    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
-                  num_classes=cfg.num_classes, vec_length=cfg.vec_length)
+    net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len,
+                  num_classes=config.num_classes, vec_length=config.vec_length)
     # Continue training if set pre_trained to be True
-    if cfg.pre_trained:
-        param_dict = load_checkpoint(cfg.checkpoint_path)
+    if config.pre_trained:
+        param_dict = load_checkpoint(config.checkpoint_path)
         load_param_into_net(net, param_dict)

-    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=learning_rate, weight_decay=cfg.weight_decay)
+    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()),
+                  learning_rate=learning_rate, weight_decay=float(config.weight_decay))
     loss = SoftmaxCrossEntropyExpand(sparse=True)
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

-    config_ck = CheckpointConfig(save_checkpoint_steps=int(cfg.epoch_size * batch_num / 2),
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size * batch_num / 2),
+                                 keep_checkpoint_max=config.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
-    ckpt_save_dir = "./ckpt_" + str(rank) + "/"
+    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
     ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck)
     loss_cb = LossMonitor()
-    model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
+    model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
     print("train success")
+
+
+if __name__ == '__main__':
+    train_net()
@@ -0,0 +1,125 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Parse arguments"""
+
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+
+import yaml
+
+_config_path = "./mr_config.yaml"
+
+
+class Config:
+    """
+    Configuration namespace. Converts a dictionary to members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
+    """
+    Parse command line arguments into the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description for each item.
+        choices: Choices for each item.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                # booleans are parsed with ast.literal_eval so "--flag True/False" works
+                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
+        except yaml.YAMLError as err:
+            raise ValueError("Failed to parse yaml file {}".format(yaml_path)) from err
+    if len(cfgs) == 1:
+        cfg_helper = {}
+        cfg = cfgs[0]
+    elif len(cfgs) == 2:
+        cfg, cfg_helper = cfgs
+    else:
+        raise ValueError("At most 2 docs (config and help description) are supported in the config yaml")
+    print(cfg_helper)
+    return cfg, cfg_helper
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from the yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
+    """
+    parser = argparse.ArgumentParser(description="default name", add_help=False)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../mr_config.yaml"),
+                        help="Config file path")
+    path_args, _ = parser.parse_known_args()
+    default, helper = parse_yaml(path_args.config_path)
+    pprint(default)
+    args = parse_cli_to_yaml(parser, default, helper, cfg_path=path_args.config_path)
+    final_config = merge(args, default)
+    return Config(final_config)
+
+
+config = get_config()
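Since every scalar yaml key becomes an argparse option, the full flag list (with the help strings taken from the yaml's second document) can be inspected from any entry script:

```bash
python train.py --config_path=./mr_config.yaml --help
```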
@@ -0,0 +1,27 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Device adapter for ModelArts"""
+
+from utils.config import config
+
+if config.enable_modelarts:
+    from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+else:
+    from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+
+__all__ = [
+    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
+]
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Local adapter"""
+
+import os
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return "Local Job"
@@ -0,0 +1,115 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+
+from mindspore import context
+from utils.config import config
+
+_global_sync_count = 0
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+
+def sync_data(from_path, to_path):
+    """
+    Download data from a remote obs url to a local directory when the first argument is remote,
+    or upload a local directory to remote obs in the opposite direction.
+    """
+    import moxing as mox  # moxing is only available inside ModelArts, so import lazily
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains 8 devices at most.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download the dataset and upload outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            run_func(*args, **kwargs)
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper