Browse Source

add incremental learn fun

tags/v1.0.0
Payne 5 years ago
parent
commit
d7e2bf974f
12 changed files with 868 additions and 470 deletions
  1. +75
    -49
      model_zoo/official/cv/mobilenetv2/Readme.md
  2. +25
    -44
      model_zoo/official/cv/mobilenetv2/eval.py
  3. +108
    -29
      model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh
  4. +39
    -8
      model_zoo/official/cv/mobilenetv2/scripts/run_train.sh
  5. +65
    -0
      model_zoo/official/cv/mobilenetv2/src/args.py
  6. +74
    -34
      model_zoo/official/cv/mobilenetv2/src/config.py
  7. +40
    -7
      model_zoo/official/cv/mobilenetv2/src/dataset.py
  8. +4
    -37
      model_zoo/official/cv/mobilenetv2/src/launch.py
  9. +121
    -29
      model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py
  10. +138
    -0
      model_zoo/official/cv/mobilenetv2/src/models.py
  11. +93
    -0
      model_zoo/official/cv/mobilenetv2/src/utils.py
  12. +86
    -233
      model_zoo/official/cv/mobilenetv2/train.py

+ 75
- 49
model_zoo/official/cv/mobilenetv2/Readme.md View File

@@ -4,23 +4,22 @@
- [Model Architecture](#model-architecture) - [Model Architecture](#model-architecture)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Features](#features) - [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements) - [Environment Requirements](#environment-requirements)
- [Script Description](#script-description) - [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process) - [Training Process](#training-process)
- [Evaluation Process](#evaluation-process) - [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Evaluation](#evaluation)
- [Model Description](#model-description) - [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation) - [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)


# [MobileNetV2 Description](#contents) # [MobileNetV2 Description](#contents)



MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019. MobileNetV2 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019.


[Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV2." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.
@@ -36,78 +35,103 @@ The overall network architecture of MobileNetV2 is show below:
Dataset used: [imagenet](http://www.image-net.org/) Dataset used: [imagenet](http://www.image-net.org/)


- Dataset size: ~125G, 1.2W colorful images in 1000 classes - Dataset size: ~125G, 1.2W colorful images in 1000 classes
- Train: 120G, 1.2W images
- Test: 5G, 50000 images
- Train: 120G, 1.2W images
- Test: 5G, 50000 images
- Data format: RGB images. - Data format: RGB images.
- Note: Data will be processed in src/dataset.py

- Note: Data will be processed in src/dataset.py


# [Features](#contents) # [Features](#contents)


## [Mixed Precision(Ascend)](#contents) ## [Mixed Precision(Ascend)](#contents)


The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.


# [Environment Requirements](#contents) # [Environment Requirements](#contents)


- Hardware(Ascend/GPU)
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Hardware(Ascend/GPU/CPU
- Prepare hardware environment with Ascend、GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - Framework
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
- For more information, please check the resources below: - For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)



# [Script description](#contents) # [Script description](#contents)


## [Script and sample code](#contents) ## [Script and sample code](#contents)


```python ```python
├── MobileNetV2
├── Readme.md # descriptions about MobileNetV2
├── scripts
│ ├──run_train.sh # shell script for train
│ ├──run_eval.sh # shell script for evaluation
├── src
│ ├──config.py # parameter configuration
├── MobileNetV2
├── Readme.md # descriptions about MobileNetV2
├── scripts
│ ├──run_train.sh # shell script for train, fine_tune or incremental learn with CPU, GPU or Ascend
│ ├──run_eval.sh # shell script for evaluation with CPU, GPU or Ascend
├── src
│ ├──args.py # parse args
│ ├──config.py # parameter configuration
│ ├──dataset.py # creating dataset │ ├──dataset.py # creating dataset
│ ├──launch.py # start python script │ ├──launch.py # start python script
│ ├──lr_generator.py # learning rate config
│ ├──lr_generator.py # learning rate config
│ ├──mobilenetV2.py # MobileNetV2 architecture │ ├──mobilenetV2.py # MobileNetV2 architecture
│ ├──models.py # contain define_net and Loss, Monitor
│ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn
├── train.py # training script ├── train.py # training script
├── eval.py # evaluation script
├── eval.py # evaluation script
``` ```


## [Training process](#contents) ## [Training process](#contents)


### Usage ### Usage



You can start training using python or shell scripts. The usage of shell scripts as follows: You can start training using python or shell scripts. The usage of shell scripts as follows:


- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
- GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
- CPU: sh run_trian.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]


### Launch ### Launch


```
```
# training example # training example
python: python:
Ascend: python train.py --dataset_path ~/imagenet/train/ --device_targe Ascend
GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method train
GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method train
CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method train

shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ train
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ train
CPU: sh run_train.sh CPU ~/imagenet/train/ train

# fine tune example
python:
Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt
GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt
CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt


shell: shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt
CPU: sh run_train.sh CPU ~/imagenet/train/ fine_tune ./pretrain_checkpoint/mobilenetv2_199.ckpt

# incremental learn example
python:
Ascend: python train.py --dataset_path ~/imagenet/train/ --platform Ascend --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt
GPU: python train.py --dataset_path ~/imagenet/train/ --platform GPU --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt
CPU: python train.py --dataset_path ~/imagenet/train/ --platform CPU --train_method incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt

shell:
Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt
GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt
CPU: sh run_train.sh CPU ~/imagenet/train/ incremental_learn ./pretrain_checkpoint/mobilenetv2_199.ckpt
``` ```


### Result ### Result


Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings.


```
```
epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
@@ -120,29 +144,32 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917


You can start training using python or shell scripts. The usage of shell scripts as follows: You can start training using python or shell scripts. The usage of shell scripts as follows:


- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH]
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] [HEAD_CKPT_PATH]
- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] [HEAD_CKPT_PATH]
- CPU: sh run_infer.sh CPU [DATASET_PATH] [BACKBONE_CKPT_PATH] [HEAD_CKPT_PATH]


### Launch ### Launch


```
```
# infer example # infer example
python: python:
Ascend: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe Ascend
GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
Ascend: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform Ascend --head_ckpt ./checkpoint/mobilenetv2_199.ckpt
GPU: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform GPU --head_ckpt ./checkpoint/mobilenetv2_199.ckpt
CPU: python eval.py --dataset_path ~/imagenet/val/ --pretrain_ckpt ~/train/mobilenet-200_625.ckpt --platform CPU --head_ckpt ./checkpoint/mobilenetv2_199.ckpt


shell: shell:
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt
GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt
CPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt ./checkpoint/mobilenetv2_199.ckpt
``` ```


> checkpoint can be produced in training process.
> checkpoint can be produced in training process.


### Result ### Result


Inference result will be stored in the example path, you can find result like the followings in `val.log`.
Inference result will be stored in the example path, you can find result like the followings in `val.log`.


```
```
result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
``` ```


@@ -177,7 +204,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
| Model Version | V1 | | | | Model Version | V1 | | |
| Resource | Ascend 910 | NV SMX2 V100-32G | Ascend 310 | | Resource | Ascend 910 | NV SMX2 V100-32G | Ascend 310 |
| uploaded Date | 05/06/2020 | 05/22/2020 | | | uploaded Date | 05/06/2020 | 05/22/2020 | |
| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 |
| MindSpore Version | 0.2.0 | 0.2.0 | 0.2.0 |
| Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W | | Dataset | ImageNet, 1.2W | ImageNet, 1.2W | ImageNet, 1.2W |
| batch_size | | 130(8P) | | | batch_size | | 130(8P) | |
| outputs | | | | | outputs | | | |
@@ -191,6 +218,5 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.


# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

+ 25
- 44
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -15,62 +15,43 @@
""" """
eval. eval.
""" """
import os
import argparse
from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from src.dataset import create_dataset
from src.config import config_ascend, config_gpu
from src.mobilenetV2 import mobilenet_v2


from src.dataset import create_dataset
from src.config import set_config
from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
from src.args import eval_parse_args
from src.models import load_ckpt
from src.utils import switch_precision, set_context


parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default=None, help='run device_target')
args_opt = parser.parse_args()
if __name__ == '__main__':
args_opt = eval_parse_args()
config = set_config(args_opt)


backbone_net = MobileNetV2Backbone(platform=args_opt.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(feature_net, head_net)


if __name__ == '__main__':
config = None
net = None
if args_opt.device_target == "Ascend":
config = config_ascend
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False)
net = mobilenet_v2(num_classes=config.num_classes, device_target="Ascend")
elif args_opt.device_target == "GPU":
config = config_gpu
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
net = mobilenet_v2(num_classes=config.num_classes, device_target="GPU")
#load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt:
load_ckpt(backbone_net, args_opt.pretrain_ckpt)
load_ckpt(head_net, args_opt.head_ckpt)
else: else:
raise ValueError("Unsupported device_target.")
load_ckpt(net, args_opt.pretrain_ckpt)


loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
set_context(config)
switch_precision(net, mstype.float16, config)


if args_opt.device_target == "Ascend":
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)

dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
config=config,
device_target=args_opt.device_target,
batch_size=config.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()

if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)


loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, metrics={'acc'}) model = Model(net, loss_fn=loss, metrics={'acc'})

res = model.eval(dataset) res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
if args_opt.head_ckpt:
print(f"head_ckpt={args_opt.head_ckpt}")

+ 108
- 29
model_zoo/official/cv/mobilenetv2/scripts/run_infer.sh View File

@@ -13,10 +13,106 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 3 ]



run_ascend()
{
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi

# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

# launch
python ${BASEPATH}/../eval.py \
--platform=$1 \
--dataset_path=$2 \
--pretrain_ckpt=$3 \
--head_ckpt=$4 \
&> ../infer.log & # dataset val folder path
}

run_gpu()
{
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

python ${BASEPATH}/../eval.py \
--platform=$1 \
--dataset_path=$2 \
--pretrain_ckpt=$3 \
--head_ckpt=$4 \
&> ../infer.log & # dataset train folder
}

run_cpu()
{
# check checkpoint file
if [ ! -f $3 ]
then
echo "error: BACKBONE_CKPT=$3 is not a file"
exit 1
fi

# check checkpoint file
if [ ! -f $4 ]
then
echo "error: HEAD_CKPT=$4 is not a file"
exit 1
fi


BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

python ${BASEPATH}/../eval.py \
--platform=$1 \
--dataset_path=$2 \
--pretrain_ckpt=$3 \
--head_ckpt=$4 \
&> ../infer.log & # dataset train folder
}


