From: @wuxuejian
Reviewed-by: @oacjiewen, @liangchenghui
Signed-off-by: @liangchenghui
pull/14739/MERGE
@@ -99,6 +99,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
 You can start training using python or shell scripts. The usage of the shell scripts is as follows:
 - Ascend: sh run_distribute_train.sh [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
+- CPU: sh run_train_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
 For distributed training, an hccl configuration file in JSON format needs to be created in advance.
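The rank table's exact fields are defined by the hccn_tools instructions linked in the next hunk; purely for orientation, a single-server table for two devices generally has the shape below. This is a sketch, not the authoritative schema, and every ID and IP is a placeholder:

```python
import json

# Illustrative sketch only; defer to the hccn_tools documentation for the
# authoritative rank-table schema. All IDs and IPs below are placeholders.
rank_table = {
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "10.155.111.140",
        "device": [
            {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
            {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"},
        ],
        "host_nic_ip": "reserve",
    }],
    "status": "completed",
}

with open("rank_table_2pcs.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```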
@@ -109,10 +110,12 @@ Please follow the instructions in the link [hccn_tools](https://gitee.com/mindsp
 ```shell
 # training example
 python:
-    Ascend: python train.py --platform Ascend --dataset_path [TRAIN_DATASET_PATH]
+    Ascend: python train.py --device_target Ascend --dataset_path [TRAIN_DATASET_PATH]
+    CPU: python train.py --device_target CPU --dataset_path [TRAIN_DATASET_PATH]
 shell:
     Ascend: sh run_distribute_train.sh [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
+    CPU: sh run_train_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
 ```
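The flag rename in this hunk (`--platform` to `--device_target`) matches the argument name used by `context.set_context` itself. For orientation, a minimal, self-contained sketch of the implied plumbing; the repo's real train.py defines more flags, and the names here simply mirror the commands shown above:

```python
import argparse

from mindspore import context

# Editorial sketch: the launch commands above imply at least these two flags.
parser = argparse.ArgumentParser(description='image classification training')
parser.add_argument('--device_target', type=str, default='Ascend',
                    choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will run')
parser.add_argument('--dataset_path', type=str, required=True,
                    help='path of the training dataset')
args_opt = parser.parse_args()

# The parsed value feeds straight into the MindSpore context.
context.set_context(mode=context.GRAPH_MODE,
                    device_target=args_opt.device_target)
```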
 ### Result
@@ -133,6 +136,7 @@ Epoch time: 150950.623, per step time: 120.664
 You can start evaluation using python or shell scripts. If the train method is train or fine-tune, `[CHECKPOINT_PATH]` should not be passed. The usage of the shell scripts is as follows:
 - Ascend: sh run_eval.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
+- CPU: sh run_eval_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
 ### Launch
@@ -140,9 +144,11 @@ You can start training using python or shell scripts.If the train method is trai
 # eval example
 python:
     Ascend: python eval.py --dataset [cifar10|imagenet2012] --dataset_path [VAL_DATASET_PATH] --pretrain_ckpt [CHECKPOINT_PATH]
+    CPU: python eval.py --dataset [cifar10|imagenet2012] --dataset_path [VAL_DATASET_PATH] --pretrain_ckpt [CHECKPOINT_PATH] --device_target CPU
 shell:
     Ascend: sh run_eval.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
+    CPU: sh run_eval_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
 ```
> The checkpoint can be produced during the training process.
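On the producing side, MindSpore training scripts usually emit those checkpoints through a `ModelCheckpoint` callback. A minimal sketch of that convention follows; `model` and `dataset` stand for the objects built in train.py, and the prefix, directory, and intervals are illustrative, not this repo's exact values:

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Save a checkpoint every 1875 steps, keeping at most 10 files on disk.
# All values here are placeholders; the repo's train.py picks its own.
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix="cpu_train", directory="./ckpt", config=config_ck)
model.train(90, dataset, callbacks=[ckpt_cb])  # `model`, `dataset` from train.py
```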
@@ -45,7 +45,7 @@ if __name__ == '__main__':
     # init context
     context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
-    if target != "GPU":
+    if target == "Ascend":
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(device_id=device_id)
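With the new guard, GPU and CPU runs both skip the `device_id` setup. One caveat remains in the Ascend path (an editorial observation, not part of the PR): `os.getenv('DEVICE_ID')` returns `None` when the variable is unset, so the `int()` call raises `TypeError`. A defensive variant:

```python
import os

# Default to card 0 when DEVICE_ID is not exported, instead of crashing.
device_id = int(os.getenv('DEVICE_ID', '0'))
```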
@@ -0,0 +1,64 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]
+then
+    echo "Usage: bash run_eval_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
+    exit 1
+fi
+
+if [ $1 != "cifar10" ] && [ $1 != "imagenet2012" ]
+then
+    echo "error: the selected dataset is neither cifar10 nor imagenet2012"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $2)
+PATH2=$(get_real_path $3)
+
+if [ ! -d $PATH1 ]
+then
+    echo "error: DATASET_PATH=$PATH1 is not a directory"
+    exit 1
+fi
+
+if [ ! -f $PATH2 ]
+then
+    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
+    exit 1
+fi
+
+if [ -d "eval" ]
+then
+    rm -rf ./eval
+fi
+mkdir ./eval
+cp ../*.py ./eval
+cp *.sh ./eval
+cp -r ../src ./eval
+cd ./eval || exit
+env > env.log
+python eval.py --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target=CPU &> log &
+cd ..
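Taken together, the `cp ../*.py` and `cp -r ../src` lines assume this script lives in a scripts directory one level below `eval.py` and `src/`. Under that assumption, a typical run is `bash run_eval_CPU.sh cifar10 ../cifar10/ ./resnet.ckpt` (checkpoint name illustrative); evaluation runs in the background, with output collected in `./eval/log` and an environment snapshot in `./eval/env.log`.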
@@ -0,0 +1,75 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ] && [ $# != 3 ]
+then
+    echo "Usage: bash run_train_CPU.sh [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)"
+    exit 1
+fi
+
+if [ $1 != "cifar10" ] && [ $1 != "imagenet2012" ]
+then
+    echo "error: the selected dataset is neither cifar10 nor imagenet2012"
+    exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $2)
+if [ $# == 3 ]
+then
+    PATH2=$(get_real_path $3)
+fi
+
+if [ ! -d $PATH1 ]
+then
+    echo "error: DATASET_PATH=$PATH1 is not a directory"
+    exit 1
+fi
+
+if [ $# == 3 ] && [ ! -f $PATH2 ]
+then
+    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
+    exit 1
+fi
+
+if [ -d "train" ]
+then
+    rm -rf ./train
+fi
+mkdir ./train
+cp ../*.py ./train
+cp *.sh ./train
+cp -r ../src ./train
+cd ./train || exit
+env > env.log
+
+if [ $# == 2 ]
+then
+    python train.py --dataset=$1 --dataset_path=$PATH1 --device_target=CPU &> log &
+fi
+
+if [ $# == 3 ]
+then
+    python train.py --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 --device_target=CPU &> log &
+fi
+cd ..
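The optional third argument is forwarded to train.py as `--pre_trained`. In model_zoo training scripts this conventionally warm-starts the network before training begins; a minimal sketch of that convention (the exact handling lives in this repo's train.py, and `net` and `args_opt` stand for objects defined there):

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Conventional warm-start: read the checkpoint file and copy its parameters
# into the freshly built network before Model is constructed.
if args_opt.pre_trained:
    param_dict = load_checkpoint(args_opt.pre_trained)
    load_param_into_net(net, param_dict)
```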
@@ -16,12 +16,15 @@
 create train or eval dataset.
 """
 import os
+from multiprocessing import cpu_count
+
 import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 import mindspore.dataset.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 from mindspore.communication.management import init, get_rank, get_group_size

+THREAD_NUM = 12 if cpu_count() >= 12 else 8

 def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
     """
@@ -38,15 +41,17 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     """
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
-    else:
+    elif target == "GPU":
         init()
         rank_id = get_rank()
         device_num = get_group_size()
+    else:
+        device_num = 1

     if device_num == 1:
-        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True)
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=THREAD_NUM, shuffle=True)
     else:
-        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=12, shuffle=True,
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=THREAD_NUM, shuffle=True,
                                      num_shards=device_num, shard_id=rank_id)

     # define map operations
| @@ -66,8 +71,8 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target= | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | type_cast_op = C2.TypeCast(mstype.int32) | ||||
| data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12) | |||||
| data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12) | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM) | |||||
| data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM) | |||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | data_set = data_set.batch(batch_size, drop_remainder=True) | ||||
| @@ -99,9 +104,9 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= | |||||
| device_num = get_group_size() | device_num = get_group_size() | ||||
| if device_num == 1: | if device_num == 1: | ||||
| data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True) | |||||
| data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=THREAD_NUM, shuffle=True) | |||||
| else: | else: | ||||
| data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True, | |||||
| data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=THREAD_NUM, shuffle=True, | |||||
| num_shards=device_num, shard_id=rank_id) | num_shards=device_num, shard_id=rank_id) | ||||
| image_size = 224 | image_size = 224 | ||||
| @@ -127,8 +132,8 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | type_cast_op = C2.TypeCast(mstype.int32) | ||||
| data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12) | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12) | |||||
| data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM) | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM) | |||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | data_set = data_set.batch(batch_size, drop_remainder=True) | ||||
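With the new `else` branch above, a CPU caller needs no `init()` or rank information; the pipeline simply runs unsharded with `device_num = 1`. A usage sketch (the `src.dataset` module path follows from the `cp -r ../src` lines in the scripts; the dataset path is a placeholder):

```python
from src.dataset import create_dataset1

# CPU target: device_num falls through to 1, so no num_shards/shard_id.
data_set = create_dataset1("/path/to/cifar10", do_train=True,
                           batch_size=32, target="CPU")
print(data_set.get_dataset_size())  # batches per epoch after drop_remainder
```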
@@ -116,38 +116,28 @@ if __name__ == '__main__':
         else:
             no_decayed_params.append(param)

-    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
-                    {'params': no_decayed_params},
-                    {'order_params': net.trainable_params()}]
-    opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
-    # define loss, model
     if target == "Ascend":
-        if args_opt.dataset == "imagenet2012":
-            if not config.use_label_smooth:
-                config.label_smooth_factor = 0.0
-            loss = CrossEntropySmooth(sparse=True, reduction="mean",
-                                      smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=False)
+        group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
+                        {'params': no_decayed_params},
+                        {'order_params': net.trainable_params()}]
+        opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
     else:
-        # GPU target
-        if args_opt.dataset == "imagenet2012":
-            if not config.use_label_smooth:
-                config.label_smooth_factor = 0.0
-            loss = CrossEntropySmooth(sparse=True, reduction="mean",
-                                      smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
         opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
                        config.loss_scale)
-        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        # Mixed precision
+
+    # define loss, model
+    if args_opt.dataset == "imagenet2012":
+        if not config.use_label_smooth:
+            config.label_smooth_factor = 0.0
+        loss = CrossEntropySmooth(sparse=True, reduction="mean",
+                                  smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+    if target != "CPU":
         model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                       amp_level="O2", keep_batchnorm_fp32=False)
+    else:
+        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})

     # define callbacks
     time_cb = TimeMonitor(data_size=step_size)
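Design note on this last hunk: the loss and `FixedLossScaleManager` definitions, previously duplicated in the Ascend and GPU branches, are now written once ahead of the device check, and only the optimizer construction stays device-specific. `amp_level="O2"` casts the network to FP16 (with `keep_batchnorm_fp32=False`, batchnorm included), a mixed-precision mode implemented by the Ascend and GPU backends; the CPU backend executes in plain FP32, which is why the CPU branch builds `Model` without the amp arguments.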