From: @zhao_ting_v Reviewed-by: @oacjiewen,@wuxuejian Signed-off-by: @wuxuejianpull/15645/MERGE
| @@ -69,8 +69,8 @@ We use about 13K images as training dataset and 3K as evaluating dataset in this | |||
| # [Environment Requirements](#contents) | |||
| - Hardware(Ascend) | |||
| - Prepare hardware environment with Ascend processor. | |||
| - Hardware(Ascend, CPU) | |||
| - Prepare hardware environment with Ascend or CPU processor. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| @@ -120,45 +120,45 @@ The entire code structure is as following: | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] | |||
| bash run_standalone_train.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID] | |||
| ``` | |||
| or (fine-tune) | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| bash run_standalone_train.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| ``` | |||
| for example: | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_standalone_train.sh /home/train.mindrecord 0 /home/a.ckpt | |||
| bash run_standalone_train.sh CPU /home/train.mindrecord 0 /home/a.ckpt | |||
| ``` | |||
| - Distribute mode (recommended) | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] | |||
| bash run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] | |||
| ``` | |||
| or (fine-tune) | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE] | |||
| bash run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE] | |||
| ``` | |||
| for example: | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_distribute_train.sh /home/train.mindrecord ./rank_table_8p.json /home/a.ckpt | |||
| bash run_distribute_train.sh /home/train.mindrecord ./rank_table_8p.json /home/a.ckpt | |||
| ``` | |||
| You will get the loss value of each step as following in "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log": | |||
| *Distribute mode doesn't support running on CPU*. You will get the loss value of each step as following in "./output/[TIME]/[TIME].log" or "./scripts/device0/train.log": | |||
| ```python | |||
| rank[0], iter[0], loss[318555.8], overflow:False, loss_scale:1024.0, lr:6.24999984211172e-06, batch_images:(64, 3, 448, 768), batch_labels:(64, 200, 6) | |||
| @@ -177,14 +177,14 @@ rank[0], iter[62499], loss[4294.194], overflow:False, loss_scale:256.0, lr:6.249 | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| bash run_eval.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| ``` | |||
| for example: | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_eval.sh /home/eval.mindrecord 0 /home/a.ckpt | |||
| bash run_eval.sh Ascend /home/eval.mindrecord 0 /home/a.ckpt | |||
| ``` | |||
| You will get the result as following in "./scripts/device0/eval.log": | |||
| @@ -202,7 +202,7 @@ If you want to infer the network on Ascend 310, you should convert the model to | |||
| ```bash | |||
| cd ./scripts | |||
| sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| bash run_export.sh [PLATFORM] [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE] | |||
| ``` | |||
| # [Model Description](#contents) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -35,14 +35,12 @@ from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_b | |||
| parse_gt_from_anno, parse_rets, calc_recall_precision_ap | |||
| plt.switch_backend('agg') | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid) | |||
| def parse_args(): | |||
| '''parse_args''' | |||
| parser = argparse.ArgumentParser('Yolov3 Face Detection') | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "CPU"), | |||
| help="run platform, support Ascend and CPU.") | |||
| parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord') | |||
| parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') | |||
| parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') | |||
| @@ -55,7 +53,8 @@ def parse_args(): | |||
| if __name__ == "__main__": | |||
| args = parse_args() | |||
| devid = int(os.getenv('DEVICE_ID', '0')) if args.run_platform != 'CPU' else 0 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.run_platform, save_graphs=False, device_id=devid) | |||
| print('=============yolov3 start evaluating==================') | |||
| # logger | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,14 +24,11 @@ from mindspore.train.serialization import export, load_checkpoint, load_param_in | |||
| from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3 | |||
| from src.config import config | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid) | |||
| def save_air(args): | |||
| '''save air''' | |||
| print('============= yolov3 start save air ==================') | |||
| devid = int(os.getenv('DEVICE_ID', '0')) if args.run_platform != 'CPU' else 0 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.run_platform, save_graphs=False, device_id=devid) | |||
| num_classes = config.num_classes | |||
| anchors_mask = config.anchors_mask | |||
| @@ -63,6 +60,8 @@ def save_air(args): | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='Convert ckpt to air') | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "CPU"), | |||
| help="run platform, support Ascend and CPU.") | |||
| parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') | |||
| parser.add_argument('--batch_size', type=int, default=8, help='batch size') | |||
| @@ -16,8 +16,8 @@ | |||
| if [ $# != 2 ] && [ $# != 3 ] | |||
| then | |||
| echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]" | |||
| echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]" | |||
| echo "Usage: bash run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]" | |||
| echo " or: bash run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]" | |||
| exit 1 | |||
| fi | |||
| @@ -14,9 +14,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 3 ] | |||
| if [ $# != 4 ] | |||
| then | |||
| echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| echo "Usage: bash run_eval.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| exit 1 | |||
| fi | |||
| @@ -42,9 +42,10 @@ SCRIPT_NAME='eval.py' | |||
| ulimit -c unlimited | |||
| MINDRECORD_FILE=$(get_real_path $1) | |||
| USE_DEVICE_ID=$2 | |||
| PRETRAINED_BACKBONE=$(get_real_path $3) | |||
| PLATFORM=$1 | |||
| MINDRECORD_FILE=$(get_real_path $2) | |||
| USE_DEVICE_ID=$3 | |||
| PRETRAINED_BACKBONE=$(get_real_path $4) | |||
| if [ ! -f $PRETRAINED_BACKBONE ] | |||
| then | |||
| @@ -52,6 +53,7 @@ if [ ! -f $PRETRAINED_BACKBONE ] | |||
| exit 1 | |||
| fi | |||
| echo $PLATFORM | |||
| echo $MINDRECORD_FILE | |||
| echo $USE_DEVICE_ID | |||
| echo $PRETRAINED_BACKBONE | |||
| @@ -65,6 +67,7 @@ cd ${current_exec_path}/device$USE_DEVICE_ID || exit | |||
| dev=`expr $USE_DEVICE_ID + 0` | |||
| export DEVICE_ID=$dev | |||
| python ${dirname_path}/${SCRIPT_NAME} \ | |||
| --run_platform=$PLATFORM \ | |||
| --mindrecord_path=$MINDRECORD_FILE \ | |||
| --pretrained=$PRETRAINED_BACKBONE > eval.log 2>&1 & | |||
| @@ -14,9 +14,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 3 ] | |||
| if [ $# != 4 ] | |||
| then | |||
| echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| echo "Usage: bash run_export.sh [PLATFORM] [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| exit 1 | |||
| fi | |||
| @@ -42,9 +42,10 @@ SCRIPT_NAME='export.py' | |||
| ulimit -c unlimited | |||
| BATCH_SIZE=$1 | |||
| USE_DEVICE_ID=$2 | |||
| PRETRAINED_BACKBONE=$(get_real_path $3) | |||
| PLATFORM=$1 | |||
| BATCH_SIZE=$2 | |||
| USE_DEVICE_ID=$3 | |||
| PRETRAINED_BACKBONE=$(get_real_path $4) | |||
| if [ ! -f $PRETRAINED_BACKBONE ] | |||
| then | |||
| @@ -52,6 +53,7 @@ if [ ! -f $PRETRAINED_BACKBONE ] | |||
| exit 1 | |||
| fi | |||
| echo $PLATFORM | |||
| echo $BATCH_SIZE | |||
| echo $USE_DEVICE_ID | |||
| echo $PRETRAINED_BACKBONE | |||
| @@ -65,6 +67,7 @@ cd ${current_exec_path}/device$USE_DEVICE_ID || exit | |||
| dev=`expr $USE_DEVICE_ID + 0` | |||
| export DEVICE_ID=$dev | |||
| python ${dirname_path}/${SCRIPT_NAME} \ | |||
| --run_platform=$PLATFORM \ | |||
| --batch_size=$BATCH_SIZE \ | |||
| --pretrained=$PRETRAINED_BACKBONE > convert.log 2>&1 & | |||
| @@ -14,10 +14,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 2 ] && [ $# != 3 ] | |||
| if [ $# != 3 ] && [ $# != 4 ] | |||
| then | |||
| echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]" | |||
| echo "Usage: bash run_standalone_train.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" | |||
| echo " or: bash run_standalone_train.sh [PLATFORM] [MINDRECORD_FILE] [USE_DEVICE_ID]" | |||
| exit 1 | |||
| fi | |||
| @@ -43,13 +43,14 @@ SCRIPT_NAME='train.py' | |||
| ulimit -c unlimited | |||
| MINDRECORD_FILE=$(get_real_path $1) | |||
| USE_DEVICE_ID=$2 | |||
| PLATFORM=$1 | |||
| MINDRECORD_FILE=$(get_real_path $2) | |||
| USE_DEVICE_ID=$3 | |||
| PRETRAINED_BACKBONE='' | |||
| if [ $# == 3 ] | |||
| if [ $# == 4 ] | |||
| then | |||
| PRETRAINED_BACKBONE=$(get_real_path $3) | |||
| PRETRAINED_BACKBONE=$(get_real_path $4) | |||
| if [ ! -f $PRETRAINED_BACKBONE ] | |||
| then | |||
| echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file" | |||
| @@ -57,6 +58,7 @@ then | |||
| fi | |||
| fi | |||
| echo $PLATFORM | |||
| echo $MINDRECORD_FILE | |||
| echo $USE_DEVICE_ID | |||
| echo $PRETRAINED_BACKBONE | |||
| @@ -70,6 +72,7 @@ cd ${current_exec_path}/device$USE_DEVICE_ID || exit | |||
| dev=`expr $USE_DEVICE_ID + 0` | |||
| export DEVICE_ID=$dev | |||
| python ${dirname_path}/${SCRIPT_NAME} \ | |||
| --run_platform=$PLATFORM \ | |||
| --world_size=1 \ | |||
| --mindrecord_path=$MINDRECORD_FILE \ | |||
| --pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 & | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| import mindspore.dataset.vision.py_transforms as P | |||
| import mindspore.dataset as de | |||
| from src.transforms import RandomCropLetterbox, RandomFlip, HSVShift, ResizeLetterbox | |||
| from src.config import config | |||
| @@ -240,5 +241,33 @@ def preprocess_fn(image, annotation): | |||
| t_cls_1, gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, \ | |||
| t_cls_2, gt_list_2 | |||
| compose_map_func = (preprocess_fn) | |||
| def create_dataset(args): | |||
| """Create dataset object.""" | |||
| args.logger.info('start create dataloader') | |||
| ds = de.MindDataset(args.mindrecord_path + "0", columns_list=["image", "annotation"], num_shards=args.world_size, | |||
| shard_id=args.local_rank) | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0', | |||
| 'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1', | |||
| 'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1', | |||
| 't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2', | |||
| 'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'], | |||
| column_order=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0', | |||
| 'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1', | |||
| 'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1', | |||
| 't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2', | |||
| 'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'], | |||
| operations=compose_map_func, num_parallel_workers=16, python_multiprocessing=True) | |||
| ds = ds.batch(args.batch_size, drop_remainder=True, num_parallel_workers=8) | |||
| ds = ds.repeat(args.max_epoch) | |||
| args.steps_per_epoch = ds.get_dataset_size() | |||
| args.logger.info('args.steps_per_epoch:{}'.format(args.steps_per_epoch)) | |||
| args.logger.info('args.world_size:{}'.format(args.world_size)) | |||
| args.logger.info('args.local_rank:{}'.format(args.local_rank)) | |||
| args.logger.info('end create dataloader') | |||
| args.logger.save_args(args) | |||
| return ds | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Face detection network wrapper.""" | |||
| import os | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| @@ -27,9 +28,12 @@ from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.FaceDetection.yolo_postprocess import YoloPostProcess | |||
| from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3 | |||
| from src.FaceDetection.yolo_loss import YoloLoss | |||
| from src.lrsche_factory import warmup_step_new | |||
| _grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
| reciprocal = P.Reciprocal() | |||
| @@ -634,3 +638,50 @@ def calc_recall_precision_ap(ground_truth, ret_list, iou_thr=0.5): | |||
| evaluate[cls] = {'recall': recall, 'precision': precision, 'ap': ap} | |||
| return evaluate | |||
| def define_network(args): | |||
| """Define train network with TrainOneStepCell.""" | |||
| # backbone and loss | |||
| num_classes = args.num_classes | |||
| num_anchors_list = args.num_anchors_list | |||
| anchors = args.anchors | |||
| anchors_mask = args.anchors_mask | |||
| momentum = args.momentum | |||
| args.logger.info('train opt momentum:{}'.format(momentum)) | |||
| weight_decay = args.weight_decay * float(args.batch_size) | |||
| args.logger.info('real weight_decay:{}'.format(weight_decay)) | |||
| lr_scale = args.world_size / 8 | |||
| args.logger.info('lr_scale:{}'.format(lr_scale)) | |||
| args.lr = warmup_step_new(args, lr_scale=lr_scale) | |||
| network = backbone_HwYolov3(num_classes, num_anchors_list, args) | |||
| criterion0 = YoloLoss(num_classes, anchors, anchors_mask[0], 64, 0, head_idx=0.0) | |||
| criterion1 = YoloLoss(num_classes, anchors, anchors_mask[1], 32, 0, head_idx=1.0) | |||
| criterion2 = YoloLoss(num_classes, anchors, anchors_mask[2], 16, 0, head_idx=2.0) | |||
| # load pretrain model | |||
| if os.path.isfile(args.pretrained): | |||
| param_dict = load_checkpoint(args.pretrained) | |||
| param_dict_new = {} | |||
| for key, values in param_dict.items(): | |||
| if key.startswith('moments.'): | |||
| continue | |||
| elif key.startswith('network.'): | |||
| param_dict_new[key[8:]] = values | |||
| else: | |||
| param_dict_new[key] = values | |||
| load_param_into_net(network, param_dict_new) | |||
| args.logger.info('load model {} success'.format(args.pretrained)) | |||
| train_net = BuildTrainNetworkV2(network, criterion0, criterion1, criterion2, args) | |||
| # optimizer | |||
| opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=Tensor(args.lr), momentum=momentum, | |||
| weight_decay=weight_decay) | |||
| # package training process | |||
| if args.use_loss_scale: | |||
| train_net = TrainOneStepWithLossScaleCell(train_net, opt) | |||
| else: | |||
| train_net = nn.TrainOneStepCell(train_net, opt) | |||
| if args.world_size != 1: | |||
| train_net.set_broadcast_flag() | |||
| return train_net | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Face detection train.""" | |||
| import os | |||
| import ast | |||
| import time | |||
| import datetime | |||
| import argparse | |||
| @@ -22,163 +23,78 @@ import numpy as np | |||
| from mindspore import context | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Momentum | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.callback import ModelCheckpoint, RunContext | |||
| from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.dataset as de | |||
| from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3 | |||
| from src.FaceDetection.yolo_loss import YoloLoss | |||
| from src.network_define import BuildTrainNetworkV2, TrainOneStepWithLossScaleCell | |||
| from src.lrsche_factory import warmup_step_new | |||
| from src.logging import get_logger | |||
| from src.data_preprocess import compose_map_func | |||
| from src.data_preprocess import create_dataset | |||
| from src.config import config | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid) | |||
| from src.network_define import define_network | |||
| def parse_args(): | |||
| '''parse_args''' | |||
| parser = argparse.ArgumentParser('Yolov3 Face Detection') | |||
| parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend", "CPU"), | |||
| help="run platform, support Ascend and CPU.") | |||
| parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord') | |||
| parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') | |||
| parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed') | |||
| parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed') | |||
| parser.add_argument("--use_loss_scale", type=ast.literal_eval, default=True, | |||
| help="Whether use dynamic loss scale, default is True.") | |||
| args, _ = parser.parse_known_args() | |||
| return args | |||
| def train(args): | |||
| '''train''' | |||
| print('=============yolov3 start trainging==================') | |||
| # init distributed | |||
| if args.world_size != 1: | |||
| init() | |||
| args.local_rank = get_rank() | |||
| args.world_size = get_group_size() | |||
| args.batch_size = config.batch_size | |||
| args.warmup_lr = config.warmup_lr | |||
| args.lr_rates = config.lr_rates | |||
| args.lr_steps = config.lr_steps | |||
| if args.run_platform == "CPU": | |||
| args.use_loss_scale = False | |||
| args.world_size = 1 | |||
| args.local_rank = 0 | |||
| if args.world_size != 8: | |||
| args.lr_steps = [i * 8 // args.world_size for i in config.lr_steps] | |||
| else: | |||
| args.lr_steps = config.lr_steps | |||
| args.gamma = config.gamma | |||
| args.weight_decay = config.weight_decay | |||
| args.weight_decay = config.weight_decay if args.world_size != 1 else 0. | |||
| args.momentum = config.momentum | |||
| args.max_epoch = config.max_epoch | |||
| args.log_interval = config.log_interval | |||
| args.ckpt_path = config.ckpt_path | |||
| args.ckpt_interval = config.ckpt_interval | |||
| args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| print('args.outputs_dir', args.outputs_dir) | |||
| args.num_classes = config.num_classes | |||
| args.anchors = config.anchors | |||
| args.anchors_mask = config.anchors_mask | |||
| args.num_anchors_list = [len(x) for x in args.anchors_mask] | |||
| return args | |||
| args.logger = get_logger(args.outputs_dir, args.local_rank) | |||
| if args.world_size != 8: | |||
| args.lr_steps = [i * 8 // args.world_size for i in args.lr_steps] | |||
| if args.world_size == 1: | |||
| args.weight_decay = 0. | |||
| def train(args): | |||
| '''train''' | |||
| print('=============yolov3 start trainging==================') | |||
| devid = int(os.getenv('DEVICE_ID', '0')) if args.run_platform != 'CPU' else 0 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.run_platform, save_graphs=False, device_id=devid) | |||
| # init distributed | |||
| if args.world_size != 1: | |||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||
| else: | |||
| parallel_mode = ParallelMode.STAND_ALONE | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.world_size, gradients_mean=True) | |||
| mindrecord_path = args.mindrecord_path | |||
| num_classes = config.num_classes | |||
| anchors = config.anchors | |||
| anchors_mask = config.anchors_mask | |||
| num_anchors_list = [len(x) for x in anchors_mask] | |||
| momentum = args.momentum | |||
| args.logger.info('train opt momentum:{}'.format(momentum)) | |||
| weight_decay = args.weight_decay * float(args.batch_size) | |||
| args.logger.info('real weight_decay:{}'.format(weight_decay)) | |||
| lr_scale = args.world_size / 8 | |||
| args.logger.info('lr_scale:{}'.format(lr_scale)) | |||
| init() | |||
| args.local_rank = get_rank() | |||
| args.world_size = get_group_size() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, device_num=args.world_size, | |||
| gradients_mean=True) | |||
| args.logger = get_logger(args.outputs_dir, args.local_rank) | |||
| # dataloader | |||
| args.logger.info('start create dataloader') | |||
| epoch = args.max_epoch | |||
| ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation"], num_shards=args.world_size, | |||
| shard_id=args.local_rank) | |||
| ds = create_dataset(args) | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0', | |||
| 'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1', | |||
| 'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1', | |||
| 't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2', | |||
| 'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'], | |||
| column_order=["image", "annotation", 'coord_mask_0', 'conf_pos_mask_0', 'conf_neg_mask_0', | |||
| 'cls_mask_0', 't_coord_0', 't_conf_0', 't_cls_0', 'gt_list_0', 'coord_mask_1', | |||
| 'conf_pos_mask_1', 'conf_neg_mask_1', 'cls_mask_1', 't_coord_1', 't_conf_1', | |||
| 't_cls_1', 'gt_list_1', 'coord_mask_2', 'conf_pos_mask_2', 'conf_neg_mask_2', | |||
| 'cls_mask_2', 't_coord_2', 't_conf_2', 't_cls_2', 'gt_list_2'], | |||
| operations=compose_map_func, num_parallel_workers=16, python_multiprocessing=True) | |||
| ds = ds.batch(args.batch_size, drop_remainder=True, num_parallel_workers=8) | |||
| args.steps_per_epoch = ds.get_dataset_size() | |||
| lr = warmup_step_new(args, lr_scale=lr_scale) | |||
| ds = ds.repeat(epoch) | |||
| args.logger.info('args.steps_per_epoch:{}'.format(args.steps_per_epoch)) | |||
| args.logger.info('args.world_size:{}'.format(args.world_size)) | |||
| args.logger.info('args.local_rank:{}'.format(args.local_rank)) | |||
| args.logger.info('end create dataloader') | |||
| args.logger.save_args(args) | |||
| args.logger.important_info('start create network') | |||
| create_network_start = time.time() | |||
| # backbone and loss | |||
| network = backbone_HwYolov3(num_classes, num_anchors_list, args) | |||
| criterion0 = YoloLoss(num_classes, anchors, anchors_mask[0], 64, 0, head_idx=0.0) | |||
| criterion1 = YoloLoss(num_classes, anchors, anchors_mask[1], 32, 0, head_idx=1.0) | |||
| criterion2 = YoloLoss(num_classes, anchors, anchors_mask[2], 16, 0, head_idx=2.0) | |||
| # load pretrain model | |||
| if os.path.isfile(args.pretrained): | |||
| param_dict = load_checkpoint(args.pretrained) | |||
| param_dict_new = {} | |||
| for key, values in param_dict.items(): | |||
| if key.startswith('moments.'): | |||
| continue | |||
| elif key.startswith('network.'): | |||
| param_dict_new[key[8:]] = values | |||
| else: | |||
| param_dict_new[key] = values | |||
| load_param_into_net(network, param_dict_new) | |||
| args.logger.info('load model {} success'.format(args.pretrained)) | |||
| train_net = BuildTrainNetworkV2(network, criterion0, criterion1, criterion2, args) | |||
| # optimizer | |||
| opt = Momentum(params=train_net.trainable_params(), learning_rate=Tensor(lr), momentum=momentum, | |||
| weight_decay=weight_decay) | |||
| # package training process | |||
| train_net = TrainOneStepWithLossScaleCell(train_net, opt) | |||
| train_net.set_broadcast_flag() | |||
| train_net = define_network(args) | |||
| # checkpoint | |||
| ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval | |||
| @@ -196,86 +112,30 @@ def train(args): | |||
| t_epoch = time.time() | |||
| old_progress = -1 | |||
| i = 0 | |||
| scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=2000) | |||
| if args.use_loss_scale: | |||
| scale_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=2000) | |||
| for data in ds.create_tuple_iterator(output_numpy=True): | |||
| batch_images = data[0] | |||
| batch_labels = data[1] | |||
| coord_mask_0 = data[2] | |||
| conf_pos_mask_0 = data[3] | |||
| conf_neg_mask_0 = data[4] | |||
| cls_mask_0 = data[5] | |||
| t_coord_0 = data[6] | |||
| t_conf_0 = data[7] | |||
| t_cls_0 = data[8] | |||
| gt_list_0 = data[9] | |||
| coord_mask_1 = data[10] | |||
| conf_pos_mask_1 = data[11] | |||
| conf_neg_mask_1 = data[12] | |||
| cls_mask_1 = data[13] | |||
| t_coord_1 = data[14] | |||
| t_conf_1 = data[15] | |||
| t_cls_1 = data[16] | |||
| gt_list_1 = data[17] | |||
| coord_mask_2 = data[18] | |||
| conf_pos_mask_2 = data[19] | |||
| conf_neg_mask_2 = data[20] | |||
| cls_mask_2 = data[21] | |||
| t_coord_2 = data[22] | |||
| t_conf_2 = data[23] | |||
| t_cls_2 = data[24] | |||
| gt_list_2 = data[25] | |||
| img_tensor = Tensor(batch_images, mstype.float32) | |||
| coord_mask_tensor_0 = Tensor(coord_mask_0.astype(np.float32)) | |||
| conf_pos_mask_tensor_0 = Tensor(conf_pos_mask_0.astype(np.float32)) | |||
| conf_neg_mask_tensor_0 = Tensor(conf_neg_mask_0.astype(np.float32)) | |||
| cls_mask_tensor_0 = Tensor(cls_mask_0.astype(np.float32)) | |||
| t_coord_tensor_0 = Tensor(t_coord_0.astype(np.float32)) | |||
| t_conf_tensor_0 = Tensor(t_conf_0.astype(np.float32)) | |||
| t_cls_tensor_0 = Tensor(t_cls_0.astype(np.float32)) | |||
| gt_list_tensor_0 = Tensor(gt_list_0.astype(np.float32)) | |||
| coord_mask_tensor_1 = Tensor(coord_mask_1.astype(np.float32)) | |||
| conf_pos_mask_tensor_1 = Tensor(conf_pos_mask_1.astype(np.float32)) | |||
| conf_neg_mask_tensor_1 = Tensor(conf_neg_mask_1.astype(np.float32)) | |||
| cls_mask_tensor_1 = Tensor(cls_mask_1.astype(np.float32)) | |||
| t_coord_tensor_1 = Tensor(t_coord_1.astype(np.float32)) | |||
| t_conf_tensor_1 = Tensor(t_conf_1.astype(np.float32)) | |||
| t_cls_tensor_1 = Tensor(t_cls_1.astype(np.float32)) | |||
| gt_list_tensor_1 = Tensor(gt_list_1.astype(np.float32)) | |||
| coord_mask_tensor_2 = Tensor(coord_mask_2.astype(np.float32)) | |||
| conf_pos_mask_tensor_2 = Tensor(conf_pos_mask_2.astype(np.float32)) | |||
| conf_neg_mask_tensor_2 = Tensor(conf_neg_mask_2.astype(np.float32)) | |||
| cls_mask_tensor_2 = Tensor(cls_mask_2.astype(np.float32)) | |||
| t_coord_tensor_2 = Tensor(t_coord_2.astype(np.float32)) | |||
| t_conf_tensor_2 = Tensor(t_conf_2.astype(np.float32)) | |||
| t_cls_tensor_2 = Tensor(t_cls_2.astype(np.float32)) | |||
| gt_list_tensor_2 = Tensor(gt_list_2.astype(np.float32)) | |||
| scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32) | |||
| loss0, overflow, _ = train_net(img_tensor, coord_mask_tensor_0, conf_pos_mask_tensor_0, | |||
| conf_neg_mask_tensor_0, cls_mask_tensor_0, t_coord_tensor_0, | |||
| t_conf_tensor_0, t_cls_tensor_0, gt_list_tensor_0, | |||
| coord_mask_tensor_1, conf_pos_mask_tensor_1, conf_neg_mask_tensor_1, | |||
| cls_mask_tensor_1, t_coord_tensor_1, t_conf_tensor_1, | |||
| t_cls_tensor_1, gt_list_tensor_1, coord_mask_tensor_2, | |||
| conf_pos_mask_tensor_2, conf_neg_mask_tensor_2, | |||
| cls_mask_tensor_2, t_coord_tensor_2, t_conf_tensor_2, | |||
| t_cls_tensor_2, gt_list_tensor_2, scaling_sens) | |||
| overflow = np.all(overflow.asnumpy()) | |||
| if overflow: | |||
| scale_manager.update_loss_scale(overflow) | |||
| input_list = [Tensor(batch_images, mstype.float32)] | |||
| for idx in range(2, 26): | |||
| input_list.append(Tensor(data[idx], mstype.float32)) | |||
| if args.use_loss_scale: | |||
| scaling_sens = Tensor(scale_manager.get_loss_scale(), dtype=mstype.float32) | |||
| loss0, overflow, _ = train_net(*input_list, scaling_sens) | |||
| overflow = np.all(overflow.asnumpy()) | |||
| if overflow: | |||
| scale_manager.update_loss_scale(overflow) | |||
| else: | |||
| scale_manager.update_loss_scale(False) | |||
| args.logger.info('rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, ' | |||
| 'batch_labels:{}'.format(args.local_rank, i, loss0, overflow, scaling_sens, args.lr[i], | |||
| batch_images.shape, batch_labels.shape)) | |||
| else: | |||
| scale_manager.update_loss_scale(False) | |||
| args.logger.info('rank[{}], iter[{}], loss[{}], overflow:{}, loss_scale:{}, lr:{}, batch_images:{}, ' | |||
| 'batch_labels:{}'.format(args.local_rank, i, loss0, overflow, scaling_sens, lr[i], | |||
| batch_images.shape, batch_labels.shape)) | |||
| loss0 = train_net(*input_list) | |||
| args.logger.info('rank[{}], iter[{}], loss[{}], lr:{}, batch_images:{}, ' | |||
| 'batch_labels:{}'.format(args.local_rank, i, loss0, args.lr[i], | |||
| batch_images.shape, batch_labels.shape)) | |||
| # save ckpt | |||
| cb_params.cur_step_num = i + 1 # current step number | |||
| cb_params.batch_num = i + 2 | |||