if [ $# -gt 4 ] || [ $# -lt 3 ]
then then
echo "Ascend: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH] \
GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] \
GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
CPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]"
exit 1 exit 1
fi fi


@@ -27,29 +123,12 @@ then
exit 1 exit 1
fi fi


# check checkpoint file
if [ ! -f $3 ]
then
echo "error: CHECKPOINT_PATH=$3 is not a file"
exit 1
fi

# set environment
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

# launch
python ${BASEPATH}/../eval.py \
--device_target=$1 \
--dataset_path=$2 \
--checkpoint_path=$3 \
&> ../infer.log & # dataset val folder path
if [ $1 = "CPU" ] ; then
run_cpu "$@"
elif [ $1 = "GPU" ] ; then
run_gpu "$@"
elif [ $1 = "Ascend" ] ; then
run_ascend "$@"
else
echo "Unsupported device_target."
fi;

+ 39
- 8
model_zoo/official/cv/mobilenetv2/scripts/run_train.sh View File

@@ -38,12 +38,14 @@ run_ascend()
mkdir ../train mkdir ../train
cd ../train || exit cd ../train || exit
python ${BASEPATH}/../src/launch.py \ python ${BASEPATH}/../src/launch.py \
--platform=$1 \
--nproc_per_node=$2 \ --nproc_per_node=$2 \
--visible_devices=$3 \ --visible_devices=$3 \
--training_script=${BASEPATH}/../train.py \ --training_script=${BASEPATH}/../train.py \
--dataset_path=$5 \ --dataset_path=$5 \
--pre_trained=$6 \
--device_target=$1 &> ../train.log & # dataset train folder
--train_method=$6 \
--pretrain_ckpt=$7 \
&> ../train.log & # dataset train folder
} }


run_gpu() run_gpu()
@@ -72,17 +74,45 @@ run_gpu()
export CUDA_VISIBLE_DEVICES="$3" export CUDA_VISIBLE_DEVICES="$3"
mpirun -n $2 --allow-run-as-root \ mpirun -n $2 --allow-run-as-root \
python ${BASEPATH}/../train.py \ python ${BASEPATH}/../train.py \
--platform=$1 \
--dataset_path=$4 \ --dataset_path=$4 \
--pre_trained=$5 \
--device_target=$1 \
--train_method=$5 \
--pretrain_ckpt=$6 \
&> ../train.log & # dataset train folder &> ../train.log & # dataset train folder
} }


if [ $# -gt 6 ] || [ $# -lt 4 ]
run_cpu()
{

if [ ! -d $4 ]
then
echo "error: DATASET_PATH=$4 is not a directory"
exit 1
fi

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit

python ${BASEPATH}/../train.py \
--platform=$1 \
--dataset_path=$2 \
--train_method=$3 \
--pretrain_ckpt=$4 \
&> ../train.log & # dataset train folder
}

if [ $# -gt 7 ] || [ $# -lt 4 ]
then then
echo "Usage:\n \ echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] \n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \
CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]\n \
" "
exit 1 exit 1
fi fi
@@ -91,7 +121,8 @@ if [ $1 = "Ascend" ] ; then
run_ascend "$@" run_ascend "$@"
elif [ $1 = "GPU" ] ; then elif [ $1 = "GPU" ] ; then
run_gpu "$@" run_gpu "$@"
elif [ $1 = "CPU" ] ; then
run_cpu "$@"
else else
echo "Unsupported device_target." echo "Unsupported device_target."
fi; fi;


+ 65
- 0
model_zoo/official/cv/mobilenetv2/src/args.py View File

@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse


def launch_parse_args():

launch_parser = argparse.ArgumentParser(description="mindspore distributed training launch helper utilty \
that will spawn up multiple distributed processes")
launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \
help="The number of processes to launch on each node, for D training, this is recommended to be set \
to the number of D in your system so that each process can be bound to a single D.")
launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \
visible devices sequentially")
launch_parser.add_argument("--training_script", type=str, default="./train.py", help="The full path to \
the single D training program/script to be launched in parallel, followed by all the arguments for \
the training script")

launch_args, unknown = launch_parser.parse_known_args()
launch_args.train_script_args = unknown
launch_args.training_script_args += ["--platform", launch_args.platform]
return launch_args

def train_parse_args():
train_parser = argparse.ArgumentParser(description='Image classification trian')
train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
help='run platform, only support CPU, GPU and Ascend')
train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning')
train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
train from initialization model")

