diff --git a/model_zoo/official/nlp/textcnn/README.md b/model_zoo/official/nlp/textcnn/README.md index e796d07d4c..3604636048 100644 --- a/model_zoo/official/nlp/textcnn/README.md +++ b/model_zoo/official/nlp/textcnn/README.md @@ -65,25 +65,59 @@ After installing MindSpore via the official website, you can start training and sh scripts/run_eval.sh ckpt_path ``` +If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows: + +```python +# run distributed training on modelarts example +# (1) First, Perform a or b. +# a. Set "enable_modelarts=True" on yaml file. +# Set other parameters on yaml file you need. +# b. Add "enable_modelarts=True" on the website UI interface. +# Add other parameters on the website UI interface. +# (2) Set the code directory to "/path/textcnn" on the website UI interface. +# (3) Set the startup file to "train.py" on the website UI interface. +# (4) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. +# (5) Create your job. + +# run evaluation on modelarts example +# (1) Copy or upload your trained model to S3 bucket. +# (2) Perform a or b. +# a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on yaml file. +# Set "checkpoint_url=/The path of checkpoint in S3/" on yaml file. +# b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface. +# Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface. +# (3) Set the code directory to "/path/textcnn" on the website UI interface. +# (4) Set the startup file to "eval.py" on the website UI interface. +# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface. +# (6) Create your job. +``` + # [Script Description](#contents) ## [Script and Sample Code](#contents) ```bash ├── model_zoo - ├── README.md // descriptions about all the models + ├── README.md // descriptions about all the models ├── textcnn - ├── README.md // descriptions about textcnn + ├── README.md // descriptions about textcnn ├──scripts - │ ├── run_train.sh // shell script for distributed on Ascend - │ ├── run_eval.sh // shell script for evaluation on Ascend + │ ├── run_train.sh // shell script for distributed on Ascend + │ ├── run_eval.sh // shell script for evaluation on Ascend ├── src - │ ├── dataset.py // Processing dataset - │ ├── textcnn.py // textcnn architecture - │ ├── config.py // parameter configuration - ├── train.py // training script - ├── eval.py // evaluation script - ├── export.py // export checkpoint to other format file + │ ├── dataset.py // Processing dataset + │ ├── textcnn.py // textcnn architecture + ├── utils + │ ├──device_adapter.py // device adapter + │ ├──local_adapter.py // local adapter + │ ├──moxing_adapter.py // moxing adapter + │ ├── config.py // parameter analysis + ├── mr_config.yaml // parameter configuration + ├── sst2_config.yaml // parameter configuration + ├── subj_config.yaml // parameter configuration + ├── train.py // training script + ├── eval.py // evaluation script + ├── export.py // export checkpoint to other format file ``` ## [Script Parameters](#contents) diff --git a/model_zoo/official/nlp/textcnn/eval.py b/model_zoo/official/nlp/textcnn/eval.py index c971a9563e..658911a4a6 100644 --- a/model_zoo/official/nlp/textcnn/eval.py +++ b/model_zoo/official/nlp/textcnn/eval.py @@ -16,50 +16,40 @@ ##############test textcnn example on movie review################# python eval.py """ -import argparse - import mindspore.nn as nn from mindspore.nn.metrics import Accuracy from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cfg_mr, cfg_subj, cfg_sst2 +from utils.moxing_adapter import moxing_wrapper +from utils.device_adapter import get_device_id +from utils.config import config from src.textcnn import TextCNN from src.dataset import MovieReview, SST2, Subjectivity -parser = argparse.ArgumentParser(description='TextCNN') -parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') -parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2']) -args_opt = parser.parse_args() - -if __name__ == '__main__': - if args_opt.dataset == 'MR': - cfg = cfg_mr - instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SUBJ': - cfg = cfg_subj +@moxing_wrapper() +def eval_net(): + '''eval net''' + if config.dataset == 'MR': + instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SUBJ': instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SST2': - cfg = cfg_sst2 - instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - device_target = cfg.device_target - context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + elif config.dataset == 'SST2': + instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + device_target = config.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) if device_target == "Ascend": - context.set_context(device_id=cfg.device_id) - dataset = instance.create_test_dataset(batch_size=cfg.batch_size) + context.set_context(device_id=get_device_id()) + dataset = instance.create_test_dataset(batch_size=config.batch_size) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) - net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len, - num_classes=cfg.num_classes, vec_length=cfg.vec_length) + net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len, + num_classes=config.num_classes, vec_length=config.vec_length) opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.001, - weight_decay=cfg.weight_decay) + weight_decay=float(config.weight_decay)) - if args_opt.checkpoint_path is not None: - param_dict = load_checkpoint(args_opt.checkpoint_path) - print("load checkpoint from [{}].".format(args_opt.checkpoint_path)) - else: - param_dict = load_checkpoint(cfg.checkpoint_path) - print("load checkpoint from [{}].".format(cfg.checkpoint_path)) + param_dict = load_checkpoint(config.checkpoint_file_path) + print("load checkpoint from [{}].".format(config.checkpoint_file_path)) load_param_into_net(net, param_dict) net.set_train(False) @@ -67,3 +57,6 @@ if __name__ == '__main__': acc = model.eval(dataset) print("accuracy: ", acc) + +if __name__ == '__main__': + eval_net() diff --git a/model_zoo/official/nlp/textcnn/export.py b/model_zoo/official/nlp/textcnn/export.py index 6404d2f481..fe75e83e5a 100644 --- a/model_zoo/official/nlp/textcnn/export.py +++ b/model_zoo/official/nlp/textcnn/export.py @@ -16,50 +16,34 @@ ##############export checkpoint file into air, onnx, mindir models################# python export.py """ -import argparse import numpy as np from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context -from src.config import cfg_mr, cfg_subj, cfg_sst2 +from utils.config import config from src.textcnn import TextCNN from src.dataset import MovieReview, SST2, Subjectivity -parser = argparse.ArgumentParser(description='TextCNN export') -parser.add_argument("--device_id", type=int, default=0, help="device id") -parser.add_argument("--ckpt_file", type=str, required=True, help="checkpoint file path.") -parser.add_argument("--file_name", type=str, default="textcnn", help="output file name.") -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') -parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", - help="device target") -parser.add_argument('--dataset', type=str, default='MR', choices=['MR', 'SUBJ', 'SST2'], - help='dataset name.') - -args = parser.parse_args() - -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) -if args.device_target == "Ascend": - context.set_context(device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) if __name__ == '__main__': - if args.dataset == 'MR': - cfg = cfg_mr - instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args.dataset == 'SUBJ': - cfg = cfg_subj - instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args.dataset == 'SST2': - cfg = cfg_sst2 - instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) + if config.dataset == 'MR': + instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SUBJ': + instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SST2': + instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9) else: raise ValueError("dataset is not support.") - net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len, - num_classes=cfg.num_classes, vec_length=cfg.vec_length) + net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len, + num_classes=config.num_classes, vec_length=config.vec_length) - param_dict = load_checkpoint(args.ckpt_file) + param_dict = load_checkpoint(config.ckpt_file) load_param_into_net(net, param_dict) - input_arr = Tensor(np.ones([cfg.batch_size, cfg.word_len], np.int32)) - export(net, input_arr, file_name=args.file_name, file_format=args.file_format) + input_arr = Tensor(np.ones([config.batch_size, config.word_len], np.int32)) + export(net, input_arr, file_name=config.file_name, file_format=config.file_format) diff --git a/model_zoo/official/nlp/textcnn/mr_config.yaml b/model_zoo/official/nlp/textcnn/mr_config.yaml new file mode 100644 index 0000000000..575418404f --- /dev/null +++ b/model_zoo/official/nlp/textcnn/mr_config.yaml @@ -0,0 +1,57 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +dataset: 'MR' +pre_trained: False +num_classes: 2 +batch_size: 64 +epoch_size: 4 +weight_decay: 3e-5 +keep_checkpoint_max: 1 +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'train_textcnn-4_149.ckpt' +word_len: 51 +vec_length: 40 +base_lr: 1e-3 + +# Export options +device_id: 0 +ckpt_file: "" +file_name: "" +file_format: "" + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +dataset: "Dataset to be trained and evaluated, choice: ['MR, SUBJ, SST2']" +train_epochs: "The number of epochs used to train." +pre_trained: 'If need load pre_trained checkpoint, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +epoch_size: "Total training epochs." +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +num_factors: "The Embedding size of MF model." +checkpoint_path: "The location of the checkpoint file." +eval_file_name: "Eval output file." +checkpoint_file_path: "The location of the checkpoint file." diff --git a/model_zoo/official/nlp/textcnn/scripts/run_eval.sh b/model_zoo/official/nlp/textcnn/scripts/run_eval.sh index 21ede3adc2..4ac6ba5f3e 100644 --- a/model_zoo/official/nlp/textcnn/scripts/run_eval.sh +++ b/model_zoo/official/nlp/textcnn/scripts/run_eval.sh @@ -13,14 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +BASE_PATH=$(cd ./ "`dirname $0`" || exit; pwd) dataset_type='MR' +CONFIG_FILE="${BASE_PATH}/mr_config.yaml" if [ $# == 2 ] then - if [ $2 != "MR" ] && [ $2 != "SUBJ" ] && [ $2 != "SST2" ] - then + if [ $2 == "MR" ]; then + CONFIG_FILE="${BASE_PATH}/mr_config.yaml" + elif [ $2 == "SUBJ" ]; then + CONFIG_FILE="${BASE_PATH}/subj_config.yaml" + elif [ $2 == "SST2" ]; then + CONFIG_FILE="${BASE_PATH}/sst2_config.yaml" + else echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}" exit 1 fi dataset_type=$2 fi -python eval.py --checkpoint_path="$1" --dataset=$dataset_type > eval.log 2>&1 & +python eval.py --checkpoint_file_path="$1" --dataset=$dataset_type --config_path=$CONFIG_FILE > eval.log 2>&1 & diff --git a/model_zoo/official/nlp/textcnn/scripts/run_train.sh b/model_zoo/official/nlp/textcnn/scripts/run_train.sh index cd7143fd7d..59bedc635d 100644 --- a/model_zoo/official/nlp/textcnn/scripts/run_train.sh +++ b/model_zoo/official/nlp/textcnn/scripts/run_train.sh @@ -13,15 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +BASE_PATH=$(cd ./ "`dirname $0`" || exit; pwd) dataset_type='MR' +CONFIG_FILE="${BASE_PATH}/mr_config.yaml" if [ $# == 1 ] then - if [ $1 != "MR" ] && [ $1 != "SUBJ" ] && [ $1 != "SST2" ] - then + if [ $1 == "MR" ]; then + CONFIG_FILE="${BASE_PATH}/mr_config.yaml" + elif [ $1 == "SUBJ" ]; then + CONFIG_FILE="${BASE_PATH}/subj_config.yaml" + elif [ $1 == "SST2" ]; then + CONFIG_FILE="${BASE_PATH}/sst2_config.yaml" + else echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}" exit 1 fi dataset_type=$1 fi rm ./ckpt_0 -rf -python train.py --dataset=$dataset_type > train.log 2>&1 & +python train.py --dataset=$dataset_type --config_path=$CONFIG_FILE --output_path './output' > train.log 2>&1 & diff --git a/model_zoo/official/nlp/textcnn/src/config.py b/model_zoo/official/nlp/textcnn/src/config.py deleted file mode 100644 index b94a4960e4..0000000000 --- a/model_zoo/official/nlp/textcnn/src/config.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in main.py -""" -from easydict import EasyDict as edict - -cfg_mr = edict({ - 'name': 'movie review', - 'pre_trained': False, - 'num_classes': 2, - 'batch_size': 64, - 'epoch_size': 4, - 'weight_decay': 3e-5, - 'data_path': './data/', - 'device_target': 'Ascend', - 'device_id': 7, - 'keep_checkpoint_max': 1, - 'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt', - 'word_len': 51, - 'vec_length': 40, - 'base_lr': 1e-3 -}) - -cfg_subj = edict({ - 'name': 'subjectivity', - 'pre_trained': False, - 'num_classes': 2, - 'batch_size': 64, - 'epoch_size': 5, - 'weight_decay': 3e-5, - 'data_path': './Subj/', - 'device_target': 'Ascend', - 'device_id': 7, - 'keep_checkpoint_max': 1, - 'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt', - 'word_len': 51, - 'vec_length': 40, - 'base_lr': 8e-4 -}) - -cfg_sst2 = edict({ - 'name': 'SST2', - 'pre_trained': False, - 'num_classes': 2, - 'batch_size': 64, - 'epoch_size': 4, - 'weight_decay': 3e-5, - 'data_path': './SST-2/', - 'device_target': 'Ascend', - 'device_id': 7, - 'keep_checkpoint_max': 1, - 'checkpoint_path': './ckpt/train_textcnn-4_149.ckpt', - 'word_len': 51, - 'vec_length': 40, - 'base_lr': 5e-3 -}) diff --git a/model_zoo/official/nlp/textcnn/sst2_config.yaml b/model_zoo/official/nlp/textcnn/sst2_config.yaml new file mode 100644 index 0000000000..8231052ea3 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/sst2_config.yaml @@ -0,0 +1,57 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +dataset: 'SST2' +pre_trained: False +num_classes: 2 +batch_size: 64 +epoch_size: 4 +weight_decay: 3e-5 +keep_checkpoint_max: 1 +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'train_textcnn-4_149.ckpt' +word_len: 51 +vec_length: 40 +base_lr: 1e-3 + +# Export options +device_id: 0 +ckpt_file: "" +file_name: "" +file_format: "" + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +dataset: "Dataset to be trained and evaluated, choice: ['MR, SUBJ, SST2']" +train_epochs: "The number of epochs used to train." +pre_trained: 'If need load pre_trained checkpoint, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +epoch_size: "Total training epochs." +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +num_factors: "The Embedding size of MF model." +checkpoint_path: "The location of the checkpoint file." +eval_file_name: "Eval output file." +checkpoint_file_path: "The location of the checkpoint file." diff --git a/model_zoo/official/nlp/textcnn/subj_config.yaml b/model_zoo/official/nlp/textcnn/subj_config.yaml new file mode 100644 index 0000000000..d65fd72296 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/subj_config.yaml @@ -0,0 +1,58 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) +enable_modelarts: False +# Url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# Path for local +data_path: "/cache/data" +output_path: "/cache/train" +load_path: "/cache/checkpoint_path/" +device_target: 'Ascend' +enable_profiling: False + +# ============================================================================== +# Training options +dataset: 'SUBJ' +pre_trained: False +num_classes: 2 +batch_size: 64 +epoch_size: 4 +weight_decay: 3e-5 +keep_checkpoint_max: 1 +checkpoint_path: './checkpoint/' +checkpoint_file_path: 'train_textcnn-4_149.ckpt' +word_len: 51 +vec_length: 40 +base_lr: 1e-3 + +# Export options +device_id: 0 +ckpt_file: "" +file_name: "" +file_format: "" + +--- +# Help description for each configuration +enable_modelarts: 'Whether training on modelarts, default: False' +data_url: 'Dataset url for obs' +train_url: 'Training output url for obs' +checkpoint_url: 'The location of checkpoint for obs' +data_path: 'Dataset path for local' +output_path: 'Training output path for local' +load_path: 'The location of checkpoint for obs' +device_target: 'Target device type, available: [Ascend, GPU, CPU]' +enable_profiling: 'Whether enable profiling while training, default: False' +dataset: "Dataset to be trained and evaluated, choice: ['MR, SUBJ, SST2']" +train_epochs: "The number of epochs used to train." +pre_trained: 'If need load pre_trained checkpoint, default: False' +num_classes: 'Class for dataset' +batch_size: "Batch size for training and evaluation" +epoch_size: "Total training epochs." +weight_decay: "Weight decay." +keep_checkpoint_max: "keep the last keep_checkpoint_max checkpoint" +num_factors: "The Embedding size of MF model." +checkpoint_path: "The location of the checkpoint file." +eval_file_name: "Eval output file." +checkpoint_file_path: "The location of the checkpoint file." + diff --git a/model_zoo/official/nlp/textcnn/train.py b/model_zoo/official/nlp/textcnn/train.py index 29dda564b8..0ff4a28987 100644 --- a/model_zoo/official/nlp/textcnn/train.py +++ b/model_zoo/official/nlp/textcnn/train.py @@ -16,7 +16,7 @@ #################train textcnn example on movie review######################## python train.py """ -import argparse +import os import math import mindspore.nn as nn @@ -26,62 +26,62 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.config import cfg_mr, cfg_subj, cfg_sst2 +from utils.moxing_adapter import moxing_wrapper +from utils.device_adapter import get_device_id, get_rank_id +from utils.config import config from src.textcnn import TextCNN from src.textcnn import SoftmaxCrossEntropyExpand from src.dataset import MovieReview, SST2, Subjectivity -parser = argparse.ArgumentParser(description='TextCNN') -parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], - help='device where the code will be implemented (default: Ascend)') -parser.add_argument('--device_id', type=int, default=5, help='device id of GPU or Ascend.') -parser.add_argument('--dataset', type=str, default="MR", choices=['MR', 'SUBJ', 'SST2']) -args_opt = parser.parse_args() +def modelarts_pre_process(): + config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path) -if __name__ == '__main__': - rank = 0 +@moxing_wrapper(pre_process=modelarts_pre_process) +def train_net(): + '''train net''' # set context - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) - context.set_context(device_id=args_opt.device_id) - if args_opt.dataset == 'MR': - cfg = cfg_mr - instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SUBJ': - cfg = cfg_subj - instance = Subjectivity(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) - elif args_opt.dataset == 'SST2': - cfg = cfg_sst2 - instance = SST2(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9) + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + context.set_context(device_id=get_device_id()) + if config.dataset == 'MR': + instance = MovieReview(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SUBJ': + instance = Subjectivity(root_dir=config.data_path, maxlen=config.word_len, split=0.9) + elif config.dataset == 'SST2': + instance = SST2(root_dir=config.data_path, maxlen=config.word_len, split=0.9) - dataset = instance.create_train_dataset(batch_size=cfg.batch_size, epoch_size=cfg.epoch_size) + dataset = instance.create_train_dataset(batch_size=config.batch_size, epoch_size=config.epoch_size) batch_num = dataset.get_dataset_size() - base_lr = cfg.base_lr + base_lr = float(config.base_lr) learning_rate = [] - warm_up = [base_lr / math.floor(cfg.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in - range(math.floor(cfg.epoch_size / 5))] - shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(cfg.epoch_size * 3 / 5))] + warm_up = [base_lr / math.floor(config.epoch_size / 5) * (i + 1) for _ in range(batch_num) for i in + range(math.floor(config.epoch_size / 5))] + shrink = [base_lr / (16 * (i + 1)) for _ in range(batch_num) for i in range(math.floor(config.epoch_size * 3 / 5))] normal_run = [base_lr for _ in range(batch_num) for i in - range(cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5))] + range(config.epoch_size - math.floor(config.epoch_size / 5) - math.floor(config.epoch_size * 2 / 5))] learning_rate = learning_rate + warm_up + normal_run + shrink - net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len, - num_classes=cfg.num_classes, vec_length=cfg.vec_length) + net = TextCNN(vocab_len=instance.get_dict_len(), word_len=config.word_len, + num_classes=config.num_classes, vec_length=config.vec_length) # Continue training if set pre_trained to be True - if cfg.pre_trained: - param_dict = load_checkpoint(cfg.checkpoint_path) + if config.pre_trained: + param_dict = load_checkpoint(config.checkpoint_path) load_param_into_net(net, param_dict) - opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=learning_rate, weight_decay=cfg.weight_decay) + opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), \ + learning_rate=learning_rate, weight_decay=float(config.weight_decay)) loss = SoftmaxCrossEntropyExpand(sparse=True) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()}) - config_ck = CheckpointConfig(save_checkpoint_steps=int(cfg.epoch_size*batch_num/2), - keep_checkpoint_max=cfg.keep_checkpoint_max) + config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size*batch_num/2), + keep_checkpoint_max=config.keep_checkpoint_max) time_cb = TimeMonitor(data_size=batch_num) - ckpt_save_dir = "./ckpt_" + str(rank) + "/" + ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path) ckpoint_cb = ModelCheckpoint(prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck) loss_cb = LossMonitor() - model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) print("train success") + +if __name__ == '__main__': + train_net() diff --git a/model_zoo/official/nlp/textcnn/utils/config.py b/model_zoo/official/nlp/textcnn/utils/config.py new file mode 100644 index 0000000000..31c47f185c --- /dev/null +++ b/model_zoo/official/nlp/textcnn/utils/config.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + +_config_path = "./mr_config.yaml" + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. + """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + else: + raise ValueError("At most 2 docs (config and help description for help) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../mr_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser, default, helper, path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff --git a/model_zoo/official/nlp/textcnn/utils/device_adapter.py b/model_zoo/official/nlp/textcnn/utils/device_adapter.py new file mode 100644 index 0000000000..92439de46b --- /dev/null +++ b/model_zoo/official/nlp/textcnn/utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from utils.config import config + +if config.enable_modelarts: + from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/official/nlp/textcnn/utils/local_adapter.py b/model_zoo/official/nlp/textcnn/utils/local_adapter.py new file mode 100644 index 0000000000..769fa6dc78 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/official/nlp/textcnn/utils/moxing_adapter.py b/model_zoo/official/nlp/textcnn/utils/moxing_adapter.py new file mode 100644 index 0000000000..420d4808f0 --- /dev/null +++ b/model_zoo/official/nlp/textcnn/utils/moxing_adapter.py @@ -0,0 +1,115 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from utils.config import config + +_global_sync_count = 0 + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. + """ + import moxing as mox + import time + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs. + """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + run_func(*args, **kwargs) + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper