| @@ -119,19 +119,27 @@ After installing MindSpore via the official website, you can start training and | |||
| Note: 1.the first run of training will generate the mindrecord file, which will take a long time. | |||
| 2.MINDRECORD_DATASET_PATH is the mindrecord dataset directory. | |||
| 3.LOAD_CHECKPOINT_PATH is the pretrained checkpoint file directory; if there is no pretrained checkpoint, just set it to "". | |||
| 4.RUN_MODE supports validation and testing; set it to "val" or "test". | |||
| ```shell | |||
| # create dataset in mindrecord format | |||
| bash scripts/convert_dataset_to_mindrecord.sh | |||
| # standalone training | |||
| bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [EPOCH_SIZE] | |||
| # standalone training on Ascend | |||
| bash scripts/run_standalone_train_ascend.sh [DEVICE_ID] [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] | |||
| # distributed training | |||
| bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [RANK_TABLE_FILE] | |||
| # standalone training on CPU | |||
| bash scripts/run_standalone_train_cpu.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] | |||
| # eval | |||
| bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] | |||
| # distributed training on Ascend | |||
| bash scripts/run_distributed_train_ascend.sh [MINDRECORD_DATASET_PATH] [LOAD_CHECKPOINT_PATH] [RANK_TABLE_FILE] | |||
| # eval on Ascend | |||
| bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] | |||
| # eval on CPU | |||
| bash scripts/run_standalone_eval_cpu.sh [RUN_MODE] [DATA_DIR] [LOAD_CHECKPOINT_PATH] | |||
| ``` | |||
| # [Script Description](#contents) | |||
| @@ -153,9 +161,11 @@ bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] | |||
| │ │ ├──get_distribute_pretrain_cmd.py // script for distributed pretraining | |||
| │ │ ├──README.md | |||
| │ ├──convert_dataset_to_mindrecord.sh // shell script for converting coco type dataset to mindrecord | |||
| │ ├──run_standalone_train_ascend.sh // shell script for standalone pretrain on ascend | |||
| │ ├──run_distributed_train_ascend.sh // shell script for distributed pretrain on ascend | |||
| │ ├──run_standalone_train_ascend.sh // shell script for standalone training on ascend | |||
| │ ├──run_distributed_train_ascend.sh // shell script for distributed training on ascend | |||
| │ ├──run_standalone_eval_ascend.sh // shell script for standalone evaluation on ascend | |||
| │ ├──run_standalone_train_cpu.sh // shell script for standalone training on cpu | |||
| │ ├──run_standalone_eval_cpu.sh // shell script for standalone evaluation on cpu | |||
| └── src | |||
| ├──__init__.py | |||
| ├──centernet_pose.py // centernet networks, training entry | |||
| @@ -259,7 +269,6 @@ config for training. | |||
| ```text | |||
| config for evaluation. | |||
| flip_test whether to use flip test: True | False, default is False | |||
| soft_nms nms after decode: True | False, default is True | |||
| keep_res keep original or fix resolution: True | False, default is False | |||
| multi_scales use multi-scales of image: List, default is [1.0] | |||
| @@ -350,12 +359,12 @@ bash scripts/convert_dataset_to_mindrecord.sh | |||
| The command above will run in the background. After conversion, the mindrecord files will be located in the path you specified. | |||
| ### Training | |||
| ### Standalone Training | |||
| #### Running on Ascend | |||
| ```bash | |||
| bash scripts/run_standalone_pretrain_ascend.sh 0 1 | |||
| bash scripts/run_standalone_train_ascend.sh device_id /path/mindrecord_dataset /path/load_ckpt | |||
| ``` | |||
| The command above will run in the background, you can view training logs in training_log.txt. After training finished, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows: | |||
| @@ -368,12 +377,31 @@ epoch: 349.0, current epoch percent: 1.00, step: 87500, outputs are (Tensor(shap | |||
| ... | |||
| ``` | |||
| #### Running on CPU | |||
| ```bash | |||
| bash scripts/run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt | |||
| ``` | |||
| The command above will run in the background, you can view training logs in training_log.txt. After training finishes, you will get some checkpoint files under the script folder by default. The loss values will be displayed as follows (resumed from a pretrained checkpoint, with batch_size set to 8): | |||
| ```text | |||
| # grep "epoch" training_log.txt | |||
| ... | |||
| epoch: 0.0, current epoch percent: 0.00, step: 1, time of per steps: 66.693 s, outputs are 3.645 | |||
| epoch: 0.0, current epoch percent: 0.00, step: 2, time of per steps: 46.594 s, outputs are 4.862 | |||
| epoch: 0.0, current epoch percent: 0.00, step: 3, time of per steps: 44.718 s, outputs are 3.927 | |||
| epoch: 0.0, current epoch percent: 0.00, step: 4, time of per steps: 45.113 s, outputs are 3.910 | |||
| epoch: 0.0, current epoch percent: 0.00, step: 5, time of per steps: 45.213 s, outputs are 3.749 | |||
| ... | |||
| ``` | |||
| ### Distributed Training | |||
| #### Running on Ascend | |||
| ```bash | |||
| bash scripts/run_distributed_pretrain_ascend.sh /path/coco2017 /path/mindrecord /path/hccl.json | |||
| bash scripts/run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json | |||
| ``` | |||
| The command above will run in the background, you can view training logs in LOG*/training_log.txt and LOG*/ms_log/. After training finished, you will get some checkpoint files under the LOG*/ckpt_0 folder by default. The loss value will be displayed as follows: | |||
| @@ -394,7 +422,11 @@ epoch: 0.0, current epoch percent: 0.002, step: 200, outputs are (Tensor(shape=[ | |||
| ```bash | |||
| # Evaluation based on the validation dataset will be done automatically, while for the test or test-dev dataset, the accuracy should be uploaded to the CodaLab official website(https://competitions.codalab.org). | |||
| bash scripts/run_standalone_eval_ascend.sh [DEVICE_ID] | |||
| # On Ascend | |||
| bash scripts/run_standalone_eval_ascend.sh device_id val(or test) /path/coco_dataset /path/load_ckpt | |||
| # On CPU | |||
| bash scripts/run_standalone_eval_cpu.sh val(or test) /path/coco_dataset /path/load_ckpt | |||
| ``` | |||
| You can see the mAP result as below: | |||
| @@ -439,7 +471,7 @@ python export.py [DEVICE_ID] | |||
| ## [Performance](#contents) | |||
| ### Training Performance | |||
| ### Training Performance On Ascend | |||
| CenterNet on 11.8K images(The annotation and data format must be the same as coco) | |||
| @@ -460,7 +492,7 @@ CenterNet on 11.8K images(The annotation and data format must be the same as coc | |||
| | Checkpoint | 242M (.ckpt file) | | |||
| | Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet> | | |||
| ### Inference Performance | |||
| ### Inference Performance On Ascend | |||
| CenterNet on validation(5K images) and test-dev(40K images) | |||
| @@ -36,6 +36,8 @@ from src.config import dataset_config, net_config, eval_config | |||
| _current_dir = os.path.dirname(os.path.realpath(__file__)) | |||
| parser = argparse.ArgumentParser(description='CenterNet evaluation') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") | |||
| parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, " | |||
| @@ -52,15 +54,20 @@ def predict(): | |||
| ''' | |||
| Predict function | |||
| ''' | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||
| if args_opt.device_target == "Ascend": | |||
| context.set_context(device_id=args_opt.device_id) | |||
| enable_nms_fp16 = True | |||
| else: | |||
| enable_nms_fp16 = False | |||
| logger.info("Begin creating {} dataset".format(args_opt.run_mode)) | |||
| coco = COCOHP(dataset_config, run_mode=args_opt.run_mode, net_opt=net_config, | |||
| enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,) | |||
| coco.init(args_opt.data_dir, keep_res=eval_config.keep_res, flip_test=eval_config.flip_test) | |||
| coco.init(args_opt.data_dir, keep_res=eval_config.keep_res) | |||
| dataset = coco.create_eval_dataset() | |||
| net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K) | |||
| net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.K, enable_nms_fp16) | |||
| net_for_eval.set_train(False) | |||
| param_dict = load_checkpoint(args_opt.load_checkpoint_path) | |||
| @@ -103,9 +110,7 @@ def predict(): | |||
| print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.)) | |||
| # post-process | |||
| soft_nms = eval_config.soft_nms or len(eval_config.multi_scales) > 0 | |||
| detections = merge_outputs(detections, soft_nms) | |||
| detections = merge_outputs(detections, eval_config.soft_nms) | |||
| # get prediction result | |||
| pred_json = convert_eval_format(detections, image_id) | |||
| gt_image_info = coco.coco.loadImgs([image_id]) | |||
| @@ -31,7 +31,7 @@ args = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) | |||
| if __name__ == '__main__': | |||
| net = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K) | |||
| net = CenterNetMultiPoseEval(net_config, eval_config.K) | |||
| net.set_train(False) | |||
| param_dict = load_checkpoint(export_config.ckpt_file) | |||
| @@ -39,8 +39,7 @@ def parse_args(): | |||
| parser.add_argument("--hyper_parameter_config_dir", type=str, default="", | |||
| help="Hyper Parameter config path, it is better to use absolute path") | |||
| parser.add_argument("--mindrecord_dir", type=str, default="", help="Mindrecord dataset directory") | |||
| parser.add_argument("--mindrecord_prefix", type=str, default="coco_hp.train.mind", | |||
| help="Prefix of MindRecord dataset filename.") | |||
| parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") | |||
| parser.add_argument("--hccl_config_dir", type=str, default="", | |||
| help="Hccl config path, it is better to use absolute path") | |||
| parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh", | |||
| @@ -72,7 +71,7 @@ def distribute_train(): | |||
| run_script = args.run_script_dir | |||
| mindrecord_dir = args.mindrecord_dir | |||
| mindrecord_prefix = args.mindrecord_prefix | |||
| load_checkpoint_path = args.load_checkpoint_path | |||
| cf = configparser.ConfigParser() | |||
| cf.read(args.hyper_parameter_config_dir) | |||
| cfg = dict(cf.items("config")) | |||
| @@ -151,7 +150,7 @@ def distribute_train(): | |||
| " 'device_num' or 'mindrecord_dir'! ") | |||
| run_cmd += opt | |||
| run_cmd += " --mindrecord_dir=" + mindrecord_dir | |||
| run_cmd += " --mindrecord_prefix=" + mindrecord_prefix | |||
| run_cmd += " --load_checkpoint_path=" + load_checkpoint_path | |||
| run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \ | |||
| + str(rank_size) + ' >./training_log.txt 2>&1 &' | |||
| @@ -5,7 +5,6 @@ enable_save_ckpt=true | |||
| do_shuffle=true | |||
| enable_data_sink=true | |||
| data_sink_steps=50 | |||
| load_checkpoint_path="" | |||
| save_checkpoint_path=./ | |||
| save_checkpoint_steps=3000 | |||
| save_checkpoint_num=1 | |||
| @@ -14,21 +14,26 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the scipt as: " | |||
| echo "bash run_distributed_train_ascend.sh DATA_DIR MINDRECORD_DIR RANK_TABLE_FILE" | |||
| echo "for example: bash run_distributed_train_ascend.sh /path/dataset /path/mindrecord /path/hccl.json" | |||
| echo "It is better to use absolute path." | |||
| echo "================================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_distributed_train_ascend.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH RANK_TABLE_FILE" | |||
| echo "for example: bash run_distributed_train_ascend.sh /path/mindrecord_dataset /path/load_ckpt /path/hccl.json" | |||
| echo "if no ckpt, just run: bash run_distributed_train_ascend.sh /path/mindrecord_dataset \"\" /path/hccl.json" | |||
| echo "It is better to use the absolute path." | |||
| echo "For hyper parameter, please note that you should customize the scripts: | |||
| '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' " | |||
| echo "==============================================================================================================" | |||
| echo "================================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| MINDRECORD_DIR=$1 | |||
| LOAD_CHECKPOINT_PATH=$2 | |||
| HCCL_RANK_FILE=$3 | |||
| python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \ | |||
| --run_script_dir=${CUR_DIR}/train.py \ | |||
| --hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \ | |||
| --mindrecord_dir=$1 \ | |||
| --hccl_config_dir=$2 \ | |||
| --mindrecord_dir=$MINDRECORD_DIR \ | |||
| --load_checkpoint_path=$LOAD_CHECKPOINT_PATH \ | |||
| --hccl_config_dir=$HCCL_RANK_FILE \ | |||
| --hccl_time_out=1200 \ | |||
| --cmd_file=distributed_cmd.sh | |||
| @@ -16,11 +16,14 @@ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_standalone_eval_ascend.sh DEVICE_ID" | |||
| echo "for example: bash run_standalone_eval_ascend.sh 0" | |||
| echo "bash run_standalone_eval_ascend.sh DEVICE_ID RUN_MODE DATA_DIR LOAD_CHECKPOINT_PATH" | |||
| echo "for example of validation: bash run_standalone_eval_ascend.sh 0 val /path/coco_dataset /path/load_ckpt" | |||
| echo "for example of test: bash run_standalone_eval_ascend.sh 0 test /path/coco_dataset /path/load_ckpt" | |||
| echo "==============================================================================================================" | |||
| DEVICE_ID=$1 | |||
| RUN_MODE=$2 | |||
| DATA_DIR=$3 | |||
| LOAD_CHECKPOINT_PATH=$4 | |||
| mkdir -p ms_log | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| CUR_DIR=`pwd` | |||
| @@ -42,10 +45,11 @@ else | |||
| fi | |||
| python ${PROJECT_DIR}/../eval.py \ | |||
| --device_target=Ascend \ | |||
| --device_id=$DEVICE_ID \ | |||
| --load_checkpoint_path="" \ | |||
| --data_dir="" \ | |||
| --load_checkpoint_path=$LOAD_CHECKPOINT_PATH \ | |||
| --data_dir=$DATA_DIR \ | |||
| --run_mode=$RUN_MODE \ | |||
| --visual_image=true \ | |||
| --enable_eval=true \ | |||
| --save_result_dir="" \ | |||
| --run_mode=val > eval_log.txt 2>&1 & | |||
| --save_result_dir=./ > eval_log.txt 2>&1 & | |||
| @@ -0,0 +1,53 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # Usage banner: the script takes RUN_MODE, DATA_DIR and LOAD_CHECKPOINT_PATH as positional arguments. | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_standalone_eval_cpu.sh RUN_MODE DATA_DIR LOAD_CHECKPOINT_PATH" | |||
| echo "for example of validation: bash run_standalone_eval_cpu.sh val /path/coco_dataset /path/load_ckpt" | |||
| echo "for example of test: bash run_standalone_eval_cpu.sh test /path/coco_dataset /path/load_ckpt" | |||
| echo "==============================================================================================================" | |||
| RUN_MODE=$1 | |||
| DATA_DIR=$2 | |||
| LOAD_CHECKPOINT_PATH=$3 | |||
| mkdir -p ms_log | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| CUR_DIR=$(pwd) | |||
| # Direct MindSpore glog output into ./ms_log instead of stderr. | |||
| export GLOG_log_dir="${CUR_DIR}/ms_log" | |||
| export GLOG_logtostderr=0 | |||
| # install nms module from third party | |||
| if python -c "import nms" > /dev/null 2>&1 | |||
| then | |||
| echo "NMS module already exists, no need to reinstall." | |||
| else | |||
| echo "NMS module was not found, install it now..." | |||
| git clone https://github.com/xingyizhou/CenterNet.git | |||
| # Stop if the directory change fails, so make/setup.py never run in the wrong directory. | |||
| cd CenterNet/src/lib/external/ || exit 1 | |||
| make | |||
| python setup.py install | |||
| cd - || exit 1 | |||
| rm -rf CenterNet | |||
| fi | |||
| # Run evaluation in the background; stdout and stderr go to eval_log.txt. | |||
| # Path arguments are quoted so that paths containing spaces are passed as a single argument. | |||
| python "${PROJECT_DIR}/../eval.py" \ | |||
| --device_target=CPU \ | |||
| --load_checkpoint_path="$LOAD_CHECKPOINT_PATH" \ | |||
| --data_dir="$DATA_DIR" \ | |||
| --run_mode="$RUN_MODE" \ | |||
| --visual_image=true \ | |||
| --enable_eval=true \ | |||
| --save_result_dir=./ > eval_log.txt 2>&1 & | |||
| @@ -16,12 +16,14 @@ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE" | |||
| echo "for example: bash run_standalone_pretrain_ascend.sh 0 350" | |||
| echo "bash run_standalone_train_ascend.sh DEVICE_ID MINDRECORD_DIR LOAD_CHECKPOINT_PATH" | |||
| echo "for example: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset /path/load_ckpt" | |||
| echo "if no ckpt, just run: bash run_standalone_train_ascend.sh 0 /path/mindrecord_dataset \"\" " | |||
| echo "==============================================================================================================" | |||
| DEVICE_ID=$1 | |||
| EPOCH_SIZE=$2 | |||
| MINDRECORD_DIR=$2 | |||
| LOAD_CHECKPOINT_PATH=$3 | |||
| mkdir -p ms_log | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| @@ -33,16 +35,16 @@ python ${PROJECT_DIR}/../train.py \ | |||
| --distribute=false \ | |||
| --need_profiler=false \ | |||
| --profiler_path=./profiler \ | |||
| --epoch_size=$EPOCH_SIZE \ | |||
| --device_id=$DEVICE_ID \ | |||
| --enable_save_ckpt=true \ | |||
| --do_shuffle=true \ | |||
| --enable_data_sink=true \ | |||
| --data_sink_steps=50 \ | |||
| --load_checkpoint_path="" \ | |||
| --epoch_size=350 \ | |||
| --load_checkpoint_path=$LOAD_CHECKPOINT_PATH \ | |||
| --save_checkpoint_steps=10000 \ | |||
| --save_checkpoint_num=1 \ | |||
| --mindrecord_dir="" \ | |||
| --mindrecord_dir=$MINDRECORD_DIR \ | |||
| --mindrecord_prefix="coco_hp.train.mind" \ | |||
| --visual_image=false \ | |||
| --save_result_dir="" > training_log.txt 2>&1 & | |||
| @@ -0,0 +1,44 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # Usage banner: the script takes MINDRECORD_DIR and LOAD_CHECKPOINT_PATH as positional arguments. | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash run_standalone_train_cpu.sh MINDRECORD_DIR LOAD_CHECKPOINT_PATH" | |||
| echo "for example: bash run_standalone_train_cpu.sh /path/mindrecord_dataset /path/load_ckpt" | |||
| echo "if no ckpt, just run: bash run_standalone_train_cpu.sh /path/mindrecord_dataset \"\" " | |||
| echo "==============================================================================================================" | |||
| MINDRECORD_DIR=$1 | |||
| LOAD_CHECKPOINT_PATH=$2 | |||
| mkdir -p ms_log | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| CUR_DIR=$(pwd) | |||
| # Direct MindSpore glog output into ./ms_log instead of stderr. | |||
| export GLOG_log_dir="${CUR_DIR}/ms_log" | |||
| export GLOG_logtostderr=0 | |||
| # Run training in the background; stdout and stderr go to training_log.txt. | |||
| # Path arguments are quoted so that paths containing spaces are passed as a single argument. | |||
| python "${PROJECT_DIR}/../train.py" \ | |||
| --device_target=CPU \ | |||
| --enable_save_ckpt=true \ | |||
| --do_shuffle=true \ | |||
| --epoch_size=1 \ | |||
| --load_checkpoint_path="$LOAD_CHECKPOINT_PATH" \ | |||
| --save_checkpoint_steps=1000 \ | |||
| --save_checkpoint_num=1 \ | |||
| --mindrecord_dir="$MINDRECORD_DIR" \ | |||
| --mindrecord_prefix="coco_hp.train.mind" \ | |||
| --visual_image=false \ | |||
| --save_result_dir="" > training_log.txt 2>&1 & | |||
| @@ -15,7 +15,7 @@ | |||
| """CenterNet Init.""" | |||
| from .centernet_pose import GatherMultiPoseFeatureCell, CenterNetMultiPoseLossCell, \ | |||
| CenterNetWithLossScaleCell, CenterNetMultiPoseEval | |||
| CenterNetWithLossScaleCell, CenterNetMultiPoseEval, CenterNetWithoutLossScaleCell | |||
| from .dataset import COCOHP | |||
| from .visual import visual_allimages, visual_image | |||
| from .decode import MultiPoseDecode | |||
| @@ -23,6 +23,7 @@ from .post_process import convert_eval_format, to_float, resize_detection, post_ | |||
| __all__ = [ | |||
| "GatherMultiPoseFeatureCell", "CenterNetMultiPoseLossCell", "CenterNetWithLossScaleCell", \ | |||
| "CenterNetMultiPoseEval", "COCOHP", "visual_allimages", "visual_image", "MultiPoseDecode", \ | |||
| "convert_eval_format", "to_float", "resize_detection", "post_process", "merge_outputs" | |||
| "CenterNetMultiPoseEval", "CenterNetWithoutLossScaleCell", "COCOHP", "visual_allimages", \ | |||
| "visual_image", "MultiPoseDecode", "convert_eval_format", "to_float", "resize_detection", \ | |||
| "post_process", "merge_outputs" | |||
| ] | |||
| @@ -197,6 +197,46 @@ class CenterNetMultiPoseLossCell(nn.Cell): | |||
| return total_loss | |||
| class CenterNetWithoutLossScaleCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of centernet training without loss scaling. | |||
| Append an optimizer to the training network after that the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| network (Cell): The training network. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| Returns: | |||
| Tensor, the loss of the network. | |||
| """ | |||
| def __init__(self, network, optimizer): | |||
| super(CenterNetWithoutLossScaleCell, self).__init__(auto_prefix=False) | |||
| # Pre-processing cell applied to the input image at the start of construct. | |||
| self.image = ImagePreProcess() | |||
| self.network = network | |||
| self.network.set_grad() | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| # Gradients are taken with respect to the parameter list; no sens value is passed (sens_param=False). | |||
| self.grad = ops.GradOperation(get_by_list=True, sens_param=False) | |||
| @ops.add_flags(has_effect=True) | |||
| def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg, | |||
| hm_hp, hp_offset, hp_ind, hp_mask): | |||
| """Defines the computation performed.""" | |||
| image = self.image(image) | |||
| weights = self.weights | |||
| # Forward pass: this loss is the value returned to the caller. | |||
| loss = self.network(image, hm, reg_mask, ind, wh, kps, kps_mask, reg, | |||
| hm_hp, hp_offset, hp_ind, hp_mask) | |||
| # Gradients of the network with respect to the optimizer's parameters, using the same inputs. | |||
| grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, kps, | |||
| kps_mask, reg, hm_hp, hp_offset, | |||
| hp_ind, hp_mask) | |||
| # Apply the gradients through the optimizer. | |||
| succ = self.optimizer(grads) | |||
| ret = loss | |||
| # Return the loss with a dependency on the optimizer update. | |||
| return ops.depend(ret, succ) | |||
| class CenterNetWithLossScaleCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of centernet training. | |||
| @@ -279,17 +319,16 @@ class CenterNetMultiPoseEval(nn.Cell): | |||
| Args: | |||
| net_config: The config info of CenterNet network. | |||
| flip_test(bool): Flip data augmentation or not. Default: False. | |||
| K(number): Max number of output objects. Default: 100. | |||
| enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True. | |||
| Returns: | |||
| Tensor, detection of images(bboxes, score, keypoints and category id of each objects) | |||
| """ | |||
| def __init__(self, net_config, flip_test=False, K=100): | |||
| def __init__(self, net_config, K=100, enable_nms_fp16=True): | |||
| super(CenterNetMultiPoseEval, self).__init__() | |||
| self.network = GatherMultiPoseFeatureCell(net_config) | |||
| self.decode = MultiPoseDecode(net_config, flip_test, K) | |||
| self.flip_test = flip_test | |||
| self.decode = MultiPoseDecode(net_config, K, enable_nms_fp16) | |||
| self.shape = ops.Shape() | |||
| self.reshape = ops.Reshape() | |||
| @@ -104,8 +104,7 @@ train_config = edict({ | |||
| eval_config = edict({ | |||
| 'flip_test': False, | |||
| 'soft_nms': False, | |||
| 'soft_nms': True, | |||
| 'keep_res': True, | |||
| 'multi_scales': [1.0], | |||
| 'pad': 31, | |||
| @@ -17,7 +17,6 @@ Data operations, will be used in train.py | |||
| """ | |||
| import os | |||
| import copy | |||
| import math | |||
| import argparse | |||
| import cv2 | |||
| @@ -66,7 +65,7 @@ class COCOHP(ds.Dataset): | |||
| if not os.path.exists(self.save_path): | |||
| os.makedirs(self.save_path) | |||
| def init(self, data_dir, keep_res=False, flip_test=False): | |||
| def init(self, data_dir, keep_res=False): | |||
| """initialize additional info""" | |||
| logger.info('Initializing coco 2017 {} data.'.format(self.run_mode)) | |||
| if not os.path.isdir(data_dir): | |||
| @@ -94,7 +93,6 @@ class COCOHP(ds.Dataset): | |||
| self.images = image_ids | |||
| self.num_samples = len(self.images) | |||
| self.keep_res = keep_res | |||
| self.flip_test = flip_test | |||
| if self.run_mode != "train": | |||
| self.pad = 31 | |||
| logger.info('Loaded {} {} samples'.format(self.run_mode, self.num_samples)) | |||
| @@ -167,7 +165,7 @@ class COCOHP(ds.Dataset): | |||
| ret = (img, image_id) | |||
| return ret | |||
| def pre_process_for_test(self, image, img_id, scale, meta=None): | |||
| def pre_process_for_test(self, image, img_id, scale): | |||
| """image pre-process for evaluation""" | |||
| b, h, w, ch = image.shape | |||
| assert b == 1, "only single image was supported here" | |||
| @@ -191,17 +189,8 @@ class COCOHP(ds.Dataset): | |||
| flags=cv2.INTER_LINEAR) | |||
| inp_img = (inp_image.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std | |||
| h, w, ch = inp_img.shape | |||
| images = copy.deepcopy(inp_img) | |||
| if self.flip_test: | |||
| flip_image = inp_img[:, ::-1, :] | |||
| inp_img = inp_img.reshape((1, h, w, ch)) | |||
| flip_image = flip_image.reshape((1, h, w, ch)) | |||
| # (2, h, w, c) | |||
| images = np.concatenate((inp_img, flip_image), axis=0) | |||
| else: | |||
| images = images.reshape((1, h, w, ch)) | |||
| images = images.transpose(0, 3, 1, 2) | |||
| eval_image = inp_img.reshape((1,) + inp_img.shape) | |||
| eval_image = eval_image.transpose(0, 3, 1, 2) | |||
| meta = {'c': c, 's': s, | |||
| 'out_height': inp_height // self.net_opt.down_ratio, | |||
| @@ -244,7 +233,7 @@ class COCOHP(ds.Dataset): | |||
| image_name = "gt_" + self.run_mode + "_image_" + str(img_id) + "_scale_" + str(scale) + ".png" | |||
| cv2.imwrite("{}/{}".format(self.save_path, image_name), inp_image) | |||
| return images, meta | |||
| return eval_image, meta | |||
| def preprocess_fn(self, img, num_objects, keypoints, bboxes, category_id): | |||
| """image pre-process and augmentation""" | |||
| @@ -30,25 +30,32 @@ class NMS(nn.Cell): | |||
| Args: | |||
| kernel(int): Maxpooling kernel size. Default: 3. | |||
| enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True. | |||
| Returns: | |||
| Tensor, heatmap after non-maximum suppression. | |||
| """ | |||
| def __init__(self, kernel=3): | |||
| def __init__(self, kernel=3, enable_nms_fp16=True): | |||
| super(NMS, self).__init__() | |||
| self.pad = (kernel - 1) // 2 | |||
| self.cast = ops.Cast() | |||
| self.dtype = ops.DType() | |||
| self.equal = ops.Equal() | |||
| self.max_pool = nn.MaxPool2d(kernel, stride=1, pad_mode="same") | |||
| self.enable_fp16 = enable_nms_fp16 | |||
| def construct(self, heat): | |||
| """Non-maximum suppression""" | |||
| dtype = self.dtype(heat) | |||
| heat = self.cast(heat, mstype.float16) | |||
| heat_max = self.max_pool(heat) | |||
| keep = self.equal(heat, heat_max) | |||
| keep = self.cast(keep, dtype) | |||
| heat = self.cast(heat, dtype) | |||
| if self.enable_fp16: | |||
| heat = self.cast(heat, mstype.float16) | |||
| heat_max = self.max_pool(heat) | |||
| keep = self.equal(heat, heat_max) | |||
| keep = self.cast(keep, dtype) | |||
| heat = self.cast(heat, dtype) | |||
| else: | |||
| heat_max = self.max_pool(heat) | |||
| keep = self.equal(heat, heat_max) | |||
| heat = heat * keep | |||
| return heat | |||
| @@ -127,18 +134,24 @@ class GatherFeatureByInd(nn.Cell): | |||
| """ | |||
| Gather features by index | |||
| Args: None | |||
| Args: | |||
| enable_cpu_gatherd (bool): Use cpu operator GatherD to gather feature or not, adaption for CPU. Default: True. | |||
| Returns: | |||
| Tensor | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, enable_cpu_gatherd=True): | |||
| super(GatherFeatureByInd, self).__init__() | |||
| self.tile = ops.Tile() | |||
| self.shape = ops.Shape() | |||
| self.concat = ops.Concat(axis=1) | |||
| self.reshape = ops.Reshape() | |||
| self.gather_nd = ops.GatherNd() | |||
| self.enable_cpu_gatherd = enable_cpu_gatherd | |||
| if self.enable_cpu_gatherd: | |||
| self.gather_nd = ops.GatherD() | |||
| self.expand_dims = ops.ExpandDims() | |||
| else: | |||
| self.gather_nd = ops.GatherNd() | |||
| def construct(self, feat, ind): | |||
| """gather by index""" | |||
| @@ -147,18 +160,24 @@ class GatherFeatureByInd(nn.Cell): | |||
| b, J, K = self.shape(ind) | |||
| feat = self.reshape(feat, (b, J, K, -1)) | |||
| _, _, _, N = self.shape(feat) | |||
| ind = self.reshape(ind, (-1, 1)) | |||
| ind_b = nn.Range(0, b * J, 1)() | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| ind_b = self.tile(ind_b, (1, K)) | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| index = self.concat((ind_b, ind)) | |||
| # (b, N, 2) | |||
| index = self.reshape(index, (-1, K, 2)) | |||
| # (b, N, c) | |||
| feat = self.reshape(feat, (-1, K, N)) | |||
| feat = self.gather_nd(feat, index) | |||
| feat = self.reshape(feat, (b, J, K, -1)) | |||
| if self.enable_cpu_gatherd: | |||
| # (b, J, K, N) | |||
| index = self.expand_dims(ind, -1) | |||
| index = self.tile(index, (1, 1, 1, N)) | |||
| feat = self.gather_nd(feat, 2, index) | |||
| else: | |||
| ind = self.reshape(ind, (-1, 1)) | |||
| ind_b = nn.Range(0, b * J, 1)() | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| ind_b = self.tile(ind_b, (1, K)) | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| index = self.concat((ind_b, ind)) | |||
| # (b*J, K, 2) | |||
| index = self.reshape(index, (-1, K, 2)) | |||
| # (b*J, K, N) | |||
| feat = self.reshape(feat, (-1, K, N)) | |||
| feat = self.gather_nd(feat, index) | |||
| feat = self.reshape(feat, (b, J, K, -1)) | |||
| return feat | |||
| @@ -285,17 +304,16 @@ class MultiPoseDecode(nn.Cell): | |||
| Args: | |||
| net_config(edict): config info for CenterNet network. | |||
| flip_test(bool): flip test or not. Default: False. | |||
| K(int): maximum objects number. Default: 100. | |||
| enable_nms_fp16(bool): Use float16 data for max_pool, adaption for CPU. Default: True. | |||
| Returns: | |||
| Tensor, multi-objects detections. | |||
| """ | |||
| def __init__(self, net_config, flip_test=False, K=100): | |||
| def __init__(self, net_config, K=100, enable_nms_fp16=True): | |||
| super(MultiPoseDecode, self).__init__() | |||
| self.K = K | |||
| self.flip_test = flip_test | |||
| self.nms = NMS() | |||
| self.nms = NMS(enable_nms_fp16=enable_nms_fp16) | |||
| self.shape = ops.Shape() | |||
| self.gather_topk = GatherTopK() | |||
| self.gather_topk_channel = GatherTopKChannel() | |||
| @@ -336,8 +354,6 @@ class MultiPoseDecode(nn.Cell): | |||
| def construct(self, feature): | |||
| """gather detections""" | |||
| heat = feature[0] | |||
| if self.flip_test: | |||
| heat = self.flip_tensor(heat) | |||
| K = self.K | |||
| b, _, _, _ = self.shape(heat) | |||
| heat = self.nms(heat) | |||
| @@ -346,8 +362,6 @@ class MultiPoseDecode(nn.Cell): | |||
| xs = self.reshape(xs, (b, K, 1)) | |||
| kps = feature[1] | |||
| if self.flip_test: | |||
| kps = self.flip_lr_off(kps) | |||
| num_joints = self.shape(kps)[1] / 2 | |||
| # (b, K, num_joints*2) | |||
| kps = self.trans_gather_feature(kps, inds) | |||
| @@ -365,15 +379,11 @@ class MultiPoseDecode(nn.Cell): | |||
| kps = self.reshape(kps, (b, K, num_joints * 2)) | |||
| wh = feature[2] | |||
| if self.flip_test: | |||
| wh = self.flip_tensor(wh) | |||
| wh = self.trans_gather_feature(wh, inds) | |||
| ws, hs = self.half(wh) | |||
| if self.reg_offset: | |||
| reg = feature[self.reg_ind] | |||
| if self.flip_test: | |||
| reg, _ = self.half_first(reg) | |||
| reg = self.trans_gather_feature(reg, inds) | |||
| reg = self.reshape(reg, (b, K, 2)) | |||
| reg_w, reg_h = self.half(reg) | |||
| @@ -387,16 +397,12 @@ class MultiPoseDecode(nn.Cell): | |||
| if self.hm_hp: | |||
| hm_hp = feature[self.hm_hp_ind] | |||
| if self.flip_test: | |||
| hm_hp = self.flip_lr(hm_hp) | |||
| hm_hp = self.nms(hm_hp) | |||
| # (b, num_joints, K) | |||
| hm_score, hm_inds, hm_ys, hm_xs = self.gather_topk_channel(hm_hp, K=K) | |||
| if self.reg_hp_offset: | |||
| hp_offset = feature[self.reg_hp_ind] | |||
| if self.flip_test: | |||
| hp_offset, _ = self.half_first(hp_offset) | |||
| hp_offset = self.trans_gather_feature(hp_offset, self.reshape(hm_inds, (b, -1))) | |||
| hp_offset = self.reshape(hp_offset, (b, num_joints, K, 2)) | |||
| hp_ws, hp_hs = self.half(hp_offset) | |||
| @@ -17,6 +17,7 @@ Functional Cells to be used. | |||
| """ | |||
| import math | |||
| import time | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| @@ -119,33 +120,46 @@ class GatherFeature(nn.Cell): | |||
| """ | |||
| Gather feature at specified position | |||
| Args: None | |||
| Args: | |||
| enable_cpu_gather (bool): Use cpu operator GatherD to gather feature or not, adaption for CPU. Default: True. | |||
| Returns: | |||
| Tensor, feature at specified position | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, enable_cpu_gather=True): | |||
| super(GatherFeature, self).__init__() | |||
| self.tile = ops.Tile() | |||
| self.shape = ops.Shape() | |||
| self.concat = ops.Concat(axis=1) | |||
| self.reshape = ops.Reshape() | |||
| self.gather_nd = ops.GatherNd() | |||
| self.enable_cpu_gather = enable_cpu_gather | |||
| if self.enable_cpu_gather: | |||
| self.gather_nd = ops.GatherD() | |||
| self.expand_dims = ops.ExpandDims() | |||
| else: | |||
| self.gather_nd = ops.GatherNd() | |||
| def construct(self, feat, ind): | |||
| """gather by specified index""" | |||
| # (b, N)->(b*N, 1) | |||
| b, N = self.shape(ind) | |||
| ind = self.reshape(ind, (-1, 1)) | |||
| ind_b = nn.Range(0, b, 1)() | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| ind_b = self.tile(ind_b, (1, N)) | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| index = self.concat((ind_b, ind)) | |||
| # (b, N, 2) | |||
| index = self.reshape(index, (b, N, -1)) | |||
| # (b, N, c) | |||
| feat = self.gather_nd(feat, index) | |||
| if self.enable_cpu_gather: | |||
| _, _, c = self.shape(feat) | |||
| # (b, N, c) | |||
| index = self.expand_dims(ind, -1) | |||
| index = self.tile(index, (1, 1, c)) | |||
| feat = self.gather_nd(feat, 1, index) | |||
| else: | |||
| # (b, N)->(b*N, 1) | |||
| b, N = self.shape(ind) | |||
| ind = self.reshape(ind, (-1, 1)) | |||
| ind_b = nn.Range(0, b, 1)() | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| ind_b = self.tile(ind_b, (1, N)) | |||
| ind_b = self.reshape(ind_b, (-1, 1)) | |||
| index = self.concat((ind_b, ind)) | |||
| # (b, N, 2) | |||
| index = self.reshape(index, (b, N, -1)) | |||
| # (b, N, c) | |||
| feat = self.gather_nd(feat, index) | |||
| return feat | |||
| @@ -477,11 +491,19 @@ class LossCallBack(Callback): | |||
| Args: | |||
| dataset_size (int): Dataset size. Default: -1. | |||
| enable_static_time (bool): print the time cost of each step, adaption for CPU. Default: False. | |||
| """ | |||
| def __init__(self, dataset_size=-1): | |||
| def __init__(self, dataset_size=-1, enable_static_time=False): | |||
| super(LossCallBack, self).__init__() | |||
| self._dataset_size = dataset_size | |||
| self._enable_static_time = enable_static_time | |||
| def step_begin(self, run_context): | |||
| """ | |||
| Get beginning time of each step | |||
| """ | |||
| self._begin_time = time.time() | |||
| def step_end(self, run_context): | |||
| """ | |||
| @@ -493,11 +515,19 @@ class LossCallBack(Callback): | |||
| if percent == 0: | |||
| percent = 1 | |||
| epoch_num -= 1 | |||
| print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" | |||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs))) | |||
| if self._enable_static_time: | |||
| cur_time = time.time() | |||
| time_per_step = cur_time - self._begin_time | |||
| print("epoch: {}, current epoch percent: {}, step: {}, time per step: {} s, outputs are {}" | |||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, "%.3f" % time_per_step, | |||
| str(cb_params.net_outputs)), flush=True) | |||
| else: | |||
| print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" | |||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, | |||
| str(cb_params.net_outputs)), flush=True) | |||
| else: | |||
| print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | |||
| str(cb_params.net_outputs))) | |||
| str(cb_params.net_outputs)), flush=True) | |||
| class CenterNetPolynomialDecayLR(LearningRateSchedule): | |||
| @@ -31,12 +31,15 @@ from mindspore.common import set_seed | |||
| from mindspore.profiler import Profiler | |||
| from src.dataset import COCOHP | |||
| from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell | |||
| from src import CenterNetWithoutLossScaleCell | |||
| from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR | |||
| from src.config import dataset_config, net_config, train_config | |||
| _current_dir = os.path.dirname(os.path.realpath(__file__)) | |||
| parser = argparse.ArgumentParser(description='CenterNet training') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], | |||
| help="Run distribute, default is false.") | |||
| parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"], | |||
| @@ -125,26 +128,32 @@ def _get_optimizer(network, dataset_size): | |||
| def train(): | |||
| """training CenterNet""" | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_context(enable_auto_mixed_precision=False) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||
| context.set_context(reserve_class_name_in_scope=False) | |||
| context.set_context(save_graphs=False) | |||
| ckpt_save_dir = args_opt.save_checkpoint_path | |||
| if args_opt.distribute == "true": | |||
| D.init() | |||
| device_num = args_opt.device_num | |||
| rank = args_opt.device_id % device_num | |||
| ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/' | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=device_num) | |||
| _set_parallel_all_reduce_split() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| rank = 0 | |||
| device_num = 1 | |||
| num_workers = 8 | |||
| if args_opt.device_target == "Ascend": | |||
| context.set_context(enable_auto_mixed_precision=False) | |||
| context.set_context(device_id=args_opt.device_id) | |||
| if args_opt.distribute == "true": | |||
| D.init() | |||
| device_num = args_opt.device_num | |||
| rank = args_opt.device_id % device_num | |||
| ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/' | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=device_num) | |||
| _set_parallel_all_reduce_split() | |||
| else: | |||
| args_opt.distribute = "false" | |||
| args_opt.need_profiler = "false" | |||
| args_opt.enable_data_sink = "false" | |||
| # Start create dataset! | |||
| # mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num. | |||
| logger.info("Begin creating dataset for CenterNet") | |||
| @@ -167,7 +176,8 @@ def train(): | |||
| optimizer = _get_optimizer(net_with_loss, dataset_size) | |||
| callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)] | |||
| enable_static_time = args_opt.device_target == "CPU" | |||
| callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size, enable_static_time)] | |||
| if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps, | |||
| keep_checkpoint_max=args_opt.save_checkpoint_num) | |||
| @@ -178,12 +188,13 @@ def train(): | |||
| if args_opt.load_checkpoint_path: | |||
| param_dict = load_checkpoint(args_opt.load_checkpoint_path) | |||
| load_param_into_net(net_with_loss, param_dict) | |||
| net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer, | |||
| sens=train_config.loss_scale_value) | |||
| if args_opt.device_target == "Ascend": | |||
| net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer, | |||
| sens=train_config.loss_scale_value) | |||
| else: | |||
| net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss, optimizer=optimizer) | |||
| model = Model(net_with_grads) | |||
| model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"), | |||
| sink_size=args_opt.data_sink_steps) | |||