train_args = train_parser.parse_args()
return train_args

def eval_parse_args():
eval_parser = argparse.ArgumentParser(description='Image classification eval')
eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
eval_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend')
eval_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning')
eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \
for fine tune or incremental learning')
eval_args = eval_parser.parse_args()

return eval_args

+ 74
- 34
model_zoo/official/cv/mobilenetv2/src/config.py View File

@@ -15,40 +15,80 @@
""" """
network config setting, will be used in train.py and eval.py network config setting, will be used in train.py and eval.py
""" """
import os
from easydict import EasyDict as ed from easydict import EasyDict as ed


config_ascend = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 256,
"epoch_size": 200,
"warmup_epochs": 4,
"lr": 0.4,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
})
def set_config(args):
config_cpu = ed({
"num_classes": 26,
"image_height": 224,
"image_width": 224,
"batch_size": 150,
"epoch_size": 15,
"warmup_epochs": 0,
"lr_max": 0.03,
"lr_end": 0.03,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./checkpoint",
"platform": args.platform
})
config_gpu = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 150,
"epoch_size": 200,
"warmup_epochs": 0,
"lr": 0.8,
"lr_max": 0.03,
"lr_end": 0.03,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"platform": args.platform,
"ccl": "nccl",
})
config_ascend = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 256,
"epoch_size": 200,
"warmup_epochs": 4,
"lr": 0.4,
"lr_max": 0.03,
"lr_end": 0.03,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"platform": args.platform,
"ccl": "hccl",
"device_id": int(os.getenv('DEVICE_ID', '0')),
"rank_id": int(os.getenv('RANK_ID', '0')),
"rank_size": int(os.getenv('RANK_SIZE', '1')),
"run_distribute": int(os.getenv('RANK_SIZE', '1')) > 1.
})
config = ed({"CPU": config_cpu,
"GPU": config_gpu,
"Ascend": config_ascend})


config_gpu = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 150,
"epoch_size": 200,
"warmup_epochs": 0,
"lr": 0.8,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
})
if args.platform not in config.keys():
raise ValueError("Unsupport platform.")

return config[args.platform]

+ 40
- 7
model_zoo/official/cv/mobilenetv2/src/dataset.py View File

@@ -16,25 +16,31 @@
create train or eval dataset. create train or eval dataset.
""" """
import os import os
from tqdm import tqdm
import numpy as np

from mindspore import Tensor
from mindspore.train.model import Model
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
def create_dataset(dataset_path, do_train, config, repeat_num=1):
""" """
create a train or eval dataset create a train or eval dataset


Args: Args:
dataset_path(string): the path of dataset. dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval. do_train(bool): whether dataset is used for train or eval.
config(struct): the config of train and eval in diffirent platform.
repeat_num(int): the repeat times of dataset. Default: 1. repeat_num(int): the repeat times of dataset. Default: 1.
batch_size(int): the batch size of dataset. Default: 32.


Returns: Returns:
dataset dataset
""" """
if device_target == "Ascend":
if config.platform == "Ascend":
rank_size = int(os.getenv("RANK_SIZE", '1')) rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0')) rank_id = int(os.getenv("RANK_ID", '0'))
if rank_size == 1: if rank_size == 1:
@@ -42,15 +48,16 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
else: else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=rank_size, shard_id=rank_id) num_shards=rank_size, shard_id=rank_id)
elif device_target == "GPU":
elif config.platform == "GPU":
if do_train: if do_train:
from mindspore.communication.management import get_rank, get_group_size from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank()) num_shards=get_group_size(), shard_id=get_rank())
else: else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
raise ValueError("Unsupported device_target.")
elif config.platform == "CPU":
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)



resize_height = config.image_height resize_height = config.image_height
resize_width = config.image_width resize_width = config.image_width
@@ -81,9 +88,35 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
ds = ds.shuffle(buffer_size=buffer_size) ds = ds.shuffle(buffer_size=buffer_size)


# apply batch operations # apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.batch(config.batch_size, drop_remainder=True)


# apply dataset repeat operation # apply dataset repeat operation
ds = ds.repeat(repeat_num) ds = ds.repeat(repeat_num)


return ds return ds

def extract_features(net, dataset_path, config):
features_folder = dataset_path + '_features'
if not os.path.exists(features_folder):
os.makedirs(features_folder)
dataset = create_dataset(dataset_path=dataset_path,
do_train=False,
config=config,
repeat_num=1)
step_size = dataset.get_dataset_size()
pbar = tqdm(list(dataset.create_dict_iterator()))
model = Model(net)
i = 0
for data in pbar:
features_path = os.path.join(features_folder, f"feature_{i}.npy")
label_path = os.path.join(features_folder, f"label_{i}.npy")
if not(os.path.exists(features_path) and os.path.exists(label_path)):
image = data["image"]
label = data["label"]
features = model.predict(Tensor(image))
np.save(features_path, features.asnumpy())
np.save(label_path, label)
pbar.set_description("Process dataset batch: %d"%(i+1))
i += 1

return step_size

+ 4
- 37
model_zoo/official/cv/mobilenetv2/src/launch.py View File

@@ -17,44 +17,11 @@ import os
import sys import sys
import subprocess import subprocess
import shutil import shutil
from argparse import ArgumentParser

def parse_args():
"""
parse args .

Args:

Returns:
args.

Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training launch "
"helper utilty that will spawn up "
"multiple distributed processes")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "
"for D training, this is recommended to be set "
"to the number of D in your system so that "
"each process can be bound to a single D.")
parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
help="will use the visible devices sequentially")
parser.add_argument("--training_script", type=str,
help="The full path to the single D training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script")
# rest from the training program
args, unknown = parser.parse_known_args()
args.training_script_args = unknown
return args

from args import launch_parse_args


def main(): def main():
print("start", __file__) print("start", __file__)
args = parse_args()
args = launch_parse_args()
print(args) print(args)
visible_devices = args.visible_devices.split(',') visible_devices = args.visible_devices.split(',')
assert os.path.isfile(args.training_script) assert os.path.isfile(args.training_script)
@@ -79,8 +46,8 @@ def main():
os.mkdir(device_dir) os.mkdir(device_dir)
os.chdir(device_dir) os.chdir(device_dir)
cmd = [sys.executable, '-u'] cmd = [sys.executable, '-u']
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
cmd.append(args.train_script)
cmd.extend(args.train_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process) processes.append(process)


+ 121
- 29
model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py View File

@@ -20,7 +20,7 @@ from mindspore.ops.operations import TensorAdd
from mindspore import Parameter, Tensor from mindspore import Parameter, Tensor
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer


__all__ = ['mobilenet_v2']
__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']




def _make_divisible(v, divisor, min_value=None): def _make_divisible(v, divisor, min_value=None):
@@ -119,17 +119,19 @@ class ConvBNReLU(nn.Cell):
>>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
""" """


def __init__(self, device_target, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
if groups == 1: if groups == 1:
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding) conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
else: else:
if device_target == "Ascend":
if platform in ("CPU", "GPU"):
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, group=in_planes, pad_mode='pad', \
padding=padding)
elif platform == "Ascend":
conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
elif device_target == "GPU":
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride,
group=in_planes, pad_mode='pad', padding=padding)
else:
raise ValueError("Unsupported Device, only support CPU, GPU and Ascend.")


layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
self.features = nn.SequentialCell(layers) self.features = nn.SequentialCell(layers)
@@ -156,7 +158,7 @@ class InvertedResidual(nn.Cell):
>>> ResidualBlock(3, 256, 1, 1) >>> ResidualBlock(3, 256, 1, 1)
""" """


def __init__(self, device_target, inp, oup, stride, expand_ratio):
def __init__(self, platform, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
assert stride in [1, 2] assert stride in [1, 2]


@@ -165,10 +167,10 @@ class InvertedResidual(nn.Cell):


layers = [] layers = []
if expand_ratio != 1: if expand_ratio != 1:
layers.append(ConvBNReLU(device_target, inp, hidden_dim, kernel_size=1))
layers.append(ConvBNReLU(platform, inp, hidden_dim, kernel_size=1))
layers.extend([ layers.extend([
# dw # dw
ConvBNReLU(device_target, hidden_dim, hidden_dim,
ConvBNReLU(platform, hidden_dim, hidden_dim,
stride=stride, groups=hidden_dim), stride=stride, groups=hidden_dim),
# pw-linear # pw-linear
nn.Conv2d(hidden_dim, oup, kernel_size=1, nn.Conv2d(hidden_dim, oup, kernel_size=1,
@@ -186,8 +188,7 @@ class InvertedResidual(nn.Cell):
return self.add(identity, x) return self.add(identity, x)
return x return x



class MobileNetV2(nn.Cell):
class MobileNetV2Backbone(nn.Cell):
""" """
MobileNetV2 architecture. MobileNetV2 architecture.


@@ -204,12 +205,10 @@ class MobileNetV2(nn.Cell):
>>> MobileNetV2(num_classes=1000) >>> MobileNetV2(num_classes=1000)
""" """


def __init__(self, device_target, num_classes=1000, width_mult=1.,
has_dropout=False, inverted_residual_setting=None, round_nearest=8):
super(MobileNetV2, self).__init__()
def __init__(self, platform, width_mult=1., inverted_residual_setting=None, round_nearest=8,
input_channel=32, last_channel=1280):
super(MobileNetV2Backbone, self).__init__()
block = InvertedResidual block = InvertedResidual
input_channel = 32
last_channel = 1280
# setting of inverted residual blocks # setting of inverted residual blocks
self.cfgs = inverted_residual_setting self.cfgs = inverted_residual_setting
if inverted_residual_setting is None: if inverted_residual_setting is None:
@@ -227,28 +226,22 @@ class MobileNetV2(nn.Cell):
# building first layer # building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest) input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(device_target, 3, input_channel, stride=2)]
features = [ConvBNReLU(platform, 3, input_channel, stride=2)]
# building inverted residual blocks # building inverted residual blocks
for t, c, n, s in self.cfgs: for t, c, n, s in self.cfgs:
output_channel = _make_divisible(c * width_mult, round_nearest) output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n): for i in range(n):
stride = s if i == 0 else 1 stride = s if i == 0 else 1
features.append(block(device_target, input_channel, output_channel, stride, expand_ratio=t))
features.append(block(platform, input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel input_channel = output_channel
# building last several layers # building last several layers
features.append(ConvBNReLU(device_target, input_channel, self.out_channels, kernel_size=1))
features.append(ConvBNReLU(platform, input_channel, self.out_channels, kernel_size=1))
# make it nn.CellList # make it nn.CellList
self.features = nn.SequentialCell(features) self.features = nn.SequentialCell(features)
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
self.head = nn.SequentialCell(head)

self._initialize_weights() self._initialize_weights()


def construct(self, x): def construct(self, x):
x = self.features(x) x = self.features(x)
x = self.head(x)
return x return x


def _initialize_weights(self): def _initialize_weights(self):
@@ -277,16 +270,115 @@ class MobileNetV2(nn.Cell):
Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_parameter_data( m.beta.set_parameter_data(
Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense):

@property
def get_features(self):
return self.features

class MobileNetV2Head(nn.Cell):
"""
MobileNetV2 architecture.

Args:
class_num (Cell): number of classes.
has_dropout (bool): Is dropout used. Default is false
Returns:
Tensor, output tensor.

Examples:
>>> MobileNetV2(num_classes=1000)
"""

def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False):
super(MobileNetV2Head, self).__init__()
# mobilenet head
head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else
[GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])
self.head = nn.SequentialCell(head)
self._initialize_weights()

def construct(self, x):
x = self.head(x)
return x

def _initialize_weights(self):
"""
Initialize weights.

Args:

Returns:
None.

Examples:
>>> _initialize_weights()
"""
self.init_parameters_data()
for _, m in self.cells_and_names():
if isinstance(m, nn.Dense):
m.weight.set_parameter_data(Tensor(np.random.normal( m.weight.set_parameter_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape).astype("float32"))) 0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None: if m.bias is not None:
m.bias.set_parameter_data( m.bias.set_parameter_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
@property
def get_head(self):
return self.head

class MobileNetV2(nn.Cell):
"""
MobileNetV2 architecture.


Args:
backbone(nn.Cell):
head(nn.Cell):
Returns:
Tensor, output tensor.


def mobilenet_v2(**kwargs):
Examples:
>>> MobileNetV2(backbone, head)
""" """
Constructs a MobileNet V2 model

def __init__(self, platform, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \
round_nearest=8, input_channel=32, last_channel=1280):
super(MobileNetV2, self).__init__()
self.backbone = MobileNetV2Backbone(platform=platform, width_mult=width_mult, \
inverted_residual_setting=inverted_residual_setting, \
round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_features
self.head = MobileNetV2Head(input_channel=self.backbone.out_channel, num_classes=num_classes, \
has_dropout=has_dropout).get_head

def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x

class MobileNetV2Combine(nn.Cell):
""" """
return MobileNetV2(**kwargs)
MobileNetV2 architecture.

Args:
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns:
Tensor, output tensor.

Examples:
>>> MobileNetV2(num_classes=1000)
"""

def __init__(self, backbone, head):
super(MobileNetV2Combine, self).__init__()
self.backbone = backbone
self.head = head

def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x

def mobilenet_v2(backbone, head):
return MobileNetV2Combine(backbone, head)

+ 138
- 0
model_zoo/official/cv/mobilenetv2/src/models.py View File

@@ -0,0 +1,138 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time
import numpy as np
from mindspore import Tensor
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.train.callback import Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2

class CrossEntropyWithLabelSmooth(_Loss):
"""
CrossEntropyWith LabelSmooth.

Args:
smooth_factor (float): smooth factor, default=0.
num_classes (int): num classes

Returns:
None.

Examples:
>>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
"""

def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropyWithLabelSmooth, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()

def construct(self, logit, label):
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
self.on_value, self.off_value)
out_loss = self.ce(logit, one_hot_label)
out_loss = self.mean(out_loss, 0)
return out_loss

class Monitor(Callback):
"""
Monitor loss and time.

Args:
lr_init (numpy array): train lr

Returns:
None

Examples:
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""

def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)

def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()

def epoch_end(self, run_context):
cb_params = run_context.original_args()

epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))

def step_begin(self, run_context):
self.step_time = time.time()

def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs

if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())

self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

def load_ckpt(network, pretrain_ckpt_path, trainable=True):
"""
incremental_learning or not
"""
param_dict = load_checkpoint(pretrain_ckpt_path)
load_param_into_net(network, param_dict)
if not trainable:
for param in network.get_parameters():
param.requires_grad = False

def define_net(args, config):
backbone_net = MobileNetV2Backbone(platform=args.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(backbone_net, head_net)

# load the ckpt file to the network for fine tune or incremental leaning
if args.pretrain_ckpt:
if args.train_method == "fine_tune":
load_ckpt(net, args.pretrain_ckpt)
elif args.train_method == "incremental_learn":
load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False)
elif args.train_method == "train":
pass
else:
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")

return backbone_net, head_net, net

+ 93
- 0
model_zoo/official/cv/mobilenetv2/src/utils.py View File

@@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import random
import numpy as np
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init
from mindspore.dataset import engine as de
from src.models import Monitor
def switch_precision(net, data_type, config):
if config.platform == "Ascend":
net.to_float(data_type)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)
def context_device_init(config):
if config.platform == "CPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
elif config.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
elif config.platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
save_graphs=False)
if config.run_distribute:
context.set_auto_parallel_context(device_num=config.rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
raise ValueError("Only support CPU, GPU and Ascend.")
def set_context(config):
if config.platform == "CPU":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
save_graphs=False)
elif config.platform == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
device_id=config.device_id, save_graphs=False)
elif config.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target=args_opt.platform, save_graphs=False)
def config_ckpoint(config, lr, step_size):
cb = None
if config.platform in ("CPU", "GPU") or config.rank_id == 0:
cb = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_save_dir = config.save_checkpoint_path
if config.platform == "GPU":
ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
return cb
def set_random_seed(seed=1):
random.seed(seed)
np.random.seed(seed)
de.config.set_seed(seed)

+ 86
- 233
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -16,263 +16,116 @@


import os import os
import time import time
import argparse
import random import random
import numpy as np import numpy as np


from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore import nn
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_group_size, get_rank
import mindspore.dataset.engine as de
from mindspore.train.serialization import _exec_save_checkpoint


from src.dataset import create_dataset
from src.dataset import create_dataset, extract_features
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.config import config_gpu, config_ascend
from src.mobilenetV2 import mobilenet_v2
from src.config import set_config


random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
from src.args import train_parse_args
from src.utils import set_random_seed, context_device_init, switch_precision, config_ckpoint
from src.models import CrossEntropyWithLabelSmooth, define_net


parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--device_target', type=str, default=None, help='run device_target')
args_opt = parser.parse_args()
set_random_seed(1)


if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID', '0'))
rank_size = int(os.getenv('RANK_SIZE', '1'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU",
save_graphs=False)
init()
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
else:
raise ValueError("Unsupported device target.")


class CrossEntropyWithLabelSmooth(_Loss):
"""
CrossEntropyWith LabelSmooth.

Args:
smooth_factor (float): smooth factor, default=0.
num_classes (int): num classes

Returns:
None.

Examples:
>>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
"""

def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropyWithLabelSmooth, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()

def construct(self, logit, label):
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
self.on_value, self.off_value)
out_loss = self.ce(logit, one_hot_label)
out_loss = self.mean(out_loss, 0)
return out_loss


class Monitor(Callback):
"""
Monitor loss and time.

Args:
lr_init (numpy array): train lr

Returns:
None

Examples:
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""

def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)

def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()

def epoch_end(self, run_context):
cb_params = run_context.original_args()

epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))

def step_begin(self, run_context):
self.step_time = time.time()

def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs

if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
if __name__ == '__main__':
args_opt = train_parse_args()
config = set_config(args_opt)
start = time.time()


self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print(f"train args: {args_opt}\ncfg: {config}")


print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
#set context and device init
context_device_init(config)


# define network
backbone_net, head_net, net = define_net(args_opt, config)


if __name__ == '__main__':
if args_opt.device_target == "GPU":
# train on gpu
print("train args: ", args_opt)
print("cfg: ", config_gpu)
# CPU only support "incremental_learn"
if args_opt.train_method == "incremental_learn":
step_size = extract_features(backbone_net, args_opt.dataset_path, config)
net = head_net


# define network
net = mobilenet_v2(num_classes=config_gpu.num_classes, device_target="GPU")
# define loss
if config_gpu.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=config_gpu.label_smooth,
num_classes=config_gpu.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
epoch_size = config_gpu.epoch_size
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config_gpu,
device_target=args_opt.device_target,
repeat_num=1,
batch_size=config_gpu.batch_size)
elif args_opt.train_method in ("train", "fine_tune"):
if args_opt.platform == "CPU":
raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# resume
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)


# get learning rate
loss_scale = FixedLossScaleManager(
config_gpu.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0,
lr_init=0,
lr_end=0,
lr_max=config_gpu.lr,
warmup_epochs=config_gpu.warmup_epochs,
total_epochs=epoch_size,
steps_per_epoch=step_size))
# Currently, only Ascend support switch precision.
switch_precision(net, mstype.float16, config)


# define optimization
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
config_gpu.weight_decay, config_gpu.loss_scale)
# define model
# define loss
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config.label_smooth, num_classes=config.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

epoch_size = config.epoch_size

# get learning rate
lr = Tensor(get_lr(global_step=0,
lr_init=0,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=epoch_size,
steps_per_epoch=step_size))

if args_opt.train_method == "incremental_learn":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)

network = WithLossCell(net, loss)
network = TrainOneStepCell(net, opt)
network.set_train()

features_path = args_opt.dataset_path + '_features'
idx_list = list(range(step_size))

if os.path.isdir(config.save_checkpoint_path):
os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time()))
os.mkdir(config.save_checkpoint_path)

for epoch in range(epoch_size):
random.shuffle(idx_list)
epoch_start = time.time()
losses = []
for j in idx_list:
feature = Tensor(np.load(os.path.join(features_path, f"feature_{j}.npy")))
label = Tensor(np.load(os.path.join(features_path, f"label_{j}.npy")))
losses.append(network(feature, label).asnumpy())
epoch_mseconds = (time.time()-epoch_start) * 1000
per_step_mseconds = epoch_mseconds / step_size
# lr cause to pynative, but cpu doesn't support this mode
# print("\r epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}, lr: {}"\
# .format(epoch + 1, step_step, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses)), \
# lr[(epoch+1)*step_size - 1]), end="")
print("\r epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0:
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start))

elif args_opt.train_method in ("train", "fine_tune"):
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)


cb = config_ckpoint(config, lr, step_size)
print("============== Starting Training ==============") print("============== Starting Training ==============")
cb = [Monitor(lr_init=lr.asnumpy())]
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
if config_gpu.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# begin train
model.train(epoch_size, dataset, callbacks=cb) model.train(epoch_size, dataset, callbacks=cb)
print("============== End Training ==============") print("============== End Training ==============")
elif args_opt.device_target == "Ascend":
# train on ascend
print("train args: ", args_opt, "\ncfg: ", config_ascend,
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))

if run_distribute:
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()

epoch_size = config_ascend.epoch_size
net = mobilenet_v2(num_classes=config_ascend.num_classes, device_target="Ascend")
net.to_float(mstype.float16)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Dense):
cell.to_float(mstype.float32)
if config_ascend.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
config=config_ascend,
device_target=args_opt.device_target,
repeat_num=1,
batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)

loss_scale = FixedLossScaleManager(
config_ascend.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0,
lr_init=0,
lr_end=0,
lr_max=config_ascend.lr,
warmup_epochs=config_ascend.warmup_epochs,
total_epochs=epoch_size,
steps_per_epoch=step_size))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_ascend.momentum,
config_ascend.weight_decay, config_ascend.loss_scale)

model = Model(net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale)

cb = None
if rank_id == 0:
cb = [Monitor(lr_init=lr.asnumpy())]
if config_ascend.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config_ascend.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config_ascend.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(
prefix="mobilenetV2", directory=config_ascend.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)
else:
raise ValueError("Unsupported device_target.")

Loading…
Cancel
Save